Skip to content

Commit

Permalink
RNS: allow mod arith types as basis types
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Nov 23, 2024
1 parent 1e5cd7e commit 14300e1
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 10 deletions.
2 changes: 2 additions & 0 deletions lib/Dialect/RNS/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ cc_library(
":ops_inc_gen",
":type_interfaces_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
Expand All @@ -45,6 +46,7 @@ cc_library(
":dialect_inc_gen",
":type_interfaces_inc_gen",
":types_inc_gen",
"@heir//lib/Dialect/ModArith/IR:Dialect",
"@heir//lib/Dialect/Polynomial/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
Expand Down
10 changes: 6 additions & 4 deletions lib/Dialect/RNS/IR/RNSTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ def RNSBasisTypeInterface : TypeInterface<"RNSBasisTypeInterface"> {
are both polynomial types, even if they have different coefficient
moduli.

`isCompatibleWith` must be both commutative and associative, in the sense
Another example is using mod arith types as the basis types, where
by the nature of chinese reminder theorem, it is required that
the modulus of them must be mutually coprime.

`isCompatibleWith` must be commutative, in the sense
that `type1.isCompatibleWith(type2)` if and only if
`type2.isCompatibleWith(type1)`, and further
`type2.isCompatibleWith(type3)` if and only if
`type1.isCompatibleWith(type3)`.
`type2.isCompatibleWith(type1)`.
}],
"bool", "isCompatibleWith", (ins "::mlir::Type":$otherRnsBasisType)>
];
Expand Down
63 changes: 58 additions & 5 deletions lib/Dialect/RNS/IR/RNSTypes.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "lib/Dialect/RNS/IR/RNSTypes.h"

#include "lib/Dialect/ModArith/IR/ModArithDialect.h"
#include "lib/Dialect/ModArith/IR/ModArithTypes.h"
#include "lib/Dialect/Polynomial/IR/PolynomialDialect.h"
#include "lib/Dialect/Polynomial/IR/PolynomialTypes.h"
#include "lib/Dialect/RNS/IR/RNSTypeInterfaces.h"
Expand All @@ -17,18 +19,38 @@ using polynomial::PolynomialDialect;
using polynomial::PolynomialType;

namespace heir {

using mod_arith::ModArithDialect;
using mod_arith::ModArithType;

namespace rns {

LogicalResult RNSType::verify(
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
llvm::ArrayRef<mlir::Type> basisTypes) {
bool compatible = true;
RNSBasisTypeInterface first =
llvm::dyn_cast<RNSBasisTypeInterface>(basisTypes[0]);
if (!first) return failure();

for (auto other : basisTypes) {
compatible &= first.isCompatibleWith(other);
auto getInterface = [&](Type type) -> FailureOr<RNSBasisTypeInterface> {
auto res = mlir::dyn_cast<RNSBasisTypeInterface>(type);
if (!res) {
return emitError() << type << " does not have RNSBasisTypeInterface";
}
return res;
};

size_t numTypes = basisTypes.size();
for (auto i = 0; i != numTypes; ++i) {
for (auto j = i + 1; j != numTypes; ++j) {
auto resI = getInterface(basisTypes[i]);
if (failed(resI)) {
return resI;
}
auto resJ = getInterface(basisTypes[j]);
if (failed(resJ)) {
return resJ;
}
compatible &= (*resI).isCompatibleWith(*resJ);
}
}

if (!compatible) {
Expand Down Expand Up @@ -60,11 +82,42 @@ struct PolynomialRNSBasisTypeInterface
}
};

struct ModArithRNSBasisTypeInterface
: public RNSBasisTypeInterface::ExternalModel<ModArithRNSBasisTypeInterface,
ModArithType> {
bool isCompatibleWith(Type type, Type otherRnsBasisType) const {
auto thisType = mlir::dyn_cast<ModArithType>(type);
if (!thisType) {
return false;
}

auto other = mlir::dyn_cast<ModArithType>(otherRnsBasisType);
if (!other) {
return false;
}

APInt thisModulus = thisType.getModulus().getValue();
APInt otherModulus = other.getModulus().getValue();

// require same bitwidth
if (thisModulus.getBitWidth() != otherModulus.getBitWidth()) {
return false;
}

// coprime test
return llvm::APIntOps::GreatestCommonDivisor(thisModulus, otherModulus) ==
1;
}
};

void registerExternalRNSTypeInterfaces(DialectRegistry &registry) {
registry.addExtension(
+[](MLIRContext *ctx, ::mlir::polynomial::PolynomialDialect *dialect) {
PolynomialType::attachInterface<PolynomialRNSBasisTypeInterface>(*ctx);
});
registry.addExtension(+[](MLIRContext *ctx, ModArithDialect *dialect) {
ModArithType::attachInterface<ModArithRNSBasisTypeInterface>(*ctx);
});
}

} // namespace rns
Expand Down
31 changes: 30 additions & 1 deletion tests/Dialect/RNS/IR/syntax.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: heir-opt --verify-diagnostics %s
// RUN: heir-opt --verify-diagnostics --split-input-file %s

#ideal = #polynomial.int_polynomial<1 + x**1024>
#ideal_2 = #polynomial.int_polynomial<1 + x**2048>
Expand All @@ -22,3 +22,32 @@ func.func @test_syntax(%arg0: !ty) -> !ty {
!poly_ty_bad = !polynomial.polynomial<ring=#ring_bad>
// expected-error@+1 {{RNS type has incompatible basis types}}
!ty_bad = !rns.rns<!poly_ty_1, !poly_ty_2, !poly_ty_bad>

// -----
// mod arith

!Zp1 = !mod_arith.int<3721063133 : i64>
!Zp2 = !mod_arith.int<2737228591 : i64>
!Zp3 = !mod_arith.int<3180146689 : i64>

!ty_modarith = !rns.rns<!Zp1, !Zp2, !Zp3>

func.func @test_syntax_modarith(%arg0: !ty_modarith) -> !ty_modarith {
return %arg0 : !ty_modarith
}

// expected-error@+1 {{RNS type has incompatible basis types}}
!ty_modarith_bad = !rns.rns<!Zp1, !Zp2, !Zp1>

// -----

!Zp1 = !mod_arith.int<3721063133 : i64>
!Zp2 = !mod_arith.int<65537 : i32>

// expected-error@+1 {{RNS type has incompatible basis types}}
!ty_modarith_bad = !rns.rns<!Zp1, !Zp2>

// -----

// expected-error@+1 {{does not have RNSBasisTypeInterface}}
!ty_int_bad = !rns.rns<i32, i64>

0 comments on commit 14300e1

Please sign in to comment.