diff --git a/src/asdl_adt/adt.py b/src/asdl_adt/adt.py index 79878fd..cbdf941 100644 --- a/src/asdl_adt/adt.py +++ b/src/asdl_adt/adt.py @@ -76,10 +76,13 @@ class _BuildClasses(asdl.VisitorBase): def __init__( self, ext_types: Optional[Mapping[str, type]] = None, + mixin_types: Optional[Mapping[str, type]] = None, memoize: Optional[Collection[str]] = None, ): super().__init__() self.module: Optional[ModuleType] = None + + self._mixin_types = mixin_types or {} self._memoize = memoize or set() self._type_map = { **_BuildClasses._builtin_types, @@ -131,7 +134,13 @@ def _adt_class(self, *, name, base, fields: Union[List[str], OrderedDict]): "__qualname__": f"{basename}.{name}", "__annotations__": {f: None for f in fields}, } - cls = attrs.frozen(init=False)(type(name, (base,), members)) + + if mixin := self._mixin_types.get(name): + base_types = (base, mixin) + else: + base_types = (base,) + + cls = attrs.frozen(init=False)(type(name, base_types, members)) if cls.__name__ in self._memoize: cls.__new__ = self._cached_new_fn(cls, fields) return cls @@ -243,10 +252,13 @@ def _get_point_validator(self, field: asdl.Field): def ADT( # pylint: disable=invalid-name asdl_str: str, + *, ext_types: Optional[Mapping[str, Union[type, Callable]]] = None, memoize: Optional[Collection[str]] = None, + mixin_types: Optional[Mapping[str, type]] = None, ): - """Function that converts an ASDL grammar into a Python Module. + """ + Function that converts an ASDL grammar into a Python Module. The returned module will contain one class for every ASDL type declared in the input grammar, and one (sub-)class for every @@ -271,7 +283,7 @@ def ADT( # pylint: disable=invalid-name ================= asdl_str : str The ASDL definition string - ext_types : Optional[Mapping[str, type]] + ext_types : Optional[Mapping[str, Union[type, Callable]]] A mapping of custom type names to Python types. Used to create validators for the __init__ method of generated classes. Several built-in types are implied, with the following corresponding Python types: @@ -282,6 +294,10 @@ def ADT( # pylint: disable=invalid-name * 'string' - str memoize : Optional[Collection[str]] collection of constructor names to memoize, optional + mixin_types : Optional[Mapping[str, type]] + A mapping of generated type names (matching the ASDL productions) to + mixin classes from which to inherit. This is useful for injecting custom + methods into the generated classes. Returns ================= @@ -307,7 +323,7 @@ def ADT( # pylint: disable=invalid-name if mod := sys.modules.get(asdl_ast.name): return mod - builder = _BuildClasses(ext_types, memoize) + builder = _BuildClasses(ext_types, mixin_types, memoize) builder.visit(asdl_ast) mod = builder.module diff --git a/test/test_mixins.py b/test/test_mixins.py new file mode 100644 index 0000000..dc7a2aa --- /dev/null +++ b/test/test_mixins.py @@ -0,0 +1,38 @@ +""" +Tests of mixin class feature +""" + +import pytest + +from asdl_adt import ADT + + +def test_basic_mixin(): + """ + Test that a basic mixin class is properly injected into the given sum type + alternative, and not the others. + """ + + class MixinA: # pylint: disable=C0115,C0116,R0903 + def double(self): + return self.update(x=2 * self.x) + + mixin_grammar = ADT( + """ + module test_basic_mixin { + prod = ( int x, int y ) + sum = A( int x ) + | B( float y ) + | C( int x, int y ) + } + """, + mixin_types={ + "A": MixinA, + }, + ) + + obj = mixin_grammar.A(3) + assert obj.double() == mixin_grammar.A(6) + + with pytest.raises(AttributeError, match="'B' object has no attribute 'double'"): + mixin_grammar.B(3.14).double() diff --git a/test/test_module_caching.py b/test/test_module_caching.py new file mode 100644 index 0000000..a289404 --- /dev/null +++ b/test/test_module_caching.py @@ -0,0 +1,16 @@ +""" +Coverage-focused test for object identity of same-named ADT types. +""" + +from asdl_adt import ADT + + +def test_module_caching(): + """ + Test that creating a second module with the same name returns the same + object identically as the first call + """ + + grammar_a = ADT("module cache_test { foo = ( int bar ) }") + grammar_b = ADT("module cache_test { foo = ( int bar ) }") + assert grammar_a is grammar_b diff --git a/test/test_ueq_grammar.py b/test/test_ueq_grammar.py index b526e1d..f995e68 100644 --- a/test/test_ueq_grammar.py +++ b/test/test_ueq_grammar.py @@ -46,7 +46,7 @@ def fixture_ueq_grammar(): | Scale( int coeff, expr e ) } """, - {"sym": Sym}, + ext_types={"sym": Sym}, )