diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2131d99b..93823122 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,14 +9,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 
 ### Added
+- Support for exporting a Data Contract to an Iceberg schema definition.
+- When importing in dbt format, add the dbt `not_null` information as a datacontract `required` field (#547)
 
 ### Changed
 - Type conversion when importing contracts into dbt and exporting contracts from dbt (#534)
 - Ensure 'name' is the first column when exporting in dbt format, considering column attributes (#541)
+- Rename dbt's `tests` to `data_tests` (#548)
 
 ### Fixed
 - Modify the arguments to narrow down the import target with `--dbt-model` (#532)
 - SodaCL: Prevent `KeyError: 'fail'` from happening when testing with SodaCL
+- Populate database and schema values for BigQuery in exported dbt sources (#543)
+- Fix the options for importing and exporting to standard output (#544)
 
 ## [0.10.15] - 2024-10-26
diff --git a/README.md b/README.md
index 4f70fc35..d59fa40a 100644
--- a/README.md
+++ b/README.md
@@ -761,7 +761,7 @@ models:
 │ t-staging-sql|odcs|odcs_v2|odcs_v3|rdf|avro|protobuf │
 │ |great-expectations|terraform|avro-idl|sql|sql-query │
 │ |html|go|bigquery|dbml|spark|sqlalchemy|data-caterer │
-│ |dcs] │
+│ |dcs|iceberg] │
 │ --output PATH Specify the file path where the exported data will be │
 │ saved. If no path is provided, the output will be │
 │ printed to stdout. │
@@ -822,6 +822,7 @@ Available export options:
 | `sqlalchemy` | Export to SQLAlchemy Models | ✅ |
 | `data-caterer` | Export to Data Caterer in YAML format | ✅ |
 | `dcs` | Export to Data Contract Specification in YAML format | ✅ |
+| `iceberg` | Export to an Iceberg JSON Schema Definition | partial |
 | Missing something? | Please create an issue on GitHub | TBD |
@@ -945,6 +946,63 @@ models:
 - **avroLogicalType**: Specifies the logical type of the field in Avro. In this example, it is `local-timestamp-micros`.
 - **avroDefault**: Specifies the default value for the field in Avro. In this example, it is 1672534861000000 which corresponds to ` 2023-01-01 01:01:01 UTC`.
 
+#### Iceberg
+
+Exports to an [Iceberg Table JSON Schema Definition](https://iceberg.apache.org/spec/#appendix-c-json-serialization).
+
+This export supports only a single model at a time, because an Iceberg schema definition describes a single table and the exporter maps one model to one table. Use the `--model` flag
+to limit the export to a single model.
+
+```bash
+ $ datacontract export --format iceberg --model orders https://datacontract.com/examples/orders-latest/datacontract.yaml --output /tmp/orders_iceberg.json
+
+ $ cat /tmp/orders_iceberg.json | jq '.'
+{
+  "type": "struct",
+  "fields": [
+    {
+      "id": 1,
+      "name": "order_id",
+      "type": "string",
+      "required": true
+    },
+    {
+      "id": 2,
+      "name": "order_timestamp",
+      "type": "timestamptz",
+      "required": true
+    },
+    {
+      "id": 3,
+      "name": "order_total",
+      "type": "long",
+      "required": true
+    },
+    {
+      "id": 4,
+      "name": "customer_id",
+      "type": "string",
+      "required": false
+    },
+    {
+      "id": 5,
+      "name": "customer_email_address",
+      "type": "string",
+      "required": true
+    },
+    {
+      "id": 6,
+      "name": "processed_timestamp",
+      "type": "timestamptz",
+      "required": true
+    }
+  ],
+  "schema-id": 0,
+  "identifier-field-ids": [
+    1
+  ]
+}
+```
 
 ### import
 
diff --git a/datacontract/cli.py b/datacontract/cli.py
index f02675d0..0c49fb30 100644
--- a/datacontract/cli.py
+++ b/datacontract/cli.py
@@ -221,7 +221,7 @@ def export(
     )
     # Don't interpret console markup in output.
     if output is None:
-        console.print(result, markup=False)
+        console.print(result, markup=False, soft_wrap=True)
     else:
         with output.open("w") as f:
             f.write(result)
@@ -298,7 +298,7 @@ def import_(
         iceberg_table=iceberg_table,
     )
     if output is None:
-        console.print(result.to_yaml())
+        console.print(result.to_yaml(), markup=False, soft_wrap=True)
     else:
         with output.open("w") as f:
             f.write(result.to_yaml())
diff --git a/datacontract/export/dbt_converter.py b/datacontract/export/dbt_converter.py
index 48f8e0af..027e72a6 100644
--- a/datacontract/export/dbt_converter.py
+++ b/datacontract/export/dbt_converter.py
@@ -59,7 +59,7 @@ def to_dbt_staging_sql(data_contract_spec: DataContractSpecification, model_name
 
 
 def to_dbt_sources_yaml(data_contract_spec: DataContractSpecification, server: str = None):
-    source = {"name": data_contract_spec.id, "tables": []}
+    source = {"name": data_contract_spec.id}
     dbt = {
         "version": 2,
         "sources": [source],
@@ -72,9 +72,14 @@ def to_dbt_sources_yaml(data_contract_spec: DataContractSpecification, server: s
     adapter_type = None
     if found_server is not None:
         adapter_type = found_server.type
-        source["database"] = found_server.database
-        source["schema"] = found_server.schema_
+        if adapter_type == "bigquery":
+            source["database"] = found_server.project
+            source["schema"] = found_server.dataset
+        else:
+            source["database"] = found_server.database
+            source["schema"] = found_server.schema_
+
+    source["tables"] = []
     for model_key, model_value in data_contract_spec.models.items():
         dbt_model = _to_dbt_source_table(model_key, model_value, adapter_type)
         source["tables"].append(dbt_model)
@@ -144,10 +149,12 @@ def _to_column(field_name: str, field: Field, supports_constraints: bool, adapte
     column = {"name": field_name}
     adapter_type = adapter_type or "snowflake"
     dbt_type = convert_to_sql_type(field, adapter_type)
+
+    column["data_tests"] = []
     if dbt_type is not None:
         column["data_type"] = dbt_type
     else:
-        column.setdefault("tests", []).append(
+        column["data_tests"].append(
             {"dbt_expectations.dbt_expectations.expect_column_values_to_be_of_type": {"column_type": dbt_type}}
         )
     if field.description is not None:
@@ -156,21 +163,21 @@ def _to_column(field_name: str, field: Field, supports_constraints: bool, adapte
         if supports_constraints:
             column.setdefault("constraints", []).append({"type": "not_null"})
         else:
-            column.setdefault("tests", []).append("not_null")
+            column["data_tests"].append("not_null")
     if field.unique:
         if supports_constraints:
             column.setdefault("constraints", []).append({"type": "unique"})
         else:
-            column.setdefault("tests", []).append("unique")
+            column["data_tests"].append("unique")
     if field.enum is not None and len(field.enum)
> 0: - column.setdefault("tests", []).append({"accepted_values": {"values": field.enum}}) + column["data_tests"].append({"accepted_values": {"values": field.enum}}) if field.minLength is not None or field.maxLength is not None: length_test = {} if field.minLength is not None: length_test["min_value"] = field.minLength if field.maxLength is not None: length_test["max_value"] = field.maxLength - column.setdefault("tests", []).append( + column["data_tests"].append( {"dbt_expectations.expect_column_value_lengths_to_be_between": length_test} ) if field.pii is not None: @@ -181,7 +188,7 @@ def _to_column(field_name: str, field: Field, supports_constraints: bool, adapte column.setdefault("tags", []).extend(field.tags) if field.pattern is not None: # Beware, the data contract pattern is a regex, not a like pattern - column.setdefault("tests", []).append( + column["data_tests"].append( {"dbt_expectations.expect_column_values_to_match_regex": {"regex": field.pattern}} ) if ( @@ -195,7 +202,7 @@ def _to_column(field_name: str, field: Field, supports_constraints: bool, adapte range_test["min_value"] = field.minimum if field.maximum is not None: range_test["max_value"] = field.maximum - column.setdefault("tests", []).append({"dbt_expectations.expect_column_values_to_be_between": range_test}) + column["data_tests"].append({"dbt_expectations.expect_column_values_to_be_between": range_test}) elif ( field.exclusiveMinimum is not None or field.exclusiveMaximum is not None @@ -208,18 +215,18 @@ def _to_column(field_name: str, field: Field, supports_constraints: bool, adapte if field.exclusiveMaximum is not None: range_test["max_value"] = field.exclusiveMaximum range_test["strictly"] = True - column.setdefault("tests", []).append({"dbt_expectations.expect_column_values_to_be_between": range_test}) + column["data_tests"].append({"dbt_expectations.expect_column_values_to_be_between": range_test}) else: if field.minimum is not None: - column.setdefault("tests", []).append( + column["data_tests"].append( {"dbt_expectations.expect_column_values_to_be_between": {"min_value": field.minimum}} ) if field.maximum is not None: - column.setdefault("tests", []).append( + column["data_tests"].append( {"dbt_expectations.expect_column_values_to_be_between": {"max_value": field.maximum}} ) if field.exclusiveMinimum is not None: - column.setdefault("tests", []).append( + column["data_tests"].append( { "dbt_expectations.expect_column_values_to_be_between": { "min_value": field.exclusiveMinimum, @@ -228,7 +235,7 @@ def _to_column(field_name: str, field: Field, supports_constraints: bool, adapte } ) if field.exclusiveMaximum is not None: - column.setdefault("tests", []).append( + column["data_tests"].append( { "dbt_expectations.expect_column_values_to_be_between": { "max_value": field.exclusiveMaximum, @@ -237,5 +244,8 @@ def _to_column(field_name: str, field: Field, supports_constraints: bool, adapte } ) + if not column["data_tests"]: + column.pop("data_tests") + # TODO: all constraints return column diff --git a/datacontract/export/exporter.py b/datacontract/export/exporter.py index 6532ab7d..ab3cefd0 100644 --- a/datacontract/export/exporter.py +++ b/datacontract/export/exporter.py @@ -40,6 +40,8 @@ class ExportFormat(str, Enum): sqlalchemy = "sqlalchemy" data_caterer = "data-caterer" dcs = "dcs" + iceberg = "iceberg" + @classmethod def get_supported_formats(cls): diff --git a/datacontract/export/exporter_factory.py b/datacontract/export/exporter_factory.py index 7923f651..d059dab7 100644 --- 
a/datacontract/export/exporter_factory.py
+++ b/datacontract/export/exporter_factory.py
@@ -168,3 +168,7 @@ def load_module_class(module_path, class_name):
 exporter_factory.register_lazy_exporter(
     name=ExportFormat.dcs, module_path="datacontract.export.dcs_exporter", class_name="DcsExporter"
 )
+
+exporter_factory.register_lazy_exporter(
+    name=ExportFormat.iceberg, module_path="datacontract.export.iceberg_converter", class_name="IcebergExporter"
+)
diff --git a/datacontract/export/iceberg_converter.py b/datacontract/export/iceberg_converter.py
new file mode 100644
index 00000000..7953edfb
--- /dev/null
+++ b/datacontract/export/iceberg_converter.py
@@ -0,0 +1,188 @@
+from pyiceberg import types
+from pyiceberg.schema import Schema, assign_fresh_schema_ids
+
+from datacontract.export.exporter import Exporter
+from datacontract.model.data_contract_specification import (
+    DataContractSpecification,
+    Field,
+    Model,
+)
+
+
+class IcebergExporter(Exporter):
+    """
+    Exporter class for exporting data contracts to Iceberg schemas.
+    """
+
+    def export(
+        self,
+        data_contract: DataContractSpecification,
+        model,
+        server,
+        sql_server_type,
+        export_args,
+    ):
+        """
+        Export the given data contract model to an Iceberg schema.
+
+        Args:
+            data_contract (DataContractSpecification): The data contract specification.
+            model: The model to export; currently only a single model is supported.
+            server: Not used in this implementation.
+            sql_server_type: Not used in this implementation.
+            export_args: Additional arguments for export.
+
+        Returns:
+            str: A string representation of the Iceberg JSON schema.
+        """
+
+        return to_iceberg(data_contract, model)
+
+
+def to_iceberg(contract: DataContractSpecification, model: str) -> str:
+    """
+    Converts a DataContractSpecification into an Iceberg JSON schema string. The JSON string follows https://iceberg.apache.org/spec/#appendix-c-json-serialization.
+
+    Args:
+        contract (DataContractSpecification): The data contract specification containing models.
+        model: The model to export; currently only a single model is supported.
+
+    Returns:
+        str: A string representation of the Iceberg JSON schema.
+    """
+    if model is None or model == "all":
+        if len(contract.models.items()) != 1:
+            # Iceberg doesn't have a way to combine multiple models into a single schema; an alternative would be to export JSON lines
+            raise Exception(f"Can only output one model at a time, found {len(contract.models.items())} models")
+        for model_name, model in contract.models.items():
+            schema = to_iceberg_schema(model)
+    else:
+        if model not in contract.models:
+            raise Exception(f"model {model} not found in contract")
+        schema = to_iceberg_schema(contract.models[model])
+
+    return schema.model_dump_json()
+
+
+def to_iceberg_schema(model: Model) -> Schema:
+    """
+    Convert a model to an Iceberg schema.
+
+    Args:
+        model (Model): The model to convert.
+
+    Returns:
+        Schema: The corresponding Iceberg schema.
+ """ + iceberg_fields = [] + primary_keys = [] + for field_name, spec_field in model.fields.items(): + iceberg_field = make_field(field_name, spec_field) + iceberg_fields.append(iceberg_field) + + if spec_field.primaryKey: + primary_keys.append(iceberg_field.name) + + schema = Schema(*iceberg_fields) + + # apply non-0 field IDs so we can set the identifier fields for the schema + schema = assign_fresh_schema_ids(schema) + for field in schema.fields: + if field.name in primary_keys: + schema.identifier_field_ids.append(field.field_id) + + return schema + + +def make_field(field_name, field): + field_type = get_field_type(field) + + # Note: might want to re-populate field_id from config['icebergFieldId'] if it exists, however, it gets + # complicated since field_ids impact the list and map element_ids, and the importer is not keeping track of those. + # Even if IDs are re-constituted, it seems like the SDK code would still reset them before any operation against a catalog, + # so it's likely not worth it. + + # Note 2: field_id defaults to 0 to signify that the exporter is not attempting to populate meaningful values. + # also, the Iceberg sdk catalog code will re-set the fieldIDs prior to executing any table operations on the schema + # ref: https://github.com/apache/iceberg-python/pull/1072 + return types.NestedField(field_id=0, name=field_name, field_type=field_type, required=field.required) + + +def make_list(item): + field_type = get_field_type(item) + + # element_id defaults to 0 to signify that the exporter is not attempting to populate meaningful values (see #make_field) + return types.ListType(element_id=0, element_type=field_type, element_required=item.required) + + +def make_map(field): + key_type = get_field_type(field.keys) + value_type = get_field_type(field.values) + + # key_id and value_id defaults to 0 to signify that the exporter is not attempting to populate meaningful values (see #make_field) + return types.MapType( + key_id=0, key_type=key_type, value_id=0, value_type=value_type, value_required=field.values.required + ) + + +def to_struct_type(fields: dict[str, Field]) -> types.StructType: + """ + Convert a dictionary of fields to a Iceberg StructType. + + Args: + fields (dict[str, Field]): The fields to convert. + + Returns: + types.StructType: The corresponding Iceberg StructType. + """ + struct_fields = [] + for field_name, field in fields.items(): + struct_field = make_field(field_name, field) + struct_fields.append(struct_field) + return types.StructType(*struct_fields) + + +def get_field_type(field: Field) -> types.IcebergType: + """ + Convert a field to a Iceberg IcebergType. + + Args: + field (Field): The field to convert. + + Returns: + types.IcebergType: The corresponding Iceberg IcebergType. 
+ """ + field_type = field.type + if field_type is None or field_type in ["null"]: + return types.NullType() + if field_type == "array": + return make_list(field.items) + if field_type == "map": + return make_map(field) + if field_type in ["object", "record", "struct"]: + return to_struct_type(field.fields) + if field_type in ["string", "varchar", "text"]: + return types.StringType() + if field_type in ["number", "decimal", "numeric"]: + precision = field.precision if field.precision is not None else 38 + scale = field.scale if field.scale is not None else 0 + return types.DecimalType(precision=precision, scale=scale) + if field_type in ["integer", "int"]: + return types.IntegerType() + if field_type in ["bigint", "long"]: + return types.LongType() + if field_type == "float": + return types.FloatType() + if field_type == "double": + return types.DoubleType() + if field_type == "boolean": + return types.BooleanType() + if field_type in ["timestamp", "timestamp_tz"]: + return types.TimestamptzType() + if field_type == "timestamp_ntz": + return types.TimestampType() + if field_type == "date": + return types.DateType() + if field_type == "bytes": + return types.BinaryType() + return types.BinaryType() diff --git a/datacontract/imports/dbt_importer.py b/datacontract/imports/dbt_importer.py index 95320697..96eece78 100644 --- a/datacontract/imports/dbt_importer.py +++ b/datacontract/imports/dbt_importer.py @@ -3,6 +3,8 @@ from dbt.artifacts.resources.v1.components import ColumnInfo from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.nodes import GenericTestNode +from dbt_common.contracts.constraints import ConstraintType from datacontract.imports.bigquery_importer import map_type_from_bigquery from datacontract.imports.importer import Importer @@ -44,7 +46,9 @@ def read_dbt_manifest(manifest_path: str) -> Manifest: """Read a manifest from file.""" with open(file=manifest_path, mode="r", encoding="utf-8") as f: manifest_dict: dict = json.load(f) - return Manifest.from_dict(manifest_dict) + manifest = Manifest.from_dict(manifest_dict) + manifest.build_parent_and_child_maps() + return manifest def import_dbt_manifest( @@ -74,27 +78,81 @@ def import_dbt_manifest( dc_model = Model( description=model_contents.description, tags=model_contents.tags, - fields=create_fields(columns=model_contents.columns, adapter_type=adapter_type), + fields=create_fields( + manifest, + model_unique_id=model_contents.unique_id, + columns=model_contents.columns, + adapter_type=adapter_type, + ), ) data_contract_specification.models[model_contents.name] = dc_model return data_contract_specification + def convert_data_type_by_adapter_type(data_type: str, adapter_type: str) -> str: if adapter_type == "bigquery": return map_type_from_bigquery(data_type) return data_type -def create_fields(columns: dict[str, ColumnInfo], adapter_type: str) -> dict[str, Field]: - fields = { - column.name: Field( - description=column.description, - type=convert_data_type_by_adapter_type(column.data_type, adapter_type) if column.data_type else "", - tags=column.tags, +def create_fields( + manifest: Manifest, model_unique_id: str, columns: dict[str, ColumnInfo], adapter_type: str +) -> dict[str, Field]: + fields = {column.name: create_field(manifest, model_unique_id, column, adapter_type) for column in columns.values()} + return fields + + +def get_column_tests(manifest: Manifest, model_name: str, column_name: str) -> list[dict[str, str]]: + column_tests = [] + model_node = manifest.nodes.get(model_name) + if not model_node: 
+ raise ValueError(f"Model {model_name} not found in manifest.") + + model_unique_id = model_node.unique_id + test_ids = manifest.child_map.get(model_unique_id, []) + + for test_id in test_ids: + test_node = manifest.nodes.get(test_id) + if not test_node or test_node.resource_type != "test": + continue + + if not isinstance(test_node, GenericTestNode): + continue + + if test_node.column_name != column_name: + continue + + if test_node.config.where is not None: + continue + + column_tests.append( + { + "test_name": test_node.name, + "test_type": test_node.test_metadata.name, + "column": test_node.column_name, + } ) - for column in columns.values() - } + return column_tests - return fields + +def create_field(manifest: Manifest, model_unique_id: str, column: ColumnInfo, adapter_type: str) -> Field: + column_type = convert_data_type_by_adapter_type(column.data_type, adapter_type) if column.data_type else "" + field = Field( + description=column.description, + type=column_type, + tags=column.tags, + ) + + all_tests = get_column_tests(manifest, model_unique_id, column.name) + + required = False + if any(constraint.type == ConstraintType.not_null for constraint in column.constraints): + required = True + if [test for test in all_tests if test["test_type"] == "not_null"]: + required = True + if required: + field.required = required + + return field diff --git a/datacontract/imports/iceberg_importer.py b/datacontract/imports/iceberg_importer.py index 86e1982b..f63db25f 100644 --- a/datacontract/imports/iceberg_importer.py +++ b/datacontract/imports/iceberg_importer.py @@ -42,8 +42,19 @@ def import_iceberg( model = Model(type="table", title=table_name) + # Iceberg identifier_fields aren't technically primary keys since Iceberg doesn't support primary keys, + # but they are close enough that we can probably treat them as primary keys on the conversion. + # ref: https://iceberg.apache.org/spec/#identifier-field-ids + # this code WILL NOT support finding nested primary key fields. + identifier_fields_ids = schema.identifier_field_ids + for field in schema.fields: - model.fields[field.name] = _field_from_nested_field(field) + model_field = _field_from_nested_field(field) + + if field.field_id in identifier_fields_ids: + model_field.primaryKey = True + + model.fields[field.name] = model_field data_contract_specification.models[table_name] = model return data_contract_specification diff --git a/pyproject.toml b/pyproject.toml index 2579b491..9c7f3e2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fastapi==0.115.6", # move to extra? "uvicorn==0.32.1", # move to extra? 
"fastjsonschema>=2.19.1,<2.22.0", - "fastparquet==2024.5.0", + "fastparquet==2024.11.0", "python-multipart==0.0.19", "rich>=13.7,<13.10", "simple-ddl-parser==1.7.1", @@ -57,7 +57,7 @@ databricks = [ ] iceberg = [ - "pyiceberg==0.8.0" + "pyiceberg==0.8.1" ] kafka = [ @@ -104,7 +104,7 @@ all = [ dev = [ "datacontract-cli[all]", - "httpx==0.28.0", + "httpx==0.28.1", "kafka-python", "moto==5.0.22", "pandas>=2.1.0", @@ -113,7 +113,7 @@ dev = [ "pytest-xdist", "pymssql==2.3.2", "ruff", - "testcontainers[minio,postgres,kafka,mssql]==4.8.2", + "testcontainers[minio,postgres,kafka,mssql]==4.9.0", "trino==0.330.0", ] diff --git a/tests/fixtures/dbt/export/datacontract.yaml b/tests/fixtures/dbt/export/datacontract.yaml index a3335891..046ef09a 100644 --- a/tests/fixtures/dbt/export/datacontract.yaml +++ b/tests/fixtures/dbt/export/datacontract.yaml @@ -20,8 +20,8 @@ servers: type: bigquery environment: production account: my-account - database: my-database - schema: my-schema + project: my-database + dataset: my-schema roles: - name: analyst_us description: Access to the data for US region diff --git a/tests/test_export_dbt_models.py b/tests/test_export_dbt_models.py index aa62b96b..302e407f 100644 --- a/tests/test_export_dbt_models.py +++ b/tests/test_export_dbt_models.py @@ -22,7 +22,7 @@ def test_to_dbt_models(): expected_dbt_model = """ version: 2 models: - - name: orders + - name: orders config: meta: owner: checkout @@ -37,12 +37,12 @@ def test_to_dbt_models(): constraints: - type: not_null - type: unique - tests: + data_tests: - dbt_expectations.expect_column_value_lengths_to_be_between: min_value: 8 max_value: 10 - dbt_expectations.expect_column_values_to_match_regex: - regex: ^B[0-9]+$ + regex: ^B[0-9]+$ meta: classification: sensitive pii: true @@ -51,9 +51,9 @@ def test_to_dbt_models(): - name: order_total data_type: NUMBER constraints: - - type: not_null + - type: not_null description: The order_total field - tests: + data_tests: - dbt_expectations.expect_column_values_to_be_between: min_value: 0 max_value: 1000000 @@ -61,7 +61,7 @@ def test_to_dbt_models(): data_type: TEXT constraints: - type: not_null - tests: + data_tests: - accepted_values: values: - 'pending' diff --git a/tests/test_export_dbt_sources.py b/tests/test_export_dbt_sources.py index 43922fd8..26bc2e98 100644 --- a/tests/test_export_dbt_sources.py +++ b/tests/test_export_dbt_sources.py @@ -43,7 +43,7 @@ def test_to_dbt_sources(): columns: - name: order_id data_type: VARCHAR - tests: + data_tests: - not_null - unique - dbt_expectations.expect_column_value_lengths_to_be_between: @@ -59,14 +59,14 @@ def test_to_dbt_sources(): - name: order_total description: The order_total field data_type: NUMBER - tests: + data_tests: - not_null - dbt_expectations.expect_column_values_to_be_between: min_value: 0 max_value: 1000000 - name: order_status data_type: TEXT - tests: + data_tests: - not_null - accepted_values: values: @@ -97,7 +97,7 @@ def test_to_dbt_sources_bigquery(): columns: - name: order_id data_type: STRING - tests: + data_tests: - not_null - unique - dbt_expectations.expect_column_value_lengths_to_be_between: @@ -113,14 +113,14 @@ def test_to_dbt_sources_bigquery(): - name: order_total description: The order_total field data_type: INT64 - tests: + data_tests: - not_null - dbt_expectations.expect_column_values_to_be_between: min_value: 0 max_value: 1000000 - name: order_status data_type: STRING - tests: + data_tests: - not_null - accepted_values: values: diff --git a/tests/test_export_iceberg.py 
b/tests/test_export_iceberg.py new file mode 100644 index 00000000..cc15eab6 --- /dev/null +++ b/tests/test_export_iceberg.py @@ -0,0 +1,254 @@ +import tempfile + +from pyiceberg import types +from pyiceberg.schema import Schema, assign_fresh_schema_ids +from typer.testing import CliRunner + +from datacontract.cli import app +from datacontract.export.iceberg_converter import IcebergExporter +from datacontract.lint.resolve import resolve_data_contract + + +def test_cli(): + with tempfile.NamedTemporaryFile(delete=True) as tmp_input_file: + # create temp file with content + tmp_input_file.write(b""" + dataContractSpecification: 0.9.3 + id: my-id + info: + title: My Title + version: 1.0.0 + models: + orders: + fields: + order_id: + type: int + required: true + primaryKey: true + """) + tmp_input_file.flush() + + with tempfile.NamedTemporaryFile(delete=True) as tmp_output_file: + runner = CliRunner() + result = runner.invoke( + app, ["export", tmp_input_file.name, "--format", "iceberg", "--output", tmp_output_file.name] + ) + assert result.exit_code == 0 + + with open(tmp_output_file.name, "r") as f: + schema = Schema.model_validate_json(f.read()) + + assert len(schema.fields) == 1 + _assert_field(schema, "order_id", types.IntegerType(), True) + assert schema.identifier_field_ids == [1] + + +def test_type_conversion(): + datacontract = resolve_data_contract( + data_contract_str=""" + dataContractSpecification: 0.9.3 + id: my-id + info: + title: My Title + version: 1.0.0 + models: + datatypes: + fields: + string_type: + type: string + text_type: + type: text + varchar_type: + type: varchar + number_type: + type: number + decimal_type: + type: decimal + precision: 4 + scale: 2 + numeric_type: + type: numeric + int_type: + type: int + integer_type: + type: integer + long_type: + type: long + bigint_type: + type: bigint + float_type: + type: float + double_type: + type: double + boolean_type: + type: boolean + timestamp_type: + type: timestamp + timestamp_tz_type: + type: timestamp_tz + timestamp_ntz_type: + type: timestamp_ntz + date_type: + type: date + array_type: + type: array + items: + type: string + map_type: + type: map + keys: + type: string + values: + type: int + object_type: + type: object + fields: + object_field1: + type: string + record_type: + type: record + fields: + record_field1: + type: int + struct_type: + type: struct + fields: + struct_field1: + type: float + """, + inline_definitions=True, + ) + datacontract.model_dump() + schema = Schema.model_validate_json(_export(datacontract)) + + assert len(schema.fields) == 22 + _assert_field(schema, "string_type", types.StringType(), False) + _assert_field(schema, "text_type", types.StringType(), False) + _assert_field(schema, "varchar_type", types.StringType(), False) + _assert_field(schema, "number_type", types.DecimalType(precision=38, scale=0), False) + _assert_field(schema, "decimal_type", types.DecimalType(precision=4, scale=2), False) + _assert_field(schema, "numeric_type", types.DecimalType(precision=38, scale=0), False) + _assert_field(schema, "int_type", types.IntegerType(), False) + _assert_field(schema, "integer_type", types.IntegerType(), False) + _assert_field(schema, "long_type", types.LongType(), False) + _assert_field(schema, "bigint_type", types.LongType(), False) + _assert_field(schema, "float_type", types.FloatType(), False) + _assert_field(schema, "double_type", types.DoubleType(), False) + _assert_field(schema, "boolean_type", types.BooleanType(), False) + _assert_field(schema, "timestamp_type", 
types.TimestamptzType(), False) + _assert_field(schema, "timestamp_tz_type", types.TimestamptzType(), False) + _assert_field(schema, "timestamp_ntz_type", types.TimestampType(), False) + _assert_field(schema, "date_type", types.DateType(), False) + _assert_field( + schema, + "array_type", + types.ListType(element_id=0, element_type=types.StringType(), element_required=False), + False, + ) + _assert_field( + schema, + "map_type", + types.MapType( + key_id=0, key_type=types.StringType(), value_id=0, value_type=types.IntegerType(), value_required=False + ), + False, + ) + _assert_field( + schema, + "object_type", + types.StructType( + types.NestedField(field_id=0, name="object_field1", field_type=types.StringType(), required=False) + ), + False, + ) + _assert_field( + schema, + "record_type", + types.StructType( + types.NestedField(field_id=0, name="record_field1", field_type=types.IntegerType(), required=False) + ), + False, + ) + _assert_field( + schema, + "struct_type", + types.StructType( + types.NestedField(field_id=0, name="struct_field1", field_type=types.FloatType(), required=False) + ), + False, + ) + + +def test_round_trip(): + with open("fixtures/iceberg/nested_schema.json", "r") as f: + starting_schema = Schema.model_validate_json(f.read()) + + import_runner = CliRunner() + import_result = import_runner.invoke( + app, + [ + "import", + "--format", + "iceberg", + "--source", + "fixtures/iceberg/nested_schema.json", + "--iceberg-table", + "test-table", + ], + ) + + assert import_result.exit_code == 0 + output = import_result.stdout.strip() + + with tempfile.NamedTemporaryFile(delete=True) as tmp_input_file: + # create temp file with content + tmp_input_file.write(output.encode()) + tmp_input_file.flush() + + with tempfile.NamedTemporaryFile(delete=True) as tmp_output_file: + runner = CliRunner() + result = runner.invoke( + app, ["export", tmp_input_file.name, "--format", "iceberg", "--output", tmp_output_file.name] + ) + assert result.exit_code == 0 + + with open(tmp_output_file.name, "r") as f: + ending_schema = Schema.model_validate_json(f.read()) + + # don't use IDs in equality check since SDK resets them anyway + starting_schema = assign_fresh_schema_ids(starting_schema) + ending_schema = assign_fresh_schema_ids(ending_schema) + assert starting_schema == ending_schema + + +def _assert_field(schema, field_name, field_type, required): + field = None + for f in schema.fields: + if f.name == field_name: + field = f + break + + assert field is not None + assert field.name == field_name + assert field.required == required + + found_type = field.field_type + if found_type.is_primitive: + assert found_type == field_type + elif isinstance(found_type, types.ListType): + assert found_type.element_type == field_type.element_type + assert found_type.element_required == field_type.element_required + elif isinstance(found_type, types.MapType): + assert found_type.key_type == field_type.key_type + assert found_type.value_type == field_type.value_type + assert found_type.value_required == field_type.value_required + elif isinstance(found_type, types.StructType): + assert len(found_type.fields) == len(field_type.fields) + for nested_field in field_type.fields: + _assert_field(found_type, nested_field.name, nested_field.field_type, nested_field.required) + else: + raise ValueError(f"Unexpected field type: {found_type}") + + +def _export(datacontract, model=None): + return IcebergExporter("iceberg").export(datacontract, model, None, None, None) diff --git a/tests/test_import_dbt.py 
b/tests/test_import_dbt.py index 7a9cf4bd..fa4a3f23 100644 --- a/tests/test_import_dbt.py +++ b/tests/test_import_dbt.py @@ -84,9 +84,11 @@ def test_import_dbt_manifest(): order_id: type: integer description: This is a unique identifier for an order + required: true customer_id: type: integer description: Foreign key to the customers table + required: true order_date: type: date description: Date (UTC) that the order was placed @@ -114,18 +116,23 @@ def test_import_dbt_manifest(): credit_card_amount: type: double description: Amount of the order (AUD) paid for by credit card + required: true coupon_amount: type: double description: Amount of the order (AUD) paid for by coupon + required: true bank_transfer_amount: type: double description: Amount of the order (AUD) paid for by bank transfer + required: true gift_card_amount: type: double description: Amount of the order (AUD) paid for by gift card + required: true amount: type: double description: Total amount (AUD) of the order + required: true tags: [] stg_customers: description: '' @@ -133,6 +140,7 @@ def test_import_dbt_manifest(): customer_id: type: integer description: '' + required: true first_name: type: varchar description: '' @@ -146,6 +154,7 @@ def test_import_dbt_manifest(): order_id: type: integer description: '' + required: true customer_id: type: integer description: '' @@ -162,6 +171,7 @@ def test_import_dbt_manifest(): payment_id: type: integer description: '' + required: true order_id: type: integer description: '' @@ -179,6 +189,7 @@ def test_import_dbt_manifest(): customer_id: type: integer description: This is a unique identifier for a customer + required: true first_name: type: varchar description: Customer's first name. PII. @@ -226,9 +237,11 @@ def test_import_dbt_manifest_bigquery(): order_id: type: bigint description: This is a unique identifier for an order + required: true customer_id: type: bigint description: Foreign key to the customers table + required: true order_date: type: date description: Date (UTC) that the order was placed @@ -256,18 +269,23 @@ def test_import_dbt_manifest_bigquery(): credit_card_amount: type: double description: Amount of the order (AUD) paid for by credit card + required: true coupon_amount: type: double description: Amount of the order (AUD) paid for by coupon + required: true bank_transfer_amount: type: double description: Amount of the order (AUD) paid for by bank transfer + required: true gift_card_amount: type: double description: Amount of the order (AUD) paid for by gift card + required: true amount: type: double description: Total amount (AUD) of the order + required: true tags: [] stg_customers: description: '' @@ -275,6 +293,7 @@ def test_import_dbt_manifest_bigquery(): customer_id: type: bigint description: '' + required: true first_name: type: string description: '' @@ -288,6 +307,7 @@ def test_import_dbt_manifest_bigquery(): order_id: type: bigint description: '' + required: true customer_id: type: bigint description: '' @@ -304,6 +324,7 @@ def test_import_dbt_manifest_bigquery(): payment_id: type: bigint description: '' + required: true order_id: type: bigint description: '' @@ -321,6 +342,7 @@ def test_import_dbt_manifest_bigquery(): customer_id: type: bigint description: This is a unique identifier for a customer + required: true first_name: type: string description: Customer's first name. PII. 
@@ -390,6 +412,7 @@ def test_import_dbt_manifest_with_filter(): customer_id: type: integer description: This is a unique identifier for a customer + required: true first_name: type: varchar description: Customer's first name. PII. diff --git a/tests/test_import_iceberg.py b/tests/test_import_iceberg.py index 1b25fc8d..6518b15d 100644 --- a/tests/test_import_iceberg.py +++ b/tests/test_import_iceberg.py @@ -28,6 +28,7 @@ title: bar type: integer required: true + primaryKey: true config: icebergFieldId: 2 baz:
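For reference, the exported Iceberg schema JSON can be checked outside the test suite by parsing it with pyiceberg, mirroring what the new `tests/test_export_iceberg.py` does. A minimal sketch, assuming the export shown in the README was written to `/tmp/orders_iceberg.json` (any readable path works):

```python
from pyiceberg.schema import Schema

# Load the JSON produced by `datacontract export --format iceberg --model orders ...`
# and validate it against pyiceberg's schema model (the same check the tests perform).
with open("/tmp/orders_iceberg.json") as f:
    schema = Schema.model_validate_json(f.read())

# Fields carry the contract's required flags; identifier-field-ids reflects the primary key.
for field in schema.fields:
    print(field.name, field.field_type, "required" if field.required else "optional")
print("identifier fields:", schema.identifier_field_ids)
```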