Skip to content

Commit

Permalink
replace unions way with avro library
Browse files Browse the repository at this point in the history
  • Loading branch information
libretto committed Nov 9, 2024
1 parent 445519e commit a1f659f
Show file tree
Hide file tree
Showing 3 changed files with 379 additions and 112 deletions.
59 changes: 23 additions & 36 deletions src/karapace/schema_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from __future__ import annotations

from avro.errors import SchemaParseException
from avro.schema import parse as avro_parse, Schema as AvroSchema
from avro.name import Names as AvroNames
from avro.schema import make_avsc_object, parse as avro_parse, Schema as AvroSchema
from collections.abc import Collection, Mapping, Sequence
from dataclasses import dataclass
from jsonschema import Draft7Validator
Expand All @@ -29,8 +30,8 @@
from karapace.utils import assert_never, json_decode, json_encode, JSONDecodeError
from typing import Any, cast, Final, final

import avro.schema
import hashlib
import json
import logging
import re

Expand Down Expand Up @@ -198,28 +199,17 @@ def schema(self) -> Draft7Validator | AvroSchema | ProtobufSchema:
return parsed_typed_schema.schema


class AvroMerge:
class AvroResolver:
def __init__(self, schema_str: str, dependencies: Mapping[str, Dependency] | None = None):
self.schema_str = json_encode(json_decode(schema_str), compact=True, sort_keys=True)
self.dependencies = dependencies
self.unique_id = 0
self.regex = re.compile(r"^\s*\[")

def union_safe_schema_str(self, schema_str: str) -> str:
# in case we meet union - we use it as is

base_schema = (
f'{{"name": "___RESERVED_KARAPACE_WRAPPER_NAME_{self.unique_id}___",'
f'"type": "record", "fields": [{{"name": "name", "type":'
)
if self.regex.match(schema_str):
return f"{base_schema} {schema_str}}}]}}"
return f"{base_schema} [{schema_str}]}}]}}"

def builder(self, schema_str: str, dependencies: Mapping[str, Dependency] | None = None) -> str:
def builder(self, schema_str: str, dependencies: Mapping[str, Dependency] | None = None) -> list:
"""To support references in AVRO we iteratively merge all referenced schemas with current schema"""
stack: list[tuple[str, Mapping[str, Dependency] | None]] = [(schema_str, dependencies)]
merged_schemas = []
merge: list = []

while stack:
current_schema_str, current_dependencies = stack.pop()
Expand All @@ -229,12 +219,15 @@ def builder(self, schema_str: str, dependencies: Mapping[str, Dependency] | None
stack.append((dependency.schema.schema_str, dependency.schema.dependencies))
else:
self.unique_id += 1
merged_schemas.append(self.union_safe_schema_str(current_schema_str))
merge.append(current_schema_str)

return ",\n".join(merged_schemas)
return merge

def wrap(self) -> str:
return "[\n" + self.builder(self.schema_str, self.dependencies) + "\n]"
def resolve(self) -> list:
"""Resolve the given ``schema_str`` with ``dependencies`` to a list of schemas
sorted in an order where all referenced schemas are located prior to their referrers.
"""
return self.builder(self.schema_str, self.dependencies)


def parse(
Expand All @@ -249,34 +242,30 @@ def parse(
) -> ParsedTypedSchema:
if schema_type not in [SchemaType.AVRO, SchemaType.JSONSCHEMA, SchemaType.PROTOBUF]:
raise InvalidSchema(f"Unknown parser {schema_type} for {schema_str}")
parsed_schema_result: Draft7Validator | AvroSchema | ProtobufSchema
parsed_schema: Draft7Validator | AvroSchema | ProtobufSchema
if schema_type is SchemaType.AVRO:
try:
if dependencies:
wrapped_schema_str = AvroMerge(schema_str, dependencies).wrap()
schemas_list = AvroResolver(schema_str, dependencies).resolve()
names = AvroNames(validate_names=validate_avro_names)
merged_schema = None
for schema in schemas_list:
# Merge dep with all previously merged ones
merged_schema = make_avsc_object(json.loads(schema), names)
merged_schema_str = str(merged_schema)
else:
wrapped_schema_str = schema_str
merged_schema_str = schema_str
parsed_schema = parse_avro_schema_definition(
wrapped_schema_str,
merged_schema_str,
validate_enum_symbols=validate_avro_enum_symbols,
validate_names=validate_avro_names,
)
if dependencies:
if isinstance(parsed_schema, avro.schema.UnionSchema):
parsed_schema_result = parsed_schema.schemas[-1].fields[0].type.schemas[-1]

else:
raise InvalidSchema
else:
parsed_schema_result = parsed_schema
return ParsedTypedSchema(
schema_type=schema_type,
schema_str=schema_str,
schema=parsed_schema_result,
schema=parsed_schema,
references=references,
dependencies=dependencies,
schema_wrapped=parsed_schema,
)
except (SchemaParseException, JSONDecodeError, TypeError) as e:
raise InvalidSchema from e
Expand Down Expand Up @@ -346,10 +335,8 @@ def __init__(
schema: Draft7Validator | AvroSchema | ProtobufSchema,
references: Sequence[Reference] | None = None,
dependencies: Mapping[str, Dependency] | None = None,
schema_wrapped: Draft7Validator | AvroSchema | ProtobufSchema | None = None,
) -> None:
self._schema_cached: Draft7Validator | AvroSchema | ProtobufSchema | None = schema
self.schema_wrapped = schema_wrapped
super().__init__(
schema_type=schema_type,
schema_str=schema_str,
Expand Down
Loading

0 comments on commit a1f659f

Please sign in to comment.