diff --git a/src/monty/collections.py b/src/monty/collections.py index b9d463f7..0a6f242a 100644 --- a/src/monty/collections.py +++ b/src/monty/collections.py @@ -6,7 +6,7 @@ import collections from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Mapping +from typing import TYPE_CHECKING, Mapping, overload if TYPE_CHECKING: from typing import Any, Iterable @@ -204,15 +204,24 @@ def __contains__(self, key: Any) -> bool: """Checks if a case-insensitive key is `in` the dictionary.""" return super().__contains__(self._converter(key)) - def __ior__(self, other: Mapping) -> Self: - """The |= operator.""" + def __ior__(self, other: Mapping | Iterable, /) -> Self: + """The `|=` operator.""" self.update(other) return self + def __or__(self, other: Mapping[Any, Any]) -> Self: + """The `|` operator.""" + if not isinstance(other, Mapping): + return NotImplemented + + new_dict = type(self)(self) + new_dict.update(other) + return new_dict + def setdefault(self, key: Any, default: Any = None) -> Any: return super().setdefault(self._converter(key), default) - def update(self, *args: Iterable[Mapping], **kwargs: Any) -> None: + def update(self, *args, **kwargs: Any) -> None: if args: for mapping in args: if isinstance(mapping, Mapping): diff --git a/tests/test_collections.py b/tests/test_collections.py index 504a4262..f98076cc 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -140,6 +140,29 @@ def test_ior_operator(self): assert self.upper_dict["F"] == 8 assert self.upper_dict["f"] == 8 + def test_or_operator(self): + # Test with another CaseInsensitiveDictUpper + other = CaseInsensitiveDictUpper({"E": 7, "F": 8}) + result = self.upper_dict | other + assert isinstance(result, CaseInsensitiveDictUpper) + assert result["E"] == 7 + assert result["e"] == 7 + assert result["F"] == 8 + assert result["f"] == 8 + assert result["HI"] == "world" + assert result["hi"] == "world" + + # Test with a regular dict + other = {"g": 9, "H": 10} + result = self.upper_dict | other + assert isinstance(result, CaseInsensitiveDictUpper) + assert result["G"] == 9 + assert result["g"] == 9 + assert result["H"] == 10 + assert result["h"] == 10 + assert result["HI"] == "world" + assert result["hi"] == "world" + def test_setdefault(self): assert self.upper_dict.setdefault("g", 9) == 9 assert self.upper_dict["G"] == 9