-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🎨 Allow
datar.all.filter regardless
of allow_conflict_names
- Loading branch information
Showing
7 changed files
with
261 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import argparse | ||
|
||
|
||
def test_getattr(module, allow_conflict_names, fun, error): | ||
from datar import options | ||
options(allow_conflict_names=allow_conflict_names) | ||
|
||
if module == "all": | ||
import datar.all as d | ||
elif module == "base": | ||
import datar.base as d | ||
elif module == "dplyr": | ||
import datar.dplyr as d | ||
|
||
if not error: | ||
return getattr(d, fun) | ||
|
||
try: | ||
getattr(d, fun) | ||
except Exception as e: | ||
raised = type(e).__name__ | ||
assert raised == error, f"Raised {raised}, expected {error}" | ||
else: | ||
raise AssertionError(f"{error} should have raised") | ||
|
||
|
||
def _import(module, fun): | ||
if module == "all" and fun == "sum": | ||
from datar.all import sum # noqa: F401 | ||
elif module == "all" and fun == "slice": | ||
from datar.all import slice # noqa: F401 | ||
elif module == "base" and fun == "sum": | ||
from datar.base import sum # noqa: F401 | ||
elif module == "dplyr" and fun == "slice": | ||
from datar.dplyr import slice # noqa: F401 | ||
|
||
|
||
def test_import(module, allow_conflict_names, fun, error): | ||
from datar import options | ||
options(allow_conflict_names=allow_conflict_names) | ||
|
||
if not error: | ||
return _import(module, fun) | ||
|
||
try: | ||
_import(module, fun) | ||
except Exception as e: | ||
raised = type(e).__name__ | ||
assert raised == error, f"Raised {raised}, expected {error}" | ||
else: | ||
raise AssertionError(f"{error} should have raised") | ||
|
||
|
||
def make_test(module, allow_conflict_names, getattr, fun, error): | ||
if fun == "_": | ||
fun = "sum" if module in ["all", "base"] else "slice" | ||
|
||
if getattr: | ||
return test_getattr(module, allow_conflict_names, fun, error) | ||
|
||
return test_import(module, allow_conflict_names, fun, error) | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--module", | ||
choices=["all", "base", "dplyr"], | ||
required=True, | ||
help="The module to test" | ||
) | ||
parser.add_argument( | ||
"--allow-conflict-names", | ||
action="store_true", | ||
help="Whether to allow conflict names", | ||
default=False, | ||
) | ||
parser.add_argument( | ||
"--getattr", | ||
action="store_true", | ||
help=( | ||
"Whether to test datar.all.sum, " | ||
"otherwise test from datar.all import sum." | ||
), | ||
default=False, | ||
) | ||
parser.add_argument( | ||
"--fun", | ||
help=( | ||
"The function to test. " | ||
"If _ then sum for all/base, slice for dplyr" | ||
), | ||
choices=["sum", "filter", "_"], | ||
default="_", | ||
) | ||
parser.add_argument( | ||
"--error", | ||
help="The error to expect", | ||
) | ||
args = parser.parse_args() | ||
|
||
make_test( | ||
args.module, | ||
args.allow_conflict_names, | ||
args.getattr, | ||
args.fun, | ||
args.error, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import sys | ||
import subprocess | ||
from pathlib import Path | ||
|
||
import pytest | ||
|
||
|
||
def _run_conflict_names(module, allow_conflict_names, getat, error): | ||
here = Path(__file__).parent | ||
conflict_names = here / "conflict_names.py" | ||
cmd = [ | ||
sys.executable, | ||
str(conflict_names), | ||
"--module", | ||
module, | ||
] | ||
if error: | ||
cmd += ["--error", error] | ||
if allow_conflict_names: | ||
cmd.append("--allow-conflict-names") | ||
if getat: | ||
cmd.append("--getattr") | ||
|
||
p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | ||
return p.wait(), " ".join(cmd) | ||
|
||
|
||
def test_from_all_import_allow_conflict_names_true(): | ||
r, cmd = _run_conflict_names("all", True, False, None) | ||
assert r == 0, cmd | ||
|
||
|
||
def test_from_all_import_allow_conflict_names_false(): | ||
r, cmd = _run_conflict_names("all", False, False, "ImportError") | ||
assert r == 0, cmd | ||
|
||
|
||
def test_all_getattr_allow_conflict_names_true(): | ||
r, cmd = _run_conflict_names("all", True, True, None) | ||
assert r == 0, cmd | ||
|
||
|
||
def test_all_getattr_allow_conflict_names_false(): | ||
r, cmd = _run_conflict_names("all", False, True, None) | ||
assert r == 0, cmd | ||
|
||
|
||
def test_from_base_import_allow_conflict_names_true(): | ||
r, cmd = _run_conflict_names("base", True, False, None) | ||
assert r == 0, cmd | ||
|
||
|
||
def test_from_base_import_allow_conflict_names_false(): | ||
r, cmd = _run_conflict_names("base", False, False, "ImportError") | ||
assert r == 0, cmd | ||
|
||
|
||
def test_base_getattr_allow_conflict_names_true(): | ||
r, cmd = _run_conflict_names("base", True, True, None) | ||
assert r == 0, cmd | ||
|
||
|
||
def test_base_getattr_allow_conflict_names_false(): | ||
r, cmd = _run_conflict_names("base", False, True, None) | ||
assert r == 0, cmd | ||
|
||
|
||
def test_from_dplyr_import_allow_conflict_names_true(): | ||
r, cmd = _run_conflict_names("dplyr", True, False, None) | ||
assert r == 0, cmd | ||
|
||
|
||
def test_from_dplyr_import_allow_conflict_names_false(): | ||
r, cmd = _run_conflict_names("dplyr", False, False, "ImportError") | ||
assert r == 0, cmd | ||
|
||
|
||
def test_dplyr_getattr_allow_conflict_names_true(): | ||
r, cmd = _run_conflict_names("dplyr", True, True, None) | ||
assert r == 0, cmd | ||
|
||
|
||
def test_dplyr_getattr_allow_conflict_names_false(): | ||
r, cmd = _run_conflict_names("dplyr", False, True, None) | ||
assert r == 0, cmd |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters