From 24b83ec8b9376d6edadad336d52516bda545f3c2 Mon Sep 17 00:00:00 2001 From: alphasentaurii Date: Mon, 1 Apr 2024 17:55:46 -0400 Subject: [PATCH] import HstSvmRadio inside Hst class method --- spacekit/extractor/radio.py | 17 +++++++++-------- spacekit/preprocessor/scrub.py | 10 +++++++--- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/spacekit/extractor/radio.py b/spacekit/extractor/radio.py index ce79e37..5cc7cdf 100644 --- a/spacekit/extractor/radio.py +++ b/spacekit/extractor/radio.py @@ -6,20 +6,21 @@ import pandas as pd from spacekit.logger.log import Logger +try: + from astroquery.mast import Observations +except ImportError: + Observations = None + +try: + from progressbar import ProgressBar +except ImportError: + ProgressBar = None def check_astroquery(): - try: - from astroquery.mast import Observations - except ImportError: - Observations = None return Observations is not None def check_progressbar(): - try: - from progressbar import ProgressBar - except ImportError: - ProgressBar = None return ProgressBar is not None diff --git a/spacekit/preprocessor/scrub.py b/spacekit/preprocessor/scrub.py index 143ee1b..334bb8f 100644 --- a/spacekit/preprocessor/scrub.py +++ b/spacekit/preprocessor/scrub.py @@ -9,7 +9,6 @@ JwstFitsScraper, scrape_catalogs, ) -from spacekit.extractor.radio import HstSvmRadio from spacekit.preprocessor.encode import HstSvmEncoder, JwstEncoder, encode_booleans from spacekit.logger.log import Logger @@ -190,6 +189,11 @@ def __init__( self.make_subsamples = make_subsamples self.set_new_cols() self.set_prefix_cols() + self.initialize_radio() + + def initialize_radio(self): + from spacekit.extractor.radio import HstSvmRadio + self.radio = HstSvmRadio def preprocess_data(self): """Main calling function to run each preprocessing step for SVM regression data.""" @@ -200,7 +204,7 @@ def preprocess_data(self): n_retries = 3 while n_retries > 0: try: - self.df = HstSvmRadio(self.df).scrape_mast() + self.df = self.radio(self.df).scrape_mast() n_retries = 0 except Exception as e: self.log.warning(e) @@ -262,7 +266,7 @@ def scrub_qa_data(self): self.scrub_columns() # STAGE 2 initial encoding self.df = SvmFitsScraper(self.df, self.input_path).scrape_fits() - self.df = HstSvmRadio(self.df).scrape_mast() + self.df = self.radio(self.df).scrape_mast() def scrub_qa_summary(self, csvfile="single_visit_mosaics*.csv", idx=0): """Alternative if no .json files available (QA step not run during processing)"""