diff --git a/numjuggler/likefunc.py b/numjuggler/likefunc.py index 33b9174..7943dcf 100644 --- a/numjuggler/likefunc.py +++ b/numjuggler/likefunc.py @@ -7,7 +7,8 @@ from __future__ import print_function from collections import OrderedDict -import sys + +from numjuggler.utils.io import resolve_fname_or_stream def trivial(x): @@ -68,29 +69,19 @@ def __init__(self, log=False): self.doc = "" return - def __call1(self, x): - res = self.get_value(x) - return res - - def __call2(self, x): + def __call__(self, x): res = self.get_value(x) - self.ld[x] = res + if self.log: + self.ld[x] = res return res - def __call__(self, x): - return self.__call1(x) - @property def log(self): return self.__lf @log.setter - def log(self, v): - if v: - self.__call__ = self.__call2 - else: - self.__call__ = self.__call1 - self.__ld = v + def log(self, _log): + self.__lf = _log def __str__(self): res = [] @@ -101,25 +92,12 @@ def __str__(self): res.append('other -> {}'.format(self.default._mydoc)) return '\n'.join(res) - def _get_log_file(self, fname): - if self.log is None: + def write_log_as_map(self, t, fname_or_stream=None): + if not self.log: raise ValueError("Cannon write log for unlogged mapping.") - - # Define where to print - if fname is None: - fout = sys.stdout - else: - try: - fout = open(fname, 'w') - except: - print("Cannot open {}. Print to stdout".format(repr(fname))) - fout = sys.stdout - return fout - - def write_log_as_map(self, t, fname=None): - fout = self._get_log_file(fname) - for nold, nnew in self.ld.items(): - print('{} {}: {}'.format(t, nnew, nold), file=fout) + with resolve_fname_or_stream(fname_or_stream, "w") as fout: + for nold, nnew in self.ld.items(): + print('{} {}: {}'.format(t, nnew, nold), file=fout) class LikeFunction(LikeFunctionBase): @@ -261,7 +239,7 @@ def read_map_file(fname, log=False): # Dictionary type -> LikeFunction maps = {} - with open(fname, 'r') as mapfile: + with resolve_fname_or_stream(fname, 'r') as mapfile: for l in mapfile: t, ranges, f = _parse_map_line(l) if t is None: diff --git a/numjuggler/utils/io.py b/numjuggler/utils/io.py index 3e99910..20e6d13 100644 --- a/numjuggler/utils/io.py +++ b/numjuggler/utils/io.py @@ -1,4 +1,6 @@ import os +import sys + from .resource import Path from contextlib import contextmanager @@ -11,3 +13,18 @@ def cd_temporarily(cd_to): yield finally: os.chdir(cur_dir) + + +@contextmanager +def resolve_fname_or_stream(fname_or_stream, mode="r"): + is_input = mode == 'r' + if fname_or_stream is None: + if is_input: + yield sys.stdin + else: + yield sys.stdout + elif is_input and hasattr(fname_or_stream, "read") or not is_input and hasattr(fname_or_stream, "write"): + yield fname_or_stream + else: + with open(fname_or_stream, mode=mode) as fid: + yield fid diff --git a/tests/test_likefunc.py b/tests/test_likefunc.py new file mode 100644 index 0000000..a5a0a66 --- /dev/null +++ b/tests/test_likefunc.py @@ -0,0 +1,54 @@ +import pytest +from numjuggler.utils.io import cd_temporarily +from six import StringIO +import numjuggler.likefunc as lf + + + +@pytest.mark.parametrize("data, log, present, absent, expected_present_value, expected_absent_value, expected_text", [ + ( + """ + c 1: 12 + c 2: 14 + """, + False, + 1, 3, + 12, 3, + "", + ), + ( + """ + c 1: 12 + c 2: 14 + """, + True, + 1, 3, + 12, 3, + "cel 12: 1\ncel 3: 3\n", + ), +]) +def test_LikeFunction( + data, + log, + present, + absent, + expected_present_value, + expected_absent_value, + expected_text +): + input = StringIO(data) + maps = lf.read_map_file(input, log) + actual = StringIO() + for k in maps: + like_function = maps[k] + assert like_function.log == log + present_value = like_function(present) + assert present_value == expected_present_value + absent_value = like_function(absent) + assert absent_value == expected_absent_value + if log: + like_function.write_log_as_map(k, actual) + else: + with pytest.raises(ValueError): + like_function.write_log_as_map(k, actual) + assert actual.getvalue() == expected_text