diff --git a/bin/infermapping b/bin/infermapping index d580858..188273a 100644 --- a/bin/infermapping +++ b/bin/infermapping @@ -3,11 +3,11 @@ from argparse import ArgumentParser import logging from typing import List, Type -from pyvis.network import Network from semanticlabeling.labeledcolumn import YetUnknownTypeColumn from semanticlabeling.labelinferencer import SemanticLabelInferencer from util.file import InputFile +from util import graphbuilder, graphvisualizer from util.knowledgesource import KnowledgeSource logging.basicConfig(level=logging.INFO) @@ -17,13 +17,15 @@ logger = logging.getLogger(__name__) def main( input_file_path: str, input_file_cls: Type[InputFile], - target_ontology_paths: List[str] + target_ontology_paths: List[str], + visualize: bool ): logger.info( f'Semantic label inferencing called with input file {input_file_path} ' f'and target ontologies {" ".join(target_ontology_paths)}') ontologies = list([KnowledgeSource(path) for path in target_ontology_paths]) + # input_file holds a list of labeled columns: # (Pdb) pp(input_file.__dict__) # {'columns': [, @@ -72,44 +74,20 @@ def main( # ) # ] input_file = input_file_cls(input_file_path=input_file_path) + label_inferencer = SemanticLabelInferencer(input_file, ontologies) + graph = graphbuilder.build_from_label_inferencer(label_inferencer) + + if visualize: + graphvisualizer.visualize(graph) - network = Network() - for knowledge_source in ontologies: - for column_name, labeled_column in knowledge_source.columns.items(): - if not isinstance(labeled_column, YetUnknownTypeColumn): - node_label = str(labeled_column) - network.add_node(column_name, label=node_label, color='green') - - for column_name, labeled_column in knowledge_source.columns.items(): - if not isinstance(labeled_column, YetUnknownTypeColumn): - source_node_id = column_name - for link_name, targets in labeled_column.links.items(): - for target in targets: - target_node_id = target.column_name - network.add_edge(source_node_id, target_node_id, title=link_name) - - for labeled_column in label_inferencer.get_labeled_columns(): - node_label = str(labeled_column) - network.add_node(labeled_column.column_name, label=node_label, color='red') - - for labeled_column in label_inferencer.get_labeled_columns(): - source_node_id = labeled_column.column_name - for link_name, targets in labeled_column.links.items(): - for target in targets: - target_node_id = target.column_name - network.add_edge(source_node_id, target_node_id, title=link_name) - import pdb; pdb.set_trace() - - network.show('network.html', notebook=False, ) - - pass raise NotImplementedError() if __name__ == '__main__': arg_parser = ArgumentParser() arg_parser.add_argument('input_file') + arg_parser.add_argument( '--filetype', default='csv', @@ -117,6 +95,7 @@ if __name__ == '__main__': ) arg_parser.add_argument('target_ontologies', nargs='+') + arg_parser.add_argument('--visualize', action='store_true') args = arg_parser.parse_args() @@ -128,5 +107,6 @@ if __name__ == '__main__': main( input_file_path=input_file_path, input_file_cls=input_file_cls, - target_ontology_paths=target_ontology_paths + target_ontology_paths=target_ontology_paths, + visualize=args.visualize )