diff --git a/avrogen/__init__.py b/avrogen/__init__.py index 323679a..1853494 100644 --- a/avrogen/__init__.py +++ b/avrogen/__init__.py @@ -2,6 +2,10 @@ from .schema import write_schema_files from .protocol import generate_protocol from .protocol import write_protocol_files -__all__ = ['generate_schema', 'generate_protocol', 'write_schema_files', 'write_protocol_files'] - +__all__ = [ + "generate_schema", + "generate_protocol", + "write_schema_files", + "write_protocol_files", +] diff --git a/avrogen/__main__.py b/avrogen/__main__.py index 647dd8d..c13be44 100644 --- a/avrogen/__main__.py +++ b/avrogen/__main__.py @@ -10,11 +10,11 @@ def main(): if "-o" in argv: avdl, _, output = argv[1:] else: - avdl, output = argv[1], './' - write_protocol_files(open(avdl, 'r').read(), output) + avdl, output = argv[1], "./" + write_protocol_files(open(avdl, "r").read(), output) -if __name__ == '__main__': +if __name__ == "__main__": try: exit(main()) except Exception: diff --git a/avrogen/avrojson.py b/avrogen/avrojson.py index 79ced34..7b8e662 100644 --- a/avrogen/avrojson.py +++ b/avrogen/avrojson.py @@ -1,30 +1,33 @@ -from avrogen.core_writer import find_type_of_default import collections + import six +from avro import io, schema + +from avrogen.core_writer import find_type_of_default from . import logical from .dict_wrapper import DictWrapper -from avro import schema -from avro import io io_validate = io.validate try: - from avro.io import SchemaResolutionException, AvroTypeException + from avro.io import AvroTypeException, SchemaResolutionException except ImportError: - from avro.errors import SchemaResolutionException, AvroTypeException + from avro.errors import AvroTypeException, SchemaResolutionException _PRIMITIVE_TYPES = set(schema.PRIMITIVE_TYPES) _json_converter = None _json_converter_tuples = None + def set_global_json_converter(json_converter: "AvroJsonConverter") -> None: global _json_converter _json_converter = json_converter global _json_converter_tuples _json_converter_tuples = json_converter.with_tuple_union(True) + def get_global_json_converter(tuples: bool = False) -> "AvroJsonConverter": if tuples: assert _json_converter_tuples @@ -34,30 +37,54 @@ def get_global_json_converter(tuples: bool = False) -> "AvroJsonConverter": class AvroJsonConverter(object): - def __init__(self, use_logical_types=False, logical_types=logical.DEFAULT_LOGICAL_TYPES, fastavro: bool = False, schema_types=None): + def __init__( + self, + use_logical_types=False, + logical_types=logical.DEFAULT_LOGICAL_TYPES, + fastavro: bool = False, + schema_types=None, + ): self.use_logical_types = use_logical_types self.logical_types = logical_types or {} self.schema_types = schema_types or {} self.fastavro = fastavro - - def with_tuple_union(self, tuples=True) -> 'AvroJsonConverter': - return AvroJsonConverter(self.use_logical_types, self.logical_types, tuples, self.schema_types) + + def with_tuple_union(self, tuples=True) -> "AvroJsonConverter": + return AvroJsonConverter( + self.use_logical_types, self.logical_types, tuples, self.schema_types + ) def validate(self, expected_schema, datum, skip_logical_types=False) -> bool: - if self.use_logical_types and expected_schema.props.get('logicalType') and not skip_logical_types \ - and expected_schema.props.get('logicalType') in self.logical_types: - return self.logical_types[expected_schema.props.get('logicalType')].can_convert(expected_schema) \ - and self.logical_types[expected_schema.props.get('logicalType')].validate(expected_schema, datum) + if ( + 
self.use_logical_types + and expected_schema.props.get("logicalType") + and not skip_logical_types + and expected_schema.props.get("logicalType") in self.logical_types + ): + return self.logical_types[ + expected_schema.props.get("logicalType") + ].can_convert(expected_schema) and self.logical_types[ + expected_schema.props.get("logicalType") + ].validate( + expected_schema, datum + ) schema_type = expected_schema.type - if schema_type == 'array': - return (isinstance(datum, list) and - all(self.validate(expected_schema.items, d, skip_logical_types) for d in datum)) - elif schema_type == 'map': - return (isinstance(datum, dict) and - False not in [isinstance(k, six.string_types) for k in datum.keys()] and - False not in - [self.validate(expected_schema.values, v, skip_logical_types) for v in datum.values()]) - elif schema_type in ['union', 'error_union']: + if schema_type == "array": + return isinstance(datum, list) and all( + self.validate(expected_schema.items, d, skip_logical_types) + for d in datum + ) + elif schema_type == "map": + return ( + isinstance(datum, dict) + and False not in [isinstance(k, six.string_types) for k in datum.keys()] + and False + not in [ + self.validate(expected_schema.values, v, skip_logical_types) + for v in datum.values() + ] + ) + elif schema_type in ["union", "error_union"]: if isinstance(datum, DictWrapper): # Match the type based on the declared schema. data_schema = self._get_record_schema_if_available(datum) @@ -75,7 +102,9 @@ def validate(self, expected_schema, datum, skip_logical_types=False) -> bool: return None value_type = items[0][0] value = items[0][1] - elif self.fastavro and (isinstance(datum, list) or isinstance(datum, tuple)): + elif self.fastavro and ( + isinstance(datum, list) or isinstance(datum, tuple) + ): if len(datum) == 2: value_type = datum[0] value = datum[1] @@ -87,23 +116,38 @@ def validate(self, expected_schema, datum, skip_logical_types=False) -> bool: return True # If the specialized validation fails, we still attempt normal validation. - return any(self.validate(s, datum, skip_logical_types) for s in expected_schema.schemas) - elif schema_type in ['record', 'error', 'request']: + return any( + self.validate(s, datum, skip_logical_types) + for s in expected_schema.schemas + ) + elif schema_type in ["record", "error", "request"]: if isinstance(datum, dict): - return all(self.validate(f.type, datum.get(f.name, f.default) if f.has_default else datum.get(f.name), skip_logical_types) for f in expected_schema.fields) + return all( + self.validate( + f.type, + datum.get(f.name, f.default) + if f.has_default + else datum.get(f.name), + skip_logical_types, + ) + for f in expected_schema.fields + ) elif isinstance(datum, DictWrapper): # DictWrapper types should have defaults initialized already. - return all(self.validate(f.type, datum.get(f.name), skip_logical_types) for f in expected_schema.fields) + return all( + self.validate(f.type, datum.get(f.name), skip_logical_types) + for f in expected_schema.fields + ) else: return False # PERF: We're basically "inlining" this logic from avro to avoid a few extra function calls. # This seems to have a ~10-15% impact on validation speed. - elif schema_type == 'null': + elif schema_type == "null": return datum is None - elif schema_type == 'string': + elif schema_type == "string": return isinstance(datum, str) - elif schema_type == 'bytes': + elif schema_type == "bytes": # Specialization for bytes, which we are encoding as strings in JSON. 
if not self.fastavro and isinstance(datum, str): return True @@ -112,8 +156,8 @@ def validate(self, expected_schema, datum, skip_logical_types=False) -> bool: else: # Defer to underlying avro lib for other types. return io_validate(expected_schema, datum) - - assert False, 'this code should be unreachable' + + assert False, "this code should be unreachable" def from_json_object(self, json_obj, writers_schema=None, readers_schema=None): if readers_schema is None: @@ -122,10 +166,12 @@ def from_json_object(self, json_obj, writers_schema=None, readers_schema=None): writers_schema = readers_schema if writers_schema is None: - raise Exception('At least one schema must be specified') + raise Exception("At least one schema must be specified") if not writers_schema.match(readers_schema): - raise SchemaResolutionException('Could not match schemas', writers_schema, readers_schema) + raise SchemaResolutionException( + "Could not match schemas", writers_schema, readers_schema + ) return self._generic_from_json(json_obj, writers_schema, readers_schema) @@ -134,7 +180,9 @@ def to_json_object(self, data_obj, writers_schema=None): writers_schema = self._get_record_schema_if_available(data_obj) if writers_schema is None: - raise Exception("Could not determine writer's schema from the object type and schema was not passed") + raise Exception( + "Could not determine writer's schema from the object type and schema was not passed" + ) assert isinstance(writers_schema, schema.Schema) if not self.validate(writers_schema, data_obj): @@ -144,40 +192,44 @@ def to_json_object(self, data_obj, writers_schema=None): def _fullname(self, schema_): if isinstance(schema_, schema.NamedSchema): - return schema_.fullname.lstrip('.') + return schema_.fullname.lstrip(".") return schema_.type def _get_record_schema_if_available(self, data_obj): - if hasattr(type(data_obj), 'RECORD_SCHEMA'): + if hasattr(type(data_obj), "RECORD_SCHEMA"): return type(data_obj).RECORD_SCHEMA return None def _generic_to_json(self, data_obj, writers_schema, was_within_array=False): - if self.use_logical_types and writers_schema.props.get('logicalType'): - lt = self.logical_types.get(writers_schema.props.get('logicalType')) # type: logical.LogicalTypeProcessor + if self.use_logical_types and writers_schema.props.get("logicalType"): + lt = self.logical_types.get( + writers_schema.props.get("logicalType") + ) # type: logical.LogicalTypeProcessor if lt.can_convert(writers_schema): if lt.validate(writers_schema, data_obj): data_obj = lt.convert(writers_schema, data_obj) else: raise schema.AvroException( - 'Wrong object for %s logical type' % writers_schema.props.get('logicalType')) + "Wrong object for %s logical type" + % writers_schema.props.get("logicalType") + ) if writers_schema.type in _PRIMITIVE_TYPES: result = self._primitive_to_json(data_obj, writers_schema) - elif writers_schema.type == 'fixed': + elif writers_schema.type == "fixed": result = self._fixed_to_json(data_obj, writers_schema) - elif writers_schema.type == 'enum': + elif writers_schema.type == "enum": result = self._enum_to_json(data_obj, writers_schema) - elif writers_schema.type == 'array': + elif writers_schema.type == "array": result = self._array_to_json(data_obj, writers_schema) - elif writers_schema.type == 'map': + elif writers_schema.type == "map": result = self._map_to_json(data_obj, writers_schema) - elif writers_schema.type in ['record', 'error', 'request']: + elif writers_schema.type in ["record", "error", "request"]: result = self._record_to_json(data_obj, writers_schema) 
- elif writers_schema.type in ['union', 'error_union']: + elif writers_schema.type in ["union", "error_union"]: result = self._union_to_json(data_obj, writers_schema, was_within_array) else: - raise schema.AvroException('Invalid schema type: %s' % writers_schema.type) + raise schema.AvroException("Invalid schema type: %s" % writers_schema.type) return result @@ -193,10 +245,16 @@ def _enum_to_json(self, data_obj, writers_schema): return data_obj def _array_to_json(self, data_obj, writers_schema): - return [self._generic_to_json(x, writers_schema.items, was_within_array=True) for x in data_obj] + return [ + self._generic_to_json(x, writers_schema.items, was_within_array=True) + for x in data_obj + ] def _map_to_json(self, data_obj, writers_schema): - return {name: self._generic_to_json(x, writers_schema.values) for name, x in six.iteritems(data_obj)} + return { + name: self._generic_to_json(x, writers_schema.values) + for name, x in six.iteritems(data_obj) + } def _record_to_json(self, data_obj, writers_schema): result = collections.OrderedDict() @@ -216,10 +274,16 @@ def _record_to_json(self, data_obj, writers_schema): result[field.name] = self._generic_to_json(obj, field.type) return result - + def _is_unambiguous_union(self, writers_schema) -> bool: - if any(isinstance(candidate_schema, schema.EnumSchema) for candidate_schema in writers_schema.schemas): - if len(writers_schema.schemas) == 2 and any(candidate_schema.type == 'null' for candidate_schema in writers_schema.schemas): + if any( + isinstance(candidate_schema, schema.EnumSchema) + for candidate_schema in writers_schema.schemas + ): + if len(writers_schema.schemas) == 2 and any( + candidate_schema.type == "null" + for candidate_schema in writers_schema.schemas + ): # Enums and null do not conflict, so this is fine. return True else: @@ -228,7 +292,7 @@ def _is_unambiguous_union(self, writers_schema) -> bool: advanced_count = 0 for candidate_schema in writers_schema.schemas: - if candidate_schema.type != 'null': + if candidate_schema.type != "null": advanced_count += 1 if advanced_count <= 1: return True @@ -236,26 +300,33 @@ def _is_unambiguous_union(self, writers_schema) -> bool: def _union_to_json(self, data_obj, writers_schema, was_within_array=False): index_of_schema = -1 + + # Check for exact matches first. data_schema = self._get_record_schema_if_available(data_obj) for i, candidate_schema in enumerate(writers_schema.schemas): - # Check for exact matches first. if data_schema and candidate_schema.fullname == data_schema.fullname: index_of_schema = i break + if index_of_schema < 0: # Fallback to schema guessing based on validation. - if self.validate(candidate_schema, data_obj): - index_of_schema = i - if candidate_schema.type == 'boolean': - break + for i, candidate_schema in enumerate(writers_schema.schemas): + if self.validate(candidate_schema, data_obj): + index_of_schema = i + if candidate_schema.type == "boolean": + break if index_of_schema < 0: raise AvroTypeException(writers_schema, data_obj) candidate_schema = writers_schema.schemas[index_of_schema] - if candidate_schema.type == 'null': + if candidate_schema.type == "null": return None - + output_obj = self._generic_to_json(data_obj, candidate_schema) - if not self.fastavro and not was_within_array and self._is_unambiguous_union(writers_schema): + if ( + not self.fastavro + and not was_within_array + and self._is_unambiguous_union(writers_schema) + ): # If the union is unambiguous, we can avoid wrapping it in # an extra layer of tuples or dicts. 
Fastavro doesn't like this though. # Arrays with unions inside must specify the type. @@ -266,43 +337,49 @@ def _union_to_json(self, data_obj, writers_schema, was_within_array=False): return {self._fullname(candidate_schema): output_obj} def _generic_from_json(self, json_obj, writers_schema, readers_schema): - if (writers_schema.type not in ['union', 'error_union'] - and readers_schema.type in ['union', 'error_union']): + if writers_schema.type not in [ + "union", + "error_union", + ] and readers_schema.type in ["union", "error_union"]: for s in readers_schema.schemas: if writers_schema.match(s): return self._generic_from_json(json_obj, writers_schema, s) - raise SchemaResolutionException('Schemas do not match', writers_schema, readers_schema) + raise SchemaResolutionException( + "Schemas do not match", writers_schema, readers_schema + ) result = None - if writers_schema.type == 'null': + if writers_schema.type == "null": result = None elif writers_schema.type in _PRIMITIVE_TYPES: result = self._primitive_from_json(json_obj, writers_schema, readers_schema) - elif writers_schema.type == 'fixed': + elif writers_schema.type == "fixed": result = self._fixed_from_json(json_obj, writers_schema, readers_schema) - elif writers_schema.type == 'enum': + elif writers_schema.type == "enum": result = self._enum_from_json(json_obj, writers_schema, readers_schema) - elif writers_schema.type == 'array': + elif writers_schema.type == "array": result = self._array_from_json(json_obj, writers_schema, readers_schema) - elif writers_schema.type == 'map': + elif writers_schema.type == "map": result = self._map_from_json(json_obj, writers_schema, readers_schema) - elif writers_schema.type in ('union', 'error_union'): + elif writers_schema.type in ("union", "error_union"): result = self._union_from_json(json_obj, writers_schema, readers_schema) - elif writers_schema.type in ('record', 'error', 'request'): + elif writers_schema.type in ("record", "error", "request"): result = self._record_from_json(json_obj, writers_schema, readers_schema) result = self._logical_type_from_json(writers_schema, readers_schema, result) return result def _logical_type_from_json(self, writers_schema, readers_schema, result): - if self.use_logical_types and readers_schema.props.get('logicalType'): - lt = self.logical_types.get(readers_schema.props.get('logicalType')) # type: logical.LogicalTypeProcessor + if self.use_logical_types and readers_schema.props.get("logicalType"): + lt = self.logical_types.get( + readers_schema.props.get("logicalType") + ) # type: logical.LogicalTypeProcessor if lt and lt.does_match(writers_schema, readers_schema): result = lt.convert_back(writers_schema, readers_schema, result) return result def _primitive_from_json(self, json_obj, writers_schema, readers_schema): - if not self.fastavro and writers_schema.type == 'bytes': + if not self.fastavro and writers_schema.type == "bytes": if isinstance(json_obj, str): return json_obj.encode() return json_obj @@ -314,12 +391,18 @@ def _enum_from_json(self, json_obj, writers_schema, readers_schema): return json_obj def _array_from_json(self, json_obj, writers_schema, readers_schema): - return [self._generic_from_json(x, writers_schema.items, readers_schema.items) - for x in json_obj] + return [ + self._generic_from_json(x, writers_schema.items, readers_schema.items) + for x in json_obj + ] def _map_from_json(self, json_obj, writers_schema, readers_schema): - return {name: self._generic_from_json(value, writers_schema.values, readers_schema.values) - for name, value in 
six.iteritems(json_obj)} + return { + name: self._generic_from_json( + value, writers_schema.values, readers_schema.values + ) + for name, value in six.iteritems(json_obj) + } def _union_from_json(self, json_obj, writers_schema, readers_schema): if json_obj is None: @@ -331,7 +414,9 @@ def _union_from_json(self, json_obj, writers_schema, readers_schema): if len(items) == 1: value_type = items[0][0] value = items[0][1] - if self.fastavro and (isinstance(json_obj, list) or isinstance(json_obj, tuple)): + if self.fastavro and ( + isinstance(json_obj, list) or isinstance(json_obj, tuple) + ): if len(json_obj) == 2: value_type = json_obj[0] value = json_obj[1] @@ -345,8 +430,8 @@ def _union_from_json(self, json_obj, writers_schema, readers_schema): for s in writers_schema.schemas: if self.validate(s, json_obj, skip_logical_types=True): return self._generic_from_json(json_obj, s, readers_schema) - raise schema.AvroException('Datum union type not in schema: %s', value_type) - + raise schema.AvroException("Datum union type not in schema: %s", value_type) + def _make_type(self, tp, record): if issubclass(tp, DictWrapper): return tp._construct(record) @@ -363,7 +448,9 @@ def _instantiate_record(self, decoded_record, writers_schema, readers_schema): return self._make_type(self.schema_types[readers_name], decoded_record) return decoded_record - def _record_from_json(self, json_obj, writers_schema, readers_schema, fail_on_extra_fields=False): + def _record_from_json( + self, json_obj, writers_schema, readers_schema, fail_on_extra_fields=False + ): writer_fields = writers_schema.fields_dict input_keys = set(json_obj.keys()) @@ -372,23 +459,33 @@ def _record_from_json(self, json_obj, writers_schema, readers_schema, fail_on_ex for field in readers_schema.fields: writers_field = writer_fields.get(field.name) if writers_field is None: - field_value = self._generic_from_json(field.default, field.type, field.type) \ - if field.has_default else None + field_value = ( + self._generic_from_json(field.default, field.type, field.type) + if field.has_default + else None + ) else: if field.name in json_obj: - field_value = self._generic_from_json(json_obj[field.name], writers_field.type, field.type) + field_value = self._generic_from_json( + json_obj[field.name], writers_field.type, field.type + ) input_keys.remove(field.name) else: _, nullable = find_type_of_default(field.type) if writers_field.has_default: - field_value = self._generic_from_json(writers_field.default, - writers_field.type, field.type) + field_value = self._generic_from_json( + writers_field.default, writers_field.type, field.type + ) elif nullable: field_value = None else: - raise ValueError(f'{readers_schema.fullname} is missing required field: {field.name}') + raise ValueError( + f"{readers_schema.fullname} is missing required field: {field.name}" + ) result[field.name] = field_value if input_keys and fail_on_extra_fields: # only throw errors if there are fields that we do not know about and fail_on_extra_fields is set to True - raise ValueError(f'{readers_schema.fullname} contains extra fields: {input_keys}') + raise ValueError( + f"{readers_schema.fullname} contains extra fields: {input_keys}" + ) return self._instantiate_record(result, writers_schema, readers_schema) diff --git a/avrogen/core_writer.py b/avrogen/core_writer.py index 95e4886..9386aeb 100644 --- a/avrogen/core_writer.py +++ b/avrogen/core_writer.py @@ -10,40 +10,43 @@ PRIMITIVE_TYPES = { - 'null', - 'boolean', - 'int', - 'long', - 'float', - 'double', - 'bytes', - 'string' + 
"null", + "boolean", + "int", + "long", + "float", + "double", + "bytes", + "string", } __PRIMITIVE_TYPE_MAPPING = { - 'null': None, - 'boolean': bool, - 'int': int, - 'long': long, - 'float': float, - 'double': float, - 'bytes': bytes, - 'string': str, + "null": None, + "boolean": bool, + "int": int, + "long": long, + "float": float, + "double": float, + "bytes": bytes, + "string": str, } def clean_fullname(fullname): - return fullname.lstrip('.') + return fullname.lstrip(".") + def _python_safe_name(name): if keyword.iskeyword(name): - return f'{name}_' + return f"{name}_" return name def convert_default(idx, full_name=None, do_json=True): if do_json: - return (f'_json_converter.from_json_object(self.RECORD_SCHEMA.fields_dict["{idx}"].default,' - + f' writers_schema=self.RECORD_SCHEMA.fields_dict["{idx}"].type)') + return ( + f'_json_converter.from_json_object(self.RECORD_SCHEMA.fields_dict["{idx}"].default,' + + f' writers_schema=self.RECORD_SCHEMA.fields_dict["{idx}"].type)' + ) else: return f'self.RECORD_SCHEMA.fields_dict["{idx}"].default' @@ -54,43 +57,61 @@ def get_default(field, use_logical_types, my_full_name=None): default_type, nullable = find_type_of_default(field.type) if field.has_default: - if use_logical_types and default_type.props.get('logicalType') \ - and default_type.props.get('logicalType') in logical.DEFAULT_LOGICAL_TYPES: - lt = logical.DEFAULT_LOGICAL_TYPES[default_type.props.get('logicalType')] - v = lt.initializer(convert_default(idx=f_name, full_name=my_full_name, do_json=isinstance(default_type, schema.RecordSchema))) + if ( + use_logical_types + and default_type.props.get("logicalType") + and default_type.props.get("logicalType") in logical.DEFAULT_LOGICAL_TYPES + ): + lt = logical.DEFAULT_LOGICAL_TYPES[default_type.props.get("logicalType")] + v = lt.initializer( + convert_default( + idx=f_name, + full_name=my_full_name, + do_json=isinstance(default_type, schema.RecordSchema), + ) + ) return v elif isinstance(default_type, schema.RecordSchema): d = convert_default(idx=f_name, do_json=True) return d - elif isinstance(default_type, (schema.PrimitiveSchema, schema.EnumSchema, schema.FixedSchema)): + elif isinstance( + default_type, + (schema.PrimitiveSchema, schema.EnumSchema, schema.FixedSchema), + ): d = convert_default(full_name=my_full_name, idx=f_name, do_json=False) return d if not default_written: default_written = True if nullable: - return 'None' - elif use_logical_types and default_type.props.get('logicalType') \ - and default_type.props.get('logicalType') in logical.DEFAULT_LOGICAL_TYPES: - lt = logical.DEFAULT_LOGICAL_TYPES[default_type.props.get('logicalType')] + return "None" + elif ( + use_logical_types + and default_type.props.get("logicalType") + and default_type.props.get("logicalType") in logical.DEFAULT_LOGICAL_TYPES + ): + lt = logical.DEFAULT_LOGICAL_TYPES[default_type.props.get("logicalType")] return str(lt.initializer()) - elif isinstance(default_type, schema.PrimitiveSchema) and not default_type.props.get('logicalType'): + elif isinstance( + default_type, schema.PrimitiveSchema + ) and not default_type.props.get("logicalType"): d = get_primitive_field_initializer(default_type) return d elif isinstance(default_type, schema.EnumSchema): f = clean_fullname(default_type.name) s = default_type.symbols[0] - return f'{f}Class.{s}' + return f"{f}Class.{s}" elif isinstance(default_type, schema.MapSchema): - return 'dict()' + return "dict()" elif isinstance(default_type, schema.ArraySchema): - return 'list()' + return "list()" elif 
isinstance(default_type, schema.FixedSchema): - return 'bytes()' + return "bytes()" elif isinstance(default_type, schema.RecordSchema): f = clean_fullname(default_type.name) - return f'{f}Class._construct_with_defaults()' - raise AttributeError('cannot get default for field') + return f"{f}Class._construct_with_defaults()" + raise AttributeError("cannot get default for field") + def write_defaults(record, writer, my_full_name=None, use_logical_types=False): """ @@ -107,11 +128,11 @@ def write_defaults(record, writer, my_full_name=None, use_logical_types=False): for field in record.fields: f_name = get_field_name(field, use_logical_types) default = get_default(field, use_logical_types, my_full_name=my_full_name) - writer.write(f'\nself.{f_name} = {default}') + writer.write(f"\nself.{f_name} = {default}") something_written = True i += 1 if not something_written: - writer.write('\npass') + writer.write("\npass") def write_fields(record, writer, use_logical_types): @@ -121,13 +142,15 @@ def write_fields(record, writer, use_logical_types): :param TabbedWriter writer: Writer to write to :return: """ - writer.write('\n\n') + writer.write("\n\n") for field in record.fields: # type: schema.Field write_field(field, writer, use_logical_types) + def get_field_name(field, use_logical_types): return _python_safe_name(field.name) + def write_field(field, writer, use_logical_types): """ Write a single field definition @@ -138,7 +161,8 @@ def write_field(field, writer, use_logical_types): name = get_field_name(field, use_logical_types) doc = field.doc get_docstring = f'"""{doc}"""' if doc else "# No docs available." - writer.write(''' + writer.write( + """ @property def {name}(self) -> {ret_type_name}: {get_docstring} @@ -148,7 +172,13 @@ def {name}(self) -> {ret_type_name}: def {name}(self, value: {ret_type_name}) -> None: self._inner_dict['{raw_name}'] = value -'''.format(name=name, get_docstring=get_docstring, raw_name=field.name, ret_type_name=get_field_type_name(field.type, use_logical_types))) +""".format( + name=name, + get_docstring=get_docstring, + raw_name=field.name, + ret_type_name=get_field_type_name(field.type, use_logical_types), + ) + ) def get_primitive_field_initializer(field_schema): @@ -159,8 +189,8 @@ def get_primitive_field_initializer(field_schema): :return: """ - if field_schema.type == 'null': - return 'None' + if field_schema.type == "null": + return "None" return get_field_type_name(field_schema, False) + "()" @@ -170,18 +200,19 @@ def get_field_type_name(field_schema, use_logical_types): :param schema.Schema field_schema: :return: String containing python type hint """ - if use_logical_types and field_schema.props.get('logicalType'): + if use_logical_types and field_schema.props.get("logicalType"): from avrogen.logical import DEFAULT_LOGICAL_TYPES - lt = DEFAULT_LOGICAL_TYPES.get(field_schema.props.get('logicalType')) + + lt = DEFAULT_LOGICAL_TYPES.get(field_schema.props.get("logicalType")) if lt: return lt.typename() if isinstance(field_schema, schema.PrimitiveSchema): - if field_schema.fullname == 'null': - return 'None' + if field_schema.fullname == "null": + return "None" return __PRIMITIVE_TYPE_MAPPING[field_schema.fullname].__name__ elif isinstance(field_schema, schema.FixedSchema): - return 'bytes' + return "bytes" elif isinstance(field_schema, schema.EnumSchema): # For enums, we have their "class" types, but they're actually # represented as strings. 
This is a decent hack to work around @@ -190,17 +221,26 @@ def get_field_type_name(field_schema, use_logical_types): elif isinstance(field_schema, schema.NamedSchema): return f'"{field_schema.name}Class"' elif isinstance(field_schema, schema.ArraySchema): - return 'List[' + get_field_type_name(field_schema.items, use_logical_types) + ']' + return ( + "List[" + get_field_type_name(field_schema.items, use_logical_types) + "]" + ) elif isinstance(field_schema, schema.MapSchema): - return 'Dict[str, ' + get_field_type_name(field_schema.values, use_logical_types) + ']' + return ( + "Dict[str, " + + get_field_type_name(field_schema.values, use_logical_types) + + "]" + ) elif isinstance(field_schema, schema.UnionSchema): - type_names = [get_field_type_name(x, use_logical_types) for x in field_schema.schemas if - get_field_type_name(x, use_logical_types)] + type_names = [ + get_field_type_name(x, use_logical_types) + for x in field_schema.schemas + if get_field_type_name(x, use_logical_types) + ] if len(type_names) > 1: - return 'Union[' + ', '.join(type_names) + ']' + return "Union[" + ", ".join(type_names) + "]" elif len(type_names) == 1: return type_names[0] - return '' + return "" def find_type_of_default(field_type): @@ -223,7 +263,7 @@ def find_type_of_default(field_type): # type_, nullable = field_type.schemas[0], True # return type_, nullable elif isinstance(field_type, schema.PrimitiveSchema): - return field_type, field_type.fullname == 'null' + return field_type, field_type.fullname == "null" else: return field_type, False @@ -242,13 +282,13 @@ def start_namespace(current, target, writer): while i < min(len(current), len(target)) and current[i] == target[i]: i += 1 - writer.write('\n\n') + writer.write("\n\n") writer.set_tab(0) - writer.write('\n') + writer.write("\n") for component in target[i:]: - writer.write('class {name}(object):'.format(name=component)) + writer.write("class {name}(object):".format(name=component)) writer.tab() - writer.write('\n') + writer.write("\n") def write_preamble(writer, use_logical_types, custom_imports): @@ -257,22 +297,22 @@ def write_preamble(writer, use_logical_types, custom_imports): :param writer: :return: """ - writer.write('import json\n') - writer.write('import os.path\n') - writer.write('import decimal\n') - writer.write('import datetime\n') - writer.write('import six\n') - - for cs in (custom_imports or []): - writer.write(f'import {cs}\n') - writer.write('from avrogen.dict_wrapper import DictWrapper\n') - writer.write('from avrogen import avrojson\n') + writer.write("import json\n") + writer.write("import os.path\n") + writer.write("import decimal\n") + writer.write("import datetime\n") + writer.write("import six\n") + + for cs in custom_imports or []: + writer.write(f"import {cs}\n") + writer.write("from avrogen.dict_wrapper import DictWrapper\n") + writer.write("from avrogen import avrojson\n") if use_logical_types: - writer.write('from avrogen import logical\n') - writer.write('from avro.schema import RecordSchema, make_avsc_object\n') - writer.write('from avro import schema as avro_schema\n') - writer.write('from typing import ClassVar, List, Dict, Union, Optional, Type\n') - writer.write('\n') + writer.write("from avrogen import logical\n") + writer.write("from avro.schema import RecordSchema, make_avsc_object\n") + writer.write("from avro import schema as avro_schema\n") + writer.write("from typing import ClassVar, List, Dict, Union, Optional, Type\n") + writer.write("\n") def write_read_file(writer): @@ -281,11 +321,11 @@ def 
write_read_file(writer): :param writer: :return: """ - writer.write('\ndef __read_file(file_name):') + writer.write("\ndef __read_file(file_name):") with writer.indent(): writer.write('\nwith open(file_name, "r") as f:') with writer.indent(): - writer.write('\nreturn f.read()\n') + writer.write("\nreturn f.read()\n") def write_get_schema(writer): @@ -294,10 +334,10 @@ def write_get_schema(writer): :param writer: :return: """ - writer.write('\n__SCHEMAS: Dict[str, RecordSchema] = {}\n\n\n') - writer.write('def get_schema_type(fullname: str) -> RecordSchema:') + writer.write("\n__SCHEMAS: Dict[str, RecordSchema] = {}\n\n\n") + writer.write("def get_schema_type(fullname: str) -> RecordSchema:") with writer.indent(): - writer.write('\nreturn __SCHEMAS[fullname]\n\n') + writer.write("\nreturn __SCHEMAS[fullname]\n\n") def write_reader_impl(record_types, writer, use_logical_types): @@ -307,37 +347,54 @@ def write_reader_impl(record_types, writer, use_logical_types): :param writer: :return: """ - writer.write('\n\n\nclass SpecificDatumReader(%s):' % ( - 'DatumReader' if not use_logical_types else 'logical.LogicalDatumReader')) + writer.write( + "\n\n\nclass SpecificDatumReader(%s):" + % ("DatumReader" if not use_logical_types else "logical.LogicalDatumReader") + ) with writer.indent(): - writer.write('\nSCHEMA_TYPES = {') + writer.write("\nSCHEMA_TYPES = {") with writer.indent(): for t in record_types: - t_class = t.split('.')[-1] + t_class = t.split(".")[-1] writer.write('\n"{t_class}": {t_class}Class,'.format(t_class=t_class)) writer.write('\n".{t_class}": {t_class}Class,'.format(t_class=t_class)) - writer.write('\n"{f_class}": {t_class}Class,'.format(t_class=t_class, f_class=t)) + writer.write( + '\n"{f_class}": {t_class}Class,'.format(t_class=t_class, f_class=t) + ) - writer.write('\n}') - writer.write('\n\n\ndef __init__(self, readers_schema=None, **kwargs):') + writer.write("\n}") + writer.write("\n\n\ndef __init__(self, readers_schema=None, **kwargs):") with writer.indent(): - writer.write('\nwriters_schema = kwargs.pop("writers_schema", readers_schema)') - writer.write('\nwriters_schema = kwargs.pop("writer_schema", writers_schema)') - writer.write('\nsuper(SpecificDatumReader, self).__init__(writers_schema, readers_schema, **kwargs)') + writer.write( + '\nwriters_schema = kwargs.pop("writers_schema", readers_schema)' + ) + writer.write( + '\nwriters_schema = kwargs.pop("writer_schema", writers_schema)' + ) + writer.write( + "\nsuper(SpecificDatumReader, self).__init__(writers_schema, readers_schema, **kwargs)" + ) - writer.write('\n\n\ndef read_record(self, writers_schema, readers_schema, decoder):') + writer.write( + "\n\n\ndef read_record(self, writers_schema, readers_schema, decoder):" + ) with writer.indent(): writer.write( - '\nresult = super(SpecificDatumReader, self).read_record(writers_schema, readers_schema, decoder)') - writer.write('\n\nif readers_schema.fullname in SpecificDatumReader.SCHEMA_TYPES:') + "\nresult = super(SpecificDatumReader, self).read_record(writers_schema, readers_schema, decoder)" + ) + writer.write( + "\n\nif readers_schema.fullname in SpecificDatumReader.SCHEMA_TYPES:" + ) with writer.indent(): - writer.write('\ntp = SpecificDatumReader.SCHEMA_TYPES[readers_schema.fullname]') - writer.write('\nif issubclass(tp, DictWrapper):') - writer.write('\n result = tp._construct(result)') - writer.write('\nelse:') - writer.write('\n # tp is an enum') - writer.write('\n result = tp(result) # type: ignore') - writer.write('\n\nreturn result') + writer.write( + "\ntp = 
SpecificDatumReader.SCHEMA_TYPES[readers_schema.fullname]" + ) + writer.write("\nif issubclass(tp, DictWrapper):") + writer.write("\n result = tp._construct(result)") + writer.write("\nelse:") + writer.write("\n # tp is an enum") + writer.write("\n result = tp(result) # type: ignore") + writer.write("\n\nreturn result") def generate_namespace_modules(names, output_folder): @@ -351,16 +408,17 @@ def generate_namespace_modules(names, output_folder): """ ns_dict = {} for name in names: - name_parts = name.split('.') + name_parts = name.split(".") full_path = output_folder for part in name_parts[:-1]: full_path = os.path.join(full_path, part) if not os.path.isdir(full_path): os.mkdir(full_path) # make sure __init__.py is created for every namespace level - with open(os.path.join(full_path, "__init__.py"), "w+"): pass + with open(os.path.join(full_path, "__init__.py"), "w+"): + pass - ns = '.'.join(name_parts[:-1]) + ns = ".".join(name_parts[:-1]) if not ns in ns_dict: ns_dict[ns] = [] ns_dict[ns].append(name_parts[-1]) @@ -376,14 +434,14 @@ def write_schema_record(record, writer, use_logical_types): """ _, type_name = ns_.split_fullname(record.fullname) - writer.write('''\nclass {name}Class(DictWrapper):'''.format(name=type_name)) + writer.write("""\nclass {name}Class(DictWrapper):""".format(name=type_name)) with writer.indent(): - writer.write('\n') + writer.write("\n") if record.doc: writer.write(f'"""{record.doc}"""') else: - writer.write('# No docs available.') + writer.write("# No docs available.") writer.write('\n\nRECORD_SCHEMA = get_schema_type("%s")' % (record.fullname)) write_record_init(record, writer, use_logical_types) @@ -392,7 +450,7 @@ def write_schema_record(record, writer, use_logical_types): def write_record_init(record, writer, use_logical_types): - writer.write('\ndef __init__(self,') + writer.write("\ndef __init__(self,") with writer.indent(): delayed_lines = [] default_map: Dict[str, str] = {} @@ -408,32 +466,32 @@ def write_record_init(record, writer, use_logical_types): ret_type_name = f"Optional[{ret_type_name}]" nullable = True if nullable: - delayed_lines.append(f'\n{name}: {ret_type_name}=None,') + delayed_lines.append(f"\n{name}: {ret_type_name}=None,") else: - writer.write(f'\n{name}: {ret_type_name},') + writer.write(f"\n{name}: {ret_type_name},") # default = get_default(field, use_logical_types) # writer.write(f'\n{name}: {ret_type_name} = {default},') for line in delayed_lines: writer.write(line) - writer.write('\n):') + writer.write("\n):") with writer.indent(): - writer.write('\n') - writer.write('super().__init__()') - writer.write('\n') + writer.write("\n") + writer.write("super().__init__()") + writer.write("\n") for field in record.fields: # type: schema.Field name = get_field_name(field, use_logical_types) if name in default_map: - writer.write(f'\nif {name} is None:') - writer.write(f'\n # default: {repr(field.default)}') - writer.write(f'\n self.{name} = {default_map[name]}') - writer.write(f'\nelse:') - writer.write(f'\n self.{name} = {name}') + writer.write(f"\nif {name} is None:") + writer.write(f"\n # default: {repr(field.default)}") + writer.write(f"\n self.{name} = {default_map[name]}") + writer.write(f"\nelse:") + writer.write(f"\n self.{name} = {name}") else: - writer.write(f'\nself.{name} = {name}') + writer.write(f"\nself.{name} = {name}") - writer.write('\n') - writer.write(f'\ndef _restore_defaults(self) -> None:') + writer.write("\n") + writer.write(f"\ndef _restore_defaults(self) -> None:") with writer.indent(): write_defaults(record, 
writer, use_logical_types=use_logical_types) @@ -446,22 +504,22 @@ def write_enum(enum, writer): :return: """ _, type_name = ns_.split_fullname(enum.fullname) - writer.write('''\nclass {name}Class(object):'''.format(name=type_name)) + writer.write("""\nclass {name}Class(object):""".format(name=type_name)) with writer.indent(): - writer.write('\n') + writer.write("\n") if enum.doc: writer.write(f'"""{enum.doc}"""') else: - writer.write('# No docs available.') + writer.write("# No docs available.") - writer.write('\n\n') - symbolDocs = enum.other_props.get('symbolDocs', {}) + writer.write("\n\n") + symbolDocs = enum.other_props.get("symbolDocs", {}) for field in enum.symbols: # Docs for enum fields go _below_ the field. writer.write('{name} = "{name}"\n'.format(name=field)) if field in symbolDocs: writer.write(f'"""{symbolDocs[field]}"""\n') if symbolDocs: - writer.write('\n') - writer.write('\n') + writer.write("\n") + writer.write("\n") diff --git a/avrogen/logical.py b/avrogen/logical.py index f640277..ae4564d 100644 --- a/avrogen/logical.py +++ b/avrogen/logical.py @@ -43,14 +43,17 @@ def initializer(self, value=None): class DecimalLogicalTypeProcessor(LogicalTypeProcessor): def can_convert(self, writers_schema): - return isinstance(writers_schema, schema.PrimitiveSchema) and writers_schema.type == 'string' + return ( + isinstance(writers_schema, schema.PrimitiveSchema) + and writers_schema.type == "string" + ) def validate(self, expected_schema, datum): return isinstance(datum, (int, float, long, decimal.Decimal)) def convert(self, writers_schema, value): if not isinstance(value, (int, float, long, decimal.Decimal)): - raise Exception('Wrong type for decimal conversion') + raise Exception("Wrong type for decimal conversion") return str(value) def convert_back(self, writers_schema, readers_schema, value): @@ -58,22 +61,25 @@ def convert_back(self, writers_schema, readers_schema, value): def does_match(self, writers_schema, readers_schema): if isinstance(writers_schema, schema.PrimitiveSchema): - if writers_schema.type == 'string': + if writers_schema.type == "string": return True return False def typename(self): - return 'decimal.Decimal' + return "decimal.Decimal" def initializer(self, value=None): - return 'decimal.Decimal(%s)' % (0 if value is None else value) + return "decimal.Decimal(%s)" % (0 if value is None else value) class DateLogicalTypeProcessor(LogicalTypeProcessor): - _matching_types = {'int', 'long', 'float', 'double'} + _matching_types = {"int", "long", "float", "double"} def can_convert(self, writers_schema): - return isinstance(writers_schema, schema.PrimitiveSchema) and writers_schema.type == 'int' + return ( + isinstance(writers_schema, schema.PrimitiveSchema) + and writers_schema.type == "int" + ) def validate(self, expected_schema, datum): return isinstance(datum, datetime.date) @@ -93,30 +99,43 @@ def does_match(self, writers_schema, readers_schema): return False def typename(self): - return 'datetime.date' + return "datetime.date" def initializer(self, value=None): - return (( - 'logical.DateLogicalTypeProcessor().convert_back(None, None, %s)' % value) if value is not None - else 'datetime.datetime.today().date()') + return ( + ("logical.DateLogicalTypeProcessor().convert_back(None, None, %s)" % value) + if value is not None + else "datetime.datetime.today().date()" + ) class TimeMicrosLogicalTypeProcessor(LogicalTypeProcessor): - _matching_types = {'int', 'long', 'float', 'double'} + _matching_types = {"int", "long", "float", "double"} def can_convert(self, 
writers_schema): - return isinstance(writers_schema, schema.PrimitiveSchema) and writers_schema.type == 'long' + return ( + isinstance(writers_schema, schema.PrimitiveSchema) + and writers_schema.type == "long" + ) def validate(self, expected_schema, datum): return isinstance(datum, datetime.time) def convert(self, writers_schema, value): if not isinstance(value, datetime.time): - raise Exception('Wrong type for time conversion') - return ((value.hour * 60 + value.minute) * 60 + value.second) * 1000000 + value.microsecond + raise Exception("Wrong type for time conversion") + return ( + (value.hour * 60 + value.minute) * 60 + value.second + ) * 1000000 + value.microsecond def convert_back(self, writers_schema, readers_schema, value): - _, hours, minutes, seconds, microseconds = TimeMicrosLogicalTypeProcessor.extract_time_parts(value) + ( + _, + hours, + minutes, + seconds, + microseconds, + ) = TimeMicrosLogicalTypeProcessor.extract_time_parts(value) return datetime.time(hours, minutes, seconds, microseconds) @staticmethod @@ -139,37 +158,58 @@ def does_match(self, writers_schema, readers_schema): return False def typename(self): - return 'datetime.time' + return "datetime.time" def initializer(self, value=None): - return (( - 'logical.TimeMicrosLogicalTypeProcessor().convert_back(None, None, %s)' % value) if value is not None - else 'datetime.datetime.today().time()') + return ( + ( + "logical.TimeMicrosLogicalTypeProcessor().convert_back(None, None, %s)" + % value + ) + if value is not None + else "datetime.datetime.today().time()" + ) class TimeMillisLogicalTypeProcessor(TimeMicrosLogicalTypeProcessor): def can_convert(self, writers_schema): - return isinstance(writers_schema, schema.PrimitiveSchema) and writers_schema.type == 'int' + return ( + isinstance(writers_schema, schema.PrimitiveSchema) + and writers_schema.type == "int" + ) def convert(self, writers_schema, value): if not isinstance(value, datetime.time): - raise Exception('Wrong type for time conversion') - return int(super(TimeMillisLogicalTypeProcessor, self).convert(writers_schema, value) // 1000) + raise Exception("Wrong type for time conversion") + return int( + super(TimeMillisLogicalTypeProcessor, self).convert(writers_schema, value) + // 1000 + ) def convert_back(self, writers_schema, readers_schema, value): - return super(TimeMillisLogicalTypeProcessor, self).convert_back(writers_schema, readers_schema, value * 1000) + return super(TimeMillisLogicalTypeProcessor, self).convert_back( + writers_schema, readers_schema, value * 1000 + ) def initializer(self, value=None): - return (( - 'logical.TimeMillisLogicalTypeProcessor().convert_back(None, None, %s)' % value) if value is not None - else 'datetime.datetime.today().time()') + return ( + ( + "logical.TimeMillisLogicalTypeProcessor().convert_back(None, None, %s)" + % value + ) + if value is not None + else "datetime.datetime.today().time()" + ) class TimestampMicrosLogicalTypeProcessor(LogicalTypeProcessor): - _matching_types = {'int', 'long', 'float', 'double'} + _matching_types = {"int", "long", "float", "double"} def can_convert(self, writers_schema): - return isinstance(writers_schema, schema.PrimitiveSchema) and writers_schema.type == 'long' + return ( + isinstance(writers_schema, schema.PrimitiveSchema) + and writers_schema.type == "long" + ) def validate(self, expected_schema, datum): return isinstance(datum, datetime.datetime) @@ -177,7 +217,16 @@ def validate(self, expected_schema, datum): def convert(self, writers_schema, value): if not isinstance(value, 
datetime.datetime): if isinstance(value, datetime.date): - value = datetime.datetime(value.year, value.month, value.day, 0, 0, 0, 0, datetime.timezone.utc) + value = datetime.datetime( + value.year, + value.month, + value.day, + 0, + 0, + 0, + 0, + datetime.timezone.utc, + ) if value.tzinfo is None: value = value.replace(tzinfo=datetime.timezone.utc) @@ -190,45 +239,69 @@ def convert_back(self, writers_schema, readers_schema, value): def does_match(self, writers_schema, readers_schema): if isinstance(writers_schema, schema.PrimitiveSchema): - if writers_schema.type in TimestampMicrosLogicalTypeProcessor._matching_types: + if ( + writers_schema.type + in TimestampMicrosLogicalTypeProcessor._matching_types + ): return True return False def typename(self): - return 'datetime.datetime' + return "datetime.datetime" def initializer(self, value=None): - return (( - 'logical.TimestampMicrosLogicalTypeProcessor().convert_back(None, None, %s)' % value) if value is not None - else 'datetime.datetime.now()') + return ( + ( + "logical.TimestampMicrosLogicalTypeProcessor().convert_back(None, None, %s)" + % value + ) + if value is not None + else "datetime.datetime.now()" + ) class TimestampMillisLogicalTypeProcessor(TimestampMicrosLogicalTypeProcessor): def convert(self, writers_schema, value): - return super(TimestampMillisLogicalTypeProcessor, self).convert(writers_schema, value) // 1000 + return ( + super(TimestampMillisLogicalTypeProcessor, self).convert( + writers_schema, value + ) + // 1000 + ) def convert_back(self, writers_schema, readers_schema, value): - return super(TimestampMillisLogicalTypeProcessor, self).convert_back(writers_schema, readers_schema, - value * 1000) + return super(TimestampMillisLogicalTypeProcessor, self).convert_back( + writers_schema, readers_schema, value * 1000 + ) def initializer(self, value=None): - return (( - 'logical.TimestampMillisLogicalTypeProcessor().convert_back(None, None, %s)' % value) if value is not None - else 'datetime.datetime.now()') + return ( + ( + "logical.TimestampMillisLogicalTypeProcessor().convert_back(None, None, %s)" + % value + ) + if value is not None + else "datetime.datetime.now()" + ) DEFAULT_LOGICAL_TYPES = { - 'decimal': DecimalLogicalTypeProcessor(), - 'date': DateLogicalTypeProcessor(), - 'time-millis': TimeMillisLogicalTypeProcessor(), - 'time-micros': TimeMicrosLogicalTypeProcessor(), - 'timestamp-millis': TimestampMillisLogicalTypeProcessor(), - 'timestamp-micros': TimestampMicrosLogicalTypeProcessor(), + "decimal": DecimalLogicalTypeProcessor(), + "date": DateLogicalTypeProcessor(), + "time-millis": TimeMillisLogicalTypeProcessor(), + "time-micros": TimeMicrosLogicalTypeProcessor(), + "timestamp-millis": TimestampMillisLogicalTypeProcessor(), + "timestamp-micros": TimestampMicrosLogicalTypeProcessor(), } class LogicalDatumReader(io.DatumReader): - def __init__(self, writers_schema=None, readers_schema=None, logical_types=DEFAULT_LOGICAL_TYPES): + def __init__( + self, + writers_schema=None, + readers_schema=None, + logical_types=DEFAULT_LOGICAL_TYPES, + ): """ Initializes DatumReader with logical type support @@ -236,7 +309,9 @@ def __init__(self, writers_schema=None, readers_schema=None, logical_types=DEFAU :param schema.Schema readers_schema: Optional reader's schema :param dict[str, LogicalTypeProcessor] logical_types: Optional logical types dict """ - super(LogicalDatumReader, self).__init__(writers_schema=writers_schema, readers_schema=readers_schema) + super(LogicalDatumReader, self).__init__( + 
writers_schema=writers_schema, readers_schema=readers_schema + ) self.logical_types = logical_types or {} def read_data(self, writers_schema, readers_schema, decoder): @@ -248,39 +323,51 @@ def read_data(self, writers_schema, readers_schema, decoder): :param io.BinaryDecoder decoder: :return: """ - result = super(LogicalDatumReader, self).read_data(writers_schema, readers_schema, decoder) - logical_type = readers_schema.props.get('logicalType') + result = super(LogicalDatumReader, self).read_data( + writers_schema, readers_schema, decoder + ) + logical_type = readers_schema.props.get("logicalType") if logical_type: logical_type_handler = self.logical_types.get(logical_type) - if logical_type_handler and logical_type_handler.does_match(writers_schema, readers_schema): - result = logical_type_handler.convert_back(writers_schema, readers_schema, result) + if logical_type_handler and logical_type_handler.does_match( + writers_schema, readers_schema + ): + result = logical_type_handler.convert_back( + writers_schema, readers_schema, result + ) return result class LogicalDatumWriter(io.DatumWriter): """ - Initializes DatumWriter with logical type support + Initializes DatumWriter with logical type support - :param schema.Schema writers_schema: Writer's schema - :param dict[str, LogicalTypeProcessor] logical_types: Optional logical types dict - """ + :param schema.Schema writers_schema: Writer's schema + :param dict[str, LogicalTypeProcessor] logical_types: Optional logical types dict + """ def __init__(self, writers_schema=None, logical_types=DEFAULT_LOGICAL_TYPES): super(LogicalDatumWriter, self).__init__(writers_schema=writers_schema) self.logical_types = logical_types def write_data(self, writers_schema, datum, encoder): - logical_type = writers_schema.props.get('logicalType') + logical_type = writers_schema.props.get("logicalType") if logical_type: logical_type_handler = self.logical_types.get(logical_type) - if logical_type_handler and logical_type_handler.can_convert(writers_schema): - return super(LogicalDatumWriter, self).write_data(writers_schema, - logical_type_handler.convert(writers_schema, datum), - encoder) - return super(LogicalDatumWriter, self).write_data(writers_schema, datum, encoder) + if logical_type_handler and logical_type_handler.can_convert( + writers_schema + ): + return super(LogicalDatumWriter, self).write_data( + writers_schema, + logical_type_handler.convert(writers_schema, datum), + encoder, + ) + return super(LogicalDatumWriter, self).write_data( + writers_schema, datum, encoder + ) def __validate(self, writers_schema, datum): - logical_type = writers_schema.props.get('logicalType') + logical_type = writers_schema.props.get("logicalType") if logical_type: lt = self.logical_types.get(logical_type) if lt: @@ -290,20 +377,26 @@ def __validate(self, writers_schema, datum): return False schema_type = writers_schema.type - if schema_type == 'array': - return (isinstance(datum, list) and - False not in [self.__validate(writers_schema.items, d) for d in datum]) - elif schema_type == 'map': - return (isinstance(datum, dict) and - False not in [isinstance(k, basestring) for k in datum.keys()] and - False not in - [self.__validate(writers_schema.values, v) for v in datum.values()]) - elif schema_type in ['union', 'error_union']: + if schema_type == "array": + return isinstance(datum, list) and False not in [ + self.__validate(writers_schema.items, d) for d in datum + ] + elif schema_type == "map": + return ( + isinstance(datum, dict) + and False not in [isinstance(k, 
basestring) for k in datum.keys()] + and False + not in [ + self.__validate(writers_schema.values, v) for v in datum.values() + ] + ) + elif schema_type in ["union", "error_union"]: return True in [self.__validate(s, datum) for s in writers_schema.schemas] - elif schema_type in ['record', 'error', 'request']: - return (isinstance(datum, dict) and - False not in - [self.__validate(f.type, datum.get(f.name)) for f in writers_schema.fields]) + elif schema_type in ["record", "error", "request"]: + return isinstance(datum, dict) and False not in [ + self.__validate(f.type, datum.get(f.name)) + for f in writers_schema.fields + ] return io.validate(writers_schema, datum) diff --git a/avrogen/namespace.py b/avrogen/namespace.py index 3d780ad..02deb6f 100644 --- a/avrogen/namespace.py +++ b/avrogen/namespace.py @@ -1,13 +1,13 @@ def make_fullname(ns, name): - return ((ns + '.') if ns else '') + name + return ((ns + ".") if ns else "") + name def split_fullname(fullname): - idx = fullname.rfind('.') + idx = fullname.rfind(".") if idx < 0: - return '', fullname + return "", fullname - return fullname[:idx], fullname[idx + 1:] + return fullname[:idx], fullname[idx + 1 :] def get_shortname(fullname): diff --git a/avrogen/protocol.py b/avrogen/protocol.py index 94c613a..f3ff2c3 100644 --- a/avrogen/protocol.py +++ b/avrogen/protocol.py @@ -19,7 +19,12 @@ from .protocol_writer import write_protocol_request -def generate_protocol(protocol_json, use_logical_types=False, custom_imports=None, avro_json_converter=None): +def generate_protocol( + protocol_json, + use_logical_types=False, + custom_imports=None, + avro_json_converter=None, +): """ Generate content of the file which will contain concrete classes for RecordSchemas and requests contained in the avro protocol @@ -31,14 +36,16 @@ def generate_protocol(protocol_json, use_logical_types=False, custom_imports=Non """ if avro_json_converter is None: - avro_json_converter = 'avrojson.AvroJsonConverter' + avro_json_converter = "avrojson.AvroJsonConverter" - if '(' not in avro_json_converter: - avro_json_converter += '(use_logical_types=%s, schema_types=__SCHEMA_TYPES)' % use_logical_types + if "(" not in avro_json_converter: + avro_json_converter += ( + "(use_logical_types=%s, schema_types=__SCHEMA_TYPES)" % use_logical_types + ) custom_imports = custom_imports or [] - if not hasattr(protocol, 'parse'): + if not hasattr(protocol, "parse"): # Older versions of avro used a capital P in Parse. 
proto = protocol.Parse(protocol_json) else: @@ -55,9 +62,19 @@ def generate_protocol(protocol_json, use_logical_types=False, custom_imports=Non schemas.append((schema_idx, record_schema)) known_types.add(clean_fullname(record_schema.fullname)) - for message in (six.itervalues(proto.messages) if six.PY2 else proto.messages): - messages.append((message, message.request, message.response if isinstance(message.response, ( - schema.EnumSchema, schema.RecordSchema)) and clean_fullname(message.response.fullname) not in known_types else None)) + for message in six.itervalues(proto.messages) if six.PY2 else proto.messages: + messages.append( + ( + message, + message.request, + message.response + if isinstance( + message.response, (schema.EnumSchema, schema.RecordSchema) + ) + and clean_fullname(message.response.fullname) not in known_types + else None, + ) + ) if isinstance(message.response, (schema.EnumSchema, schema.RecordSchema)): known_types.add(clean_fullname(message.response.fullname)) @@ -65,17 +82,17 @@ def generate_protocol(protocol_json, use_logical_types=False, custom_imports=Non for schema_idx, record_schema in schemas: ns, name = ns_.split_fullname(clean_fullname(record_schema.fullname)) if ns not in namespaces: - namespaces[ns] = {'requests': [], 'records': [], 'responses': []} - namespaces[ns]['records'].append((schema_idx, record_schema)) + namespaces[ns] = {"requests": [], "records": [], "responses": []} + namespaces[ns]["records"].append((schema_idx, record_schema)) for message, request, response in messages: fullname = ns_.make_fullname(proto.namespace, clean_fullname(message.name)) ns, name = ns_.split_fullname(fullname) if ns not in namespaces: - namespaces[ns] = {'requests': [], 'records': [], 'responses': []} - namespaces[ns]['requests'].append(message) + namespaces[ns] = {"requests": [], "records": [], "responses": []} + namespaces[ns]["requests"].append(message) if response: - namespaces[ns]['responses'].append(message) + namespaces[ns]["responses"].append(message) main_out = StringIO() writer = TabbedWriter(main_out) @@ -85,83 +102,94 @@ def generate_protocol(protocol_json, use_logical_types=False, custom_imports=Non write_get_schema(writer) write_populate_schemas(writer) - writer.write('\n\n\nclass SchemaClasses(object):') + writer.write("\n\n\nclass SchemaClasses(object):") with writer.indent(): - writer.write('\n\n') + writer.write("\n\n") current_namespace = tuple() all_ns = sorted(namespaces.keys()) for ns in all_ns: - if not (namespaces[ns]['responses'] or namespaces[ns]['records']): + if not (namespaces[ns]["responses"] or namespaces[ns]["records"]): continue - namespace = ns.split('.') + namespace = ns.split(".") if namespace != current_namespace: start_namespace(current_namespace, namespace, writer) - for idx, record in namespaces[ns]['records']: + for idx, record in namespaces[ns]["records"]: schema_names.add(clean_fullname(record.fullname)) if isinstance(record, schema.RecordSchema): write_schema_record(record, writer, use_logical_types) elif isinstance(record, schema.EnumSchema): write_enum(record, writer) - for message in namespaces[ns]['responses']: + for message in namespaces[ns]["responses"]: schema_names.add(clean_fullname(message.response.fullname)) if isinstance(message.response, schema.RecordSchema): write_schema_record(message.response, writer, use_logical_types) elif isinstance(message.response, schema.EnumSchema): write_enum(message.response, writer) - writer.write('\n\npass') + writer.write("\n\npass") writer.set_tab(0) - writer.write('\n\n\nclass 
@@ -85,83 +102,94 @@ def generate_protocol(protocol_json, use_logical_types=False, custom_imports=Non
 
     write_get_schema(writer)
     write_populate_schemas(writer)
 
-    writer.write('\n\n\nclass SchemaClasses(object):')
+    writer.write("\n\n\nclass SchemaClasses(object):")
     with writer.indent():
-        writer.write('\n\n')
+        writer.write("\n\n")
         current_namespace = tuple()
         all_ns = sorted(namespaces.keys())
         for ns in all_ns:
-            if not (namespaces[ns]['responses'] or namespaces[ns]['records']):
+            if not (namespaces[ns]["responses"] or namespaces[ns]["records"]):
                 continue
-            namespace = ns.split('.')
+            namespace = ns.split(".")
             if namespace != current_namespace:
                 start_namespace(current_namespace, namespace, writer)
-            for idx, record in namespaces[ns]['records']:
+            for idx, record in namespaces[ns]["records"]:
                 schema_names.add(clean_fullname(record.fullname))
                 if isinstance(record, schema.RecordSchema):
                     write_schema_record(record, writer, use_logical_types)
                 elif isinstance(record, schema.EnumSchema):
                     write_enum(record, writer)
-            for message in namespaces[ns]['responses']:
+            for message in namespaces[ns]["responses"]:
                 schema_names.add(clean_fullname(message.response.fullname))
                 if isinstance(message.response, schema.RecordSchema):
                     write_schema_record(message.response, writer, use_logical_types)
                 elif isinstance(message.response, schema.EnumSchema):
                     write_enum(message.response, writer)
-        writer.write('\n\npass')
+        writer.write("\n\npass")
 
     writer.set_tab(0)
-    writer.write('\n\n\nclass RequestClasses(object):')
+    writer.write("\n\n\nclass RequestClasses(object):")
     with writer.indent() as indent:
-        writer.write('\n\n')
+        writer.write("\n\n")
         current_namespace = tuple()
         all_ns = sorted(namespaces.keys())
         for ns in all_ns:
-            if not (namespaces[ns]['requests'] or namespaces[ns]['responses']):
+            if not (namespaces[ns]["requests"] or namespaces[ns]["responses"]):
                 continue
-            namespace = ns.split('.')
+            namespace = ns.split(".")
             if namespace != current_namespace:
                 start_namespace(current_namespace, namespace, writer)
-            for message in namespaces[ns]['requests']:
-                request_names.add(ns_.make_fullname(proto.namespace, clean_fullname(message.name)))
-                write_protocol_request(message, proto.namespace, writer, use_logical_types)
+            for message in namespaces[ns]["requests"]:
+                request_names.add(
+                    ns_.make_fullname(proto.namespace, clean_fullname(message.name))
+                )
+                write_protocol_request(
+                    message, proto.namespace, writer, use_logical_types
+                )
 
-        writer.write('\n\npass')
+        writer.write("\n\npass")
 
     writer.untab()
     writer.set_tab(0)
-    writer.write('\n__SCHEMA_TYPES = {\n')
+    writer.write("\n__SCHEMA_TYPES = {\n")
     writer.tab()
     all_ns = sorted(namespaces.keys())
     for ns in all_ns:
-        for idx, record in (namespaces[ns]['records'] or []):
-            writer.write("'%s': SchemaClasses.%sClass,\n" % (clean_fullname(record.fullname),
-                                                             clean_fullname(record.fullname)))
-
-        for message in (namespaces[ns]['responses'] or []):
-            writer.write("'%s': SchemaClasses.%sClass,\n" % (clean_fullname(message.response.fullname),
-                                                             clean_fullname(message.response.fullname)))
-
-        for message in (namespaces[ns]['requests'] or []):
+        for idx, record in namespaces[ns]["records"] or []:
+            writer.write(
+                "'%s': SchemaClasses.%sClass,\n"
+                % (clean_fullname(record.fullname), clean_fullname(record.fullname))
+            )
+
+        for message in namespaces[ns]["responses"] or []:
+            writer.write(
+                "'%s': SchemaClasses.%sClass,\n"
+                % (
+                    clean_fullname(message.response.fullname),
+                    clean_fullname(message.response.fullname),
+                )
+            )
+
+        for message in namespaces[ns]["requests"] or []:
             name = ns_.make_fullname(proto.namespace, clean_fullname(message.name))
             writer.write("'%s': RequestClasses.%sRequestClass, \n" % (name, name))
     writer.untab()
-    writer.write('\n}\n')
+    writer.write("\n}\n")
 
-    writer.write('_json_converter = %s\n' % avro_json_converter)
-    writer.write('avrojson.set_global_json_converter(_json_converter)\n')
+    writer.write("_json_converter = %s\n" % avro_json_converter)
+    writer.write("avrojson.set_global_json_converter(_json_converter)\n")
 
     value = main_out.getvalue()
     main_out.close()
     return value, schema_names, request_names
@@ -177,18 +205,20 @@ def write_protocol_preamble(writer, use_logical_types, custom_imports):
     :return:
     """
     write_read_file(writer)
-    writer.write('\nfrom avro import protocol as avro_protocol')
+    writer.write("\nfrom avro import protocol as avro_protocol")
 
-    for i in (custom_imports or []):
-        writer.write('import %s\n' % i)
+    for i in custom_imports or []:
+        writer.write("import %s\n" % i)
 
     if use_logical_types:
-        writer.write('\nfrom avrogen import logical')
-    writer.write('\n\ndef __get_protocol(file_name):')
+        writer.write("\nfrom avrogen import logical")
+    writer.write("\n\ndef __get_protocol(file_name):")
     with writer.indent():
-        writer.write('\nproto = avro_protocol.parse(__read_file(file_name))')
-        writer.write('\nreturn proto')
-    writer.write('\n\nPROTOCOL = __get_protocol(os.path.join(os.path.dirname(__file__), "protocol.avpr"))')
+        writer.write("\nproto = avro_protocol.parse(__read_file(file_name))")
+        writer.write("\nreturn proto")
+    writer.write(
+        '\n\nPROTOCOL = __get_protocol(os.path.join(os.path.dirname(__file__), "protocol.avpr"))'
+    )
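
Pieced together from the writes just above, the preamble that lands in the generated module reads roughly as follows. This is an excerpt, not standalone code: __read_file and the os import are emitted earlier by write_read_file.

# Sketch of the preamble write_protocol_preamble emits (excerpt of generated output).
from avro import protocol as avro_protocol

def __get_protocol(file_name):
    proto = avro_protocol.parse(__read_file(file_name))
    return proto

PROTOCOL = __get_protocol(os.path.join(os.path.dirname(__file__), "protocol.avpr"))
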
writer.write("\nreturn proto") + writer.write( + '\n\nPROTOCOL = __get_protocol(os.path.join(os.path.dirname(__file__), "protocol.avpr"))' + ) def write_populate_schemas(writer): @@ -197,20 +227,28 @@ def write_populate_schemas(writer): :param writer: :return: """ - writer.write('\nfor rec in PROTOCOL.types:') + writer.write("\nfor rec in PROTOCOL.types:") with writer.indent(): - writer.write('\n__SCHEMAS[rec.fullname] = rec') + writer.write("\n__SCHEMAS[rec.fullname] = rec") - writer.write('\nfor resp in (six.itervalues(PROTOCOL.messages) if six.PY2 else PROTOCOL.messages):') + writer.write( + "\nfor resp in (six.itervalues(PROTOCOL.messages) if six.PY2 else PROTOCOL.messages):" + ) with writer.indent(): - writer.write('\nif isinstance(resp.response, (avro_schema.RecordSchema, avro_schema.EnumSchema)):') + writer.write( + "\nif isinstance(resp.response, (avro_schema.RecordSchema, avro_schema.EnumSchema)):" + ) with writer.indent(): - writer.write('\n__SCHEMAS[resp.response.fullname] = resp.response') + writer.write("\n__SCHEMAS[resp.response.fullname] = resp.response") - writer.write('\nPROTOCOL_MESSAGES = {m.name.lstrip("."):m for m in (six.itervalues(PROTOCOL.messages) if six.PY2 else PROTOCOL.messages)}\n') + writer.write( + '\nPROTOCOL_MESSAGES = {m.name.lstrip("."):m for m in (six.itervalues(PROTOCOL.messages) if six.PY2 else PROTOCOL.messages)}\n' + ) -def write_protocol_files(protocol_json, output_folder, use_logical_types=False, custom_imports=None): +def write_protocol_files( + protocol_json, output_folder, use_logical_types=False, custom_imports=None +): """ Generates concrete classes for RecordSchemas and requests and a SpecificReader for types and messages contained in the avro protocol. @@ -219,7 +257,9 @@ def write_protocol_files(protocol_json, output_folder, use_logical_types=False, :param list[str] custom_imports: Add additional import modules :return: """ - proto_py, record_names, request_names = generate_protocol(protocol_json, use_logical_types, custom_imports) + proto_py, record_names, request_names = generate_protocol( + protocol_json, use_logical_types, custom_imports + ) names = sorted(list(record_names) + list(request_names)) if not os.path.isdir(output_folder): os.mkdir(output_folder) @@ -248,8 +288,10 @@ def write_specific_reader(record_types, output_folder, use_logical_types): """ with open(os.path.join(output_folder, "__init__.py"), "a+") as f: writer = TabbedWriter(f) - writer.write('\n\nfrom .schema_classes import SchemaClasses, PROTOCOL as my_proto, get_schema_type') - writer.write('\nfrom avro.io import DatumReader') + writer.write( + "\n\nfrom .schema_classes import SchemaClasses, PROTOCOL as my_proto, get_schema_type" + ) + writer.write("\nfrom avro.io import DatumReader") write_reader_impl(record_types, writer, use_logical_types) @@ -264,18 +306,33 @@ def write_namespace_modules(ns_dict, request_names, output_folder): :return: """ for ns in six.iterkeys(ns_dict): - with open(os.path.join(output_folder, ns.replace('.', os.path.sep), "__init__.py"), "w+") as f: - currency = '.' - if ns != '': - currency += '.' * len(ns.split('.')) - f.write('from {currency}schema_classes import SchemaClasses\n'.format(currency=currency)) - f.write('from {currency}schema_classes import RequestClasses\n'.format(currency=currency)) + with open( + os.path.join(output_folder, ns.replace(".", os.path.sep), "__init__.py"), + "w+", + ) as f: + currency = "." + if ns != "": + currency += "." 
diff --git a/avrogen/protocol_writer.py b/avrogen/protocol_writer.py
index 25528cf..26dd81b 100644
--- a/avrogen/protocol_writer.py
+++ b/avrogen/protocol_writer.py
@@ -16,19 +16,31 @@ def write_protocol_request(message, namespace, writer, use_logical_types):
     fullname = ns_.make_fullname(namespace, clean_fullname(message.name))
     namespace, type_name = ns_.split_fullname(fullname)
 
-    writer.write('''\nclass {name}RequestClass(DictWrapper):'''.format(name=type_name))
+    writer.write("""\nclass {name}RequestClass(DictWrapper):""".format(name=type_name))
     with writer.indent():
         writer.write("\n\n")
-        writer.write('\nRECORD_SCHEMA = PROTOCOL_MESSAGES["%s"].request' % clean_fullname(message.name))
+        writer.write(
+            '\nRECORD_SCHEMA = PROTOCOL_MESSAGES["%s"].request'
+            % clean_fullname(message.name)
+        )
 
-        writer.write('\n\n\ndef __init__(self, inner_dict=None):')
+        writer.write("\n\n\ndef __init__(self, inner_dict=None):")
         with writer.indent():
-            writer.write('\n')
-            writer.write('super(RequestClasses.{name}RequestClass, self).__init__(inner_dict)'.format(name=fullname))
-
-            writer.write('\nif inner_dict is None:')
+            writer.write("\n")
+            writer.write(
+                "super(RequestClasses.{name}RequestClass, self).__init__(inner_dict)".format(
+                    name=fullname
+                )
+            )
+
+            writer.write("\nif inner_dict is None:")
             with writer.indent():
-                write_defaults(message.request, writer, my_full_name=fullname + "Request", use_logical_types=use_logical_types)
+                write_defaults(
+                    message.request,
+                    writer,
+                    my_full_name=fullname + "Request",
+                    use_logical_types=use_logical_types,
+                )
 
     write_fields(message.request, writer, use_logical_types)
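
From the writes above, a generated request class comes out roughly like this. The message name "getUser" is hypothetical and assumes an empty protocol namespace (so the fullname equals the short type name); write_defaults and write_fields fill in the body, so this is only an excerpt of generated output:

# Sketch of write_protocol_request output (excerpt; names hypothetical).
class getUserRequestClass(DictWrapper):

    RECORD_SCHEMA = PROTOCOL_MESSAGES["getUser"].request

    def __init__(self, inner_dict=None):
        super(RequestClasses.getUserRequestClass, self).__init__(inner_dict)
        if inner_dict is None:
            pass  # write_defaults emits per-field default assignments here
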
f"(use_logical_types={use_logical_types}, schema_types=__SCHEMA_TYPES)" + ) custom_imports = custom_imports or [] names = schema.Names() make_avsc_object(json.loads(schema_json), names) - names = [k for k in six.iteritems(names.names) if isinstance(k[1], (schema.RecordSchema, schema.EnumSchema))] + names = [ + k + for k in six.iteritems(names.names) + if isinstance(k[1], (schema.RecordSchema, schema.EnumSchema)) + ] names = sorted(names, key=lambda x: x[0]) main_out = StringIO() @@ -56,17 +64,19 @@ def generate_schema(schema_json, use_logical_types=False, custom_imports=None, a for name, field_schema in names: # type: str, schema.Schema name = clean_fullname(name) - namespace = tuple(name.split('.')[:-1]) + namespace = tuple(name.split(".")[:-1]) if namespace != current_namespace: current_namespace = namespace if isinstance(field_schema, schema.RecordSchema): - logger.debug(f'Writing schema: {clean_fullname(field_schema.fullname)}') + logger.debug(f"Writing schema: {clean_fullname(field_schema.fullname)}") write_schema_record(field_schema, writer, use_logical_types) elif isinstance(field_schema, schema.EnumSchema): - logger.debug(f'Writing enum: {field_schema.fullname}', field_schema.fullname) + logger.debug( + f"Writing enum: {field_schema.fullname}", field_schema.fullname + ) write_enum(field_schema, writer) writer.set_tab(0) - writer.write('\n__SCHEMA_TYPES = {') + writer.write("\n__SCHEMA_TYPES = {") writer.tab() # Lookup table for fullname. @@ -81,10 +91,10 @@ def generate_schema(schema_json, use_logical_types=False, custom_imports=None, a writer.write(f"\n'{n}': {n}Class,") writer.untab() - writer.write('\n}\n\n') + writer.write("\n}\n\n") - writer.write(f'_json_converter = {avro_json_converter}\n') - writer.write('avrojson.set_global_json_converter(_json_converter)\n\n') + writer.write(f"_json_converter = {avro_json_converter}\n") + writer.write("avrojson.set_global_json_converter(_json_converter)\n\n") value = main_out.getvalue() main_out.close() @@ -99,13 +109,15 @@ def write_schema_preamble(writer): :return: """ write_read_file(writer) - writer.write('\n\ndef __get_names_and_schema(json_str):') + writer.write("\n\ndef __get_names_and_schema(json_str):") with writer.indent(): - writer.write('\nnames = avro_schema.Names()') - writer.write('\nschema = make_avsc_object(json.loads(json_str), names)') - writer.write('\nreturn names, schema') - writer.write('\n\n\n_SCHEMA_JSON_STR = __read_file(os.path.join(os.path.dirname(__file__), "schema.avsc"))') - writer.write('\n\n\n__NAMES, _SCHEMA = __get_names_and_schema(_SCHEMA_JSON_STR)') + writer.write("\nnames = avro_schema.Names()") + writer.write("\nschema = make_avsc_object(json.loads(json_str), names)") + writer.write("\nreturn names, schema") + writer.write( + '\n\n\n_SCHEMA_JSON_STR = __read_file(os.path.join(os.path.dirname(__file__), "schema.avsc"))' + ) + writer.write("\n\n\n__NAMES, _SCHEMA = __get_names_and_schema(_SCHEMA_JSON_STR)") def write_populate_schemas(writer): @@ -114,7 +126,9 @@ def write_populate_schemas(writer): :param writer: :return: """ - writer.write('\n__SCHEMAS = dict((n.fullname.lstrip("."), n) for n in six.itervalues(__NAMES.names))\n') + writer.write( + '\n__SCHEMAS = dict((n.fullname.lstrip("."), n) for n in six.itervalues(__NAMES.names))\n' + ) def write_namespace_modules(ns_dict, output_folder): @@ -126,14 +140,17 @@ def write_namespace_modules(ns_dict, output_folder): :return: """ for ns in six.iterkeys(ns_dict): - with open(os.path.join(output_folder, ns.replace('.', os.path.sep), "__init__.py"), "w+") 
 
 
 def write_namespace_modules(ns_dict, output_folder):
@@ -126,14 +140,17 @@
     :return:
     """
     for ns in six.iterkeys(ns_dict):
-        with open(os.path.join(output_folder, ns.replace('.', os.path.sep), "__init__.py"), "w+") as f:
-            currency = '.'
-            if ns != '':
-                currency += '.' * len(ns.split('.'))
+        with open(
+            os.path.join(output_folder, ns.replace(".", os.path.sep), "__init__.py"),
+            "w+",
+        ) as f:
+            currency = "."
+            if ns != "":
+                currency += "." * len(ns.split("."))
             for name in ns_dict[ns]:
-                f.write(f'from {currency}schema_classes import {name}Class\n')
+                f.write(f"from {currency}schema_classes import {name}Class\n")
 
-            f.write('\n\n')
+            f.write("\n\n")
             for name in ns_dict[ns]:
                 f.write(f"{name} = {name}Class\n")
@@ -148,20 +165,22 @@ def write_specific_reader(record_types, output_folder, use_logical_types):
     """
     with open(os.path.join(output_folder, "__init__.py"), "a+") as f:
         writer = TabbedWriter(f)
-        writer.write('from typing import cast')
-        writer.write('\nfrom avrogen.dict_wrapper import DictWrapper')
-        writer.write('\nfrom .schema_classes import _SCHEMA as get_schema_type')
-        writer.write('\nfrom .schema_classes import _json_converter as json_converter')
+        writer.write("from typing import cast")
+        writer.write("\nfrom avrogen.dict_wrapper import DictWrapper")
+        writer.write("\nfrom .schema_classes import _SCHEMA as get_schema_type")
+        writer.write("\nfrom .schema_classes import _json_converter as json_converter")
         for t in record_types:
             writer.write(f'\nfrom .schema_classes import {t.split(".")[-1]}Class')
-        writer.write('\nfrom avro.io import DatumReader')
+        writer.write("\nfrom avro.io import DatumReader")
         if use_logical_types:
-            writer.write('\nfrom avrogen import logical')
+            writer.write("\nfrom avrogen import logical")
 
     write_reader_impl(record_types, writer, use_logical_types)
 
 
-def write_schema_files(schema_json, output_folder, use_logical_types=False, custom_imports=None):
+def write_schema_files(
+    schema_json, output_folder, use_logical_types=False, custom_imports=None
+):
     """
     Generates concrete classes, namespace modules, and a SpecificRecordReader for a given avro schema
     :param str schema_json: JSON containing avro schema
diff --git a/avrogen/tabbed_writer.py b/avrogen/tabbed_writer.py
index 5968974..470f6c4 100644
--- a/avrogen/tabbed_writer.py
+++ b/avrogen/tabbed_writer.py
@@ -12,23 +12,23 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.writer.untab()
 
-    def __init__(self, inner_writer, tab_symbol='    '):
+    def __init__(self, inner_writer, tab_symbol="    "):
         self.__inner_writer = inner_writer
         self.__tabs = 0
         self.__tab_symbol = tab_symbol
-        self.__current_tab = ''
+        self.__current_tab = ""
 
     def write(self, text):
         assert isinstance(text, six.string_types)
 
         start_pos = 0
-        last_pos = text.find('\n')
+        last_pos = text.find("\n")
 
         while last_pos >= 0:
-            self.__inner_writer.write(text[start_pos:last_pos + 1])
+            self.__inner_writer.write(text[start_pos : last_pos + 1])
             self.__inner_writer.write(self.__current_tab)
             start_pos = last_pos + 1
-            last_pos = text.find('\n', start_pos)
+            last_pos = text.find("\n", start_pos)
 
         self.__inner_writer.write(text[start_pos:])
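
Finally, a quick sketch of the TabbedWriter contract the rest of this diff leans on: write() re-inserts the current indent after every newline, and indent() is a context manager that tabs in and back out. io.StringIO stands in for a real file, and the expected output is an inference from the write() code above:

# Usage sketch for TabbedWriter.
import io
from avrogen.tabbed_writer import TabbedWriter

out = io.StringIO()
writer = TabbedWriter(out)
writer.write("def f():")
with writer.indent():
    writer.write("\nreturn 1")
print(out.getvalue())
# Expected (roughly):
# def f():
#     return 1
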