diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md
index 1777d2461b..3ae6cc56bb 100644
--- a/doc/releases/changelog-dev.md
+++ b/doc/releases/changelog-dev.md
@@ -322,6 +322,10 @@
[(#2253](https://github.com/PennyLaneAI/catalyst/pull/2253)
+ * Removed the `getRotationKind` and `setRotationKind` methods from
+ the QEC interface `QECOpInterface` to simplify the interface.
+ [(#2250)](https://github.com/PennyLaneAI/catalyst/pull/2250)
+
Documentation 📝
* A typo in the code example for :func:`~.passes.ppr_to_ppm` has been corrected.
diff --git a/mlir/include/QEC/IR/QECOpInterfaces.td b/mlir/include/QEC/IR/QECOpInterfaces.td
index eed6dbb0cc..dc2fbd3e4a 100644
--- a/mlir/include/QEC/IR/QECOpInterfaces.td
+++ b/mlir/include/QEC/IR/QECOpInterfaces.td
@@ -17,12 +17,16 @@
include "mlir/IR/OpBase.td"
+//===----------------------------------------------------------------------===//
+// QEC Operation Interface
+//===----------------------------------------------------------------------===//
+
def QECOpInterface : OpInterface<"QECOpInterface"> {
let description = [{
This interface provides a generic way to interact with instructions that are
- considered QEC Operations. These are characterized by operating on zero
- or more qubit values, and returning the same amount of qubit values.
+ considered QEC Operations. These are characterized by operating on one or more qubit values,
+ and returning the same amount of qubit values.
}];
let cppNamespace = "::catalyst::qec";
@@ -40,32 +44,16 @@ def QECOpInterface : OpInterface<"QECOpInterface"> {
>,
InterfaceMethod<
/*desc=*/"Get the Pauli product for this operation.",
- /*retTy=*/" ::mlir::ArrayAttr",
+ /*retTy=*/"::mlir::ArrayAttr",
/*methodName=*/"getPauliProduct"
>,
- InterfaceMethod<
- /*desc=*/"Get the Pauli product for this operation.",
- /*retTy=*/" ::mlir::ArrayAttr",
- /*methodName=*/"getPauliProductAttr"
- >,
- InterfaceMethod<
- /*desc=*/"Get the rotation kind for this operation.",
- /*retTy=*/"uint16_t",
- /*methodName=*/"getRotationKind"
- >,
InterfaceMethod<
/*desc=*/"Set the Pauli product for this operation.",
/*retTy=*/"void",
- /*methodName=*/"setPauliProductAttr", (ins " ::mlir::ArrayAttr":$attr)
- >,
- InterfaceMethod<
- /*desc=*/"Set the rotation kind for this operation.",
- /*retTy=*/"void",
- /*methodName=*/"setRotationKind", (ins "uint16_t":$attrValue)
+ /*methodName=*/"setPauliProductAttr", (ins "const ::mlir::ArrayAttr&":$propValue)
>
];
-
}
#endif // QECOP_INTERFACES
diff --git a/mlir/lib/QEC/Transforms/CountPPMSpecs.cpp b/mlir/lib/QEC/Transforms/CountPPMSpecs.cpp
index 6ddd67a711..4eb415d854 100644
--- a/mlir/lib/QEC/Transforms/CountPPMSpecs.cpp
+++ b/mlir/lib/QEC/Transforms/CountPPMSpecs.cpp
@@ -157,7 +157,11 @@ struct CountPPMSpecsPass : public impl::CountPPMSpecsPassBase
assert(!layer.empty() && "Layer is empty");
auto op = layer.getOps().back();
- int16_t absRk = std::abs(static_cast(op.getRotationKind()));
+
+ int16_t absRk = 0;
+ if (auto pprOp = dyn_cast(op.getOperation())) {
+ absRk = std::abs(static_cast(pprOp.getRotationKind()));
+ }
auto parentFuncOp = op->getParentOfType();
StringRef funcName = parentFuncOp.getName();
llvm::StringSaver saver(stringAllocator);
diff --git a/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp b/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp
index e1361f401d..247a1777ae 100644
--- a/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp
+++ b/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp
@@ -34,7 +34,7 @@ using namespace catalyst::quantum;
namespace {
// Return the magic state or complex conjugate of the magic state
-LogicalInitKind getMagicState(QECOpInterface op)
+LogicalInitKind getMagicState(PPRotationOp op)
{
int16_t rotationKind = static_cast(op.getRotationKind());
if (rotationKind > 0) {
diff --git a/mlir/lib/QEC/Transforms/PPRToMBQC.cpp b/mlir/lib/QEC/Transforms/PPRToMBQC.cpp
index d76038f667..f254f29378 100644
--- a/mlir/lib/QEC/Transforms/PPRToMBQC.cpp
+++ b/mlir/lib/QEC/Transforms/PPRToMBQC.cpp
@@ -142,11 +142,14 @@ void constructKernelOperation(SmallVector &qubits, Value &measResult, QEC
measResult = measOp.getMres();
qubits[0] = measOp.getOutQubit();
}
- else {
- int16_t signedRk = static_cast(op.getRotationKind());
+ else if (auto pprOp = dyn_cast(op.getOperation())) {
+ int16_t signedRk = static_cast(pprOp.getRotationKind());
double rk = llvm::numbers::pi / (static_cast(signedRk) / 2);
qubits[0] = buildSingleQubitGate(qubits[0], "RZ", {rk}, rewriter).getOutQubits().front();
}
+ else if (isa(op)) {
+ op->emitError("Unsupported qec.ppr.arbitrary operation.");
+ }
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/QEC/Transforms/TLayerReduction.cpp b/mlir/lib/QEC/Transforms/TLayerReduction.cpp
index d6596cb402..65f86b2f70 100644
--- a/mlir/lib/QEC/Transforms/TLayerReduction.cpp
+++ b/mlir/lib/QEC/Transforms/TLayerReduction.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
+#include "QEC/IR/QECOps.h"
#include "QEC/Utils/PauliStringWrapper.h"
#include "QEC/Utils/QECLayer.h"
#include "QEC/Utils/QECOpUtils.h"
@@ -64,6 +65,8 @@ std::pair checkCommutationAndFindMerge(QECOpInterface rhsO
auto normalizedOps =
normalizePPROps(lhsOp, rhsOp, lhsOp.getInQubits(), rhsOpInQubitsFromLhsOp);
+ // TODO: Handle PPRotationArbitraryOp properly
+
if (!normalizedOps.first.commutes(normalizedOps.second)) {
return std::pair(false, nullptr);
}
@@ -133,8 +136,11 @@ void moveOpToLayer(QECOpInterface rhsOp, QECLayer &rhsLayer, QECOpInterface merg
// then just remove the `rhsOp` from the rhsLayer.
void mergePPR(QECOpInterface rhsOp, QECLayer &rhsLayer, QECOpInterface mergeOp, IRRewriter &writer)
{
- int16_t signedRk = static_cast(mergeOp.getRotationKind());
- mergeOp.setRotationKind(static_cast(signedRk / 2));
+ auto mergeOpPprOp = dyn_cast(mergeOp.getOperation());
+ assert(mergeOpPprOp != nullptr && "Op is not a PPRotationOp");
+
+ int16_t signedRk = static_cast(mergeOpPprOp.getRotationKind());
+ mergeOpPprOp.setRotationKind(static_cast(signedRk / 2));
rhsLayer.eraseOp(rhsOp);
writer.replaceOp(rhsOp, rhsOp->getOperands());
diff --git a/mlir/lib/QEC/Utils/PauliStringWrapper.cpp b/mlir/lib/QEC/Utils/PauliStringWrapper.cpp
index 8a72978d06..5670de52d8 100644
--- a/mlir/lib/QEC/Utils/PauliStringWrapper.cpp
+++ b/mlir/lib/QEC/Utils/PauliStringWrapper.cpp
@@ -151,8 +151,21 @@ PauliWordPair normalizePPROps(QECOpInterface lhs, QECOpInterface rhs, ValueRange
lhsPSWrapper.op = lhs;
rhsPSWrapper.op = rhs;
- lhsPSWrapper.updateSign((int16_t)lhs.getRotationKind() < 0);
- rhsPSWrapper.updateSign((int16_t)rhs.getRotationKind() < 0);
+ auto applySignFromOp = [](PauliStringWrapper &wrapper, QECOpInterface qecOp) {
+ Operation *operation = qecOp.getOperation();
+
+ if (auto pprOp = dyn_cast(operation)) {
+ wrapper.updateSign(static_cast(pprOp.getRotationKind()) < 0);
+ return;
+ }
+
+ if (auto ppmOp = dyn_cast(operation)) {
+ wrapper.updateSign(static_cast(ppmOp.getRotationSign()) < 0);
+ }
+ };
+
+ applySignFromOp(lhsPSWrapper, lhs);
+ applySignFromOp(rhsPSWrapper, rhs);
return std::make_pair(std::move(lhsPSWrapper), std::move(rhsPSWrapper));
}
@@ -211,10 +224,17 @@ void updatePauliWord(QECOpInterface op, const PauliWord &newPauliWord, PatternRe
void updatePauliWordSign(QECOpInterface op, bool isNegated, PatternRewriter &rewriter)
{
- int16_t rotationKind = static_cast(op.getRotationKind());
- int16_t sign = isNegated ? -1 : 1;
- rotationKind = (rotationKind < 0 ? -rotationKind : rotationKind) * sign;
- op.setRotationKind(rotationKind);
+ if (auto pprOp = dyn_cast(op.getOperation())) {
+ int16_t rotationKind = static_cast(pprOp.getRotationKind());
+ int16_t sign = isNegated ? -1 : 1;
+ rotationKind = (rotationKind < 0 ? -rotationKind : rotationKind) * sign;
+ pprOp.setRotationKind(rotationKind);
+ }
+ else if (auto ppmOp = dyn_cast(op.getOperation())) {
+ int16_t rotationSign = static_cast(ppmOp.getRotationSign());
+ rotationSign = (rotationSign < 0 ? -rotationSign : rotationSign) * (isNegated ? -1 : 1);
+ ppmOp.setRotationSign(rotationSign);
+ }
}
SmallVector extractPauliString(QECOpInterface op)