diff --git a/avrogen/core_writer.py b/avrogen/core_writer.py index 4da2178..5f9fa80 100644 --- a/avrogen/core_writer.py +++ b/avrogen/core_writer.py @@ -43,7 +43,7 @@ def convert_default(idx, full_name=None, do_json=True): return f'self.RECORD_SCHEMA.fields_dict["{idx}"].default' -def get_default(field, use_logical_types, my_full_name=None, f_name=None): +def get_default(field, use_logical_types, my_full_name=None, f_name=None, class_naming_strategy=None): default_written = False f_name = f_name if f_name is not None else field.name if keyword.iskeyword(field.name): @@ -85,16 +85,21 @@ def get_default(field, use_logical_types, my_full_name=None, f_name=None): elif isinstance(default_type, schema.FixedSchema): return 'bytes()' elif isinstance(default_type, schema.RecordSchema): - f = clean_fullname(default_type.name) - return f'{f}Class.construct_with_defaults()' + classname = ( + class_naming_strategy(default_type.fullname) + if class_naming_strategy + else f'{clean_fullname(default_type.name)}Class' + ) + return f'{classname}.construct_with_defaults()' raise AttributeError('cannot get default for field') -def write_defaults(record, writer, my_full_name=None, use_logical_types=False): +def write_defaults(record, writer, my_full_name=None, use_logical_types=False, class_naming_strategy=None): """ Write concrete record class's constructor part which initializes fields with default values :param schema.RecordSchema record: Avro RecordSchema whose class we are generating :param TabbedWriter writer: Writer to write to :param str my_full_name: Full name of the RecordSchema we are writing. Should only be provided for protocol requests. + :param function class_naming_strategy: Function to compute class names from full names :return: """ i = 0 @@ -105,7 +110,7 @@ def write_defaults(record, writer, my_full_name=None, use_logical_types=False): f_name = field.name if keyword.iskeyword(field.name): f_name = field.name + get_field_type_name(field.type, use_logical_types) - default = get_default(field, use_logical_types, my_full_name=my_full_name, f_name=f_name) + default = get_default(field, use_logical_types, my_full_name=my_full_name, f_name=f_name, class_naming_strategy=class_naming_strategy) writer.write(f'\nself.{f_name} = {default}') something_written = True i += 1 @@ -113,16 +118,17 @@ def write_defaults(record, writer, my_full_name=None, use_logical_types=False): writer.write('\npass') -def write_fields(record, writer, use_logical_types): +def write_fields(record, writer, use_logical_types, class_naming_strategy=None): """ Write field definitions for a given RecordSchema :param schema.RecordSchema record: Avro RecordSchema we are generating :param TabbedWriter writer: Writer to write to + :param function class_naming_strategy: Function to compute class names from full names :return: """ writer.write('\n\n') for field in record.fields: # type: schema.Field - write_field(field, writer, use_logical_types) + write_field(field, writer, use_logical_types, class_naming_strategy=class_naming_strategy) def get_field_name(field, use_logical_types): name = field.name @@ -130,11 +136,12 @@ def get_field_name(field, use_logical_types): name = field.name + get_field_type_name(field.type, use_logical_types) return name -def write_field(field, writer, use_logical_types): +def write_field(field, writer, use_logical_types, class_naming_strategy=None): """ Write a single field definition :param field: :param writer: + :param function class_naming_strategy: Function to compute class names from full names :return: """ name = get_field_name(field, use_logical_types) @@ -152,7 +159,13 @@ def {name}(self, value: {ret_type_name}) -> None: {set_docstring} self._inner_dict['{raw_name}'] = value -'''.format(name=name, get_docstring=get_docstring, set_docstring=set_docstring, raw_name=field.name, ret_type_name=get_field_type_name(field.type, use_logical_types))) +'''.format( + name=name, + get_docstring=get_docstring, + set_docstring=set_docstring, + raw_name=field.name, + ret_type_name=get_field_type_name(field.type, use_logical_types, class_naming_strategy=class_naming_strategy) +)) def get_primitive_field_initializer(field_schema): @@ -168,10 +181,11 @@ def get_primitive_field_initializer(field_schema): return get_field_type_name(field_schema, False) + "()" -def get_field_type_name(field_schema, use_logical_types): +def get_field_type_name(field_schema, use_logical_types, class_naming_strategy=None): """ Gets a python type-hint for a given schema :param schema.Schema field_schema: + :param function class_naming_strategy: Function to compute class names from full names :return: String containing python type hint """ if use_logical_types and field_schema.props.get('logicalType'): @@ -190,16 +204,21 @@ def get_field_type_name(field_schema, use_logical_types): # For enums, we have their "class" types, but they're actually # represented as strings. This is a decent hack to work around # the issue. - return f'Union[str, "{field_schema.name}Class"]' + classname = class_naming_strategy(field_schema.fullname) if class_naming_strategy else f'{field_schema.name}Class' + return f'Union[str, "{classname}"]' elif isinstance(field_schema, schema.NamedSchema): - return f'"{field_schema.name}Class"' + classname = class_naming_strategy(field_schema.fullname) if class_naming_strategy else f'{field_schema.name}Class' + return f'"{classname}"' elif isinstance(field_schema, schema.ArraySchema): return 'List[' + get_field_type_name(field_schema.items, use_logical_types) + ']' elif isinstance(field_schema, schema.MapSchema): return 'Dict[str, ' + get_field_type_name(field_schema.values, use_logical_types) + ']' elif isinstance(field_schema, schema.UnionSchema): - type_names = [get_field_type_name(x, use_logical_types) for x in field_schema.schemas if - get_field_type_name(x, use_logical_types)] + type_names = [ + get_field_type_name(x, use_logical_types, class_naming_strategy=class_naming_strategy) + for x in field_schema.schemas + if get_field_type_name(x, use_logical_types, class_naming_strategy=class_naming_strategy) + ] if len(type_names) > 1: return 'Union[' + ', '.join(type_names) + ']' elif len(type_names) == 1: @@ -304,11 +323,12 @@ def write_get_schema(writer): writer.write('\nreturn __SCHEMAS.get(fullname)\n\n') -def write_reader_impl(record_types, writer, use_logical_types): +def write_reader_impl(record_types, writer, use_logical_types, class_naming_strategy=None): """ Write specific reader implementation :param list[schema.RecordSchema] record_types: :param writer: + :param function class_naming_strategy: Function to compute class names from full names :return: """ writer.write('\n\n\nclass SpecificDatumReader(%s):' % ( @@ -317,10 +337,11 @@ def write_reader_impl(record_types, writer, use_logical_types): writer.write('\nSCHEMA_TYPES = {') with writer.indent(): for t in record_types: - t_class = t.split('.')[-1] - writer.write('\n"{t_class}": {t_class}Class,'.format(t_class=t_class)) - writer.write('\n".{t_class}": {t_class}Class,'.format(t_class=t_class)) - writer.write('\n"{f_class}": {t_class}Class,'.format(t_class=t_class, f_class=t)) + t_class = class_naming_strategy(t) if class_naming_strategy else t.split(".")[-1] + classname = class_naming_strategy(t) if class_naming_strategy else f'{t.split(".")[-1]}Class' + writer.write('\n"{t_class}": {classname},'.format(t_class=t_class, classname=classname)) + writer.write('\n".{t_class}": {classname},'.format(t_class=t_class, classname=classname)) + writer.write('\n"{f_class}": {classname},'.format(t_class=t_class, f_class=t, classname=classname)) writer.write('\n}') writer.write('\n\n\ndef __init__(self, readers_schema=None, **kwargs):') @@ -371,16 +392,21 @@ def generate_namespace_modules(names, output_folder): return ns_dict -def write_schema_record(record, writer, use_logical_types): +def write_schema_record(record, writer, use_logical_types, class_naming_strategy=None): """ Writes class representing Avro record schema :param avro.schema.RecordSchema record: :param TabbedWriter writer: + :param function class_naming_strategy: Function to compute class names from full names :return: """ - - _, type_name = ns_.split_fullname(record.fullname) - writer.write('''\nclass {name}Class(DictWrapper):'''.format(name=type_name)) + fullname = record.fullname + classname = ( + class_naming_strategy(fullname) + if class_naming_strategy + else f'{ns_.split_fullname(fullname)[-1]}Class' + ) + writer.write("""\nclass {name}(DictWrapper):""".format(name=classname)) with writer.indent(): writer.write('\n') @@ -390,24 +416,24 @@ def write_schema_record(record, writer, use_logical_types): writer.write('# No docs available.') writer.write('\n\nRECORD_SCHEMA = get_schema_type("%s")' % (record.fullname)) - write_record_init(record, writer, use_logical_types) + write_record_init(record, writer, use_logical_types, class_naming_strategy=class_naming_strategy) - write_fields(record, writer, use_logical_types) + write_fields(record, writer, use_logical_types, class_naming_strategy=class_naming_strategy) -def write_record_init(record, writer, use_logical_types): +def write_record_init(record, writer, use_logical_types, class_naming_strategy=None): writer.write('\ndef __init__(self,') with writer.indent(): delayed_lines = [] default_map: Dict[str, str] = {} for field in record.fields: # type: schema.Field name = get_field_name(field, use_logical_types) - ret_type_name = get_field_type_name(field.type, use_logical_types) + ret_type_name = get_field_type_name(field.type, use_logical_types, class_naming_strategy=class_naming_strategy) default_type, nullable = find_type_of_default(field.type) if not nullable and field.has_default: # print(record.name, field.name, field.default) - default = get_default(field, use_logical_types, f_name=field.name) + default = get_default(field, use_logical_types, f_name=field.name, class_naming_strategy=class_naming_strategy) default_map[name] = default ret_type_name = f"Optional[{ret_type_name}]" nullable = True @@ -437,7 +463,8 @@ def write_record_init(record, writer, use_logical_types): writer.write(f'\nself.{name} = {name}') writer.write('\n\n@classmethod') - writer.write(f'\ndef construct_with_defaults(cls) -> "{record.name}Class":') + classname = class_naming_strategy(record.fullname) if class_naming_strategy else f'{record.name}Class' + writer.write(f'\ndef construct_with_defaults(cls) -> "{classname}":') with writer.indent(): writer.write('\nself = cls.construct({})') writer.write('\nself._restore_defaults()') @@ -447,7 +474,7 @@ def write_record_init(record, writer, use_logical_types): writer.write('\n') writer.write(f'\ndef _restore_defaults(self) -> None:') with writer.indent(): - write_defaults(record, writer, use_logical_types=use_logical_types) + write_defaults(record, writer, use_logical_types=use_logical_types, class_naming_strategy=class_naming_strategy) def write_enum(enum, writer): diff --git a/avrogen/schema.py b/avrogen/schema.py index 620b498..c5d1510 100644 --- a/avrogen/schema.py +++ b/avrogen/schema.py @@ -22,12 +22,13 @@ logger.setLevel(logging.INFO) -def generate_schema(schema_json, use_logical_types=False, custom_imports=None, avro_json_converter=None): +def generate_schema(schema_json, use_logical_types=False, custom_imports=None, avro_json_converter=None, class_naming_strategy=None): """ Generate file containing concrete classes for RecordSchemas in given avro schema json :param str schema_json: JSON representing avro schema :param list[str] custom_imports: Add additional import modules :param str avro_json_converter: AvroJsonConverter type to use for default values + :param function class_naming_strategy: Function to compute class names from full names :return Dict[str, str]: """ @@ -61,7 +62,7 @@ def generate_schema(schema_json, use_logical_types=False, custom_imports=None, a current_namespace = namespace if isinstance(field_schema, schema.RecordSchema): logger.debug(f'Writing schema: {clean_fullname(field_schema.fullname)}') - write_schema_record(field_schema, writer, use_logical_types) + write_schema_record(field_schema, writer, use_logical_types, class_naming_strategy=class_naming_strategy) elif isinstance(field_schema, schema.EnumSchema): logger.debug(f'Writing enum: {field_schema.fullname}', field_schema.fullname) write_enum(field_schema, writer) @@ -71,14 +72,14 @@ def generate_schema(schema_json, use_logical_types=False, custom_imports=None, a # Lookup table for fullname. for name, field_schema in names: - n = clean_fullname(field_schema.name) + classname = class_naming_strategy(name) if class_naming_strategy else f'{clean_fullname(field_schema.name)}Class' full = field_schema.fullname - writer.write(f"\n'{full}': {n}Class,") + writer.write(f"\n'{full}': {classname},") # Lookup table for names without namespace. for name, field_schema in names: - n = clean_fullname(field_schema.name) - writer.write(f"\n'{n}': {n}Class,") + classname = class_naming_strategy(name) if class_naming_strategy else f'{clean_fullname(field_schema.name)}Class' + writer.write(f"\n'{classname}': {classname},") writer.untab() writer.write('\n}\n\n') @@ -116,12 +117,13 @@ def write_populate_schemas(writer): writer.write('\n__SCHEMAS = dict((n.fullname.lstrip("."), n) for n in six.itervalues(__NAMES.names))\n') -def write_namespace_modules(ns_dict, output_folder): +def write_namespace_modules(ns_dict, output_folder, class_naming_strategy=None): """ Writes content of the generated namespace modules. A python module will be created for each namespace and will import concrete schema classes from SchemaClasses :param ns_dict: :param output_folder: + :param function class_naming_strategy: Function to compute class names from full names :return: """ for ns in six.iterkeys(ns_dict): @@ -130,19 +132,22 @@ def write_namespace_modules(ns_dict, output_folder): if ns != '': currency += '.' * len(ns.split('.')) for name in ns_dict[ns]: - f.write(f'from {currency}schema_classes import {name}Class\n') + classname = class_naming_strategy(ns + '.' + name) if class_naming_strategy else f'{name}Class' + f.write(f'from {currency}schema_classes import {classname}\n') f.write('\n\n') for name in ns_dict[ns]: - f.write(f"{name} = {name}Class\n") + if not class_naming_strategy: + f.write(f"{name} = {name}Class\n") -def write_specific_reader(record_types, output_folder, use_logical_types): +def write_specific_reader(record_types, output_folder, use_logical_types, class_naming_strategy=None): """ Writes specific reader for a avro schema into generated root module :param record_types: :param output_folder: + :param function class_naming_strategy: Function to compute class names from full names :return: """ with open(os.path.join(output_folder, "__init__.py"), "a+") as f: @@ -152,23 +157,25 @@ def write_specific_reader(record_types, output_folder, use_logical_types): writer.write('\nfrom .schema_classes import SCHEMA as get_schema_type') writer.write('\nfrom .schema_classes import _json_converter as json_converter') for t in record_types: - writer.write(f'\nfrom .schema_classes import {t.split(".")[-1]}Class') + classname = class_naming_strategy(t) if class_naming_strategy else f'{t.split(".")[-1]}Class' + writer.write(f'\nfrom .schema_classes import {classname}') writer.write('\nfrom avro.io import DatumReader') if use_logical_types: writer.write('\nfrom avrogen import logical') - write_reader_impl(record_types, writer, use_logical_types) + write_reader_impl(record_types, writer, use_logical_types, class_naming_strategy=class_naming_strategy) -def write_schema_files(schema_json, output_folder, use_logical_types=False, custom_imports=None): +def write_schema_files(schema_json, output_folder, use_logical_types=False, custom_imports=None, class_naming_strategy=None): """ Generates concrete classes, namespace modules, and a SpecificRecordReader for a given avro schema :param str schema_json: JSON containing avro schema :param str output_folder: Folder in which to create generated files :param list[str] custom_imports: Add additional import modules + :param function class_naming_strategy: Function to compute class names from full names :return: """ - schema_py, names = generate_schema(schema_json, use_logical_types, custom_imports) + schema_py, names = generate_schema(schema_json, use_logical_types, custom_imports, class_naming_strategy=class_naming_strategy) names = sorted(names) if not os.path.isdir(output_folder): @@ -185,5 +192,5 @@ def write_schema_files(schema_json, output_folder, use_logical_types=False, cust with open(os.path.join(output_folder, "__init__.py"), "w+") as f: pass # make sure we create this file from scratch - write_namespace_modules(ns_dict, output_folder) - write_specific_reader(names, output_folder, use_logical_types) + write_namespace_modules(ns_dict, output_folder, class_naming_strategy=class_naming_strategy) + write_specific_reader(names, output_folder, use_logical_types, class_naming_strategy=class_naming_strategy) diff --git a/tests/generator_tests.py b/tests/generator_tests.py index edfd36e..0b605fb 100644 --- a/tests/generator_tests.py +++ b/tests/generator_tests.py @@ -205,6 +205,34 @@ def test_tweet(self): self.assertTrue(hasattr(kop, 'known')) self.assertTrue(hasattr(kop, 'data')) + def test_tweet_custom(self): + def class_naming_strategy(fullname): + return 'Custom' + fullname.split('.')[-1] + + schema_json = self.read_schema('tweet.json') + avrogen.schema.write_schema_files(schema_json, self.output_dir, class_naming_strategy=class_naming_strategy) + root_module, _ = self.load_gen(self.test_name) + twitter_ns = importlib.import_module('.com.bifflabs.grok.model.twitter.avro', self.test_name) + common_ns = importlib.import_module('.com.bifflabs.grok.model.common.avro', self.test_name) + + self.assertTrue(hasattr(twitter_ns, 'CustomAvroTweet')) + self.assertTrue(hasattr(twitter_ns, 'CustomAvroTweetMetadata')) + self.assertTrue(hasattr(common_ns, 'CustomAvroPoint')) + self.assertTrue(hasattr(common_ns, 'CustomAvroDateTime')) + self.assertTrue(hasattr(common_ns, 'CustomAvroKnowableOptionString')) + self.assertTrue(hasattr(common_ns, 'CustomAvroKnowableListString')) + self.assertTrue(hasattr(common_ns, 'CustomAvroKnowableBoolean')) + self.assertTrue(hasattr(common_ns, 'CustomAvroKnowableOptionPoint')) + + self.assertIsNotNone(twitter_ns.CustomAvroTweet.construct_with_defaults()) + self.assertIsNotNone(twitter_ns.CustomAvroTweetMetadata.construct_with_defaults()) + self.assertIsNotNone(common_ns.CustomAvroPoint.construct_with_defaults()) + self.assertIsNotNone(common_ns.CustomAvroDateTime.construct_with_defaults()) + self.assertIsNotNone(common_ns.CustomAvroKnowableOptionString.construct_with_defaults()) + self.assertIsNotNone(common_ns.CustomAvroKnowableListString.construct_with_defaults()) + self.assertIsNotNone(common_ns.CustomAvroKnowableBoolean.construct_with_defaults()) + self.assertIsNotNone(common_ns.CustomAvroKnowableOptionString.construct_with_defaults()) + def test_logical(self): schema_json = self.read_schema('logical_types.json') avrogen.schema.write_schema_files(schema_json, self.output_dir, use_logical_types=True)