Skip to content

Commit

Permalink
Add mixin types feature to ADT (#10)
Browse files Browse the repository at this point in the history
Now, by specifying mixin_types={'prod': cls} one can inject an additional
base class to any production `prod`. This enables users to attach methods
to the generated classes. It might also be useful to static type checkers
by setting a base type to be one of Python 3.8's "Protocol" types.
  • Loading branch information
alexreinking authored Nov 8, 2022
1 parent 5838ad8 commit 5c29b7a
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 5 deletions.
24 changes: 20 additions & 4 deletions src/asdl_adt/adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ class _BuildClasses(asdl.VisitorBase):
def __init__(
self,
ext_types: Optional[Mapping[str, type]] = None,
mixin_types: Optional[Mapping[str, type]] = None,
memoize: Optional[Collection[str]] = None,
):
super().__init__()
self.module: Optional[ModuleType] = None

self._mixin_types = mixin_types or {}
self._memoize = memoize or set()
self._type_map = {
**_BuildClasses._builtin_types,
Expand Down Expand Up @@ -131,7 +134,13 @@ def _adt_class(self, *, name, base, fields: Union[List[str], OrderedDict]):
"__qualname__": f"{basename}.{name}",
"__annotations__": {f: None for f in fields},
}
cls = attrs.frozen(init=False)(type(name, (base,), members))

if mixin := self._mixin_types.get(name):
base_types = (base, mixin)
else:
base_types = (base,)

cls = attrs.frozen(init=False)(type(name, base_types, members))
if cls.__name__ in self._memoize:
cls.__new__ = self._cached_new_fn(cls, fields)
return cls
Expand Down Expand Up @@ -243,10 +252,13 @@ def _get_point_validator(self, field: asdl.Field):

def ADT( # pylint: disable=invalid-name
asdl_str: str,
*,
ext_types: Optional[Mapping[str, Union[type, Callable]]] = None,
memoize: Optional[Collection[str]] = None,
mixin_types: Optional[Mapping[str, type]] = None,
):
"""Function that converts an ASDL grammar into a Python Module.
"""
Function that converts an ASDL grammar into a Python Module.
The returned module will contain one class for every ASDL type
declared in the input grammar, and one (sub-)class for every
Expand All @@ -271,7 +283,7 @@ def ADT( # pylint: disable=invalid-name
=================
asdl_str : str
The ASDL definition string
ext_types : Optional[Mapping[str, type]]
ext_types : Optional[Mapping[str, Union[type, Callable]]]
A mapping of custom type names to Python types. Used to create validators for
the __init__ method of generated classes. Several built-in types are implied,
with the following corresponding Python types:
Expand All @@ -282,6 +294,10 @@ def ADT( # pylint: disable=invalid-name
* 'string' - str
memoize : Optional[Collection[str]]
collection of constructor names to memoize, optional
mixin_types : Optional[Mapping[str, type]]
A mapping of generated type names (matching the ASDL productions) to
mixin classes from which to inherit. This is useful for injecting custom
methods into the generated classes.
Returns
=================
Expand All @@ -307,7 +323,7 @@ def ADT( # pylint: disable=invalid-name
if mod := sys.modules.get(asdl_ast.name):
return mod

builder = _BuildClasses(ext_types, memoize)
builder = _BuildClasses(ext_types, mixin_types, memoize)
builder.visit(asdl_ast)

mod = builder.module
Expand Down
38 changes: 38 additions & 0 deletions test/test_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Tests of mixin class feature
"""

import pytest

from asdl_adt import ADT


def test_basic_mixin():
"""
Test that a basic mixin class is properly injected into the given sum type
alternative, and not the others.
"""

class MixinA: # pylint: disable=C0115,C0116,R0903
def double(self):
return self.update(x=2 * self.x)

mixin_grammar = ADT(
"""
module test_basic_mixin {
prod = ( int x, int y )
sum = A( int x )
| B( float y )
| C( int x, int y )
}
""",
mixin_types={
"A": MixinA,
},
)

obj = mixin_grammar.A(3)
assert obj.double() == mixin_grammar.A(6)

with pytest.raises(AttributeError, match="'B' object has no attribute 'double'"):
mixin_grammar.B(3.14).double()
16 changes: 16 additions & 0 deletions test/test_module_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Coverage-focused test for object identity of same-named ADT types.
"""

from asdl_adt import ADT


def test_module_caching():
"""
Test that creating a second module with the same name returns the same
object identically as the first call
"""

grammar_a = ADT("module cache_test { foo = ( int bar ) }")
grammar_b = ADT("module cache_test { foo = ( int bar ) }")
assert grammar_a is grammar_b
2 changes: 1 addition & 1 deletion test/test_ueq_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def fixture_ueq_grammar():
| Scale( int coeff, expr e )
}
""",
{"sym": Sym},
ext_types={"sym": Sym},
)


Expand Down

0 comments on commit 5c29b7a

Please sign in to comment.