Skip to content

Commit

Permalink
🎨 Allow datar.all.filter regardless of allow_conflict_names
Browse files Browse the repository at this point in the history
  • Loading branch information
pwwang committed Aug 11, 2023
1 parent dce636d commit 2a6cc41
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 15 deletions.
18 changes: 17 additions & 1 deletion datar/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,23 @@

__all__ = [key for key in locals() if not key.startswith("_")]

if get_option("allow_conflict_names"): # noqa: F405 pragma: no cover
if get_option("allow_conflict_names"): # noqa: F405
__all__.extend(_base_conflict_names | _dplyr_conflict_names)
for name in _base_conflict_names | _dplyr_conflict_names:
locals()[name] = locals()[name + "_"]


def __getattr__(name):
"""Even when allow_conflict_names is False, datar.base.sum should be fine
"""
if name in _base_conflict_names | _dplyr_conflict_names:
import sys
import ast
from executing import Source
node = Source.executing(sys._getframe(1)).node
if isinstance(node, (ast.Call, ast.Attribute)):
# import datar.all as d
# d.sum(...) or getattr(d, "sum")(...)
return globals()[name + "_"]

raise AttributeError
18 changes: 17 additions & 1 deletion datar/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,23 @@
__all__ = [key for key in locals() if not key.startswith("_")]
_conflict_names = {"min", "max", "sum", "abs", "round", "all", "any", "re"}

if get_option("allow_conflict_names"): # noqa: F405 pragma: no cover
if get_option("allow_conflict_names"): # noqa: F405
__all__.extend(_conflict_names)
for name in _conflict_names:
locals()[name] = locals()[name + "_"]


def __getattr__(name):
"""Even when allow_conflict_names is False, datar.base.sum should be fine
"""
if name in _conflict_names:
import sys
import ast
from executing import Source
node = Source.executing(sys._getframe(1)).node
if isinstance(node, (ast.Call, ast.Attribute)):
# import datar.base as d
# d.sum(...)
return globals()[name + "_"]

raise AttributeError
18 changes: 17 additions & 1 deletion datar/dplyr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,23 @@
__all__ = [key for key in locals() if not key.startswith("_")]
_conflict_names = {"filter", "slice"}

if _get_option("allow_conflict_names"): # pragma: no cover
if _get_option("allow_conflict_names"):
__all__.extend(_conflict_names)
for name in _conflict_names:
locals()[name] = locals()[name + "_"]


def __getattr__(name):
"""Even when allow_conflict_names is False, datar.base.sum should be fine
"""
if name in _conflict_names:
import sys
import ast
from executing import Source
node = Source.executing(sys._getframe(1)).node
if isinstance(node, (ast.Call, ast.Attribute)):
# import datar.dplyr as d
# d.sum(...)
return globals()[name + "_"]

raise AttributeError
24 changes: 12 additions & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

112 changes: 112 additions & 0 deletions tests/conflict_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import argparse


def test_getattr(module, allow_conflict_names, fun, error):
from datar import options
options(allow_conflict_names=allow_conflict_names)

if module == "all":
import datar.all as d
elif module == "base":
import datar.base as d
elif module == "dplyr":
import datar.dplyr as d

if not error:
return getattr(d, fun)

try:
getattr(d, fun)
except Exception as e:
raised = type(e).__name__
assert raised == error, f"Raised {raised}, expected {error}"
else:
raise AssertionError(f"{error} should have raised")


def _import(module, fun):
if module == "all" and fun == "sum":
from datar.all import sum # noqa: F401
elif module == "all" and fun == "slice":
from datar.all import slice # noqa: F401
elif module == "base" and fun == "sum":
from datar.base import sum # noqa: F401
elif module == "dplyr" and fun == "slice":
from datar.dplyr import slice # noqa: F401


def test_import(module, allow_conflict_names, fun, error):
from datar import options
options(allow_conflict_names=allow_conflict_names)

if not error:
return _import(module, fun)

try:
_import(module, fun)
except Exception as e:
raised = type(e).__name__
assert raised == error, f"Raised {raised}, expected {error}"
else:
raise AssertionError(f"{error} should have raised")


def make_test(module, allow_conflict_names, getattr, fun, error):
if fun == "_":
fun = "sum" if module in ["all", "base"] else "slice"

if getattr:
return test_getattr(module, allow_conflict_names, fun, error)

return test_import(module, allow_conflict_names, fun, error)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--module",
choices=["all", "base", "dplyr"],
required=True,
help="The module to test"
)
parser.add_argument(
"--allow-conflict-names",
action="store_true",
help="Whether to allow conflict names",
default=False,
)
parser.add_argument(
"--getattr",
action="store_true",
help=(
"Whether to test datar.all.sum, "
"otherwise test from datar.all import sum."
),
default=False,
)
parser.add_argument(
"--fun",
help=(
"The function to test. "
"If _ then sum for all/base, slice for dplyr"
),
choices=["sum", "filter", "_"],
default="_",
)
parser.add_argument(
"--error",
help="The error to expect",
)
args = parser.parse_args()

make_test(
args.module,
args.allow_conflict_names,
args.getattr,
args.fun,
args.error,
)


if __name__ == "__main__":
main()
85 changes: 85 additions & 0 deletions tests/test_conflict_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import sys
import subprocess
from pathlib import Path

import pytest


def _run_conflict_names(module, allow_conflict_names, getat, error):
here = Path(__file__).parent
conflict_names = here / "conflict_names.py"
cmd = [
sys.executable,
str(conflict_names),
"--module",
module,
]
if error:
cmd += ["--error", error]
if allow_conflict_names:
cmd.append("--allow-conflict-names")
if getat:
cmd.append("--getattr")

p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return p.wait(), " ".join(cmd)


def test_from_all_import_allow_conflict_names_true():
r, cmd = _run_conflict_names("all", True, False, None)
assert r == 0, cmd


def test_from_all_import_allow_conflict_names_false():
r, cmd = _run_conflict_names("all", False, False, "ImportError")
assert r == 0, cmd


def test_all_getattr_allow_conflict_names_true():
r, cmd = _run_conflict_names("all", True, True, None)
assert r == 0, cmd


def test_all_getattr_allow_conflict_names_false():
r, cmd = _run_conflict_names("all", False, True, None)
assert r == 0, cmd


def test_from_base_import_allow_conflict_names_true():
r, cmd = _run_conflict_names("base", True, False, None)
assert r == 0, cmd


def test_from_base_import_allow_conflict_names_false():
r, cmd = _run_conflict_names("base", False, False, "ImportError")
assert r == 0, cmd


def test_base_getattr_allow_conflict_names_true():
r, cmd = _run_conflict_names("base", True, True, None)
assert r == 0, cmd


def test_base_getattr_allow_conflict_names_false():
r, cmd = _run_conflict_names("base", False, True, None)
assert r == 0, cmd


def test_from_dplyr_import_allow_conflict_names_true():
r, cmd = _run_conflict_names("dplyr", True, False, None)
assert r == 0, cmd


def test_from_dplyr_import_allow_conflict_names_false():
r, cmd = _run_conflict_names("dplyr", False, False, "ImportError")
assert r == 0, cmd


def test_dplyr_getattr_allow_conflict_names_true():
r, cmd = _run_conflict_names("dplyr", True, True, None)
assert r == 0, cmd


def test_dplyr_getattr_allow_conflict_names_false():
r, cmd = _run_conflict_names("dplyr", False, True, None)
assert r == 0, cmd
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ per-file-ignores =
datar/base.py: F401, F402, F403, E402
datar/dplyr.py: F401, F402, F403, E402
datar/data/metadata.py: E501
tests/test_conflict_names.py: F401
max-line-length = 81

0 comments on commit 2a6cc41

Please sign in to comment.