Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

resolve pyright issues for access pattern and scheduler #323

Merged
merged 5 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions compiler/ir/stream/access_pattern.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import ABC
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from typing import Generic
from typing import Generic, cast

from typing_extensions import Self, TypeVar, deprecated, overload
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineMap
from typing_extensions import Self, TypeVar, overload
from xdsl.ir.affine import AffineConstantExpr, AffineDimExpr, AffineExpr, AffineMap

from compiler.util.canonicalize_affine import canonicalize_map

Expand Down Expand Up @@ -76,7 +76,7 @@ class SchedulePattern(AccessPattern):
bounds: tuple[int, ...]

def __init__(self, bounds: Sequence[int], pattern: AffineMap):
if any(bound is None or bound <= 0 for bound in bounds):
if any(bound <= 0 for bound in bounds):
raise ValueError(
"All bounds must be static, strictly positive integers for a schedule"
)
Expand Down Expand Up @@ -220,10 +220,10 @@ def __iter__(self) -> Iterator[P]:
def __eq__(self, other: object) -> bool:
if not isinstance(other, PatternCollection):
return False
other = cast(PatternCollection[P], other)
return self._patterns == other._patterns

@property
@deprecated("only valid in trivial cases")
def num_dims(self) -> int:
return self[0].num_dims

Expand All @@ -244,7 +244,7 @@ def clear_unused_dims(self, bounds: tuple[int] | None = None) -> Self:
else:
pattern_bounds = bounds
unused_dims = tuple(i for i, bound in enumerate(pattern_bounds) if bound == 1)
dim_substitutions = []
dim_substitutions: list[AffineExpr] = []
unused_counter = 0
for dim in range(self.num_dims):
if dim not in unused_dims:
Expand Down
2 changes: 1 addition & 1 deletion pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ typeCheckingMode = "strict"
"compiler/inference/helpers.py",
"compiler/inference/scoped_setups.py",
"compiler/inference/trace_acc_state.py",
"compiler/ir/stream/access_pattern.py",
"compiler/ir/stream/scheduler.py",
"compiler/ir/tsl/tiled_strided_layout.py",
"compiler/parser/tsl_parser.py",
"compiler/transforms/accfg_dedup.py",
Expand All @@ -121,7 +119,6 @@ typeCheckingMode = "strict"
"tests/dialects/test_snax.py",
"tests/dialects/test_tsl.py",
"tests/inference/test_accfg_state_tracing.py",
"tests/ir/stream/test_access_pattern.py",
"tests/ir/tsl/test_stride.py",
"tests/ir/tsl/test_tiled_stride.py",
"tests/ir/tsl/test_tiled_strided_layout.py",
Expand Down
7 changes: 1 addition & 6 deletions tests/ir/stream/test_access_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_access_pattern_disable_dims():

# test 4: disable 3 dims (all)
disabled_pattern = access_pattern.disable_dims(3)
expected_bounds = tuple()
expected_bounds: tuple[int, ...] = tuple()
expected_results = (
AffineConstantExpr(0),
AffineConstantExpr(0),
Expand Down Expand Up @@ -87,11 +87,6 @@ def test_schedule_pattern_invalid_bounds():
pattern = AffineMap(
num_dims=2, num_symbols=0, results=(AffineDimExpr(0), AffineDimExpr(1))
)
with pytest.raises(
ValueError,
match="All bounds must be static, strictly positive integers for a schedule",
):
SchedulePattern((10, None), pattern) # pyright: ignore
with pytest.raises(
ValueError,
match="All bounds must be static, strictly positive integers for a schedule",
Expand Down
Loading