diff --git a/application/database/inmemory_graph.py b/application/database/inmemory_graph.py index 1284cfc8..ecafb97d 100644 --- a/application/database/inmemory_graph.py +++ b/application/database/inmemory_graph.py @@ -46,26 +46,35 @@ def get_hierarchy(self, rootIDs: List[str], creID: str): ): include_cres.append(el[0]) include_cres.append(el[1]) - + for el in rootIDs: - if el not in include_cres: # If the root is not in the parent/children graph, add it to prevent an error and continue, there is not path to our CRE anyway + if ( + el not in include_cres + ): # If the root is not in the parent/children graph, add it to prevent an error and continue, there is not path to our CRE anyway include_cres.append(f"CRE: {el}") - self.__parent_child_subgraph = self.graph.subgraph(set(include_cres) ) - + self.__parent_child_subgraph = self.graph.subgraph(set(include_cres)) + shortest_path = sys.maxsize - for root in rootIDs: + for root in rootIDs: try: shortest_path = min( shortest_path, len( nx.shortest_path( - self.__parent_child_subgraph, f"CRE: {root}", f"CRE: {creID}" + self.__parent_child_subgraph, + f"CRE: {root}", + f"CRE: {creID}", ) - )-1, + ) + - 1, ) - except nx.NodeNotFound as nnf: # If the CRE is not in the parent/children graph it means that it's a lone CRE, so it's a root and we return 0 + except ( + nx.NodeNotFound + ) as nnf: # If the CRE is not in the parent/children graph it means that it's a lone CRE, so it's a root and we return 0 return 0 - except nx.NetworkXNoPath as nxnp: # If there is no path to the CRE, continue + except ( + nx.NetworkXNoPath + ) as nxnp: # If there is no path to the CRE, continue continue return shortest_path diff --git a/application/frontend/src/pages/Explorer/explorer.tsx b/application/frontend/src/pages/Explorer/explorer.tsx index 4836e2df..08a29d22 100644 --- a/application/frontend/src/pages/Explorer/explorer.tsx +++ b/application/frontend/src/pages/Explorer/explorer.tsx @@ -8,9 +8,9 @@ import { LoadingAndErrorIndicator } from '../../components/LoadingAndErrorIndica import { TYPE_CONTAINS, TYPE_LINKED_TO } from '../../const'; import { useDataStore } from '../../providers/DataProvider'; import { LinkedTreeDocument, TreeDocument } from '../../types'; -import { LinkedStandards } from './LinkedStandards'; import { getDocumentDisplayName } from '../../utils'; import { getInternalUrl } from '../../utils/document'; +import { LinkedStandards } from './LinkedStandards'; export const Explorer = () => { const { dataLoading, dataTree } = useDataStore(); diff --git a/application/web/web_main.py b/application/web/web_main.py index a1da2ac5..febb1059 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -137,7 +137,7 @@ def find_cre(creid: str = None, crename: str = None) -> Any: # refer return f"
{mdutils.cre_to_md([cre])}
" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(collection=database, docs=[cre]) + docs = sheet_utils.prepare_spreadsheet(docs=[cre]) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: @@ -209,7 +209,7 @@ def find_node_by_name(name: str, ntype: str = defs.Credoctypes.Standard.value) - return f"
{mdutils.cre_to_md(nodes)}
" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(collection=database, docs=nodes) + docs = sheet_utils.prepare_spreadsheet(docs=nodes) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: @@ -242,7 +242,7 @@ def find_document_by_tag() -> Any: if opt_format == SupportedFormats.Markdown.value: return f"
{mdutils.cre_to_md(documents)}
" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(collection=database, docs=documents) + docs = sheet_utils.prepare_spreadsheet(docs=documents) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: return jsonify(json.loads(oscal_utils.list_to_oscal(documents))) @@ -372,7 +372,7 @@ def text_search() -> Any: if opt_format == SupportedFormats.Markdown.value: return f"
{mdutils.cre_to_md(documents)}
" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(collection=database, docs=documents) + docs = sheet_utils.prepare_spreadsheet(docs=documents) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: return jsonify(json.loads(oscal_utils.list_to_oscal(documents))) @@ -402,7 +402,7 @@ def find_root_cres() -> Any: if opt_format == SupportedFormats.Markdown.value: return f"
{mdutils.cre_to_md(documents)}
" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(collection=database, docs=documents) + docs = sheet_utils.prepare_spreadsheet(docs=documents) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: return jsonify(json.loads(oscal_utils.list_to_oscal(documents))) @@ -721,30 +721,48 @@ def get_cre_csv() -> Any: @app.route("/rest/v1/cre_csv_import", methods=["POST"]) def import_from_cre_csv() -> Any: + if not os.environ.get("CRE_ALLOW_IMPORT"): + abort( + 403, + "Importing is disabled, set the environment variable CRE_ALLOW_IMPORT to allow this functionality", + ) + # TODO: (spyros) add optional gap analysis and embeddings calculation database = db.Node_collection().with_graph() file = request.files.get("cre_csv") + calculate_embeddings = ( + False if not request.args.get("calculate_embeddings") else True + ) + calculate_gap_analysis = ( + False if not request.args.get("calculate_gap_analysis") else True + ) + if file is None: abort(400, "No file provided") contents = file.read() csv_read = csv.DictReader(contents.decode("utf-8").splitlines()) + documents = spreadsheet_parsers.parse_export_format(list(csv_read)) + cres = documents.pop(defs.Credoctypes.CRE.value) - cres, standards = spreadsheet_parsers.parse_export_format(list(csv_read)) - + standards = documents new_cres = [] for cre in cres: new_cre, exists = cre_main.register_cre(cre, database) if not exists: new_cres.append(new_cre) - for standard in standards: - cre_main.register_node(collection=database, node=standard) - + for _, entries in standards.items(): + cre_main.register_standard( + collection=database, + standard_entries=list(entries), + generate_embeddings=calculate_embeddings, + calculate_gap_analysis=calculate_gap_analysis, + ) return jsonify( { "status": "success", "new_cres": [c.external_id for c in new_cres], - "new_standard_entries": len(standards), + "new_standards": len(standards), } )