Allow more path customization for feature generation/inference (#493)
* Add path_to_features to config

* Extend path_to_preds and path_to_features to inference code

* Update GCN cronjob

* Update generate_features.py

* Update config default for ids_skipgaia file

* Remove code.interact()
bfhealy authored Sep 29, 2023
1 parent 029d1b4 commit 3ef9c54
Showing 8 changed files with 55 additions and 20 deletions.
4 changes: 3 additions & 1 deletion config.defaults.yaml
@@ -1358,13 +1358,15 @@ feature_stats:

# Below, specify algorithms, external catalogs and features to include in generated feature lists:
feature_generation:
+# Path to save generated features
+path_to_features:
# If --doSpecificIDs is set in generate_features.py, the script will generate features for the below dataset instead of a field/ccd/quad.
# Dataset must contain columns named "ztf_id" and "coordinates" with data in the format of these fields on Kowalski
# Default dataset is the training set downloadable from Fritz
dataset: tools/fritzDownload/merged_classifications_features.parquet
# Once generate_features.py is run with --doSpecificIDs, a file will be saved with the default name below.
# Set --skipCloseSources to load the file below and skip the identification of close sources:
-ids_skipGaia: tools/fritzDownload/specific_ids_dropCloseSources.json
+ids_skipGaia: specific_ids_dropCloseSources.json
period_algorithms:
CPU:
- LS
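Each of the scripts changed below consumes the new key with the same pattern. A minimal sketch of that pattern (assuming the standard config.yaml at the repository root; the fallback behavior mirrors the diffs that follow):

    import pathlib
    import yaml

    # Default: the repository root, as before this change
    BASE_DIR = pathlib.Path(__file__).parent.parent.absolute()

    with open(BASE_DIR / "config.yaml") as config_yaml:
        config = yaml.load(config_yaml, Loader=yaml.FullLoader)

    # A blank value in the YAML parses as None, so the repo-root default is kept
    path_to_features = config['feature_generation']['path_to_features']
    if path_to_features is not None:
        BASE_DIR = pathlib.Path(path_to_features)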
6 changes: 3 additions & 3 deletions gcn_cronjob.py
@@ -201,18 +201,18 @@ def query_gcn_events(
if (not features_file.exists()) | (has_new_sources):
print("Generating features on Expanse...")
os.system(
f"scp {filepath} {username}@login.expanse.sdsc.edu:/home/{username}/scope/tools/fritzDownload/."
f"scp {filepath} {username}@login.expanse.sdsc.edu:/expanse/lustre/projects/umn131/{username}/{generated_features_dirname}/fg_sources/."
)
os.system(
f'ssh -tt {username}@login.expanse.sdsc.edu \
"source .bash_profile && \
-cd scope/{generated_features_dirname}/slurm && \
+cd /expanse/lustre/projects/umn131/{username}/{generated_features_dirname}/slurm && \
sbatch --wait --export=DOBS={save_dateobs},DS={filepath.name} {partition}_slurm.sub"'
)
print("Finished generating features on Expanse.")

os.system(
f"rsync -avh {username}@login.expanse.sdsc.edu:/home/{username}/scope/{generated_features_dirname} {BASE_DIR}/."
f"rsync -avh {username}@login.expanse.sdsc.edu:/expanse/lustre/projects/umn131/{username}/{generated_features_dirname} {BASE_DIR}/."
)

if features_file.exists():
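The updated commands assume a shared directory layout under the project allocation on Expanse. A small sketch of the paths involved (the username and dirname values are hypothetical; the umn131 allocation is taken from the diff):

    import pathlib

    username = 'jdoe'                                  # hypothetical
    generated_features_dirname = 'generated_features'  # hypothetical
    remote_base = pathlib.PurePosixPath(
        f'/expanse/lustre/projects/umn131/{username}/{generated_features_dirname}'
    )
    print(remote_base / 'fg_sources')  # scp destination for the uploaded dataset
    print(remote_base / 'slurm')       # directory from which sbatch is invoked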
16 changes: 12 additions & 4 deletions tools/generate_features.py
@@ -73,6 +73,10 @@
ext_catalog_info = config['feature_generation']['external_catalog_features']
cesium_feature_list = config['feature_generation']['cesium_features']
period_algorithms = config['feature_generation']['period_algorithms']
+path_to_features = config['feature_generation']['path_to_features']
+
+if path_to_features is not None:
+    BASE_DIR = pathlib.Path(path_to_features)

kowalski_instances = Kowalski(timeout=timeout, instances=instances)

@@ -87,7 +91,8 @@ def drop_close_bright_stars(
limit: int = 10000,
Ncore: int = 8,
save: bool = False,
-save_filename: str = 'tools/fritzDownload/specific_ids_dropCloseSources.json',
+save_directory: str = 'generated_features',
+save_filename: str = 'specific_ids_dropCloseSources.json',
):
"""
Use Gaia to identify and drop sources that are too close to bright stars
@@ -103,7 +108,8 @@ def drop_close_bright_stars(
:param limit: if doSpecificIDs is set, max number of sources to be queried in one batch (int)
:param Ncore: if doSpecificIDs is set, number of cores over which to parallelize queries (int)
:param save: if set, save sources passing the close source analysis (bool)
-:param save_filename: path/name from BASE_DIR to save sources (str)
+:param save_directory: directory within BASE_DIR to save sources (str)
+:param save_filename: filename to use when saving sources (str)
:return id_dct_keep: dictionary containing subset of input sources far enough away from bright stars
"""
@@ -383,7 +389,8 @@ def drop_close_bright_stars(
id_dct_keep = id_dct

if save:
-with open(str(BASE_DIR / save_filename), 'w') as f:
+os.makedirs(BASE_DIR / save_directory, exist_ok=True)
+with open(str(BASE_DIR / save_directory / save_filename), 'w') as f:
json.dump(id_dct_keep, f)

print(f"Dropped {len(id_dct) - len(id_dct_keep)} sources.")
@@ -591,7 +598,7 @@ def generate_features(
else:
# Load pre-saved dataset if Gaia analysis already complete
fg_sources_config = config['feature_generation']['ids_skipGaia']
-fg_sources_path = str(BASE_DIR / fg_sources_config)
+fg_sources_path = str(BASE_DIR / dirname / fg_sources_config)

if fg_sources_path.endswith('.json'):
with open(fg_sources_path, 'r') as f:
@@ -643,6 +650,7 @@ def generate_features(
limit=limit,
Ncore=Ncore,
save=not doNotSave,
+save_directory=dirname,
)

else:
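Taken together, these changes move the skip-Gaia IDs file from a hard-coded tools/fritzDownload location to a configurable one. A sketch of where the file now lands, with hypothetical values:

    import pathlib

    BASE_DIR = pathlib.Path('/data/scope_features')       # hypothetical path_to_features
    save_directory = 'generated_features'                 # default; set to dirname when called from generate_features
    save_filename = 'specific_ids_dropCloseSources.json'  # new default filename

    print(BASE_DIR / save_directory / save_filename)
    # /data/scope_features/generated_features/specific_ids_dropCloseSources.json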
3 changes: 3 additions & 0 deletions tools/generate_features_job_submission.py
@@ -16,6 +16,9 @@
config = yaml.load(config_yaml, Loader=yaml.FullLoader)

fields_to_run = config['feature_generation']['fields_to_run']
+path_to_features = config['feature_generation']['path_to_features']
+if path_to_features is not None:
+    BASE_DIR = pathlib.Path(path_to_features)


def parse_commandline():
4 changes: 4 additions & 0 deletions tools/generate_features_slurm.py
@@ -52,6 +52,10 @@
gaia_catalog = config['kowalski']['collections']['gaia']
ext_catalog_info = config['feature_generation']['external_catalog_features']
cesium_feature_list = config['feature_generation']['cesium_features']
+path_to_features = config['feature_generation']['path_to_features']
+
+if path_to_features is not None:
+    BASE_DIR = pathlib.Path(path_to_features)


def check_quads_for_sources(
25 changes: 17 additions & 8 deletions tools/inference.py
@@ -26,13 +26,22 @@
warnings.filterwarnings('ignore')

BASE_DIR = pathlib.Path(__file__).parent.parent.absolute()
+BASE_DIR_FEATS = pathlib.Path(__file__).parent.parent.absolute()
+BASE_DIR_PREDS = pathlib.Path(__file__).parent.parent.absolute()
JUST = 50


config_path = BASE_DIR / "config.yaml"
with open(config_path) as config_yaml:
config = yaml.load(config_yaml, Loader=yaml.FullLoader)

+path_to_features = config['feature_generation']['path_to_features']
+path_to_preds = config['inference']['path_to_preds']
+
+if path_to_features is not None:
+    BASE_DIR_FEATS = pathlib.Path(path_to_features)
+if path_to_preds is not None:
+    BASE_DIR_PREDS = pathlib.Path(path_to_preds)

period_suffix_config = config['features']['info']['period_suffix']

# Load training set
@@ -114,7 +123,7 @@ def clean_data(
# file to store flagged ids and features with missing values
if not whole_field:
filename = (
-str(BASE_DIR)
+str(BASE_DIR_PREDS)
+ f"/preds_{algorithm}/field_"
+ str(field)
+ "/ccd_"
@@ -125,7 +134,7 @@
)
else:
filename = (
-str(BASE_DIR)
+str(BASE_DIR_PREDS)
+ f"/preds_{algorithm}/field_"
+ str(field)
+ f"/field_{field}_flagged.json"
@@ -263,19 +272,19 @@ def run_inference(
if not int_field:
if 'specific_ids' in field:
default_features_file = str(
-BASE_DIR
+BASE_DIR_FEATS
/ f"{feature_directory}/specific_ids/gen_gcn_features_{field}.parquet"
)
else:
# default file location for source ids
if whole_field:
default_features_file = (
-str(BASE_DIR) + f"/{feature_directory}/field_" + str(field)
+str(BASE_DIR_FEATS) + f"/{feature_directory}/field_" + str(field)
)
else:
if feature_directory == 'features':
default_features_file = (
-str(BASE_DIR)
+str(BASE_DIR_FEATS)
+ f"/{feature_directory}/field_"
+ str(field)
+ "/ccd_"
@@ -286,7 +295,7 @@
)
else:
default_features_file = (
-str(BASE_DIR)
+str(BASE_DIR_FEATS)
+ f"/{feature_directory}/field_"
+ str(field)
+ f"/{feature_file_prefix}_"
@@ -302,7 +311,7 @@
features_filename = kwargs.get("features_filename", default_features_file)

out_dir = os.path.join(
-os.path.dirname(__file__), f"{str(BASE_DIR)}/preds_{algorithm}/"
+os.path.dirname(__file__), f"{str(BASE_DIR_PREDS)}/preds_{algorithm}/"
)

if not whole_field:
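With both config keys set, inference can read features from one location and write predictions to another. A minimal sketch of the resulting paths (all values hypothetical, following the patterns in the diff):

    import pathlib

    BASE_DIR_FEATS = pathlib.Path('/data/shared/features')  # from feature_generation: path_to_features
    BASE_DIR_PREDS = pathlib.Path('/scratch/preds')         # from inference: path_to_preds

    field, algorithm = 296, 'dnn'             # hypothetical
    feature_directory = 'generated_features'  # hypothetical

    features_file = str(BASE_DIR_FEATS) + f"/{feature_directory}/field_" + str(field)
    out_dir = f"{str(BASE_DIR_PREDS)}/preds_{algorithm}/"
    print(features_file)  # /data/shared/features/generated_features/field_296
    print(out_dir)        # /scratch/preds/preds_dnn/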
11 changes: 8 additions & 3 deletions tools/run_inference_job_submission.py
@@ -6,12 +6,17 @@
import numpy as np

BASE_DIR = pathlib.Path(__file__).parent.parent.absolute()
+BASE_DIR_PREDS = pathlib.Path(__file__).parent.parent.absolute()

# Read config file
config_path = BASE_DIR / "config.yaml"
with open(config_path) as config_yaml:
config = yaml.load(config_yaml, Loader=yaml.FullLoader)

+path_to_preds = config['inference']['path_to_preds']
+if path_to_preds is not None:
+    BASE_DIR_PREDS = pathlib.Path(path_to_preds)


def parse_commandline():
"""
@@ -23,7 +28,7 @@ def parse_commandline():
"--dirname",
type=str,
default='inference',
help="Directory name for training slurm scripts",
help="Directory name for inference slurm scripts",
)
parser.add_argument(
"-f", "--filetype", default="slurm", help="Type of job submission file"
@@ -50,7 +55,7 @@ def filter_completed(fields, algorithm):
fields_copy = fields.copy()

for field in fields:
-searchDir = BASE_DIR / f'preds_{algorithm}' / f'field_{field}'
+searchDir = BASE_DIR_PREDS / f'preds_{algorithm}' / f'field_{field}'
searchDir.mkdir(parents=True, exist_ok=True)
generator = searchDir.iterdir()
has_parquet = np.sum([x.suffix == '.parquet' for x in generator]) > 0
@@ -78,7 +83,7 @@ def run_job(field):
filetype = args.filetype
dirname = args.dirname

-slurmDir = str(BASE_DIR / dirname)
+slurmDir = str(BASE_DIR_PREDS / dirname)

fields = config['inference']['fields_to_run']
algorithm = args.algorithm
6 changes: 5 additions & 1 deletion tools/run_inference_slurm.py
@@ -6,11 +6,15 @@


BASE_DIR = pathlib.Path(__file__).parent.parent.absolute()
+BASE_DIR_PREDS = pathlib.Path(__file__).parent.parent.absolute()

config_path = BASE_DIR / "config.yaml"
with open(config_path) as config_yaml:
config = yaml.load(config_yaml, Loader=yaml.FullLoader)

+path_to_preds = config['inference']['path_to_preds']
+if path_to_preds is not None:
+    BASE_DIR_PREDS = pathlib.Path(path_to_preds)

if __name__ == "__main__":

@@ -141,7 +145,7 @@
dirname = f"{algorithm}_{args.dirname}"
jobname = f"{args.job_name}_{algorithm}"

-dirpath = BASE_DIR / dirname
+dirpath = BASE_DIR_PREDS / dirname
os.makedirs(dirpath, exist_ok=True)

slurmDir = os.path.join(dirpath, 'slurm')
