From bc1f8c944d5c1a8b99a8fa7863898d2a39b6e8c7 Mon Sep 17 00:00:00 2001 From: Eric Schweitz Date: Thu, 12 Dec 2024 13:43:49 -0800 Subject: [PATCH] Change the constant used for dynamic sized quake.veq type. Signed-off-by: Eric Schweitz --- include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td | 7 +++++-- lib/Optimizer/Dialect/Quake/QuakeTypes.cpp | 4 ++-- python/cudaq/kernel/ast_bridge.py | 2 +- python/cudaq/kernel/kernel_builder.py | 8 ++++---- python/cudaq/kernel/quake_value.py | 2 +- python/runtime/mlir/py_register_dialects.cpp | 3 ++- 6 files changed, 15 insertions(+), 11 deletions(-) diff --git a/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td b/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td index 7dd31469fd..0aac6fa46c 100644 --- a/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td +++ b/include/cudaq/Optimizer/Dialect/Quake/QuakeTypes.td @@ -153,9 +153,12 @@ def VeqType : QuakeType<"Veq", "veq"> { let hasCustomAssemblyFormat = 1; let extraClassDeclaration = [{ - bool hasSpecifiedSize() const { return getSize(); } + static constexpr std::size_t kDynamicSize = + std::numeric_limits::max(); + + bool hasSpecifiedSize() const { return getSize() != kDynamicSize; } static VeqType getUnsized(mlir::MLIRContext *ctx) { - return VeqType::get(ctx, 0); + return VeqType::get(ctx, kDynamicSize); } }]; } diff --git a/lib/Optimizer/Dialect/Quake/QuakeTypes.cpp b/lib/Optimizer/Dialect/Quake/QuakeTypes.cpp index b536a59710..d124e32c3c 100644 --- a/lib/Optimizer/Dialect/Quake/QuakeTypes.cpp +++ b/lib/Optimizer/Dialect/Quake/QuakeTypes.cpp @@ -39,9 +39,9 @@ void quake::VeqType::print(AsmPrinter &os) const { Type quake::VeqType::parse(AsmParser &parser) { if (parser.parseLess()) return {}; - std::size_t size = 0; + std::size_t size = kDynamicSize; if (succeeded(parser.parseOptionalQuestion())) - size = 0; + size = kDynamicSize; else if (parser.parseInteger(size)) return {}; if (parser.parseGreater()) diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index d76a802d72..7783822bd3 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -3168,7 +3168,7 @@ def bodyBuilder(iterVar): # we currently handle `veq` and `stdvec` types if quake.VeqType.isinstance(iterable.type): size = quake.VeqType.getSize(iterable.type) - if size: + if quake.VeqType.hasSpecifiedSize(iterable.type): totalSize = self.getConstantInt(size) else: totalSize = quake.VeqSizeOp(self.getIntegerType(64), diff --git a/python/cudaq/kernel/kernel_builder.py b/python/cudaq/kernel/kernel_builder.py index 192aad8929..59c091cba2 100644 --- a/python/cudaq/kernel/kernel_builder.py +++ b/python/cudaq/kernel/kernel_builder.py @@ -675,8 +675,8 @@ def __applyControlOrAdjoint(self, target, isAdjoint, controls, *args): if (quake.VeqType.isinstance(inTy) and quake.VeqType.isinstance(argTy)): - if quake.VeqType.getSize( - inTy) and not quake.VeqType.getSize(argTy): + if quake.VeqType.hasSpecifiedSize( + inTy) and not quake.VeqType.hasSpecifiedSize(argTy): value = quake.RelaxSizeOp(argTy, value).result mlirValues.append(value) @@ -1029,8 +1029,8 @@ def reset(self, target): return # target is a VeqType - size = quake.VeqType.getSize(target.mlirValue.type) - if size: + if quake.VeqType.hasSpecifiedSize(target.mlirValue.type): + size = quake.VeqType.getSize(target.mlirValue.type) for i in range(size): extracted = quake.ExtractRefOp(quake.RefType.get(self.ctx), target.mlirValue, i).result diff --git a/python/cudaq/kernel/quake_value.py b/python/cudaq/kernel/quake_value.py index 3c55a9170c..41689cbfad 100644 --- a/python/cudaq/kernel/quake_value.py +++ b/python/cudaq/kernel/quake_value.py @@ -67,7 +67,7 @@ def size(self): if quake.VeqType.isinstance(type): size = quake.VeqType.getSize(type) - if size: + if quake.VeqType.hasSpecifiedSize(type): return size return QuakeValue( quake.VeqSizeOp(self.intType, self.mlirValue).result, diff --git a/python/runtime/mlir/py_register_dialects.cpp b/python/runtime/mlir/py_register_dialects.cpp index 9c0c4f2985..cda4f2a30a 100644 --- a/python/runtime/mlir/py_register_dialects.cpp +++ b/python/runtime/mlir/py_register_dialects.cpp @@ -74,7 +74,8 @@ void registerQuakeDialectAndTypes(py::module &m) { [](py::object cls, MlirContext ctx, std::size_t size) { return wrap(quake::VeqType::get(unwrap(ctx), size)); }, - py::arg("cls"), py::arg("context"), py::arg("size") = 0) + py::arg("cls"), py::arg("context"), + py::arg("size") = std::numeric_limits::max()) .def_staticmethod( "hasSpecifiedSize", [](MlirType type) {