Skip to content

Commit

Permalink
add schema class naming strategy
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucasrcosta authored and lcosta-ch committed Oct 3, 2022
1 parent 7958ac1 commit 9e3740f
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 46 deletions.
87 changes: 57 additions & 30 deletions avrogen/core_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def convert_default(idx, full_name=None, do_json=True):
return f'self.RECORD_SCHEMA.fields_dict["{idx}"].default'


def get_default(field, use_logical_types, my_full_name=None, f_name=None):
def get_default(field, use_logical_types, my_full_name=None, f_name=None, class_naming_strategy=None):
default_written = False
f_name = f_name if f_name is not None else field.name
if keyword.iskeyword(field.name):
Expand Down Expand Up @@ -85,16 +85,21 @@ def get_default(field, use_logical_types, my_full_name=None, f_name=None):
elif isinstance(default_type, schema.FixedSchema):
return 'bytes()'
elif isinstance(default_type, schema.RecordSchema):
f = clean_fullname(default_type.name)
return f'{f}Class.construct_with_defaults()'
classname = (
class_naming_strategy(default_type.fullname)
if class_naming_strategy
else f'{clean_fullname(default_type.name)}Class'
)
return f'{classname}.construct_with_defaults()'
raise AttributeError('cannot get default for field')

def write_defaults(record, writer, my_full_name=None, use_logical_types=False):
def write_defaults(record, writer, my_full_name=None, use_logical_types=False, class_naming_strategy=None):
"""
Write concrete record class's constructor part which initializes fields with default values
:param schema.RecordSchema record: Avro RecordSchema whose class we are generating
:param TabbedWriter writer: Writer to write to
:param str my_full_name: Full name of the RecordSchema we are writing. Should only be provided for protocol requests.
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
i = 0
Expand All @@ -105,36 +110,38 @@ def write_defaults(record, writer, my_full_name=None, use_logical_types=False):
f_name = field.name
if keyword.iskeyword(field.name):
f_name = field.name + get_field_type_name(field.type, use_logical_types)
default = get_default(field, use_logical_types, my_full_name=my_full_name, f_name=f_name)
default = get_default(field, use_logical_types, my_full_name=my_full_name, f_name=f_name, class_naming_strategy=class_naming_strategy)
writer.write(f'\nself.{f_name} = {default}')
something_written = True
i += 1
if not something_written:
writer.write('\npass')


def write_fields(record, writer, use_logical_types):
def write_fields(record, writer, use_logical_types, class_naming_strategy=None):
"""
Write field definitions for a given RecordSchema
:param schema.RecordSchema record: Avro RecordSchema we are generating
:param TabbedWriter writer: Writer to write to
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
writer.write('\n\n')
for field in record.fields: # type: schema.Field
write_field(field, writer, use_logical_types)
write_field(field, writer, use_logical_types, class_naming_strategy=class_naming_strategy)

def get_field_name(field, use_logical_types):
    """
    Gets the name under which a field is exposed in generated code
    :param schema.Field field: Field whose name is being resolved
    :param use_logical_types: Passed through to get_field_type_name when a suffix is needed
    :return: field.name unchanged, or field.name followed by the field's type name
        when field.name is a Python keyword
    """
    if not keyword.iskeyword(field.name):
        return field.name
    # A Python keyword cannot be used as an identifier, so the field's type
    # name is appended to it.
    return field.name + get_field_type_name(field.type, use_logical_types)

def write_field(field, writer, use_logical_types):
def write_field(field, writer, use_logical_types, class_naming_strategy=None):
"""
Write a single field definition
:param field:
:param writer:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
name = get_field_name(field, use_logical_types)
Expand All @@ -152,7 +159,13 @@ def {name}(self, value: {ret_type_name}) -> None:
{set_docstring}
self._inner_dict['{raw_name}'] = value
'''.format(name=name, get_docstring=get_docstring, set_docstring=set_docstring, raw_name=field.name, ret_type_name=get_field_type_name(field.type, use_logical_types)))
'''.format(
name=name,
get_docstring=get_docstring,
set_docstring=set_docstring,
raw_name=field.name,
ret_type_name=get_field_type_name(field.type, use_logical_types, class_naming_strategy=class_naming_strategy)
))


def get_primitive_field_initializer(field_schema):
Expand All @@ -168,10 +181,11 @@ def get_primitive_field_initializer(field_schema):
return get_field_type_name(field_schema, False) + "()"


def get_field_type_name(field_schema, use_logical_types):
def get_field_type_name(field_schema, use_logical_types, class_naming_strategy=None):
"""
Gets a python type-hint for a given schema
:param schema.Schema field_schema:
:param function class_naming_strategy: Function to compute class names from full names
:return: String containing python type hint
"""
if use_logical_types and field_schema.props.get('logicalType'):
Expand All @@ -190,16 +204,21 @@ def get_field_type_name(field_schema, use_logical_types):
# For enums, we have their "class" types, but they're actually
# represented as strings. This is a decent hack to work around
# the issue.
return f'Union[str, "{field_schema.name}Class"]'
classname = class_naming_strategy(field_schema.fullname) if class_naming_strategy else f'{field_schema.name}Class'
return f'Union[str, "{classname}"]'
elif isinstance(field_schema, schema.NamedSchema):
return f'"{field_schema.name}Class"'
classname = class_naming_strategy(field_schema.fullname) if class_naming_strategy else f'{field_schema.name}Class'
return f'"{classname}"'
elif isinstance(field_schema, schema.ArraySchema):
return 'List[' + get_field_type_name(field_schema.items, use_logical_types) + ']'
elif isinstance(field_schema, schema.MapSchema):
return 'Dict[str, ' + get_field_type_name(field_schema.values, use_logical_types) + ']'
elif isinstance(field_schema, schema.UnionSchema):
type_names = [get_field_type_name(x, use_logical_types) for x in field_schema.schemas if
get_field_type_name(x, use_logical_types)]
type_names = [
get_field_type_name(x, use_logical_types, class_naming_strategy=class_naming_strategy)
for x in field_schema.schemas
if get_field_type_name(x, use_logical_types, class_naming_strategy=class_naming_strategy)
]
if len(type_names) > 1:
return 'Union[' + ', '.join(type_names) + ']'
elif len(type_names) == 1:
Expand Down Expand Up @@ -304,11 +323,12 @@ def write_get_schema(writer):
writer.write('\nreturn __SCHEMAS.get(fullname)\n\n')


def write_reader_impl(record_types, writer, use_logical_types):
def write_reader_impl(record_types, writer, use_logical_types, class_naming_strategy=None):
"""
Write specific reader implementation
:param list[schema.RecordSchema] record_types:
:param writer:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
writer.write('\n\n\nclass SpecificDatumReader(%s):' % (
Expand All @@ -317,10 +337,11 @@ def write_reader_impl(record_types, writer, use_logical_types):
writer.write('\nSCHEMA_TYPES = {')
with writer.indent():
for t in record_types:
t_class = t.split('.')[-1]
writer.write('\n"{t_class}": {t_class}Class,'.format(t_class=t_class))
writer.write('\n".{t_class}": {t_class}Class,'.format(t_class=t_class))
writer.write('\n"{f_class}": {t_class}Class,'.format(t_class=t_class, f_class=t))
t_class = class_naming_strategy(t) if class_naming_strategy else t.split(".")[-1]
classname = class_naming_strategy(t) if class_naming_strategy else f'{t.split(".")[-1]}Class'
writer.write('\n"{t_class}": {classname},'.format(t_class=t_class, classname=classname))
writer.write('\n".{t_class}": {classname},'.format(t_class=t_class, classname=classname))
writer.write('\n"{f_class}": {classname},'.format(t_class=t_class, f_class=t, classname=classname))

writer.write('\n}')
writer.write('\n\n\ndef __init__(self, readers_schema=None, **kwargs):')
Expand Down Expand Up @@ -371,16 +392,21 @@ def generate_namespace_modules(names, output_folder):
return ns_dict


def write_schema_record(record, writer, use_logical_types):
def write_schema_record(record, writer, use_logical_types, class_naming_strategy=None):
"""
Writes class representing Avro record schema
:param avro.schema.RecordSchema record:
:param TabbedWriter writer:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""

_, type_name = ns_.split_fullname(record.fullname)
writer.write('''\nclass {name}Class(DictWrapper):'''.format(name=type_name))
fullname = record.fullname
classname = (
class_naming_strategy(fullname)
if class_naming_strategy
else f'{ns_.split_fullname(fullname)[-1]}Class'
)
writer.write("""\nclass {name}(DictWrapper):""".format(name=classname))

with writer.indent():
writer.write('\n')
Expand All @@ -390,24 +416,24 @@ def write_schema_record(record, writer, use_logical_types):
writer.write('# No docs available.')
writer.write('\n\nRECORD_SCHEMA = get_schema_type("%s")' % (record.fullname))

write_record_init(record, writer, use_logical_types)
write_record_init(record, writer, use_logical_types, class_naming_strategy=class_naming_strategy)

write_fields(record, writer, use_logical_types)
write_fields(record, writer, use_logical_types, class_naming_strategy=class_naming_strategy)


def write_record_init(record, writer, use_logical_types):
def write_record_init(record, writer, use_logical_types, class_naming_strategy=None):
writer.write('\ndef __init__(self,')
with writer.indent():
delayed_lines = []
default_map: Dict[str, str] = {}
for field in record.fields: # type: schema.Field
name = get_field_name(field, use_logical_types)
ret_type_name = get_field_type_name(field.type, use_logical_types)
ret_type_name = get_field_type_name(field.type, use_logical_types, class_naming_strategy=class_naming_strategy)
default_type, nullable = find_type_of_default(field.type)

if not nullable and field.has_default:
# print(record.name, field.name, field.default)
default = get_default(field, use_logical_types, f_name=field.name)
default = get_default(field, use_logical_types, f_name=field.name, class_naming_strategy=class_naming_strategy)
default_map[name] = default
ret_type_name = f"Optional[{ret_type_name}]"
nullable = True
Expand Down Expand Up @@ -437,7 +463,8 @@ def write_record_init(record, writer, use_logical_types):
writer.write(f'\nself.{name} = {name}')

writer.write('\n\n@classmethod')
writer.write(f'\ndef construct_with_defaults(cls) -> "{record.name}Class":')
classname = class_naming_strategy(record.fullname) if class_naming_strategy else f'{record.name}Class'
writer.write(f'\ndef construct_with_defaults(cls) -> "{classname}":')
with writer.indent():
writer.write('\nself = cls.construct({})')
writer.write('\nself._restore_defaults()')
Expand All @@ -447,7 +474,7 @@ def write_record_init(record, writer, use_logical_types):
writer.write('\n')
writer.write(f'\ndef _restore_defaults(self) -> None:')
with writer.indent():
write_defaults(record, writer, use_logical_types=use_logical_types)
write_defaults(record, writer, use_logical_types=use_logical_types, class_naming_strategy=class_naming_strategy)


def write_enum(enum, writer):
Expand Down
39 changes: 23 additions & 16 deletions avrogen/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
logger.setLevel(logging.INFO)


def generate_schema(schema_json, use_logical_types=False, custom_imports=None, avro_json_converter=None):
def generate_schema(schema_json, use_logical_types=False, custom_imports=None, avro_json_converter=None, class_naming_strategy=None):
"""
Generate file containing concrete classes for RecordSchemas in given avro schema json
:param str schema_json: JSON representing avro schema
:param list[str] custom_imports: Add additional import modules
:param str avro_json_converter: AvroJsonConverter type to use for default values
:param function class_naming_strategy: Function to compute class names from full names
:return Dict[str, str]:
"""

Expand Down Expand Up @@ -61,7 +62,7 @@ def generate_schema(schema_json, use_logical_types=False, custom_imports=None, a
current_namespace = namespace
if isinstance(field_schema, schema.RecordSchema):
logger.debug(f'Writing schema: {clean_fullname(field_schema.fullname)}')
write_schema_record(field_schema, writer, use_logical_types)
write_schema_record(field_schema, writer, use_logical_types, class_naming_strategy=class_naming_strategy)
elif isinstance(field_schema, schema.EnumSchema):
logger.debug(f'Writing enum: {field_schema.fullname}', field_schema.fullname)
write_enum(field_schema, writer)
Expand All @@ -71,14 +72,14 @@ def generate_schema(schema_json, use_logical_types=False, custom_imports=None, a

# Lookup table for fullname.
for name, field_schema in names:
n = clean_fullname(field_schema.name)
classname = class_naming_strategy(name) if class_naming_strategy else f'{clean_fullname(field_schema.name)}Class'
full = field_schema.fullname
writer.write(f"\n'{full}': {n}Class,")
writer.write(f"\n'{full}': {classname},")

# Lookup table for names without namespace.
for name, field_schema in names:
n = clean_fullname(field_schema.name)
writer.write(f"\n'{n}': {n}Class,")
classname = class_naming_strategy(name) if class_naming_strategy else f'{clean_fullname(field_schema.name)}Class'
writer.write(f"\n'{classname}': {classname},")

writer.untab()
writer.write('\n}\n\n')
Expand Down Expand Up @@ -116,12 +117,13 @@ def write_populate_schemas(writer):
writer.write('\n__SCHEMAS = dict((n.fullname.lstrip("."), n) for n in six.itervalues(__NAMES.names))\n')


def write_namespace_modules(ns_dict, output_folder):
def write_namespace_modules(ns_dict, output_folder, class_naming_strategy=None):
"""
Writes content of the generated namespace modules. A python module will be created for each namespace
and will import concrete schema classes from SchemaClasses
:param ns_dict:
:param output_folder:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
for ns in six.iterkeys(ns_dict):
Expand All @@ -130,19 +132,22 @@ def write_namespace_modules(ns_dict, output_folder):
if ns != '':
currency += '.' * len(ns.split('.'))
for name in ns_dict[ns]:
f.write(f'from {currency}schema_classes import {name}Class\n')
classname = class_naming_strategy(ns + '.' + name) if class_naming_strategy else f'{name}Class'
f.write(f'from {currency}schema_classes import {classname}\n')

f.write('\n\n')

for name in ns_dict[ns]:
f.write(f"{name} = {name}Class\n")
if not class_naming_strategy:
f.write(f"{name} = {name}Class\n")


def write_specific_reader(record_types, output_folder, use_logical_types):
def write_specific_reader(record_types, output_folder, use_logical_types, class_naming_strategy=None):
"""
Writes specific reader for a avro schema into generated root module
:param record_types:
:param output_folder:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
with open(os.path.join(output_folder, "__init__.py"), "a+") as f:
Expand All @@ -152,23 +157,25 @@ def write_specific_reader(record_types, output_folder, use_logical_types):
writer.write('\nfrom .schema_classes import SCHEMA as get_schema_type')
writer.write('\nfrom .schema_classes import _json_converter as json_converter')
for t in record_types:
writer.write(f'\nfrom .schema_classes import {t.split(".")[-1]}Class')
classname = class_naming_strategy(t) if class_naming_strategy else f'{t.split(".")[-1]}Class'
writer.write(f'\nfrom .schema_classes import {classname}')
writer.write('\nfrom avro.io import DatumReader')
if use_logical_types:
writer.write('\nfrom avrogen import logical')

write_reader_impl(record_types, writer, use_logical_types)
write_reader_impl(record_types, writer, use_logical_types, class_naming_strategy=class_naming_strategy)


def write_schema_files(schema_json, output_folder, use_logical_types=False, custom_imports=None):
def write_schema_files(schema_json, output_folder, use_logical_types=False, custom_imports=None, class_naming_strategy=None):
"""
Generates concrete classes, namespace modules, and a SpecificRecordReader for a given avro schema
:param str schema_json: JSON containing avro schema
:param str output_folder: Folder in which to create generated files
:param list[str] custom_imports: Add additional import modules
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
schema_py, names = generate_schema(schema_json, use_logical_types, custom_imports)
schema_py, names = generate_schema(schema_json, use_logical_types, custom_imports, class_naming_strategy=class_naming_strategy)
names = sorted(names)

if not os.path.isdir(output_folder):
Expand All @@ -185,5 +192,5 @@ def write_schema_files(schema_json, output_folder, use_logical_types=False, cust
with open(os.path.join(output_folder, "__init__.py"), "w+") as f:
pass # make sure we create this file from scratch

write_namespace_modules(ns_dict, output_folder)
write_specific_reader(names, output_folder, use_logical_types)
write_namespace_modules(ns_dict, output_folder, class_naming_strategy=class_naming_strategy)
write_specific_reader(names, output_folder, use_logical_types, class_naming_strategy=class_naming_strategy)
Loading

0 comments on commit 9e3740f

Please sign in to comment.