Skip to content

Commit

Permalink
add support for new format in main
Browse files Browse the repository at this point in the history
  • Loading branch information
northdpole committed Aug 11, 2024
1 parent ac058ce commit 9e7d6a6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 18 deletions.
30 changes: 13 additions & 17 deletions application/cmd/cre_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def parse_file(

def register_standard(
standard_entries: List[defs.Standard],
collection: db.Node_collection,
collection: db.Node_collection = None,
generate_embeddings=True,
calculate_gap_analysis=True,
db_connection_str: str = "",
Expand All @@ -237,15 +237,17 @@ def register_standard(
generate_embeddings = False

if not standard_entries:
logger.warning("register_standard() calleed with no standard_entries")
logger.warning("register_standard() called with no standard_entries")
return
if not collection:

if collection is None:
collection = db_connect(path=db_connection_str)

conn = redis.connect()
ph = prompt_client.PromptHandler(database=collection)
importing_name = standard_entries[0].name
standard_hash = gap_analysis.make_resources_key([importing_name])
if conn.get(standard_hash):
if calculate_gap_analysis and conn.get(standard_hash):
logger.info(
f"Standard importing job with info-hash {standard_hash} has already returned, skipping"
)
Expand All @@ -267,11 +269,12 @@ def register_standard(
generate_embeddings = False
if generate_embeddings and importing_name:
ph.generate_embeddings_for(importing_name)
populate_neo4j_db(collection)
# calculate gap analysis
jobs = []
pending_stadards = collection.standards()

if calculate_gap_analysis and not os.environ.get("CRE_NO_CALCULATE_GAP_ANALYSIS"):
# calculate gap analysis
populate_neo4j_db(collection)
jobs = []
pending_stadards = collection.standards()
for standard_name in pending_stadards:
if standard_name == importing_name:
continue
Expand Down Expand Up @@ -316,12 +319,7 @@ def parse_standards_from_spreadsheeet(
) -> None:
"""given a yaml with standards, build a list of standards in the db"""
collection = db_connect(cache_location)
if "CRE:name" in cre_file[0].keys():
collection = collection.with_graph()
documents = spreadsheet_parsers.parse_export_format(cre_file)
register_cre(documents, collection)
pass
elif any(key.startswith("CRE hierarchy") for key in cre_file[0].keys()):
if any(key.startswith("CRE hierarchy") for key in cre_file[0].keys()):
conn = redis.connect()
collection = collection.with_graph()
redis.empty_queues(conn)
Expand Down Expand Up @@ -649,9 +647,7 @@ def create_spreadsheet(
) -> Any:
"""Reads cre docs exported from a standards_collection.export()
dumps each doc into a workbook"""
flat_dicts = sheet_utils.prepare_spreadsheet(
collection=collection, docs=exported_documents
)
flat_dicts = sheet_utils.prepare_spreadsheet(docs=exported_documents)
return sheet_utils.write_spreadsheet(
title=title, docs=flat_dicts, emails=share_with
)
Expand Down
2 changes: 1 addition & 1 deletion application/database/inmemory_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def add_node(self, *args, **kwargs):
def get_hierarchy(self, rootIDs: List[str], creID: str):
if creID in rootIDs:
return 0

if self.__parent_child_subgraph == None:
if len(self.graph.edges) == 0:
raise ValueError("Graph has no edges")
Expand Down

0 comments on commit 9e7d6a6

Please sign in to comment.