Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add schema class naming strategy #14

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 57 additions & 30 deletions avrogen/core_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def convert_default(idx, full_name=None, do_json=True):
return f'self.RECORD_SCHEMA.fields_dict["{idx}"].default'


def get_default(field, use_logical_types, my_full_name=None, f_name=None):
def get_default(field, use_logical_types, my_full_name=None, f_name=None, class_naming_strategy=None):
default_written = False
f_name = f_name if f_name is not None else field.name
if keyword.iskeyword(field.name):
Expand Down Expand Up @@ -85,16 +85,21 @@ def get_default(field, use_logical_types, my_full_name=None, f_name=None):
elif isinstance(default_type, schema.FixedSchema):
return 'bytes()'
elif isinstance(default_type, schema.RecordSchema):
f = clean_fullname(default_type.name)
return f'{f}Class.construct_with_defaults()'
classname = (
class_naming_strategy(default_type.fullname)
if class_naming_strategy
else f'{clean_fullname(default_type.name)}Class'
)
return f'{classname}.construct_with_defaults()'
raise AttributeError('cannot get default for field')

def write_defaults(record, writer, my_full_name=None, use_logical_types=False):
def write_defaults(record, writer, my_full_name=None, use_logical_types=False, class_naming_strategy=None):
"""
Write concrete record class's constructor part which initializes fields with default values
:param schema.RecordSchema record: Avro RecordSchema whose class we are generating
:param TabbedWriter writer: Writer to write to
:param str my_full_name: Full name of the RecordSchema we are writing. Should only be provided for protocol requests.
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
i = 0
Expand All @@ -105,36 +110,38 @@ def write_defaults(record, writer, my_full_name=None, use_logical_types=False):
f_name = field.name
if keyword.iskeyword(field.name):
f_name = field.name + get_field_type_name(field.type, use_logical_types)
default = get_default(field, use_logical_types, my_full_name=my_full_name, f_name=f_name)
default = get_default(field, use_logical_types, my_full_name=my_full_name, f_name=f_name, class_naming_strategy=class_naming_strategy)
writer.write(f'\nself.{f_name} = {default}')
something_written = True
i += 1
if not something_written:
writer.write('\npass')


def write_fields(record, writer, use_logical_types):
def write_fields(record, writer, use_logical_types, class_naming_strategy=None):
"""
Write field definitions for a given RecordSchema
:param schema.RecordSchema record: Avro RecordSchema we are generating
:param TabbedWriter writer: Writer to write to
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
writer.write('\n\n')
for field in record.fields: # type: schema.Field
write_field(field, writer, use_logical_types)
write_field(field, writer, use_logical_types, class_naming_strategy=class_naming_strategy)

def get_field_name(field, use_logical_types):
name = field.name
if keyword.iskeyword(field.name):
name = field.name + get_field_type_name(field.type, use_logical_types)
return name

def write_field(field, writer, use_logical_types):
def write_field(field, writer, use_logical_types, class_naming_strategy=None):
"""
Write a single field definition
:param field:
:param writer:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
name = get_field_name(field, use_logical_types)
Expand All @@ -152,7 +159,13 @@ def {name}(self, value: {ret_type_name}) -> None:
{set_docstring}
self._inner_dict['{raw_name}'] = value

'''.format(name=name, get_docstring=get_docstring, set_docstring=set_docstring, raw_name=field.name, ret_type_name=get_field_type_name(field.type, use_logical_types)))
'''.format(
name=name,
get_docstring=get_docstring,
set_docstring=set_docstring,
raw_name=field.name,
ret_type_name=get_field_type_name(field.type, use_logical_types, class_naming_strategy=class_naming_strategy)
))


def get_primitive_field_initializer(field_schema):
Expand All @@ -168,10 +181,11 @@ def get_primitive_field_initializer(field_schema):
return get_field_type_name(field_schema, False) + "()"


def get_field_type_name(field_schema, use_logical_types):
def get_field_type_name(field_schema, use_logical_types, class_naming_strategy=None):
"""
Gets a python type-hint for a given schema
:param schema.Schema field_schema:
:param function class_naming_strategy: Function to compute class names from full names
:return: String containing python type hint
"""
if use_logical_types and field_schema.props.get('logicalType'):
Expand All @@ -190,16 +204,21 @@ def get_field_type_name(field_schema, use_logical_types):
# For enums, we have their "class" types, but they're actually
# represented as strings. This is a decent hack to work around
# the issue.
return f'Union[str, "{field_schema.name}Class"]'
classname = class_naming_strategy(field_schema.fullname) if class_naming_strategy else f'{field_schema.name}Class'
return f'Union[str, "{classname}"]'
elif isinstance(field_schema, schema.NamedSchema):
return f'"{field_schema.name}Class"'
classname = class_naming_strategy(field_schema.fullname) if class_naming_strategy else f'{field_schema.name}Class'
return f'"{classname}"'
elif isinstance(field_schema, schema.ArraySchema):
return 'List[' + get_field_type_name(field_schema.items, use_logical_types) + ']'
elif isinstance(field_schema, schema.MapSchema):
return 'Dict[str, ' + get_field_type_name(field_schema.values, use_logical_types) + ']'
elif isinstance(field_schema, schema.UnionSchema):
type_names = [get_field_type_name(x, use_logical_types) for x in field_schema.schemas if
get_field_type_name(x, use_logical_types)]
type_names = [
get_field_type_name(x, use_logical_types, class_naming_strategy=class_naming_strategy)
for x in field_schema.schemas
if get_field_type_name(x, use_logical_types, class_naming_strategy=class_naming_strategy)
]
if len(type_names) > 1:
return 'Union[' + ', '.join(type_names) + ']'
elif len(type_names) == 1:
Expand Down Expand Up @@ -304,11 +323,12 @@ def write_get_schema(writer):
writer.write('\nreturn __SCHEMAS.get(fullname)\n\n')


def write_reader_impl(record_types, writer, use_logical_types):
def write_reader_impl(record_types, writer, use_logical_types, class_naming_strategy=None):
"""
Write specific reader implementation
:param list[schema.RecordSchema] record_types:
:param writer:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
writer.write('\n\n\nclass SpecificDatumReader(%s):' % (
Expand All @@ -317,10 +337,11 @@ def write_reader_impl(record_types, writer, use_logical_types):
writer.write('\nSCHEMA_TYPES = {')
with writer.indent():
for t in record_types:
t_class = t.split('.')[-1]
writer.write('\n"{t_class}": {t_class}Class,'.format(t_class=t_class))
writer.write('\n".{t_class}": {t_class}Class,'.format(t_class=t_class))
writer.write('\n"{f_class}": {t_class}Class,'.format(t_class=t_class, f_class=t))
t_class = class_naming_strategy(t) if class_naming_strategy else t.split(".")[-1]
classname = class_naming_strategy(t) if class_naming_strategy else f'{t.split(".")[-1]}Class'
writer.write('\n"{t_class}": {classname},'.format(t_class=t_class, classname=classname))
writer.write('\n".{t_class}": {classname},'.format(t_class=t_class, classname=classname))
writer.write('\n"{f_class}": {classname},'.format(t_class=t_class, f_class=t, classname=classname))

writer.write('\n}')
writer.write('\n\n\ndef __init__(self, readers_schema=None, **kwargs):')
Expand Down Expand Up @@ -371,16 +392,21 @@ def generate_namespace_modules(names, output_folder):
return ns_dict


def write_schema_record(record, writer, use_logical_types):
def write_schema_record(record, writer, use_logical_types, class_naming_strategy=None):
"""
Writes class representing Avro record schema
:param avro.schema.RecordSchema record:
:param TabbedWriter writer:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""

_, type_name = ns_.split_fullname(record.fullname)
writer.write('''\nclass {name}Class(DictWrapper):'''.format(name=type_name))
fullname = record.fullname
classname = (
class_naming_strategy(fullname)
if class_naming_strategy
else f'{ns_.split_fullname(fullname)[-1]}Class'
)
writer.write("""\nclass {name}(DictWrapper):""".format(name=classname))

with writer.indent():
writer.write('\n')
Expand All @@ -390,24 +416,24 @@ def write_schema_record(record, writer, use_logical_types):
writer.write('# No docs available.')
writer.write('\n\nRECORD_SCHEMA = get_schema_type("%s")' % (record.fullname))

write_record_init(record, writer, use_logical_types)
write_record_init(record, writer, use_logical_types, class_naming_strategy=class_naming_strategy)

write_fields(record, writer, use_logical_types)
write_fields(record, writer, use_logical_types, class_naming_strategy=class_naming_strategy)


def write_record_init(record, writer, use_logical_types):
def write_record_init(record, writer, use_logical_types, class_naming_strategy=None):
writer.write('\ndef __init__(self,')
with writer.indent():
delayed_lines = []
default_map: Dict[str, str] = {}
for field in record.fields: # type: schema.Field
name = get_field_name(field, use_logical_types)
ret_type_name = get_field_type_name(field.type, use_logical_types)
ret_type_name = get_field_type_name(field.type, use_logical_types, class_naming_strategy=class_naming_strategy)
default_type, nullable = find_type_of_default(field.type)

if not nullable and field.has_default:
# print(record.name, field.name, field.default)
default = get_default(field, use_logical_types, f_name=field.name)
default = get_default(field, use_logical_types, f_name=field.name, class_naming_strategy=class_naming_strategy)
default_map[name] = default
ret_type_name = f"Optional[{ret_type_name}]"
nullable = True
Expand Down Expand Up @@ -437,7 +463,8 @@ def write_record_init(record, writer, use_logical_types):
writer.write(f'\nself.{name} = {name}')

writer.write('\n\n@classmethod')
writer.write(f'\ndef construct_with_defaults(cls) -> "{record.name}Class":')
classname = class_naming_strategy(record.fullname) if class_naming_strategy else f'{record.name}Class'
writer.write(f'\ndef construct_with_defaults(cls) -> "{classname}":')
with writer.indent():
writer.write('\nself = cls.construct({})')
writer.write('\nself._restore_defaults()')
Expand All @@ -447,7 +474,7 @@ def write_record_init(record, writer, use_logical_types):
writer.write('\n')
writer.write(f'\ndef _restore_defaults(self) -> None:')
with writer.indent():
write_defaults(record, writer, use_logical_types=use_logical_types)
write_defaults(record, writer, use_logical_types=use_logical_types, class_naming_strategy=class_naming_strategy)


def write_enum(enum, writer):
Expand Down
39 changes: 23 additions & 16 deletions avrogen/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
logger.setLevel(logging.INFO)


def generate_schema(schema_json, use_logical_types=False, custom_imports=None, avro_json_converter=None):
def generate_schema(schema_json, use_logical_types=False, custom_imports=None, avro_json_converter=None, class_naming_strategy=None):
"""
Generate file containing concrete classes for RecordSchemas in given avro schema json
:param str schema_json: JSON representing avro schema
:param list[str] custom_imports: Add additional import modules
:param str avro_json_converter: AvroJsonConverter type to use for default values
:param function class_naming_strategy: Function to compute class names from full names
:return Dict[str, str]:
"""

Expand Down Expand Up @@ -61,7 +62,7 @@ def generate_schema(schema_json, use_logical_types=False, custom_imports=None, a
current_namespace = namespace
if isinstance(field_schema, schema.RecordSchema):
logger.debug(f'Writing schema: {clean_fullname(field_schema.fullname)}')
write_schema_record(field_schema, writer, use_logical_types)
write_schema_record(field_schema, writer, use_logical_types, class_naming_strategy=class_naming_strategy)
elif isinstance(field_schema, schema.EnumSchema):
logger.debug(f'Writing enum: {field_schema.fullname}', field_schema.fullname)
write_enum(field_schema, writer)
Expand All @@ -71,14 +72,14 @@ def generate_schema(schema_json, use_logical_types=False, custom_imports=None, a

# Lookup table for fullname.
for name, field_schema in names:
n = clean_fullname(field_schema.name)
classname = class_naming_strategy(name) if class_naming_strategy else f'{clean_fullname(field_schema.name)}Class'
full = field_schema.fullname
writer.write(f"\n'{full}': {n}Class,")
writer.write(f"\n'{full}': {classname},")

# Lookup table for names without namespace.
for name, field_schema in names:
n = clean_fullname(field_schema.name)
writer.write(f"\n'{n}': {n}Class,")
classname = class_naming_strategy(name) if class_naming_strategy else f'{clean_fullname(field_schema.name)}Class'
writer.write(f"\n'{classname}': {classname},")

writer.untab()
writer.write('\n}\n\n')
Expand Down Expand Up @@ -116,12 +117,13 @@ def write_populate_schemas(writer):
writer.write('\n__SCHEMAS = dict((n.fullname.lstrip("."), n) for n in six.itervalues(__NAMES.names))\n')


def write_namespace_modules(ns_dict, output_folder):
def write_namespace_modules(ns_dict, output_folder, class_naming_strategy=None):
"""
Writes content of the generated namespace modules. A python module will be created for each namespace
and will import concrete schema classes from SchemaClasses
:param ns_dict:
:param output_folder:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
for ns in six.iterkeys(ns_dict):
Expand All @@ -130,19 +132,22 @@ def write_namespace_modules(ns_dict, output_folder):
if ns != '':
currency += '.' * len(ns.split('.'))
for name in ns_dict[ns]:
f.write(f'from {currency}schema_classes import {name}Class\n')
classname = class_naming_strategy(ns + '.' + name) if class_naming_strategy else f'{name}Class'
f.write(f'from {currency}schema_classes import {classname}\n')

f.write('\n\n')

for name in ns_dict[ns]:
f.write(f"{name} = {name}Class\n")
if not class_naming_strategy:
f.write(f"{name} = {name}Class\n")


def write_specific_reader(record_types, output_folder, use_logical_types):
def write_specific_reader(record_types, output_folder, use_logical_types, class_naming_strategy=None):
"""
Writes specific reader for a avro schema into generated root module
:param record_types:
:param output_folder:
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
with open(os.path.join(output_folder, "__init__.py"), "a+") as f:
Expand All @@ -152,23 +157,25 @@ def write_specific_reader(record_types, output_folder, use_logical_types):
writer.write('\nfrom .schema_classes import SCHEMA as get_schema_type')
writer.write('\nfrom .schema_classes import _json_converter as json_converter')
for t in record_types:
writer.write(f'\nfrom .schema_classes import {t.split(".")[-1]}Class')
classname = class_naming_strategy(t) if class_naming_strategy else f'{t.split(".")[-1]}Class'
writer.write(f'\nfrom .schema_classes import {classname}')
writer.write('\nfrom avro.io import DatumReader')
if use_logical_types:
writer.write('\nfrom avrogen import logical')

write_reader_impl(record_types, writer, use_logical_types)
write_reader_impl(record_types, writer, use_logical_types, class_naming_strategy=class_naming_strategy)


def write_schema_files(schema_json, output_folder, use_logical_types=False, custom_imports=None):
def write_schema_files(schema_json, output_folder, use_logical_types=False, custom_imports=None, class_naming_strategy=None):
"""
Generates concrete classes, namespace modules, and a SpecificRecordReader for a given avro schema
:param str schema_json: JSON containing avro schema
:param str output_folder: Folder in which to create generated files
:param list[str] custom_imports: Add additional import modules
:param function class_naming_strategy: Function to compute class names from full names
:return:
"""
schema_py, names = generate_schema(schema_json, use_logical_types, custom_imports)
schema_py, names = generate_schema(schema_json, use_logical_types, custom_imports, class_naming_strategy=class_naming_strategy)
names = sorted(names)

if not os.path.isdir(output_folder):
Expand All @@ -185,5 +192,5 @@ def write_schema_files(schema_json, output_folder, use_logical_types=False, cust
with open(os.path.join(output_folder, "__init__.py"), "w+") as f:
pass # make sure we create this file from scratch

write_namespace_modules(ns_dict, output_folder)
write_specific_reader(names, output_folder, use_logical_types)
write_namespace_modules(ns_dict, output_folder, class_naming_strategy=class_naming_strategy)
write_specific_reader(names, output_folder, use_logical_types, class_naming_strategy=class_naming_strategy)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

setup(
name="avro-gen3",
version="0.7.6",
version="0.7.7",
description="Avro record class and specific record reader generator",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
Loading