some problem with overriding
This commit is contained in:
@@ -74,6 +74,9 @@ def create_parser():
|
||||
"One plot is made for each limit set provided. Example : 600,1200,nm or -2,2,ps",
|
||||
)
|
||||
plot_parser.add_argument("--options", "-o", default=None)
|
||||
plot_parser.add_argument(
|
||||
"--show", action="store_true", help="show the plots instead of saving them"
|
||||
)
|
||||
plot_parser.set_defaults(func=plot_all)
|
||||
|
||||
dispersion_parser = subparsers.add_parser(
|
||||
@@ -186,7 +189,7 @@ def plot_all(args):
|
||||
if args.options is not None:
|
||||
opts |= dict([o.split("=")[:2] for o in re.split("[, ]", args.options)])
|
||||
root = Path(args.sim_dir).resolve()
|
||||
scripts.plot_all(root, args.spectrum_limits, **opts)
|
||||
scripts.plot_all(root, args.spectrum_limits, show=args.show, **opts)
|
||||
|
||||
|
||||
def plot_init_field_spec(args):
|
||||
|
||||
@@ -13,6 +13,8 @@ from scipy.interpolate import UnivariateSpline
|
||||
from scipy.interpolate.interpolate import interp1d
|
||||
from tqdm import utils
|
||||
|
||||
from scgenerator.const import PARAM_SEPARATOR
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
from . import io, math
|
||||
@@ -65,7 +67,7 @@ def plot_setup(
|
||||
|
||||
# ensure no overwrite
|
||||
ind = 0
|
||||
while (full_path := (out_dir / (plot_name + f"_{ind}." + file_type))).exists():
|
||||
while (full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))).exists():
|
||||
ind += 1
|
||||
|
||||
if mode == "default":
|
||||
@@ -864,10 +866,11 @@ def uniform_axis(
|
||||
else:
|
||||
raise TypeError(f"Don't know how to interpret {new_axis_spec}")
|
||||
tmp_axis, ind, ext = units.sort_axis(axis, plt_range)
|
||||
values = np.atleast_2d(values)
|
||||
if np.allclose((diff := np.diff(tmp_axis))[0], diff):
|
||||
new_axis = tmp_axis
|
||||
values = values[:, ind]
|
||||
else:
|
||||
values = np.atleast_2d(values)
|
||||
if plt_range.unit.type == "WL":
|
||||
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
|
||||
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))
|
||||
|
||||
@@ -9,11 +9,12 @@ import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..utils.parameter import BareParams
|
||||
from ..const import PARAM_SEPARATOR
|
||||
|
||||
from ..initialize import ParamSequence
|
||||
from ..physics import units, fiber
|
||||
from ..spectra import Pulse
|
||||
from ..utils import pretty_format_value, pretty_format_from_file_name, auto_crop
|
||||
from ..utils import pretty_format_value, pretty_format_from_sim_name, auto_crop
|
||||
from ..plotting import plot_setup
|
||||
from .. import env, math
|
||||
|
||||
@@ -24,7 +25,7 @@ def fingerprint(params: BareParams):
|
||||
return h1, h2
|
||||
|
||||
|
||||
def plot_all(sim_dir: Path, limits: list[str], **opts):
|
||||
def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
|
||||
for k, v in opts.items():
|
||||
if k in ["skip"]:
|
||||
opts[k] = int(v)
|
||||
@@ -39,7 +40,12 @@ def plot_all(sim_dir: Path, limits: list[str], **opts):
|
||||
pulse = Pulse(p)
|
||||
for left, right, unit in limits:
|
||||
path, fig, ax = plot_setup(
|
||||
pulse.path.parent / f"{pulse.path.name}_{left:.1f}_{right:.1f}_{unit}"
|
||||
pulse.path.parent
|
||||
/ (
|
||||
pretty_format_from_sim_name(pulse.path.name)
|
||||
+ PARAM_SEPARATOR
|
||||
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
|
||||
)
|
||||
)
|
||||
pulse.plot_2D(
|
||||
left,
|
||||
@@ -49,8 +55,11 @@ def plot_all(sim_dir: Path, limits: list[str], **opts):
|
||||
**opts,
|
||||
)
|
||||
bar.update()
|
||||
fig.savefig(path, bbox_inches="tight")
|
||||
plt.close("all")
|
||||
if show:
|
||||
plt.show()
|
||||
else:
|
||||
fig.savefig(path, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def plot_init_field_spec(
|
||||
|
||||
@@ -208,15 +208,30 @@ def format_value(value) -> str:
|
||||
|
||||
|
||||
def pretty_format_value(name: str, value) -> str:
|
||||
return getattr(BareParams, name).display(value)
|
||||
try:
|
||||
return getattr(BareParams, name).display(value)
|
||||
except AttributeError:
|
||||
return name + PARAM_SEPARATOR + str(value)
|
||||
|
||||
|
||||
def pretty_format_from_file_name(name: str) -> str:
|
||||
def pretty_format_from_sim_name(name: str) -> str:
|
||||
"""formats a pretty version of a simulation directory
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
name of the simulation (directory name)
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
prettier name
|
||||
"""
|
||||
s = name.split(PARAM_SEPARATOR)
|
||||
out = []
|
||||
for key, value in zip(s[::2], s[1::2]):
|
||||
try:
|
||||
out.append(getattr(BareParams, key).display(float(value)))
|
||||
out += [key.replace("_", " "), getattr(BareParams, key).display(float(value))]
|
||||
except (AttributeError, ValueError):
|
||||
out.append(key + PARAM_SEPARATOR + value)
|
||||
return PARAM_SEPARATOR.join(out)
|
||||
@@ -307,6 +322,9 @@ def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig:
|
||||
variable[k] = v
|
||||
for k in variable:
|
||||
new_dict[k] = None
|
||||
|
||||
new_dict["readjust_wavelength"] = False
|
||||
|
||||
return replace(old, variable=variable, **new_dict)
|
||||
|
||||
|
||||
|
||||
@@ -359,7 +359,7 @@ class BareParams:
|
||||
quantum_noise: bool = Parameter(boolean)
|
||||
shape: str = Parameter(literal("gaussian", "sech"))
|
||||
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
|
||||
intensity_noise: float = Parameter(in_range_incl(0, 1))
|
||||
intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"))
|
||||
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
import scgenerator as sc
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
p = Path("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PPP")
|
||||
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/")
|
||||
|
||||
configs = [
|
||||
sc.io.load_config(p / c)
|
||||
for c in ("PM1550.toml", "PMHNLF_appended.toml", "PM2000_appended.toml")
|
||||
]
|
||||
root = Path("PM1550+PMHNLF+PM1550+PM2000")
|
||||
|
||||
for variable, params in sc.utils.required_simulations(*configs):
|
||||
print(variable)
|
||||
confs = sc.io.load_config_sequence(root / "4_PM2000.toml")
|
||||
final = sc.utils.final_config_from_sequence(*confs)
|
||||
|
||||
# sc.initialize.ContinuationParamSequence(configs[-1])
|
||||
print(final)
|
||||
|
||||
Reference in New Issue
Block a user