Skip to content

Commit

Permalink
add testing util to mock koza, rename cli runner to cli utils
Browse files Browse the repository at this point in the history
  • Loading branch information
glass-ships committed Apr 24, 2024
1 parent 004332a commit e2fb994
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 17 deletions.
24 changes: 9 additions & 15 deletions src/koza/cli_runner.py → src/koza/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_koza_app(source_name) -> Optional[KozaApp]:
def transform_source(
source: str,
output_dir: str,
output_format: OutputFormat = OutputFormat('tsv'),
output_format: OutputFormat = OutputFormat("tsv"),
global_table: str = None,
local_table: str = None,
schema: str = None,
Expand All @@ -61,9 +61,9 @@ def transform_source(
"""
logger = get_logger(name=Path(source).name if log else None, verbose=verbose)

with open(source, 'r') as source_fh:
with open(source, "r") as source_fh:
source_config = PrimaryFileConfig(**yaml.load(source_fh, Loader=UniqueIncludeLoader))

# TODO: Try moving this to source_config class
if not source_config.name:
source_config.name = Path(source).stem
Expand All @@ -73,7 +73,7 @@ def transform_source(
if not Path(filename).exists():
filename = Path(source).parent / "transform.py"
if not Path(filename).exists():
raise FileNotFoundError(f"Could not find transform file for {source}")
raise FileNotFoundError(f"Could not find transform file for {source}")
source_config.transform_code = filename

koza_source = Source(source_config, row_limit)
Expand All @@ -94,7 +94,7 @@ def transform_source(
def validate_file(
file: str,
format: FormatType = FormatType.csv,
delimiter: str = ',',
delimiter: str = ",",
header_delimiter: str = None,
skip_blank_lines: bool = True,
):
Expand Down Expand Up @@ -149,14 +149,14 @@ def get_translation_table(
logger.debug("No global table used for transform")
else:
if isinstance(global_table, str):
with open(global_table, 'r') as global_tt_fh:
with open(global_table, "r") as global_tt_fh:
global_tt = yaml.safe_load(global_tt_fh)
elif isinstance(global_table, Dict):
global_tt = global_table

if local_table:
if isinstance(local_table, str):
with open(local_table, 'r') as local_tt_fh:
with open(local_table, "r") as local_tt_fh:
local_tt = yaml.safe_load(local_tt_fh)
elif isinstance(local_table, Dict):
local_tt = local_table
Expand All @@ -170,8 +170,8 @@ def get_translation_table(
def _set_koza_app(
source: Source,
translation_table: TranslationTable = None,
output_dir: str = './output',
output_format: OutputFormat = OutputFormat('tsv'),
output_dir: str = "./output",
output_format: OutputFormat = OutputFormat("tsv"),
schema: str = None,
node_type: str = None,
edge_type: str = None,
Expand All @@ -184,9 +184,3 @@ def _set_koza_app(
)
logger.debug(f"koza_apps entry created for {source.config.name}: {koza_apps[source.config.name]}")
return koza_apps[source.config.name]


def test_koza(koza: KozaApp):
"""Manually sets KozaApp (for testing)"""
global koza_app
koza_app = koza
2 changes: 1 addition & 1 deletion src/koza/converter/biolink_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from biolink_model.datamodel.pydanticmodel_v2 import Gene

from koza.cli_runner import koza_app
from koza.cli_utils import koza_app


def gpi2gene(row: dict) -> Gene:
Expand Down
2 changes: 1 addition & 1 deletion src/koza/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
from typing import Optional

from koza.cli_runner import transform_source, validate_file
from koza.cli_utils import transform_source, validate_file
from koza.model.config.source_config import FormatType, OutputFormat

import typer
Expand Down
78 changes: 78 additions & 0 deletions src/koza/utils/testing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import types
from typing import Iterable

from loguru import logger

from koza.app import KozaApp
from koza.cli_utils import get_koza_app, get_translation_table, _set_koza_app
from koza.model.config.source_config import PrimaryFileConfig
from koza.model.source import Source

def test_koza(koza: KozaApp):
"""Manually sets KozaApp for testing"""
global koza_app
koza_app = koza

def mock_koza():
"""Mock KozaApp for testing"""
def _mock_write(self, *entities):
if hasattr(self, '_entities'):
self._entities.extend(list(entities))
else:
self._entities = list(entities)

def _make_mock_koza_app(
name: str,
data: Iterable,
transform_code: str,
map_cache=None,
filters=None,
global_table=None,
local_table=None,
):
mock_source_file_config = PrimaryFileConfig(
name=name,
files=[],
transform_code=transform_code,
)
mock_source_file = Source(mock_source_file_config)
mock_source_file._reader = data

_set_koza_app(
source=mock_source_file,
translation_table=get_translation_table(global_table, local_table, logger),
logger=logger,
)
koza = get_koza_app(name)

# TODO filter mocks
koza._map_cache = map_cache
koza.write = types.MethodType(_mock_write, koza)

return koza

def _transform(
name: str,
data: Iterable,
transform_code: str,
map_cache=None,
filters=None,
global_table=None,
local_table=None,
):
koza_app = _make_mock_koza_app(
name,
data,
transform_code,
map_cache=map_cache,
filters=filters,
global_table=global_table,
local_table=local_table,
)
test_koza(koza_app)
koza_app.process_sources()
if not hasattr(koza_app, '_entities'):
koza_app._entities = []
return koza_app._entities

return _transform

0 comments on commit e2fb994

Please sign in to comment.