Skip to content

Commit

Permalink
Change the constant used for dynamic sized quake.veq type.
Browse files Browse the repository at this point in the history
Signed-off-by: Eric Schweitz <[email protected]>
  • Loading branch information
schweitzpgi committed Dec 12, 2024
1 parent 105c05a commit bc1f8c9
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 11 deletions.
7 changes: 5 additions & 2 deletions include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,12 @@ def VeqType : QuakeType<"Veq", "veq"> {
let hasCustomAssemblyFormat = 1;

let extraClassDeclaration = [{
bool hasSpecifiedSize() const { return getSize(); }
static constexpr std::size_t kDynamicSize =
std::numeric_limits<std::size_t>::max();

bool hasSpecifiedSize() const { return getSize() != kDynamicSize; }
static VeqType getUnsized(mlir::MLIRContext *ctx) {
return VeqType::get(ctx, 0);
return VeqType::get(ctx, kDynamicSize);
}
}];
}
Expand Down
4 changes: 2 additions & 2 deletions lib/Optimizer/Dialect/Quake/QuakeTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ void quake::VeqType::print(AsmPrinter &os) const {
Type quake::VeqType::parse(AsmParser &parser) {
if (parser.parseLess())
return {};
std::size_t size = 0;
std::size_t size = kDynamicSize;
if (succeeded(parser.parseOptionalQuestion()))
size = 0;
size = kDynamicSize;
else if (parser.parseInteger(size))
return {};
if (parser.parseGreater())
Expand Down
2 changes: 1 addition & 1 deletion python/cudaq/kernel/ast_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3168,7 +3168,7 @@ def bodyBuilder(iterVar):
# we currently handle `veq` and `stdvec` types
if quake.VeqType.isinstance(iterable.type):
size = quake.VeqType.getSize(iterable.type)
if size:
if quake.VeqType.hasSpecifiedSize(iterable.type):
totalSize = self.getConstantInt(size)
else:
totalSize = quake.VeqSizeOp(self.getIntegerType(64),
Expand Down
8 changes: 4 additions & 4 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,8 +675,8 @@ def __applyControlOrAdjoint(self, target, isAdjoint, controls, *args):

if (quake.VeqType.isinstance(inTy) and
quake.VeqType.isinstance(argTy)):
if quake.VeqType.getSize(
inTy) and not quake.VeqType.getSize(argTy):
if quake.VeqType.hasSpecifiedSize(
inTy) and not quake.VeqType.hasSpecifiedSize(argTy):
value = quake.RelaxSizeOp(argTy, value).result

mlirValues.append(value)
Expand Down Expand Up @@ -1029,8 +1029,8 @@ def reset(self, target):
return

# target is a VeqType
size = quake.VeqType.getSize(target.mlirValue.type)
if size:
if quake.VeqType.hasSpecifiedSize(target.mlirValue.type):
size = quake.VeqType.getSize(target.mlirValue.type)
for i in range(size):
extracted = quake.ExtractRefOp(quake.RefType.get(self.ctx),
target.mlirValue, i).result
Expand Down
2 changes: 1 addition & 1 deletion python/cudaq/kernel/quake_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def size(self):

if quake.VeqType.isinstance(type):
size = quake.VeqType.getSize(type)
if size:
if quake.VeqType.hasSpecifiedSize(type):
return size
return QuakeValue(
quake.VeqSizeOp(self.intType, self.mlirValue).result,
Expand Down
3 changes: 2 additions & 1 deletion python/runtime/mlir/py_register_dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ void registerQuakeDialectAndTypes(py::module &m) {
[](py::object cls, MlirContext ctx, std::size_t size) {
return wrap(quake::VeqType::get(unwrap(ctx), size));
},
py::arg("cls"), py::arg("context"), py::arg("size") = 0)
py::arg("cls"), py::arg("context"),
py::arg("size") = std::numeric_limits<std::size_t>::max())
.def_staticmethod(
"hasSpecifiedSize",
[](MlirType type) {
Expand Down

0 comments on commit bc1f8c9

Please sign in to comment.