diff --git a/changes/364.bugfix.rst b/changes/364.bugfix.rst new file mode 100644 index 00000000..96801fac --- /dev/null +++ b/changes/364.bugfix.rst @@ -0,0 +1 @@ +Allow ``merge_property_trees`` to retain input schema id in ``model.schema["id"]``. diff --git a/src/stdatamodels/jwst/datamodels/tests/test_models.py b/src/stdatamodels/jwst/datamodels/tests/test_models.py index cedd4ba2..5a76a00b 100644 --- a/src/stdatamodels/jwst/datamodels/tests/test_models.py +++ b/src/stdatamodels/jwst/datamodels/tests/test_models.py @@ -129,7 +129,7 @@ def test_imagemodel(): def test_model_with_nonstandard_primary_array(): - class NonstandardPrimaryArrayModel(JwstDataModel): + class _NonstandardPrimaryArrayModel(JwstDataModel): schema_url = os.path.join(ROOT_DIR, "nonstandard_primary_array.schema.yaml") # The wavelength array is the primary array. @@ -137,7 +137,7 @@ class NonstandardPrimaryArrayModel(JwstDataModel): def get_primary_array_name(self): return 'wavelength' - m = NonstandardPrimaryArrayModel((10,)) + m = _NonstandardPrimaryArrayModel((10,)) assert 'wavelength' in list(m.keys()) assert m.wavelength.sum() == 0 diff --git a/src/stdatamodels/schema.py b/src/stdatamodels/schema.py index 1edebdd5..70c59a0f 100644 --- a/src/stdatamodels/schema.py +++ b/src/stdatamodels/schema.py @@ -1,7 +1,6 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst import re -from collections import OrderedDict # return_result included for backward compatibility @@ -169,11 +168,17 @@ def merge_property_trees(schema): can be represented in individual files and then referenced elsewhere. They are then combined by this function into a single schema data structure. """ - newschema = OrderedDict() + # track the "combined" and "top" items separately + # this allows the top level "id", "$schema", etc to overwrite + # combined items (so that the schema["id"] doesn't change). + combined_items = {} + top_items = {} def add_entry(path, schema, combiner): + if not top_items: + top_items.update(schema) # TODO: Simplify? - cursor = newschema + cursor = combined_items for i in range(len(path)): part = path[i] if part == combiner: @@ -185,23 +190,23 @@ def add_entry(path, schema, combiner): cursor.append({}) cursor = cursor[part] elif part == 'items': - cursor = cursor.setdefault('items', OrderedDict()) + cursor = cursor.setdefault('items', {}) else: - cursor = cursor.setdefault('properties', OrderedDict()) + cursor = cursor.setdefault('properties', {}) if i < len(path) - 1 and isinstance(path[i + 1], int): cursor = cursor.setdefault(part, []) else: - cursor = cursor.setdefault(part, OrderedDict()) + cursor = cursor.setdefault(part, {}) cursor.update(schema) def callback(schema, path, combiner, ctx, recurse): - type = schema.get('type') - schema = OrderedDict(schema) - if type == 'object': + schema_type = schema.get('type') + schema = dict(schema) # shallow copy + if schema_type == 'object': if 'properties' in schema: del schema['properties'] - elif type == 'array': + elif schema_type == 'array': del schema['items'] if 'allOf' in schema: del schema['allOf'] @@ -210,7 +215,7 @@ def callback(schema, path, combiner, ctx, recurse): walk_schema(schema, callback) - return newschema + return combined_items | top_items def build_docstring(klass, template="{fits_hdu} {title}"): diff --git a/tests/test_schema.py b/tests/test_schema.py index 519834ca..b3cdbade 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -206,6 +206,19 @@ def test_merge_property_trees(combiner): assert f == s +def test_merge_property_tree_top(): + s = { + "id": "foo", + "allOf": [ + { + "id": "bar", + }, + ] + } + f = merge_property_trees(s) + assert f["id"] == "foo" + + def test_schema_docstring(): template = "{fits_hdu} {title}" docstring = build_docstring(FitsModel, template).split("\n")