feat(ingest/s3): S3 add partition to schema (datahub-project#8900)
Co-authored-by: Pedro Silva <[email protected]>
treff7es and pedro93 authored Oct 21, 2023
1 parent 86e0023 commit 04216e3
Showing 2 changed files with 37 additions and 1 deletion.
metadata-ingestion/src/datahub/ingestion/source/s3/config.py (4 additions, 1 deletion)

@@ -75,7 +75,10 @@ class DataLakeSourceConfig(
        default=100,
        description="Maximum number of rows to use when inferring schemas for TSV and CSV files.",
    )

    add_partition_columns_to_schema: bool = Field(
        default=False,
        description="Whether to add partition fields to the schema.",
    )
    verify_ssl: Union[bool, str] = Field(
        default=True,
        description="Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use.",
    )
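For context, the new option would be switched on in the source config. Below is a minimal sketch, not part of this commit, assuming pydantic-v1-style parse_obj (as DataHub's ConfigModel used at the time) and a hypothetical bucket and path layout:

from datahub.ingestion.source.s3.config import DataLakeSourceConfig

config = DataLakeSourceConfig.parse_obj(
    {
        # Hypothetical layout; {partition_key[0]}={partition[0]} names the
        # partition variable that path_spec.get_named_vars() later extracts.
        "path_specs": [
            {
                "include": "s3://my-bucket/data/{table}/{partition_key[0]}={partition[0]}/*.parquet"
            }
        ],
        # The flag added by this commit; defaults to False.
        "add_partition_columns_to_schema": True,
    }
)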
metadata-ingestion/src/datahub/ingestion/source/s3/source.py (33 additions, 0 deletions)

@@ -78,6 +78,7 @@
    NullTypeClass,
    NumberTypeClass,
    RecordTypeClass,
    SchemaField,
    SchemaFieldDataType,
    SchemaMetadata,
    StringTypeClass,
@@ -90,6 +91,7 @@
    OperationClass,
    OperationTypeClass,
    OtherSchemaClass,
    SchemaFieldDataTypeClass,
    _Aspect,
)
from datahub.telemetry import stats, telemetry
@@ -458,8 +460,39 @@ def get_fields(self, table_data: TableData, path_spec: PathSpec) -> List:
        logger.debug(f"Extracted fields in schema: {fields}")
        fields = sorted(fields, key=lambda f: f.fieldPath)

        if self.source_config.add_partition_columns_to_schema:
            self.add_partition_columns_to_schema(
                fields=fields, path_spec=path_spec, full_path=table_data.full_path
            )

        return fields

    def add_partition_columns_to_schema(
        self, path_spec: PathSpec, full_path: str, fields: List[SchemaField]
    ) -> None:
        is_fieldpath_v2 = False
        for field in fields:
            if field.fieldPath.startswith("[version=2.0]"):
                is_fieldpath_v2 = True
                break
        vars = path_spec.get_named_vars(full_path)
        if vars is not None and "partition_key" in vars:
            for partition_key in vars["partition_key"].values():
                fields.append(
                    SchemaField(
                        fieldPath=f"{partition_key}"
                        if not is_fieldpath_v2
                        else f"[version=2.0].[type=string].{partition_key}",
                        nativeDataType="string",
                        type=SchemaFieldDataType(StringTypeClass())
                        if not is_fieldpath_v2
                        else SchemaFieldDataTypeClass(type=StringTypeClass()),
                        isPartitioningKey=True,
                        nullable=True,
                        recursive=False,
                    )
                )

    def get_table_profile(
        self, table_data: TableData, dataset_urn: str
    ) -> Iterable[MetadataWorkUnit]:
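To illustrate the effect: for a file at s3://my-bucket/data/events/year=2023/part-0.parquet matched by the path spec sketched above, get_named_vars() recovers the partition key "year", and add_partition_columns_to_schema appends a field equivalent to the following. This is a sketch assuming the table's existing fields use v2 field paths; SchemaField in the diff is the pegasus2avro alias of the generated SchemaFieldClass, and "year" is a hypothetical partition key:

from datahub.metadata.schema_classes import (
    SchemaFieldClass,
    SchemaFieldDataTypeClass,
    StringTypeClass,
)

# "year" is the hypothetical partition key recovered from the file path.
partition_field = SchemaFieldClass(
    fieldPath="[version=2.0].[type=string].year",  # v1 paths would be just "year"
    nativeDataType="string",
    type=SchemaFieldDataTypeClass(type=StringTypeClass()),
    isPartitioningKey=True,  # marks the column as a partition key
    nullable=True,
    recursive=False,
)

With fieldpath v1, the conditional in the method instead emits the bare column name and wraps the type in SchemaFieldDataType.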
