added overwrite functionality

This commit is contained in:
Benoît Sierro
2023-08-09 11:30:30 +02:00
parent 7bb15871c3
commit ea8bc0360d
2 changed files with 22 additions and 1 deletions

View File

@@ -191,6 +191,7 @@ def propagation(
file_or_params: os.PathLike | Parameters,
params: Parameters | None = None,
bundle_data: bool = False,
overwrite: bool = False,
) -> Propagation:
file = None
if isinstance(file_or_params, Parameters):
@@ -198,7 +199,7 @@ def propagation(
else:
file = Path(file_or_params)
if file is not None and file.exists():
if file is not None and file.exists() and params is None:
io = ZipFileIOHandler(file)
return _open_existing_propagation(io)
@@ -206,6 +207,11 @@ def propagation(
raise ValueError("Parameters must be specified to create new simulation")
if file is not None:
if file.exists() and params is not None:
if overwrite:
Path(file).unlink()
else:
raise FileExistsError(f"{file} already exists. use overwrite=True to overwrite")
io = ZipFileIOHandler(file)
else:
io = MemoryIOHandler()

View File

@@ -101,6 +101,21 @@ def test_clear(tmp_path: Path):
assert not zpath.exists()
def test_overwrite(tmp_path: Path):
params = Parameters(**PARAMS)
zpath = tmp_path / "file.zip"
_ = propagation(zpath, params)
orig_file = zpath.read_bytes()
with pytest.raises(FileExistsError):
_ = propagation(zpath, params)
_ = propagation(zpath, params, overwrite=True)
assert zpath.read_bytes() != orig_file
assert len(zpath.read_bytes()) == len(orig_file)
def test_zip_bundle(tmp_path: Path):
params = Parameters(**PARAMS)