some problem with overriding

This commit is contained in:
Benoît Sierro
2021-07-29 13:24:29 +02:00
parent ea943d7adf
commit cdec1cf43f
6 changed files with 51 additions and 20 deletions

View File

@@ -74,6 +74,9 @@ def create_parser():
"One plot is made for each limit set provided. Example : 600,1200,nm or -2,2,ps", "One plot is made for each limit set provided. Example : 600,1200,nm or -2,2,ps",
) )
plot_parser.add_argument("--options", "-o", default=None) plot_parser.add_argument("--options", "-o", default=None)
plot_parser.add_argument(
"--show", action="store_true", help="show the plots instead of saving them"
)
plot_parser.set_defaults(func=plot_all) plot_parser.set_defaults(func=plot_all)
dispersion_parser = subparsers.add_parser( dispersion_parser = subparsers.add_parser(
@@ -186,7 +189,7 @@ def plot_all(args):
if args.options is not None: if args.options is not None:
opts |= dict([o.split("=")[:2] for o in re.split("[, ]", args.options)]) opts |= dict([o.split("=")[:2] for o in re.split("[, ]", args.options)])
root = Path(args.sim_dir).resolve() root = Path(args.sim_dir).resolve()
scripts.plot_all(root, args.spectrum_limits, **opts) scripts.plot_all(root, args.spectrum_limits, show=args.show, **opts)
def plot_init_field_spec(args): def plot_init_field_spec(args):

View File

@@ -13,6 +13,8 @@ from scipy.interpolate import UnivariateSpline
from scipy.interpolate.interpolate import interp1d from scipy.interpolate.interpolate import interp1d
from tqdm import utils from tqdm import utils
from scgenerator.const import PARAM_SEPARATOR
from .logger import get_logger from .logger import get_logger
from . import io, math from . import io, math
@@ -65,7 +67,7 @@ def plot_setup(
# ensure no overwrite # ensure no overwrite
ind = 0 ind = 0
while (full_path := (out_dir / (plot_name + f"_{ind}." + file_type))).exists(): while (full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))).exists():
ind += 1 ind += 1
if mode == "default": if mode == "default":
@@ -864,10 +866,11 @@ def uniform_axis(
else: else:
raise TypeError(f"Don't know how to interpret {new_axis_spec}") raise TypeError(f"Don't know how to interpret {new_axis_spec}")
tmp_axis, ind, ext = units.sort_axis(axis, plt_range) tmp_axis, ind, ext = units.sort_axis(axis, plt_range)
values = np.atleast_2d(values)
if np.allclose((diff := np.diff(tmp_axis))[0], diff): if np.allclose((diff := np.diff(tmp_axis))[0], diff):
new_axis = tmp_axis new_axis = tmp_axis
values = values[:, ind]
else: else:
values = np.atleast_2d(values)
if plt_range.unit.type == "WL": if plt_range.unit.type == "WL":
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis) values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis)) new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))

View File

@@ -9,11 +9,12 @@ import numpy as np
from tqdm import tqdm from tqdm import tqdm
from ..utils.parameter import BareParams from ..utils.parameter import BareParams
from ..const import PARAM_SEPARATOR
from ..initialize import ParamSequence from ..initialize import ParamSequence
from ..physics import units, fiber from ..physics import units, fiber
from ..spectra import Pulse from ..spectra import Pulse
from ..utils import pretty_format_value, pretty_format_from_file_name, auto_crop from ..utils import pretty_format_value, pretty_format_from_sim_name, auto_crop
from ..plotting import plot_setup from ..plotting import plot_setup
from .. import env, math from .. import env, math
@@ -24,7 +25,7 @@ def fingerprint(params: BareParams):
return h1, h2 return h1, h2
def plot_all(sim_dir: Path, limits: list[str], **opts): def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
for k, v in opts.items(): for k, v in opts.items():
if k in ["skip"]: if k in ["skip"]:
opts[k] = int(v) opts[k] = int(v)
@@ -39,7 +40,12 @@ def plot_all(sim_dir: Path, limits: list[str], **opts):
pulse = Pulse(p) pulse = Pulse(p)
for left, right, unit in limits: for left, right, unit in limits:
path, fig, ax = plot_setup( path, fig, ax = plot_setup(
pulse.path.parent / f"{pulse.path.name}_{left:.1f}_{right:.1f}_{unit}" pulse.path.parent
/ (
pretty_format_from_sim_name(pulse.path.name)
+ PARAM_SEPARATOR
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
)
) )
pulse.plot_2D( pulse.plot_2D(
left, left,
@@ -49,8 +55,11 @@ def plot_all(sim_dir: Path, limits: list[str], **opts):
**opts, **opts,
) )
bar.update() bar.update()
if show:
plt.show()
else:
fig.savefig(path, bbox_inches="tight") fig.savefig(path, bbox_inches="tight")
plt.close("all") plt.close(fig)
def plot_init_field_spec( def plot_init_field_spec(

View File

@@ -208,15 +208,30 @@ def format_value(value) -> str:
def pretty_format_value(name: str, value) -> str: def pretty_format_value(name: str, value) -> str:
try:
return getattr(BareParams, name).display(value) return getattr(BareParams, name).display(value)
except AttributeError:
return name + PARAM_SEPARATOR + str(value)
def pretty_format_from_file_name(name: str) -> str: def pretty_format_from_sim_name(name: str) -> str:
"""formats a pretty version of a simulation directory
Parameters
----------
name : str
name of the simulation (directory name)
Returns
-------
str
prettier name
"""
s = name.split(PARAM_SEPARATOR) s = name.split(PARAM_SEPARATOR)
out = [] out = []
for key, value in zip(s[::2], s[1::2]): for key, value in zip(s[::2], s[1::2]):
try: try:
out.append(getattr(BareParams, key).display(float(value))) out += [key.replace("_", " "), getattr(BareParams, key).display(float(value))]
except (AttributeError, ValueError): except (AttributeError, ValueError):
out.append(key + PARAM_SEPARATOR + value) out.append(key + PARAM_SEPARATOR + value)
return PARAM_SEPARATOR.join(out) return PARAM_SEPARATOR.join(out)
@@ -307,6 +322,9 @@ def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig:
variable[k] = v variable[k] = v
for k in variable: for k in variable:
new_dict[k] = None new_dict[k] = None
new_dict["readjust_wavelength"] = False
return replace(old, variable=variable, **new_dict) return replace(old, variable=variable, **new_dict)

View File

@@ -359,7 +359,7 @@ class BareParams:
quantum_noise: bool = Parameter(boolean) quantum_noise: bool = Parameter(boolean)
shape: str = Parameter(literal("gaussian", "sech")) shape: str = Parameter(literal("gaussian", "sech"))
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm")) wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
intensity_noise: float = Parameter(in_range_incl(0, 1)) intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"))
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs")) t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))

View File

@@ -1,14 +1,12 @@
import scgenerator as sc import scgenerator as sc
from pathlib import Path from pathlib import Path
import os
p = Path("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PPP") os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/")
configs = [ root = Path("PM1550+PMHNLF+PM1550+PM2000")
sc.io.load_config(p / c)
for c in ("PM1550.toml", "PMHNLF_appended.toml", "PM2000_appended.toml")
]
for variable, params in sc.utils.required_simulations(*configs): confs = sc.io.load_config_sequence(root / "4_PM2000.toml")
print(variable) final = sc.utils.final_config_from_sequence(*confs)
# sc.initialize.ContinuationParamSequence(configs[-1]) print(final)