diff --git a/setup.cfg b/setup.cfg index 6ead902..f38dfaf 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,4 +35,8 @@ scgenerator = data/silica.json [options.packages.find] -where = src \ No newline at end of file +where = src + +[options.entry_points] +console_scripts = + scgenerator = scgenerator.cli.cli:main \ No newline at end of file diff --git a/src/scgenerator/cli/__init__.py b/src/scgenerator/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/scgenerator/cli/__main__.py b/src/scgenerator/cli/__main__.py deleted file mode 100644 index 7a8f1df..0000000 --- a/src/scgenerator/cli/__main__.py +++ /dev/null @@ -1,25 +0,0 @@ -import argparse - -def create_parser(): - parser = argparse.ArgumentParser( - description="scgenerator command", - prog="scgenerator" - ) - - return parser - - -def main(): - parser = create_parser() - subparsers = parser.add_subparsers( - help="sub-command help" - ) - - newconfig = subparsers.add_parser( - "newconfig", - help="create a new configuration file" - ) - - -if __name__ == "__main__": - main() diff --git a/src/scgenerator/cli/cli.py b/src/scgenerator/cli/cli.py new file mode 100644 index 0000000..912ef35 --- /dev/null +++ b/src/scgenerator/cli/cli.py @@ -0,0 +1,83 @@ +import argparse +import os +import random + +import ray +from scgenerator.physics.simulate import new_simulations, resume_simulations + + +def create_parser(): + parser = argparse.ArgumentParser(description="scgenerator command", prog="scgenerator") + + subparsers = parser.add_subparsers(help="sub-command help") + + parser.add_argument( + "--id", + type=int, + default=random.randint(0, 1e18), + help="Unique id of the session. Only useful when running several processes at the same time.", + ) + parser.add_argument( + "--start-ray", + action="store_true", + help="assume no ray instance has been started beforehand", + ) + + run_parser = subparsers.add_parser("run", help="run a simulation from a config file") + + run_parser.add_argument("config", help="path to the toml configuration file") + run_parser.set_defaults(func=run_sim) + + resume_parser = subparsers.add_parser("resume", help="resume a simulation") + resume_parser.add_argument( + "data_dir", + help="path to the directory where the initial_config.toml and the data is stored", + ) + resume_parser.set_defaults(func=resume_sim) + + newconfig_parser = subparsers.add_parser("newconfig", help="create a new configuration file") + + return parser + + +def main(): + parser = create_parser() + args = parser.parse_args() + args.func(args) + + +def run_sim(args): + + if args.start_ray: + init_str = ray.init() + else: + init_str = ray.init( + address="auto", + _node_ip_address=os.environ.get("ip_head", "127.0.0.1").split(":")[0], + _redis_password=os.environ.get("redis_password", "caco1234"), + ) + + print(init_str) + sim = new_simulations(args.config, args.id) + + sim.run() + + +def resume_sim(args): + if args.start_ray: + init_str = ray.init() + else: + init_str = ray.init( + address="auto", + _node_ip_address=os.environ.get("ip_head", "127.0.0.1").split(":")[0], + _redis_password=os.environ.get("redis_password", "caco1234"), + ) + + print(init_str) + sim = resume_simulations(args.data_dir, args.id) + + sim.run() + + +if __name__ == "__main__": + main() diff --git a/src/scgenerator/cli/new_config.py b/src/scgenerator/cli/new_config.py new file mode 100644 index 0000000..7d3c088 --- /dev/null +++ b/src/scgenerator/cli/new_config.py @@ -0,0 +1,61 @@ +from .. import const + + +def list_input(): + answer = "" + while answer == "": + answer = input("Please enter a list of values (one per line)\n") + + out = [process_input(answer)] + + while answer != "": + answer = input() + out.append(process_input(answer)) + + return out[:-1] + + +def process_input(s): + try: + return int(s) + except ValueError: + pass + + try: + return float(s) + except ValueError: + pass + + return s + + +def accept(question, default=True): + question += " ([y]/n)" if default else " (y/[n])" + question += "\n" + inp = input(question) + + yes_str = ["y", "yes"] + if default: + yes_str.append("") + + return inp.lower() in yes_str + + +def get(section, param_name): + question = f"Please enter a value for the parameter '{param_name}'\n" + valid = const.valid_param_types[section][param_name] + + is_valid = False + value = None + + while not is_valid: + answer = input(question) + if answer == "\\variable" and param_name in const.valid_variable[section]: + value = list_input() + print(value) + is_valid = all(valid(v) for v in value) + else: + value = process_input(answer) + is_valid = valid(value) + + return value \ No newline at end of file diff --git a/src/scgenerator/initialize.py b/src/scgenerator/initialize.py index c4c8608..29f3f0b 100644 --- a/src/scgenerator/initialize.py +++ b/src/scgenerator/initialize.py @@ -140,15 +140,15 @@ def tspace(time_window=None, t_num=None, dt=None): raise TypeError("not enough parameter to determine time vector") -def validate_single_parameter(parent, key, value): +def validate_single_parameter(section, key, value): try: - func = valid_param_types[parent][key] + func = valid_param_types[section][key] except KeyError: s = f"The parameter '{key}' does not belong " - if parent == "root": + if section == "root": s += "at the root of the config file" else: - s += f"in the category '{parent}'" + s += f"in the category '{section}'" s += ". Make sure it is a valid parameter in the first place" raise TypeError(s) if not func(value): @@ -178,11 +178,11 @@ def _validate_types(config): if param_name == "variable": for k_vary, v_vary in param_value.items(): if not isinstance(v_vary, list): - raise TypeError(f"Varying parameters should be specified in a list") + raise TypeError(f"Variable parameters should be specified in a list") if len(v_vary) < 1: raise ValueError( - f"Varying parameters lists should contain at least 1 element" + f"Variable parameters lists should contain at least 1 element" ) if k_vary not in valid_variable[domain]: diff --git a/src/scgenerator/io.py b/src/scgenerator/io.py index 71dec7a..7ad7567 100644 --- a/src/scgenerator/io.py +++ b/src/scgenerator/io.py @@ -411,11 +411,13 @@ def merge_same_simulations(path: str): logger.warning(f"could not send send {len(base_folders)} folder(s) to trash") -def get_data_folder(task_id: int, name_if_new: str = ""): +def get_data_folder(task_id: int, name_if_new: str = "data"): + if name_if_new == "": + name_if_new = "data" idstr = str(int(task_id)) tmp = os.getenv(TMP_FOLDER_KEY_BASE + idstr) if tmp is None: - tmp = ensure_folder("scgenerator_" + name_if_new + idstr) + tmp = ensure_folder("scgenerator " + name_if_new) os.environ[TMP_FOLDER_KEY_BASE + idstr] = tmp return tmp diff --git a/src/scgenerator/spectra.py b/src/scgenerator/spectra.py index 7a3bf41..b34ff08 100644 --- a/src/scgenerator/spectra.py +++ b/src/scgenerator/spectra.py @@ -5,11 +5,31 @@ from typing import Any, List, Tuple import numpy as np -from . import io +from . import io, initialize, math +from .plotting import units from .logger import get_logger -class Spectra(Sequence): +class Spectrum(np.ndarray): + def __new__(cls, input_array, wl, frep=1): + # Input array is an already formed ndarray instance + # We first cast to be our class type + obj = np.asarray(input_array).view(cls) + # add the new attribute to the created instance + obj.frep = frep + obj.wl = wl + # Finally, we must return the newly created object: + return obj + + def __array_finalize__(self, obj): + # see InfoArray.__array_finalize__ for comments + if obj is None: + return + self.frep = getattr(obj, "frep", None) + self.wl = getattr(obj, "wl", None) + + +class Pulse(Sequence): def __init__(self, path: str): self.logger = get_logger(__name__) self.path = path @@ -35,6 +55,12 @@ class Spectra(Sequence): if self.nmax <= 0: raise FileNotFoundError(f"No appropriate file in specified folder {self.path}") + self.t = self.params["t"] + w = initialize.wspace(self.t) + units.m(self.params["wavelength"]) + self.w_order = np.argsort(w) + self.w = w + self.wl = units.m.inv(self.w) + def __iter__(self): """ similar to all_spectra but works as an iterator @@ -42,7 +68,7 @@ class Spectra(Sequence): self.logger.debug(f"iterating through {self.path}") for i in range(self.nmax): - yield io.load_single_spectrum(self.path, i) + yield self._load1(i) def __len__(self): return self.nmax @@ -50,6 +76,74 @@ class Spectra(Sequence): def __getitem__(self, key): return self.all_spectra(ind=range(self.nmax)[key]) + def intensity(self, unit): + if unit.type in ["WL", "FREQ", "AFREQ"]: + x_axis = unit.inv(self.w) + else: + x_axis = unit.inv(self.t) + + order = np.argsort(x_axis) + func = dict( + WL=self._to_wl_int, + FREQ=self._to_freq_int, + AFREQ=self._to_afreq_int, + TIME=self._to_time_int, + )[unit.type] + + for spec in self: + yield x_axis[order], func(spec)[:, order] + + def _to_wl_int(self, spectrum): + return units.to_WL(math.abs2(spectrum), spectrum.frep, spectrum.wl) + + def _to_freq_int(self, spectrum): + return math.abs2(spectrum) + + def _to_afreq_int(self, spectrum): + return math.abs2(spectrum) + + def _to_time_int(self, spectrum): + return math.abs2(np.fft.ifft(spectrum)) + + def amplitude(self, unit): + if unit.type in ["WL", "FREQ", "AFREQ"]: + x_axis = unit.inv(self.w) + else: + x_axis = unit.inv(self.t) + + order = np.argsort(x_axis) + func = dict( + WL=self._to_wl_amp, + FREQ=self._to_freq_amp, + AFREQ=self._to_afreq_amp, + TIME=self._to_time_amp, + )[unit.type] + + for spec in self: + yield x_axis[order], func(spec)[:, order] + + def _to_wl_amp(self, spectrum): + return ( + np.sqrt( + units.to_WL( + math.abs2(spectrum), + spectrum.frep, + spectrum.wl, + ) + ) + * spectrum + / np.abs(spectrum) + ) + + def _to_freq_amp(self, spectrum): + return spectrum + + def _to_afreq_amp(self, spectrum): + return spectrum + + def _to_time_amp(self, spectrum): + return np.fft.ifft(spectrum) + def all_spectra(self, ind=None): """ loads the data already simulated. @@ -86,6 +180,11 @@ class Spectra(Sequence): return spectra.squeeze() + def _load1(self, i: int): + return Spectrum( + np.atleast_2d(io.load_single_spectrum(self.path, i)), self.wl, self.params["frep"] + ) + class SpectraCollection(Mapping, Sequence): def __init__(self, path: str):