Skip to content

Commit

Permalink
Update estimators.py
Browse files: browse the repository at this point in the history
  • Loading branch information
lukeshingles committed Jun 17, 2024
1 parent 4903705 commit f4013d2
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions artistools/estimators/estimators.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -225,26 +225,42 @@ def get_rankbatch_parquetfile(
parquetfilepath = folderpath / parquetfilename

if not parquetfilepath.exists():
generate_parquet = True
elif next(folderpath.glob("estimators_????.out*")).stat().st_mtime > parquetfilepath.stat().st_mtime:
print(
f" {parquetfilepath.relative_to(modelpath.parent)} is older than the estimator text files. File will be deleted and regenerated..."
)
parquetfilepath.unlink()
generate_parquet = True
else:
generate_parquet = False

if generate_parquet:
print(f" generating {parquetfilepath.relative_to(modelpath.parent)}...")
estfilepaths = []
for mpirank in batch_mpiranks:
# not worth printing an error, because ranks with no cells to update do not produce an estimator file
with contextlib.suppress(FileNotFoundError):
estfilepath = at.firstexisting(f"estimators_{mpirank:04d}.out", folder=folderpath, tryzipped=True)
estfilepaths.append(estfilepath)
estfilepaths.append(
at.firstexisting(f"estimators_{mpirank:04d}.out", folder=folderpath, tryzipped=True)
)

time_start = time.perf_counter()

try:
from artistools.rustext import estimparse as rustestimparse
except ImportError:
print("WARNING: Rust extension not available. Falling back to slow python reader.")
use_rust = False

print(
f" reading {len(estfilepaths)} estimator files from {folderpath.relative_to(Path(folderpath).parent)}{' with rust compiled function' if use_rust else ''}...",
f" reading {len(estfilepaths)} estimator files in {folderpath.relative_to(Path(folderpath).parent)} with {'fast rust reader' if use_rust else 'slow python reader'}...",
end="",
flush=True,
)

pldf_batch: pl.DataFrame
if use_rust:
from artistools.rustext import estimparse as rustestimparse

pldf_batch = rustestimparse(str(folderpath), min(batch_mpiranks), max(batch_mpiranks))
pldf_batch = pldf_batch.with_columns(
pl.col(c).cast(pl.Int32)
Expand All @@ -267,7 +283,7 @@ def get_rankbatch_parquetfile(
)
)
print(
f"took {time.perf_counter() - time_start:.1f} s. Writing {parquetfilepath.relative_to(modelpath.parent)}...",
f"took {time.perf_counter() - time_start:.1f} s. Writing parquet file...",
end="",
flush=True,
)
Expand Down

0 comments on commit f4013d2

Please sign in to comment.