diff --git a/application/database/db.py b/application/database/db.py
index e68292890..e70fd9084 100644
--- a/application/database/db.py
+++ b/application/database/db.py
@@ -171,7 +171,13 @@ class Embeddings(BaseModel):  # type: ignore
         ),
     )
 
-
+class GapAnalysisResults(BaseModel):
+    __tablename__ = "gap_analysis_results"
+    cache_key = sqla.Column(sqla.String, primary_key=True)
+    ga_object = sqla.Column(sqla.String)
+    __table_args__ = (
+        sqla.UniqueConstraint(cache_key, name="unique_cache_key_field"),
+    )
 
 class RelatedRel(StructuredRel):
     pass
@@ -425,7 +431,6 @@ def link_CRE_to_Node(self, CRE_id, node_id, link_type):
 
     def gap_analysis(self, name_1, name_2):
         base_standard = NeoStandard.nodes.filter(name=name_1)
         denylist = ["Cross-cutting concerns"]
-        from pprint import pprint
         from datetime import datetime
 
         t1 = datetime.now()
@@ -442,8 +447,6 @@ def gap_analysis(self, name_1, name_2):
             resolve_objects=True,
         )
         t2 = datetime.now()
-        pprint(f"path records all took {t2-t1}")
-        pprint(path_records_all.__len__())
        path_records, _ = db.cypher_query(
             """
             OPTIONAL MATCH (BaseStandard:NeoStandard {name: $name1})
@@ -484,10 +487,6 @@ def format_path_record(rec):
                 "end": NEO_DB.parse_node(rec.end_node),
                 "path": [format_segment(seg, rec.nodes) for seg in rec.relationships],
             }
-
-        pprint(
-            f"path records all took {t2-t1} path records took {t3 - t2}, total: {t3 - t1}"
-        )
         return [NEO_DB.parse_node(rec) for rec in base_standard], [
             format_path_record(rec[0]) for rec in (path_records + path_records_all)
         ]
@@ -1634,8 +1633,21 @@ def add_embedding(
         self.session.commit()
         return existing
 
-
+    def get_gap_analysis_result(self, cache_key: str) -> Optional[str]:
+        res = (
+            self.session.query(GapAnalysisResults)
+            .filter(GapAnalysisResults.cache_key == cache_key)
+            .first()
+        )
+        return res.ga_object if res else None
+
+    def add_gap_analysis_result(self, cache_key: str, ga_object: dict) -> None:
+        # results for a given cache key are immutable, so only write once
+        if not self.get_gap_analysis_result(cache_key):
+            res = GapAnalysisResults(
+                cache_key=cache_key, ga_object=flask_json.dumps(ga_object)
+            )
+            self.session.add(res)
+            self.session.commit()
+
 
 def dbNodeFromNode(doc: cre_defs.Node) -> Optional[Node]:
     if doc.doctype == cre_defs.Credoctypes.Standard:
         return dbNodeFromStandard(doc)
@@ -1766,7 +1778,9 @@ def gap_analysis(
     node_names: List[str],
     store_in_cache: bool = False,
     cache_key: str = "",
 ):
+    cre_db = Node_collection()
     base_standard, paths = neo_db.gap_analysis(node_names[0], node_names[1])
     if base_standard is None:
         return None
@@ -1809,16 +1823,21 @@ def gap_analysis(
     ):  # lightweight memory option to not return potentially huge object and instead store in a cache,
         # in case this is called via worker, we save both this and the caller memory by avoiding duplicate object in mem
-        conn = redis.connect()
         if cache_key == "":
             cache_key = make_array_hash(node_names)
-        conn.set(cache_key, flask_json.dumps({"result": grouped_paths}))
+        cre_db.add_gap_analysis_result(
+            cache_key=cache_key, ga_object={"result": grouped_paths}
+        )
         for key in extra_paths_dict:
-            conn.set(
-                cache_key + "->" + key,
-                flask_json.dumps({"result": extra_paths_dict[key]}),
-            )
+            cre_db.add_gap_analysis_result(
+                cache_key=cache_key + "->" + key,
+                ga_object={"result": extra_paths_dict[key]},
+            )
         return (node_names, {}, {})
     return (node_names, grouped_paths, extra_paths_dict)
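The new persistence pair is write-once: add_gap_analysis_result only inserts when the key is absent, and get_gap_analysis_result returns the stored JSON string or None. A minimal round-trip sketch, assuming the gap_analysis_results table has been created and Node_collection is constructed inside an app context as elsewhere in the codebase; the make_array_hash import path is an assumption:

    import json

    from application.database.db import Node_collection
    from application.utils.gap_analysis import make_array_hash  # assumed module path

    collection = Node_collection()
    cache_key = make_array_hash(["ASVS", "CWE"])

    # first write wins; a second call with the same key is a no-op
    collection.add_gap_analysis_result(cache_key=cache_key, ga_object={"result": []})

    stored = collection.get_gap_analysis_result(cache_key)  # JSON string or None
    assert stored is not None and json.loads(stored) == {"result": []}
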
diff --git a/application/utils/redis.py b/application/utils/redis.py
index 769d9d5f6..7fa47e645 100644
--- a/application/utils/redis.py
+++ b/application/utils/redis.py
@@ -5,13 +5,14 @@
 
 def connect():
     redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
-
-    url = urlparse(redis_url)
-    r = redis.Redis(
-        host=url.hostname,
-        port=url.port,
-        password=url.password,
-        ssl=True,
-        ssl_cert_reqs=None,
-    )
-    return r
+    # the default URL points at a local development Redis, which does not use TLS
+    if redis_url == "redis://localhost:6379":
+        return redis.from_url(redis_url)
+    else:
+        url = urlparse(redis_url)
+        return redis.Redis(
+            host=url.hostname,
+            port=url.port,
+            password=url.password,
+            ssl=True,
+            ssl_cert_reqs=None,
+        )
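Note that connect() now special-cases only the exact default URL, so any other plain redis:// URL would still be opened with ssl=True and fail to connect. A more general variant, sketched here as an alternative rather than what this patch does, keys TLS off the URL scheme instead, since rediss:// is the TLS scheme that hosted providers typically hand out:

    import os

    import redis


    def connect():
        redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
        if redis_url.startswith("rediss://"):
            # TLS scheme: from_url enables SSL; skip cert verification as before
            return redis.from_url(redis_url, ssl_cert_reqs=None)
        return redis.from_url(redis_url)
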
jsonify({"result": gap_analysis_dict.get("result")}) abort(404, "No such Cache") @@ -315,7 +326,9 @@ def fetch_job() -> Any: if conn.exists(standards_hash): logger.info("and hash is already in cache") - ga = conn.get(standards_hash) + # ga = conn.get(standards_hash) + database = db.Node_collection() + ga = database.get_gap_analysis_result(standards_hash) if ga: logger.info("and results in cache") ga = json.loads(ga)