From 6a1991cc3a2695ba361fa3982001a57fa1b2f097 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Mon, 6 Jan 2025 12:11:37 +0100 Subject: [PATCH] affine transform util: add interoperability with AffineMaps (#324) * add to_affine_map * fix import * add some comments * add check for pure linearity --- compiler/ir/autoflow/affine_transform.py | 62 ++++++++++++++++++++++ tests/ir/autoflow/test_affine_transform.py | 28 ++++++++++ 2 files changed, 90 insertions(+) diff --git a/compiler/ir/autoflow/affine_transform.py b/compiler/ir/autoflow/affine_transform.py index 67ce6045..cab29721 100644 --- a/compiler/ir/autoflow/affine_transform.py +++ b/compiler/ir/autoflow/affine_transform.py @@ -3,6 +3,14 @@ import numpy as np import numpy.typing as npt from typing_extensions import Self +from xdsl.ir.affine import ( + AffineBinaryOpExpr, + AffineBinaryOpKind, + AffineConstantExpr, + AffineDimExpr, + AffineExpr, + AffineMap, +) @dataclass(frozen=True) @@ -25,6 +33,60 @@ def __post_init__(self): if self.A.shape[0] != self.b.shape[0]: raise ValueError("Matrix A and vector b must have compatible dimensions.") + @classmethod + def from_affine_map(cls, map: AffineMap) -> Self: + """ + Return the affine transform representation of the given affine map. + + For this, the affine map must be a pure linear transformation (i.e., no floordiv/ceildiv/modulo operations) + """ + + # check for pure linear transformation + for result in map.results: + for expr in result.dfs(): + if isinstance(expr, AffineBinaryOpExpr): + if expr.kind in ( + AffineBinaryOpKind.FloorDiv, + AffineBinaryOpKind.CeilDiv, + AffineBinaryOpKind.Mod, + ): + raise ValueError( + "Affine map is not a pure linear transformation" + ) + + # generate a list with n zeros and a 1 at index d: + # [0, 0, 0, 1] + def generate_one_list(n: int, d: int): + return [1 if x == d else 0 for x in range(n)] + + # determine indices of the matrices a and b by getting the unit response of every dimension + + # bias b is determined by setting all dimensions to zero + b = np.array(map.eval(generate_one_list(map.num_dims, -1), [])) + + # columns of a are determined by toggling every dimension separately + a = np.zeros((len(map.results), map.num_dims), dtype=np.int_) + for dim in range(map.num_dims): + temp = np.array(map.eval(generate_one_list(map.num_dims, dim), [])) + a[:, dim] = temp - b + + return cls(a, b) + + def to_affine_map(self) -> AffineMap: + """ + Return the xDSL AffineMap representation of this AffineTransform + """ + results: list[AffineExpr] = [] + for result in range(self.num_results): + expr = AffineConstantExpr(int(self.b[result])) + for dim in range(self.num_dims): + if self.A[result, dim] != 0: + expr += AffineConstantExpr( + int(self.A[result, dim]) + ) * AffineDimExpr(dim) + results.append(expr) + return AffineMap(self.num_dims, 0, tuple(results)) + @property def num_dims(self) -> int: return self.A.shape[1] diff --git a/tests/ir/autoflow/test_affine_transform.py b/tests/ir/autoflow/test_affine_transform.py index 04bbe178..c09a7256 100644 --- a/tests/ir/autoflow/test_affine_transform.py +++ b/tests/ir/autoflow/test_affine_transform.py @@ -1,7 +1,9 @@ import numpy as np import pytest +from xdsl.ir.affine import AffineMap from compiler.ir.autoflow import AffineTransform +from compiler.util.canonicalize_affine import canonicalize_map def test_affine_transform_initialization_valid(): @@ -95,3 +97,29 @@ def test_affine_transform_str(): transform = AffineTransform(A, b) expected = "AffineTransform(A=\n[[1 0]\n [0 1]],\nb=[1 2])" assert str(transform) == expected + + +def test_affine_map_interop(): + map = AffineMap.from_callable(lambda a, b, c: (a + 2 * b, -b + c, 3 * a + c + 4)) + + # convert AffineMap to AffineTransform + transform = AffineTransform.from_affine_map(map) + + expected_a = np.array([[1, 2, 0], [0, -1, 1], [3, 0, 1]]) + expected_b = np.array([0, 0, 4]) + + assert (transform.A == expected_a).all() + assert (transform.b == expected_b).all() + + # convert back to AffineMap + + original_map = transform.to_affine_map() + assert canonicalize_map(map) == canonicalize_map(original_map) + + invalid_map = AffineMap.from_callable(lambda a: (a // 2,)) + + with pytest.raises( + ValueError, + match="Affine map is not a pure linear transformation", + ): + AffineTransform.from_affine_map(invalid_map)