Skip to content

Commit

Permalink
Refactoring in main executable
Browse files (browse the repository at this point in the history)
  • Loading branch information
Patrick Westphal committed Jun 21, 2024
1 parent e772a6c commit 684e2c9
Showing 1 changed file with 13 additions and 33 deletions.
46 changes: 13 additions & 33 deletions bin/infermapping
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,11 +3,11 @@ from argparse import ArgumentParser
import logging
from typing import List, Type

from pyvis.network import Network

from semanticlabeling.labeledcolumn import YetUnknownTypeColumn
from semanticlabeling.labelinferencer import SemanticLabelInferencer
from util.file import InputFile
from util import graphbuilder, graphvisualizer
from util.knowledgesource import KnowledgeSource

logging.basicConfig(level=logging.INFO)
Expand All @@ -17,13 +17,15 @@ logger = logging.getLogger(__name__)
def main(
input_file_path: str,
input_file_cls: Type[InputFile],
target_ontology_paths: List[str]
target_ontology_paths: List[str],
visualize: bool
):
logger.info(
f'Semantic label inferencing called with input file {input_file_path} '
f'and target ontologies {" ".join(target_ontology_paths)}')

ontologies = list([KnowledgeSource(path) for path in target_ontology_paths])

# input_file holds a list of labeled columns:
# (Pdb) pp(input_file.__dict__)
# {'columns': [<semanticlabeling.labeledcolumn.IDColumn object at 0x12c737080>,
Expand Down Expand Up @@ -72,51 +74,28 @@ def main(
# )
# ]
input_file = input_file_cls(input_file_path=input_file_path)

label_inferencer = SemanticLabelInferencer(input_file, ontologies)
graph = graphbuilder.build_from_label_inferencer(label_inferencer)

if visualize:
graphvisualizer.visualize(graph)

network = Network()
for knowledge_source in ontologies:
for column_name, labeled_column in knowledge_source.columns.items():
if not isinstance(labeled_column, YetUnknownTypeColumn):
node_label = str(labeled_column)
network.add_node(column_name, label=node_label, color='green')

for column_name, labeled_column in knowledge_source.columns.items():
if not isinstance(labeled_column, YetUnknownTypeColumn):
source_node_id = column_name
for link_name, targets in labeled_column.links.items():
for target in targets:
target_node_id = target.column_name
network.add_edge(source_node_id, target_node_id, title=link_name)

for labeled_column in label_inferencer.get_labeled_columns():
node_label = str(labeled_column)
network.add_node(labeled_column.column_name, label=node_label, color='red')

for labeled_column in label_inferencer.get_labeled_columns():
source_node_id = labeled_column.column_name
for link_name, targets in labeled_column.links.items():
for target in targets:
target_node_id = target.column_name
network.add_edge(source_node_id, target_node_id, title=link_name)
import pdb; pdb.set_trace()

network.show('network.html', notebook=False, )

pass
raise NotImplementedError()


if __name__ == '__main__':
arg_parser = ArgumentParser()
arg_parser.add_argument('input_file')

arg_parser.add_argument(
'--filetype',
default='csv',
help='csv or sampled_csv'
)

arg_parser.add_argument('target_ontologies', nargs='+')
arg_parser.add_argument('--visualize', action='store_true')

args = arg_parser.parse_args()

Expand All @@ -128,5 +107,6 @@ if __name__ == '__main__':
main(
input_file_path=input_file_path,
input_file_cls=input_file_cls,
target_ontology_paths=target_ontology_paths
target_ontology_paths=target_ontology_paths,
visualize=args.visualize
)

0 comments on commit 684e2c9

Please sign in to comment.