some problem with overriding

This commit is contained in:
Benoît Sierro
2021-07-29 13:24:29 +02:00
parent ea943d7adf
commit cdec1cf43f
6 changed files with 51 additions and 20 deletions

View File

@@ -74,6 +74,9 @@ def create_parser():
"One plot is made for each limit set provided. Example : 600,1200,nm or -2,2,ps",
)
plot_parser.add_argument("--options", "-o", default=None)
plot_parser.add_argument(
"--show", action="store_true", help="show the plots instead of saving them"
)
plot_parser.set_defaults(func=plot_all)
dispersion_parser = subparsers.add_parser(
@@ -186,7 +189,7 @@ def plot_all(args):
if args.options is not None:
opts |= dict([o.split("=")[:2] for o in re.split("[, ]", args.options)])
root = Path(args.sim_dir).resolve()
scripts.plot_all(root, args.spectrum_limits, **opts)
scripts.plot_all(root, args.spectrum_limits, show=args.show, **opts)
def plot_init_field_spec(args):

View File

@@ -13,6 +13,8 @@ from scipy.interpolate import UnivariateSpline
from scipy.interpolate.interpolate import interp1d
from tqdm import utils
from scgenerator.const import PARAM_SEPARATOR
from .logger import get_logger
from . import io, math
@@ -65,7 +67,7 @@ def plot_setup(
# ensure no overwrite
ind = 0
while (full_path := (out_dir / (plot_name + f"_{ind}." + file_type))).exists():
while (full_path := (out_dir / (plot_name + f"{PARAM_SEPARATOR}{ind}." + file_type))).exists():
ind += 1
if mode == "default":
@@ -864,10 +866,11 @@ def uniform_axis(
else:
raise TypeError(f"Don't know how to interpret {new_axis_spec}")
tmp_axis, ind, ext = units.sort_axis(axis, plt_range)
values = np.atleast_2d(values)
if np.allclose((diff := np.diff(tmp_axis))[0], diff):
new_axis = tmp_axis
values = values[:, ind]
else:
values = np.atleast_2d(values)
if plt_range.unit.type == "WL":
values[:, ind] = np.apply_along_axis(units.to_WL, 1, values[:, ind], tmp_axis)
new_axis = np.linspace(tmp_axis.min(), tmp_axis.max(), len(tmp_axis))

View File

@@ -9,11 +9,12 @@ import numpy as np
from tqdm import tqdm
from ..utils.parameter import BareParams
from ..const import PARAM_SEPARATOR
from ..initialize import ParamSequence
from ..physics import units, fiber
from ..spectra import Pulse
from ..utils import pretty_format_value, pretty_format_from_file_name, auto_crop
from ..utils import pretty_format_value, pretty_format_from_sim_name, auto_crop
from ..plotting import plot_setup
from .. import env, math
@@ -24,7 +25,7 @@ def fingerprint(params: BareParams):
return h1, h2
def plot_all(sim_dir: Path, limits: list[str], **opts):
def plot_all(sim_dir: Path, limits: list[str], show=False, **opts):
for k, v in opts.items():
if k in ["skip"]:
opts[k] = int(v)
@@ -39,7 +40,12 @@ def plot_all(sim_dir: Path, limits: list[str], **opts):
pulse = Pulse(p)
for left, right, unit in limits:
path, fig, ax = plot_setup(
pulse.path.parent / f"{pulse.path.name}_{left:.1f}_{right:.1f}_{unit}"
pulse.path.parent
/ (
pretty_format_from_sim_name(pulse.path.name)
+ PARAM_SEPARATOR
+ f"{left:.1f}{PARAM_SEPARATOR}{right:.1f}{PARAM_SEPARATOR}{unit}"
)
)
pulse.plot_2D(
left,
@@ -49,8 +55,11 @@ def plot_all(sim_dir: Path, limits: list[str], **opts):
**opts,
)
bar.update()
if show:
plt.show()
else:
fig.savefig(path, bbox_inches="tight")
plt.close("all")
plt.close(fig)
def plot_init_field_spec(

View File

@@ -208,15 +208,30 @@ def format_value(value) -> str:
def pretty_format_value(name: str, value) -> str:
try:
return getattr(BareParams, name).display(value)
except AttributeError:
return name + PARAM_SEPARATOR + str(value)
def pretty_format_from_file_name(name: str) -> str:
def pretty_format_from_sim_name(name: str) -> str:
"""formats a pretty version of a simulation directory
Parameters
----------
name : str
name of the simulation (directory name)
Returns
-------
str
prettier name
"""
s = name.split(PARAM_SEPARATOR)
out = []
for key, value in zip(s[::2], s[1::2]):
try:
out.append(getattr(BareParams, key).display(float(value)))
out += [key.replace("_", " "), getattr(BareParams, key).display(float(value))]
except (AttributeError, ValueError):
out.append(key + PARAM_SEPARATOR + value)
return PARAM_SEPARATOR.join(out)
@@ -307,6 +322,9 @@ def override_config(new: BareConfig, old: BareConfig = None) -> BareConfig:
variable[k] = v
for k in variable:
new_dict[k] = None
new_dict["readjust_wavelength"] = False
return replace(old, variable=variable, **new_dict)

View File

@@ -359,7 +359,7 @@ class BareParams:
quantum_noise: bool = Parameter(boolean)
shape: str = Parameter(literal("gaussian", "sech"))
wavelength: float = Parameter(in_range_incl(100e-9, 3000e-9), display_info=(1e9, "nm"))
intensity_noise: float = Parameter(in_range_incl(0, 1))
intensity_noise: float = Parameter(in_range_incl(0, 1), display_info=(1e2, "%"))
width: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))
t0: float = Parameter(in_range_excl(0, 1e-9), display_info=(1e15, "fs"))

View File

@@ -1,14 +1,12 @@
import scgenerator as sc
from pathlib import Path
import os
p = Path("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/PPP")
os.chdir("/Users/benoitsierro/Nextcloud/PhD/Supercontinuum/PCF Simulations/")
configs = [
sc.io.load_config(p / c)
for c in ("PM1550.toml", "PMHNLF_appended.toml", "PM2000_appended.toml")
]
root = Path("PM1550+PMHNLF+PM1550+PM2000")
for variable, params in sc.utils.required_simulations(*configs):
print(variable)
confs = sc.io.load_config_sequence(root / "4_PM2000.toml")
final = sc.utils.final_config_from_sequence(*confs)
# sc.initialize.ContinuationParamSequence(configs[-1])
print(final)