diff --git a/examples/dudley2006.py b/src/scgenerator/examples/dudley2006.py similarity index 82% rename from examples/dudley2006.py rename to src/scgenerator/examples/dudley2006.py index 3b24cbe..8d91ab3 100644 --- a/examples/dudley2006.py +++ b/src/scgenerator/examples/dudley2006.py @@ -1,3 +1,4 @@ +import click import colorcet as cc import matplotlib.pyplot as plt import numpy as np @@ -6,6 +7,7 @@ from plotapp import PlotApp import scgenerator as sc params = sc.Parameters( + name="Dudley2006 example", wavelength=835e-9, width=50e-15, peak_power=10e3, @@ -35,14 +37,12 @@ params = sc.Parameters( ) -def compute_manual(): +def compute_manual(save: bool): spec0 = params.compute("spec_0") w_c, w0, gamma = params.compute("w_c", "w0", "gamma") p = params.compile() print(p.dt) - beta_op = sc.operators.constant_polynomial_dispersion( - params.beta2_coefficients, w_c, params.compute("dispersion_ind") - ) + beta_op = sc.operators.constant_polynomial_dispersion(params.beta2_coefficients, w_c) linear = sc.operators.envelope_linear_operator( beta_op, # sc.operators.constant_quantity(0), @@ -94,7 +94,10 @@ def compute_manual(): def linear(_): return linear_arr - prop = sc.propagation("examples/dudley_manual.zip", params) + if save: + prop = sc.propagation("examples/dudley_manual.zip", params) + else: + prop = sc.propagation(params) z = [] for i, (spec, stat) in enumerate( sc.solve43(spec0, linear, nonlinear, params.length, 1e-6, 1e-6, 0.9, h_const=20e-6) @@ -106,8 +109,11 @@ def compute_manual(): z.append(stat["z"]) -def compute_auto(): - sc.compute(params, True, "examples/dudley2006") +def compute_auto(save: bool): + if save: + sc.compute(params, True, "examples/dudley2006") + else: + sc.compute(params) def plot(): @@ -120,6 +126,15 @@ def plot(): plt.show() +@click.command() +@click.option("--show/--no-show", default=False) +@click.option("--save/--no-save", default=False) +def main(show: bool, save: bool): + compute_manual(save) + compute_auto(save) + if show: + plot() + + if __name__ == "__main__": - # compute_auto() - plot() + main() diff --git a/src/scgenerator/helpers.py b/src/scgenerator/helpers.py index a1f0cea..7f85999 100644 --- a/src/scgenerator/helpers.py +++ b/src/scgenerator/helpers.py @@ -1,6 +1,7 @@ """ series of helper functions """ + import os import warnings from pathlib import Path @@ -166,28 +167,25 @@ def extend_axis(axis: np.ndarray) -> np.ndarray: def compute( parameters: Parameters, overwrite: bool = False, output: os.PathLike | None = None ) -> Propagation: + prop_params = parameters.compile() if output is None: - name = Path(parameters.compute("name")).stem + ".zip" + prop = propagation(prop_params) else: path = Path(output) name = path.parent / (path.stem + ".zip") - - prop_params = parameters.compile() - prop = propagation(name, prop_params, bundle_data=True, overwrite=overwrite) + prop = propagation(name, prop_params, bundle_data=True, overwrite=overwrite) with warnings.catch_warnings(), tqdm(total=prop_params.z_num) as pbar: warnings.filterwarnings("error") - for i, (spec, new_stat) in enumerate( - solve43( - prop_params.spec_0, - prop_params.linear_operator, - prop_params.nonlinear_operator, - prop_params.length, - prop_params.tolerated_error, - prop_params.tolerated_error, - 0.9, - targets=prop_params.z_targets, - ) + for spec, _ in solve43( + prop_params.spec_0, + prop_params.linear_operator, + prop_params.nonlinear_operator, + prop_params.length, + prop_params.tolerated_error, + prop_params.tolerated_error, + 0.9, + targets=prop_params.z_targets, ): pbar.update() prop.append(spec) diff --git a/src/scgenerator/math.py b/src/scgenerator/math.py index 6457c91..cdd354b 100644 --- a/src/scgenerator/math.py +++ b/src/scgenerator/math.py @@ -22,16 +22,10 @@ c = 299792458.0 def fft_functions( full_field: bool, ) -> tuple[Callable[[np.ndarray], np.ndarray], Callable[[np.ndarray], np.ndarray]]: - if platform.processor() == "arm": - if full_field: - return sfft.rfft, sfft.irfft - else: - return sfft.fft, sfft.ifft + if full_field: + return sfft.rfft, sfft.irfft else: - if full_field: - return np.fft.rfft, np.fft.irfft - else: - return np.fft.fft, np.fft.ifft + return sfft.fft, sfft.ifft def expm1_int(y: np.ndarray, dx: float) -> np.ndarray: