file name still wrong in converter
This commit is contained in:
@@ -2,9 +2,24 @@ from . import math
|
|||||||
from .math import abs2, argclosest, span
|
from .math import abs2, argclosest, span
|
||||||
from .physics import fiber, materials, pulse, simulate, units
|
from .physics import fiber, materials, pulse, simulate, units
|
||||||
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
|
from .physics.simulate import RK4IP, parallel_RK4IP, run_simulation
|
||||||
from .plotting import mean_values_plot, plot_spectrogram, propagation_plot, single_position_plot
|
from .plotting import (
|
||||||
|
mean_values_plot,
|
||||||
|
plot_spectrogram,
|
||||||
|
propagation_plot,
|
||||||
|
single_position_plot,
|
||||||
|
transform_2D_propagation,
|
||||||
|
transform_1D_values,
|
||||||
|
transform_mean_values,
|
||||||
|
get_extent,
|
||||||
|
)
|
||||||
from .spectra import Pulse, Spectrum, SimulationSeries
|
from .spectra import Pulse, Spectrum, SimulationSeries
|
||||||
from ._utils import Paths, open_config, parameter
|
from ._utils import Paths, open_config, parameter
|
||||||
from ._utils.parameter import Configuration, Parameters
|
from ._utils.parameter import Configuration, Parameters
|
||||||
from ._utils.utils import PlotRange
|
from ._utils.utils import PlotRange
|
||||||
from ._utils.variationer import Variationer, VariationDescriptor, VariationSpecsError
|
from ._utils.legacy import convert_sim_folder
|
||||||
|
from ._utils.variationer import (
|
||||||
|
Variationer,
|
||||||
|
VariationDescriptor,
|
||||||
|
VariationSpecsError,
|
||||||
|
DescriptorDict,
|
||||||
|
)
|
||||||
|
|||||||
@@ -8,8 +8,9 @@ import numpy as np
|
|||||||
import toml
|
import toml
|
||||||
|
|
||||||
from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN
|
from ..const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1, Z_FN
|
||||||
from .parameter import Parameters
|
from .parameter import Configuration, Parameters
|
||||||
from .utils import fiber_folder, update_path, save_parameters
|
from .utils import fiber_folder, save_parameters
|
||||||
|
from .pbar import PBars
|
||||||
from .variationer import VariationDescriptor, Variationer
|
from .variationer import VariationDescriptor, Variationer
|
||||||
|
|
||||||
|
|
||||||
@@ -29,21 +30,32 @@ def convert_sim_folder(path: os.PathLike):
|
|||||||
path = Path(path)
|
path = Path(path)
|
||||||
config_paths, configs = load_config_sequence(path)
|
config_paths, configs = load_config_sequence(path)
|
||||||
master_config = dict(name=path.name, Fiber=configs)
|
master_config = dict(name=path.name, Fiber=configs)
|
||||||
|
with open(path / "initial_config.toml", "w") as f:
|
||||||
|
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
|
||||||
|
configuration = Configuration(path / "initial_config.toml")
|
||||||
new_fiber_paths: list[Path] = [
|
new_fiber_paths: list[Path] = [
|
||||||
path / fiber_folder(i, path.name, cfg["name"]) for i, cfg in enumerate(configs)
|
path / fiber_folder(i, path.name, cfg["name"]) for i, cfg in enumerate(configs)
|
||||||
]
|
]
|
||||||
for p in new_fiber_paths:
|
for p in new_fiber_paths:
|
||||||
p.mkdir(exist_ok=True)
|
p.mkdir(exist_ok=True)
|
||||||
var = Variationer(c["variable"] for c in configs)
|
repeat = configs[0].get("repeat", 1)
|
||||||
|
|
||||||
paths: dict[Path, VariationDescriptor] = {
|
pbar = PBars(configuration.total_num_steps, "Converting")
|
||||||
path / descr.branch.formatted_descriptor(): descr for descr in var.iterate()
|
|
||||||
|
old_paths: dict[Path, VariationDescriptor] = {
|
||||||
|
path / descr.branch.formatted_descriptor(): (descr, param.final_path)
|
||||||
|
for descr, param in configuration
|
||||||
}
|
}
|
||||||
for p in paths:
|
|
||||||
|
# create map from old to new path
|
||||||
|
|
||||||
|
pprint(old_paths)
|
||||||
|
quit()
|
||||||
|
for p in old_paths:
|
||||||
if not p.is_dir():
|
if not p.is_dir():
|
||||||
raise FileNotFoundError(f"missing {p} from {path}")
|
raise FileNotFoundError(f"missing {p} from {path}")
|
||||||
processed_paths: Set[Path] = set()
|
processed_paths: Set[Path] = set()
|
||||||
for old_variation_path, descriptor in paths.items(): # fiberA=0, fiber B=0
|
for old_variation_path, descriptor in old_paths.items(): # fiberA=0, fiber B=0
|
||||||
vary_parts = old_variation_path.name.split("fiber")[1:]
|
vary_parts = old_variation_path.name.split("fiber")[1:]
|
||||||
identifiers = [
|
identifiers = [
|
||||||
"".join("fiber" + el for el in vary_parts[: i + 1]).strip()
|
"".join("fiber" + el for el in vary_parts[: i + 1]).strip()
|
||||||
@@ -72,6 +84,9 @@ def convert_sim_folder(path: os.PathLike):
|
|||||||
new_variation_path / SPEC1_FN_N.format(spec_num - cum_z_num, j),
|
new_variation_path / SPEC1_FN_N.format(spec_num - cum_z_num, j),
|
||||||
spec1,
|
spec1,
|
||||||
)
|
)
|
||||||
|
pbar.update()
|
||||||
|
else:
|
||||||
|
pbar.update(value=repeat)
|
||||||
old_spec.unlink()
|
old_spec.unlink()
|
||||||
if move:
|
if move:
|
||||||
if i > 0:
|
if i > 0:
|
||||||
@@ -88,8 +103,6 @@ def convert_sim_folder(path: os.PathLike):
|
|||||||
|
|
||||||
for cp in config_paths:
|
for cp in config_paths:
|
||||||
cp.unlink()
|
cp.unlink()
|
||||||
with open(path / "initial_config.toml", "w") as f:
|
|
||||||
toml.dump(master_config, f, encoder=toml.TomlNumpyEncoder())
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from ..const import PARAM_FN, PARAM_SEPARATOR, __version__
|
|||||||
from ..errors import EvaluatorError, NoDefaultError
|
from ..errors import EvaluatorError, NoDefaultError
|
||||||
from ..logger import get_logger
|
from ..logger import get_logger
|
||||||
from ..physics import fiber, materials, pulse, units
|
from ..physics import fiber, materials, pulse, units
|
||||||
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names
|
from .utils import _mock_function, fiber_folder, func_rewrite, get_arg_names, update_path
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@@ -75,6 +75,7 @@ VALID_VARIABLE = {
|
|||||||
"interpolation_degree",
|
"interpolation_degree",
|
||||||
"ideal_gas",
|
"ideal_gas",
|
||||||
"length",
|
"length",
|
||||||
|
"num",
|
||||||
}
|
}
|
||||||
|
|
||||||
MANDATORY_PARAMETERS = [
|
MANDATORY_PARAMETERS = [
|
||||||
@@ -519,6 +520,12 @@ class Parameters(_AbstractParameters):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@property
|
||||||
|
def final_path(self) -> Path:
|
||||||
|
if self.output_path is not None:
|
||||||
|
return update_path(self.output_path)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class Rule:
|
class Rule:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -777,6 +784,7 @@ class Configuration:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
fiber_configs: list[dict[str, Any]]
|
fiber_configs: list[dict[str, Any]]
|
||||||
|
vary_dicts: list[dict[str, list]]
|
||||||
master_config: dict[str, Any]
|
master_config: dict[str, Any]
|
||||||
fiber_paths: list[Path]
|
fiber_paths: list[Path]
|
||||||
num_sim: int
|
num_sim: int
|
||||||
@@ -814,9 +822,11 @@ class Configuration:
|
|||||||
self,
|
self,
|
||||||
final_config_path: os.PathLike,
|
final_config_path: os.PathLike,
|
||||||
overwrite: bool = True,
|
overwrite: bool = True,
|
||||||
|
wait: bool = False,
|
||||||
skip_callback: Callable[[int], None] = None,
|
skip_callback: Callable[[int], None] = None,
|
||||||
):
|
):
|
||||||
self.logger = get_logger(__name__)
|
self.logger = get_logger(__name__)
|
||||||
|
self.wait = wait
|
||||||
|
|
||||||
self.overwrite = overwrite
|
self.overwrite = overwrite
|
||||||
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
|
self.final_path, self.fiber_configs = utils.load_config_sequence(final_config_path)
|
||||||
@@ -842,7 +852,8 @@ class Configuration:
|
|||||||
config.setdefault("name", Parameters.name.default)
|
config.setdefault("name", Parameters.name.default)
|
||||||
self.z_num += config["z_num"]
|
self.z_num += config["z_num"]
|
||||||
fiber_names.add(config["name"])
|
fiber_names.add(config["name"])
|
||||||
self.variationer.append(config.pop("variable"))
|
vary_dict = config.pop("variable")
|
||||||
|
self.variationer.append(vary_dict)
|
||||||
self.fiber_paths.append(
|
self.fiber_paths.append(
|
||||||
utils.ensure_folder(
|
utils.ensure_folder(
|
||||||
self.final_path / fiber_folder(i, self.name, config["name"]),
|
self.final_path / fiber_folder(i, self.name, config["name"]),
|
||||||
@@ -850,9 +861,11 @@ class Configuration:
|
|||||||
prevent_overwrite=not self.overwrite,
|
prevent_overwrite=not self.overwrite,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.__validate_variable(config)
|
self.__validate_variable(vary_dict)
|
||||||
self.num_fibers += 1
|
self.num_fibers += 1
|
||||||
Evaluator.evaluate_default(config, True)
|
Evaluator.evaluate_default(
|
||||||
|
self.__build_base_config() | config | {k: v[0] for k, v in vary_dict.items()}, True
|
||||||
|
)
|
||||||
self.num_sim = self.variationer.var_num()
|
self.num_sim = self.variationer.var_num()
|
||||||
self.total_num_steps = sum(
|
self.total_num_steps = sum(
|
||||||
config["z_num"] * self.variationer.var_num(i)
|
config["z_num"] * self.variationer.var_num(i)
|
||||||
@@ -860,8 +873,13 @@ class Configuration:
|
|||||||
)
|
)
|
||||||
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
self.parallel = self.master_config.get("parallel", Parameters.parallel.default)
|
||||||
|
|
||||||
def __validate_variable(self, config: dict[str, Any]):
|
def __build_base_config(self):
|
||||||
for k, v in config.get("variable", {}).items():
|
cfg = self.fiber_configs[0].copy()
|
||||||
|
vary = cfg.pop("variable", {})
|
||||||
|
return cfg | {k: v[0] for k, v in vary.items()}
|
||||||
|
|
||||||
|
def __validate_variable(self, vary_dict: dict[str, list]):
|
||||||
|
for k, v in vary_dict.items():
|
||||||
p = getattr(Parameters, k)
|
p = getattr(Parameters, k)
|
||||||
validator_list(p.validator)("variable " + k, v)
|
validator_list(p.validator)("variable " + k, v)
|
||||||
if k not in VALID_VARIABLE:
|
if k not in VALID_VARIABLE:
|
||||||
@@ -873,7 +891,6 @@ class Configuration:
|
|||||||
for i in range(self.num_fibers):
|
for i in range(self.num_fibers):
|
||||||
for sim_config in self.iterate_single_fiber(i):
|
for sim_config in self.iterate_single_fiber(i):
|
||||||
params = Parameters(**sim_config.config)
|
params = Parameters(**sim_config.config)
|
||||||
params.compute()
|
|
||||||
yield sim_config.descriptor, params
|
yield sim_config.descriptor, params
|
||||||
|
|
||||||
def iterate_single_fiber(
|
def iterate_single_fiber(
|
||||||
@@ -943,6 +960,8 @@ class Configuration:
|
|||||||
config dictionary. The only key possibly modified is 'prev_data_dir', which
|
config dictionary. The only key possibly modified is 'prev_data_dir', which
|
||||||
gets set if the simulation is partially completed
|
gets set if the simulation is partially completed
|
||||||
"""
|
"""
|
||||||
|
if not self.wait:
|
||||||
|
return self.Action.RUN, sim_config.config
|
||||||
out_status, num = self.sim_status(sim_config.output_path, sim_config.config)
|
out_status, num = self.sim_status(sim_config.output_path, sim_config.config)
|
||||||
if out_status == self.State.COMPLETE:
|
if out_status == self.State.COMPLETE:
|
||||||
return self.Action.SKIP, sim_config.config
|
return self.Action.SKIP, sim_config.config
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import abc
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import multiprocessing
|
|
||||||
import threading
|
import threading
|
||||||
import typing
|
import typing
|
||||||
|
from collections import abc
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, Union
|
from typing import Iterable, Union
|
||||||
@@ -24,7 +24,19 @@ class PBars:
|
|||||||
head_kwargs=None,
|
head_kwargs=None,
|
||||||
worker_kwargs=None,
|
worker_kwargs=None,
|
||||||
) -> "PBars":
|
) -> "PBars":
|
||||||
|
"""creates a PBars obj
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
task : int | Iterable
|
||||||
|
if int : total length of the main task
|
||||||
|
if Iterable : behaves like tqdm
|
||||||
|
desc : str
|
||||||
|
description of the main task
|
||||||
|
num_sub_bars : int
|
||||||
|
number of sub-tasks
|
||||||
|
|
||||||
|
"""
|
||||||
self.id = random.randint(100000, 999999)
|
self.id = random.randint(100000, 999999)
|
||||||
try:
|
try:
|
||||||
self.width = os.get_terminal_size().columns
|
self.width = os.get_terminal_size().columns
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from collections import defaultdict
|
|||||||
from functools import cache
|
from functools import cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from string import printable as str_printable
|
from string import printable as str_printable
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable, Iterator, Set
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import toml
|
import toml
|
||||||
@@ -236,3 +236,24 @@ def update_path(p: str) -> str:
|
|||||||
|
|
||||||
def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
|
def fiber_folder(i: int, sim_name: str, fiber_name: str) -> str:
|
||||||
return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])
|
return PARAM_SEPARATOR.join([format(i), sim_name, fiber_name])
|
||||||
|
|
||||||
|
|
||||||
|
def iter_simulations(path: os.PathLike) -> list[Path]:
|
||||||
|
"""finds simulations folders contained in a parent directory
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : os.PathLike
|
||||||
|
parent path
|
||||||
|
|
||||||
|
Yields
|
||||||
|
-------
|
||||||
|
Path
|
||||||
|
Absolute Path to the simulation folder
|
||||||
|
"""
|
||||||
|
paths: list[Path] = []
|
||||||
|
for pwd, _, files in os.walk(path):
|
||||||
|
if PARAM_FN in files:
|
||||||
|
paths.append(Path(pwd))
|
||||||
|
paths.sort(key=lambda el: el.parent.name)
|
||||||
|
return [p for p in paths if p.parent.name == paths[-1].parent.name]
|
||||||
|
|||||||
@@ -2,14 +2,17 @@ from math import prod
|
|||||||
import itertools
|
import itertools
|
||||||
from collections.abc import MutableMapping, Sequence
|
from collections.abc import MutableMapping, Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Generator, Iterable, Optional, Union
|
from typing import Any, Callable, Generator, Generic, Iterable, Iterator, Optional, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import validator
|
from pydantic import validator
|
||||||
|
from pydantic.main import BaseModel
|
||||||
|
|
||||||
from ..const import PARAM_SEPARATOR
|
from ..const import PARAM_SEPARATOR
|
||||||
from . import utils
|
from . import utils
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class VariationSpecsError(ValueError):
|
class VariationSpecsError(ValueError):
|
||||||
pass
|
pass
|
||||||
@@ -111,15 +114,15 @@ class Variationer:
|
|||||||
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
|
return max(1, prod(prod(el) for el in self.all_indices[: index + 1]))
|
||||||
|
|
||||||
|
|
||||||
class VariationDescriptor(utils.HashableBaseModel):
|
class VariationDescriptor(BaseModel):
|
||||||
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
|
raw_descr: tuple[tuple[tuple[str, Any], ...], ...]
|
||||||
index: tuple[tuple[int, ...], ...]
|
index: tuple[tuple[int, ...], ...]
|
||||||
separator: str = "fiber"
|
separator: str = "fiber"
|
||||||
_format_registry: dict[str, Callable[..., str]] = {}
|
_format_registry: dict[str, Callable[..., str]] = {}
|
||||||
__ids: dict[int, int] = {}
|
__ids: dict[int, int] = {}
|
||||||
|
|
||||||
def __str__(self) -> str:
|
class Config:
|
||||||
return self.formatted_descriptor(add_identifier=False)
|
allow_mutation = False
|
||||||
|
|
||||||
def formatted_descriptor(self, add_identifier=False) -> str:
|
def formatted_descriptor(self, add_identifier=False) -> str:
|
||||||
"""formats a variable list into a str such that each simulation has a unique
|
"""formats a variable list into a str such that each simulation has a unique
|
||||||
@@ -183,6 +186,24 @@ class VariationDescriptor(utils.HashableBaseModel):
|
|||||||
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
|
raw_descr=self.raw_descr[key], index=self.index[key], separator=self.separator
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.formatted_descriptor(add_identifier=False)
|
||||||
|
|
||||||
|
def __lt__(self, other: "VariationDescriptor") -> bool:
|
||||||
|
return self.raw_descr < other.raw_descr
|
||||||
|
|
||||||
|
def __le__(self, other: "VariationDescriptor") -> bool:
|
||||||
|
return self.raw_descr <= other.raw_descr
|
||||||
|
|
||||||
|
def __gt__(self, other: "VariationDescriptor") -> bool:
|
||||||
|
return self.raw_descr > other.raw_descr
|
||||||
|
|
||||||
|
def __ge__(self, other: "VariationDescriptor") -> bool:
|
||||||
|
return self.raw_descr >= other.raw_descr
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(self.raw_descr)
|
||||||
|
|
||||||
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
|
def update_config(self, cfg: dict[str, Any], index=-1) -> dict[str, Any]:
|
||||||
"""updates a dictionary with the value of the descriptor
|
"""updates a dictionary with the value of the descriptor
|
||||||
|
|
||||||
@@ -252,3 +273,34 @@ class BranchDescriptor(VariationDescriptor):
|
|||||||
@validator("raw_descr")
|
@validator("raw_descr")
|
||||||
def validate_raw_descr(cls, v):
|
def validate_raw_descr(cls, v):
|
||||||
return tuple(tuple(el for el in variable if el[0] != "num") for variable in v)
|
return tuple(tuple(el for el in variable if el[0] != "num") for variable in v)
|
||||||
|
|
||||||
|
|
||||||
|
class DescriptorDict(Generic[T]):
|
||||||
|
def __init__(self, dico: dict[VariationDescriptor, T] = None):
|
||||||
|
self.dico: dict[tuple[tuple[tuple[str, Any], ...], ...], tuple[VariationDescriptor, T]] = {}
|
||||||
|
if dico is not None:
|
||||||
|
for k, v in dico.items():
|
||||||
|
self[k] = v
|
||||||
|
|
||||||
|
def __setitem__(self, key: VariationDescriptor, value: T):
|
||||||
|
if not isinstance(key, VariationDescriptor):
|
||||||
|
raise TypeError("key must be a VariationDescriptor instance")
|
||||||
|
self.dico[key.raw_descr] = (key, value)
|
||||||
|
|
||||||
|
def __getitem__(
|
||||||
|
self, key: Union[VariationDescriptor, tuple[tuple[tuple[str, Any], ...], ...]]
|
||||||
|
) -> T:
|
||||||
|
if isinstance(key, VariationDescriptor):
|
||||||
|
return self.dico[key.raw_descr][1]
|
||||||
|
else:
|
||||||
|
return self.dico[key][1]
|
||||||
|
|
||||||
|
def items(self) -> Iterator[tuple[VariationDescriptor, T]]:
|
||||||
|
for k, v in self.dico.items():
|
||||||
|
yield k, v[1]
|
||||||
|
|
||||||
|
def keys(self) -> list[VariationDescriptor]:
|
||||||
|
return [v[0] for v in self.dico.values()]
|
||||||
|
|
||||||
|
def values(self) -> list[T]:
|
||||||
|
return [v[1] for v in self.dico.values()]
|
||||||
|
|||||||
@@ -491,6 +491,7 @@ class Simulations:
|
|||||||
|
|
||||||
def _run_available(self):
|
def _run_available(self):
|
||||||
for variable, params in self.configuration:
|
for variable, params in self.configuration:
|
||||||
|
params.compute()
|
||||||
v_list_str = variable.formatted_descriptor(True)
|
v_list_str = variable.formatted_descriptor(True)
|
||||||
save_parameters(params.prepare_for_dump(), Path(params.output_path))
|
save_parameters(params.prepare_for_dump(), Path(params.output_path))
|
||||||
|
|
||||||
@@ -718,7 +719,7 @@ def run_simulation(
|
|||||||
config_file: os.PathLike,
|
config_file: os.PathLike,
|
||||||
method: Union[str, Type[Simulations]] = None,
|
method: Union[str, Type[Simulations]] = None,
|
||||||
):
|
):
|
||||||
config = Configuration(config_file)
|
config = Configuration(config_file, wait=True)
|
||||||
|
|
||||||
sim = new_simulation(config, method)
|
sim = new_simulation(config, method)
|
||||||
sim.run()
|
sim.run()
|
||||||
@@ -760,6 +761,8 @@ def parallel_RK4IP(
|
|||||||
]:
|
]:
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
params = list(Configuration(config))
|
params = list(Configuration(config))
|
||||||
|
for _, param in params:
|
||||||
|
param.compute()
|
||||||
n = len(params)
|
n = len(params)
|
||||||
z_num = params[0][1].z_num
|
z_num = params[0][1].z_num
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union
|
from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Union
|
||||||
@@ -12,8 +13,8 @@ from pydantic import BaseModel, DirectoryPath, root_validator
|
|||||||
from . import math
|
from . import math
|
||||||
from ._utils import load_spectrum
|
from ._utils import load_spectrum
|
||||||
from ._utils.parameter import Parameters
|
from ._utils.parameter import Parameters
|
||||||
from ._utils.utils import PlotRange
|
from ._utils.utils import PlotRange, iter_simulations
|
||||||
from .const import SPECN_FN1, PARAM_FN, SPEC1_FN_N, SPEC1_FN
|
from .const import PARAM_FN, SPEC1_FN, SPEC1_FN_N, SPECN_FN1
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .physics import pulse, units
|
from .physics import pulse, units
|
||||||
from .plotting import (
|
from .plotting import (
|
||||||
@@ -131,11 +132,10 @@ class SimulationSeries:
|
|||||||
|
|
||||||
def __init__(self, path: os.PathLike):
|
def __init__(self, path: os.PathLike):
|
||||||
self.logger = get_logger()
|
self.logger = get_logger()
|
||||||
path = Path(path)
|
for self.path in iter_simulations(path):
|
||||||
subdirs = [el for el in path.glob("*") if (el / PARAM_FN).exists()]
|
break
|
||||||
while not (path / PARAM_FN).exists() and len(subdirs) == 1:
|
else:
|
||||||
path = subdirs[0]
|
raise FileNotFoundError(f"No simulation in {path}")
|
||||||
self.path = path
|
|
||||||
self.params = Parameters.load(self.path / PARAM_FN)
|
self.params = Parameters.load(self.path / PARAM_FN)
|
||||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
||||||
self.t = self.params.t
|
self.t = self.params.t
|
||||||
@@ -356,24 +356,22 @@ class SimulationSeries:
|
|||||||
|
|
||||||
|
|
||||||
class Pulse(Sequence):
|
class Pulse(Sequence):
|
||||||
path: Path
|
def __new__(cls, path: os.PathLike):
|
||||||
default_ind: Optional[int]
|
warnings.warn(
|
||||||
params: Parameters
|
"You are using the legacy version of the pulse loader. "
|
||||||
z: np.ndarray
|
"Please consider updating your data with scgenerator.convert_sim_folder "
|
||||||
namx: int
|
"and loading data with the SimulationSeries class"
|
||||||
t: np.ndarray
|
)
|
||||||
w: np.ndarray
|
if (Path(path) / SPECN_FN1.format(0)).exists():
|
||||||
w_order: np.ndarray
|
return LegacyPulse(path)
|
||||||
|
return SimulationSeries(path)
|
||||||
|
|
||||||
def __new__(cls, path: os.PathLike, *args, **kwargs) -> "Pulse":
|
def __getitem__(self, key) -> Spectrum:
|
||||||
try:
|
raise NotImplementedError()
|
||||||
if load_spectrum(Path(path) / SPECN_FN1.format(0)).ndim == 2:
|
|
||||||
return super().__new__(LegacyPulse)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
return super().__new__(cls)
|
|
||||||
|
|
||||||
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
|
|
||||||
|
class LegacyPulse(Sequence):
|
||||||
|
def __init__(self, path: os.PathLike):
|
||||||
"""load a data folder as a pulse
|
"""load a data folder as a pulse
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -388,6 +386,35 @@ class Pulse(Sequence):
|
|||||||
FileNotFoundError
|
FileNotFoundError
|
||||||
path does not contain proper data
|
path does not contain proper data
|
||||||
"""
|
"""
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
self.path = Path(path)
|
||||||
|
|
||||||
|
if not self.path.is_dir():
|
||||||
|
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
||||||
|
|
||||||
|
self.params = Parameters.load(self.path / "params.toml")
|
||||||
|
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
||||||
|
if self.params.fiber_map is None:
|
||||||
|
self.params.fiber_map = [(0.0, self.params.name)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.z = np.load(os.path.join(path, "z.npy"))
|
||||||
|
except FileNotFoundError:
|
||||||
|
if self.params is not None:
|
||||||
|
self.z = self.params.z_targets
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
self.nmax = len(list(self.path.glob("spectra_*.npy")))
|
||||||
|
if self.nmax <= 0:
|
||||||
|
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
||||||
|
|
||||||
|
self.t = self.params.t
|
||||||
|
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
||||||
|
self.w_order = np.argsort(w)
|
||||||
|
self.w = w
|
||||||
|
self.wl = units.m.inv(self.w)
|
||||||
|
self.params.w = self.w
|
||||||
|
self.params.z_targets = self.z
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""
|
"""
|
||||||
@@ -404,6 +431,73 @@ class Pulse(Sequence):
|
|||||||
def __getitem__(self, key) -> Spectrum:
|
def __getitem__(self, key) -> Spectrum:
|
||||||
return self.all_spectra(key)
|
return self.all_spectra(key)
|
||||||
|
|
||||||
|
def intensity(self, unit):
|
||||||
|
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||||
|
x_axis = unit.inv(self.w)
|
||||||
|
else:
|
||||||
|
x_axis = unit.inv(self.t)
|
||||||
|
|
||||||
|
order = np.argsort(x_axis)
|
||||||
|
func = dict(
|
||||||
|
WL=self._to_wl_int,
|
||||||
|
FREQ=self._to_freq_int,
|
||||||
|
AFREQ=self._to_afreq_int,
|
||||||
|
TIME=self._to_time_int,
|
||||||
|
)[unit.type]
|
||||||
|
|
||||||
|
for spec in self:
|
||||||
|
yield x_axis[order], func(spec)[:, order]
|
||||||
|
|
||||||
|
def _to_wl_int(self, spectrum):
|
||||||
|
return units.to_WL(math.abs2(spectrum), spectrum.wl)
|
||||||
|
|
||||||
|
def _to_freq_int(self, spectrum):
|
||||||
|
return math.abs2(spectrum)
|
||||||
|
|
||||||
|
def _to_afreq_int(self, spectrum):
|
||||||
|
return math.abs2(spectrum)
|
||||||
|
|
||||||
|
def _to_time_int(self, spectrum):
|
||||||
|
return math.abs2(np.fft.ifft(spectrum))
|
||||||
|
|
||||||
|
def amplitude(self, unit):
|
||||||
|
if unit.type in ["WL", "FREQ", "AFREQ"]:
|
||||||
|
x_axis = unit.inv(self.w)
|
||||||
|
else:
|
||||||
|
x_axis = unit.inv(self.t)
|
||||||
|
|
||||||
|
order = np.argsort(x_axis)
|
||||||
|
func = dict(
|
||||||
|
WL=self._to_wl_amp,
|
||||||
|
FREQ=self._to_freq_amp,
|
||||||
|
AFREQ=self._to_afreq_amp,
|
||||||
|
TIME=self._to_time_amp,
|
||||||
|
)[unit.type]
|
||||||
|
|
||||||
|
for spec in self:
|
||||||
|
yield x_axis[order], func(spec)[:, order]
|
||||||
|
|
||||||
|
def _to_wl_amp(self, spectrum):
|
||||||
|
return (
|
||||||
|
np.sqrt(
|
||||||
|
units.to_WL(
|
||||||
|
math.abs2(spectrum),
|
||||||
|
spectrum.wl,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
* spectrum
|
||||||
|
/ np.abs(spectrum)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _to_freq_amp(self, spectrum):
|
||||||
|
return spectrum
|
||||||
|
|
||||||
|
def _to_afreq_amp(self, spectrum):
|
||||||
|
return spectrum
|
||||||
|
|
||||||
|
def _to_time_amp(self, spectrum):
|
||||||
|
return np.fft.ifft(spectrum)
|
||||||
|
|
||||||
def all_spectra(self, ind=None) -> Spectrum:
|
def all_spectra(self, ind=None) -> Spectrum:
|
||||||
"""
|
"""
|
||||||
loads the data already simulated.
|
loads the data already simulated.
|
||||||
@@ -425,10 +519,7 @@ class Pulse(Sequence):
|
|||||||
# Check if file exists and assert how many z positions there are
|
# Check if file exists and assert how many z positions there are
|
||||||
|
|
||||||
if ind is None:
|
if ind is None:
|
||||||
if self.default_ind is None:
|
ind = range(self.nmax)
|
||||||
ind = range(self.nmax)
|
|
||||||
else:
|
|
||||||
ind = self.default_ind
|
|
||||||
if isinstance(ind, (int, np.integer)):
|
if isinstance(ind, (int, np.integer)):
|
||||||
ind = [ind]
|
ind = [ind]
|
||||||
elif isinstance(ind, (float, np.floating)):
|
elif isinstance(ind, (float, np.floating)):
|
||||||
@@ -452,7 +543,12 @@ class Pulse(Sequence):
|
|||||||
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
|
return np.fft.ifft(self.all_spectra(ind=ind), axis=-1)
|
||||||
|
|
||||||
def _load1(self, i: int):
|
def _load1(self, i: int):
|
||||||
pass
|
if i < 0:
|
||||||
|
i = self.nmax + i
|
||||||
|
spec = load_spectrum(self.path / SPECN_FN1.format(i))
|
||||||
|
spec = np.atleast_2d(spec)
|
||||||
|
spec = Spectrum(spec, self.params)
|
||||||
|
return spec
|
||||||
|
|
||||||
def plot_2D(
|
def plot_2D(
|
||||||
self,
|
self,
|
||||||
@@ -554,46 +650,3 @@ class Pulse(Sequence):
|
|||||||
index
|
index
|
||||||
"""
|
"""
|
||||||
return math.argclosest(self.z, z)
|
return math.argclosest(self.z, z)
|
||||||
|
|
||||||
|
|
||||||
class LegacyPulse(Pulse):
|
|
||||||
def __init__(self, path: os.PathLike, default_ind: Union[int, Iterable[int]] = None):
|
|
||||||
print("old init called", path, default_ind)
|
|
||||||
self.logger = get_logger(__name__)
|
|
||||||
self.path = Path(path)
|
|
||||||
self.default_ind = default_ind
|
|
||||||
|
|
||||||
if not self.path.is_dir():
|
|
||||||
raise FileNotFoundError(f"Folder {self.path} does not exist")
|
|
||||||
|
|
||||||
self.params = Parameters.load(self.path / "params.toml")
|
|
||||||
self.params.compute(["name", "t", "l", "w_c", "w0", "z_targets"])
|
|
||||||
if self.params.fiber_map is None:
|
|
||||||
self.params.fiber_map = [(0.0, self.params.name)]
|
|
||||||
|
|
||||||
try:
|
|
||||||
self.z = np.load(os.path.join(path, "z.npy"))
|
|
||||||
except FileNotFoundError:
|
|
||||||
if self.params is not None:
|
|
||||||
self.z = self.params.z_targets
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
self.nmax = len(list(self.path.glob("spectra_*.npy")))
|
|
||||||
if self.nmax <= 0:
|
|
||||||
raise FileNotFoundError(f"No appropriate file in specified folder {self.path}")
|
|
||||||
|
|
||||||
self.t = self.params.t
|
|
||||||
w = math.wspace(self.t) + units.m(self.params.wavelength)
|
|
||||||
self.w_order = np.argsort(w)
|
|
||||||
self.w = w
|
|
||||||
self.wl = units.m.inv(self.w)
|
|
||||||
self.params.w = self.w
|
|
||||||
self.params.z_targets = self.z
|
|
||||||
|
|
||||||
def _load1(self, i: int):
|
|
||||||
if i < 0:
|
|
||||||
i = self.nmax + i
|
|
||||||
spec = load_spectrum(self.path / SPECN_FN1.format(i))
|
|
||||||
spec = np.atleast_2d(spec)
|
|
||||||
spec = Spectrum(spec, self.params)
|
|
||||||
return spec
|
|
||||||
|
|||||||
Reference in New Issue
Block a user