diff --git a/application/utils/spreadsheet.py b/application/utils/spreadsheet.py index 38136711..f1606405 100644 --- a/application/utils/spreadsheet.py +++ b/application/utils/spreadsheet.py @@ -272,8 +272,9 @@ def prepare_spreadsheet( def write_csv(docs: List[Dict[str, Any]]) -> io.StringIO: data = io.StringIO() - fieldnames: List[str] = list(docs[0].keys()) - writer: csv.DictWriter = csv.DictWriter(data, fieldnames=fieldnames) # type: ignore + fieldnames = {} + [fieldnames.update(d) for d in docs] + writer: csv.DictWriter = csv.DictWriter(data, fieldnames=fieldnames.keys()) # type: ignore writer.writeheader() writer.writerows(docs) return data diff --git a/application/web/web_main.py b/application/web/web_main.py index febb1059..56ece8c0 100644 --- a/application/web/web_main.py +++ b/application/web/web_main.py @@ -80,7 +80,8 @@ def extend_cre_with_tag_links( others = list(frozenset(others)) for o in others: o.links = [] - cre.add_link(defs.Link(ltype=defs.LinkTypes.Related, document=o)) + if not cre.link_exists(o) and o.id != cre.id: + cre.add_link(defs.Link(ltype=defs.LinkTypes.Related, document=o)) return cre @@ -137,7 +138,9 @@ def find_cre(creid: str = None, crename: str = None) -> Any: # refer return f"
{mdutils.cre_to_md([cre])}" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(docs=[cre]) + docs = sheet_utils.ExportSheet().prepare_spreadsheet( + docs=[cre], storage=database + ) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: @@ -209,7 +212,9 @@ def find_node_by_name(name: str, ntype: str = defs.Credoctypes.Standard.value) - return f"
{mdutils.cre_to_md(nodes)}" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(docs=nodes) + docs = sheet_utils.ExportSheet().prepare_spreadsheet( + docs=nodes, storage=database + ) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: @@ -242,7 +247,9 @@ def find_document_by_tag() -> Any: if opt_format == SupportedFormats.Markdown.value: return f"
{mdutils.cre_to_md(documents)}" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(docs=documents) + docs = sheet_utils.ExportSheet().prepare_spreadsheet( + docs=documents, storage=database + ) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: return jsonify(json.loads(oscal_utils.list_to_oscal(documents))) @@ -372,7 +379,9 @@ def text_search() -> Any: if opt_format == SupportedFormats.Markdown.value: return f"
{mdutils.cre_to_md(documents)}" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(docs=documents) + docs = sheet_utils.ExportSheet().prepare_spreadsheet( + docs=documents, storage=database + ) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: return jsonify(json.loads(oscal_utils.list_to_oscal(documents))) @@ -402,7 +411,9 @@ def find_root_cres() -> Any: if opt_format == SupportedFormats.Markdown.value: return f"
{mdutils.cre_to_md(documents)}" elif opt_format == SupportedFormats.CSV.value: - docs = sheet_utils.prepare_spreadsheet(docs=documents) + docs = sheet_utils.ExportSheet().prepare_spreadsheet( + docs=documents, storage=database + ) return write_csv(docs=docs).getvalue().encode("utf-8") elif opt_format == SupportedFormats.OSCAL.value: return jsonify(json.loads(oscal_utils.list_to_oscal(documents)))