Skip to content

Commit

Permalink
override or
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielYang59 committed Nov 28, 2024
1 parent b44eb1a commit 021e524
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/monty/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import collections
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Mapping
from typing import TYPE_CHECKING, Mapping, overload

if TYPE_CHECKING:
from typing import Any, Iterable
Expand Down Expand Up @@ -204,15 +204,24 @@ def __contains__(self, key: Any) -> bool:
"""Checks if a case-insensitive key is `in` the dictionary."""
return super().__contains__(self._converter(key))

def __ior__(self, other: Mapping) -> Self:
"""The |= operator."""
def __ior__(self, other: Mapping | Iterable, /) -> Self:
"""The `|=` operator."""
self.update(other)
return self

def __or__(self, other: Mapping[Any, Any]) -> Self:
"""The `|` operator."""
if not isinstance(other, Mapping):
return NotImplemented

Check warning on line 215 in src/monty/collections.py

View check run for this annotation

Codecov / codecov/patch

src/monty/collections.py#L215

Added line #L215 was not covered by tests

new_dict = type(self)(self)
new_dict.update(other)
return new_dict

def setdefault(self, key: Any, default: Any = None) -> Any:
return super().setdefault(self._converter(key), default)

def update(self, *args: Iterable[Mapping], **kwargs: Any) -> None:
def update(self, *args, **kwargs: Any) -> None:
if args:
for mapping in args:
if isinstance(mapping, Mapping):
Expand Down
23 changes: 23 additions & 0 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,29 @@ def test_ior_operator(self):
assert self.upper_dict["F"] == 8
assert self.upper_dict["f"] == 8

def test_or_operator(self):
# Test with another CaseInsensitiveDictUpper
other = CaseInsensitiveDictUpper({"E": 7, "F": 8})
result = self.upper_dict | other
assert isinstance(result, CaseInsensitiveDictUpper)
assert result["E"] == 7
assert result["e"] == 7
assert result["F"] == 8
assert result["f"] == 8
assert result["HI"] == "world"
assert result["hi"] == "world"

# Test with a regular dict
other = {"g": 9, "H": 10}
result = self.upper_dict | other
assert isinstance(result, CaseInsensitiveDictUpper)
assert result["G"] == 9
assert result["g"] == 9
assert result["H"] == 10
assert result["h"] == 10
assert result["HI"] == "world"
assert result["hi"] == "world"

def test_setdefault(self):
assert self.upper_dict.setdefault("g", 9) == 9
assert self.upper_dict["G"] == 9
Expand Down

0 comments on commit 021e524

Please sign in to comment.