diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 1777d2461b..3ae6cc56bb 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -322,6 +322,10 @@ [(#2253](https://github.com/PennyLaneAI/catalyst/pull/2253) + * Removed the `getRotationKind` and `setRotationKind` methods from + the QEC interface `QECOpInterface` to simplify the interface. + [(#2250)](https://github.com/PennyLaneAI/catalyst/pull/2250) +

Documentation 📝

* A typo in the code example for :func:`~.passes.ppr_to_ppm` has been corrected. diff --git a/mlir/include/QEC/IR/QECOpInterfaces.td b/mlir/include/QEC/IR/QECOpInterfaces.td index eed6dbb0cc..dc2fbd3e4a 100644 --- a/mlir/include/QEC/IR/QECOpInterfaces.td +++ b/mlir/include/QEC/IR/QECOpInterfaces.td @@ -17,12 +17,16 @@ include "mlir/IR/OpBase.td" +//===----------------------------------------------------------------------===// +// QEC Operation Interface +//===----------------------------------------------------------------------===// + def QECOpInterface : OpInterface<"QECOpInterface"> { let description = [{ This interface provides a generic way to interact with instructions that are - considered QEC Operations. These are characterized by operating on zero - or more qubit values, and returning the same amount of qubit values. + considered QEC Operations. These are characterized by operating on one or more qubit values, + and returning the same amount of qubit values. }]; let cppNamespace = "::catalyst::qec"; @@ -40,32 +44,16 @@ def QECOpInterface : OpInterface<"QECOpInterface"> { >, InterfaceMethod< /*desc=*/"Get the Pauli product for this operation.", - /*retTy=*/" ::mlir::ArrayAttr", + /*retTy=*/"::mlir::ArrayAttr", /*methodName=*/"getPauliProduct" >, - InterfaceMethod< - /*desc=*/"Get the Pauli product for this operation.", - /*retTy=*/" ::mlir::ArrayAttr", - /*methodName=*/"getPauliProductAttr" - >, - InterfaceMethod< - /*desc=*/"Get the rotation kind for this operation.", - /*retTy=*/"uint16_t", - /*methodName=*/"getRotationKind" - >, InterfaceMethod< /*desc=*/"Set the Pauli product for this operation.", /*retTy=*/"void", - /*methodName=*/"setPauliProductAttr", (ins " ::mlir::ArrayAttr":$attr) - >, - InterfaceMethod< - /*desc=*/"Set the rotation kind for this operation.", - /*retTy=*/"void", - /*methodName=*/"setRotationKind", (ins "uint16_t":$attrValue) + /*methodName=*/"setPauliProductAttr", (ins "const ::mlir::ArrayAttr&":$propValue) > ]; - } #endif // QECOP_INTERFACES diff --git a/mlir/lib/QEC/Transforms/CountPPMSpecs.cpp b/mlir/lib/QEC/Transforms/CountPPMSpecs.cpp index 6ddd67a711..4eb415d854 100644 --- a/mlir/lib/QEC/Transforms/CountPPMSpecs.cpp +++ b/mlir/lib/QEC/Transforms/CountPPMSpecs.cpp @@ -157,7 +157,11 @@ struct CountPPMSpecsPass : public impl::CountPPMSpecsPassBase assert(!layer.empty() && "Layer is empty"); auto op = layer.getOps().back(); - int16_t absRk = std::abs(static_cast(op.getRotationKind())); + + int16_t absRk = 0; + if (auto pprOp = dyn_cast(op.getOperation())) { + absRk = std::abs(static_cast(pprOp.getRotationKind())); + } auto parentFuncOp = op->getParentOfType(); StringRef funcName = parentFuncOp.getName(); llvm::StringSaver saver(stringAllocator); diff --git a/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp b/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp index e1361f401d..247a1777ae 100644 --- a/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp +++ b/mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp @@ -34,7 +34,7 @@ using namespace catalyst::quantum; namespace { // Return the magic state or complex conjugate of the magic state -LogicalInitKind getMagicState(QECOpInterface op) +LogicalInitKind getMagicState(PPRotationOp op) { int16_t rotationKind = static_cast(op.getRotationKind()); if (rotationKind > 0) { diff --git a/mlir/lib/QEC/Transforms/PPRToMBQC.cpp b/mlir/lib/QEC/Transforms/PPRToMBQC.cpp index d76038f667..f254f29378 100644 --- a/mlir/lib/QEC/Transforms/PPRToMBQC.cpp +++ b/mlir/lib/QEC/Transforms/PPRToMBQC.cpp @@ -142,11 +142,14 @@ void constructKernelOperation(SmallVector &qubits, Value &measResult, QEC measResult = measOp.getMres(); qubits[0] = measOp.getOutQubit(); } - else { - int16_t signedRk = static_cast(op.getRotationKind()); + else if (auto pprOp = dyn_cast(op.getOperation())) { + int16_t signedRk = static_cast(pprOp.getRotationKind()); double rk = llvm::numbers::pi / (static_cast(signedRk) / 2); qubits[0] = buildSingleQubitGate(qubits[0], "RZ", {rk}, rewriter).getOutQubits().front(); } + else if (isa(op)) { + op->emitError("Unsupported qec.ppr.arbitrary operation."); + } } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/QEC/Transforms/TLayerReduction.cpp b/mlir/lib/QEC/Transforms/TLayerReduction.cpp index d6596cb402..65f86b2f70 100644 --- a/mlir/lib/QEC/Transforms/TLayerReduction.cpp +++ b/mlir/lib/QEC/Transforms/TLayerReduction.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "QEC/IR/QECOps.h" #include "QEC/Utils/PauliStringWrapper.h" #include "QEC/Utils/QECLayer.h" #include "QEC/Utils/QECOpUtils.h" @@ -64,6 +65,8 @@ std::pair checkCommutationAndFindMerge(QECOpInterface rhsO auto normalizedOps = normalizePPROps(lhsOp, rhsOp, lhsOp.getInQubits(), rhsOpInQubitsFromLhsOp); + // TODO: Handle PPRotationArbitraryOp properly + if (!normalizedOps.first.commutes(normalizedOps.second)) { return std::pair(false, nullptr); } @@ -133,8 +136,11 @@ void moveOpToLayer(QECOpInterface rhsOp, QECLayer &rhsLayer, QECOpInterface merg // then just remove the `rhsOp` from the rhsLayer. void mergePPR(QECOpInterface rhsOp, QECLayer &rhsLayer, QECOpInterface mergeOp, IRRewriter &writer) { - int16_t signedRk = static_cast(mergeOp.getRotationKind()); - mergeOp.setRotationKind(static_cast(signedRk / 2)); + auto mergeOpPprOp = dyn_cast(mergeOp.getOperation()); + assert(mergeOpPprOp != nullptr && "Op is not a PPRotationOp"); + + int16_t signedRk = static_cast(mergeOpPprOp.getRotationKind()); + mergeOpPprOp.setRotationKind(static_cast(signedRk / 2)); rhsLayer.eraseOp(rhsOp); writer.replaceOp(rhsOp, rhsOp->getOperands()); diff --git a/mlir/lib/QEC/Utils/PauliStringWrapper.cpp b/mlir/lib/QEC/Utils/PauliStringWrapper.cpp index 8a72978d06..5670de52d8 100644 --- a/mlir/lib/QEC/Utils/PauliStringWrapper.cpp +++ b/mlir/lib/QEC/Utils/PauliStringWrapper.cpp @@ -151,8 +151,21 @@ PauliWordPair normalizePPROps(QECOpInterface lhs, QECOpInterface rhs, ValueRange lhsPSWrapper.op = lhs; rhsPSWrapper.op = rhs; - lhsPSWrapper.updateSign((int16_t)lhs.getRotationKind() < 0); - rhsPSWrapper.updateSign((int16_t)rhs.getRotationKind() < 0); + auto applySignFromOp = [](PauliStringWrapper &wrapper, QECOpInterface qecOp) { + Operation *operation = qecOp.getOperation(); + + if (auto pprOp = dyn_cast(operation)) { + wrapper.updateSign(static_cast(pprOp.getRotationKind()) < 0); + return; + } + + if (auto ppmOp = dyn_cast(operation)) { + wrapper.updateSign(static_cast(ppmOp.getRotationSign()) < 0); + } + }; + + applySignFromOp(lhsPSWrapper, lhs); + applySignFromOp(rhsPSWrapper, rhs); return std::make_pair(std::move(lhsPSWrapper), std::move(rhsPSWrapper)); } @@ -211,10 +224,17 @@ void updatePauliWord(QECOpInterface op, const PauliWord &newPauliWord, PatternRe void updatePauliWordSign(QECOpInterface op, bool isNegated, PatternRewriter &rewriter) { - int16_t rotationKind = static_cast(op.getRotationKind()); - int16_t sign = isNegated ? -1 : 1; - rotationKind = (rotationKind < 0 ? -rotationKind : rotationKind) * sign; - op.setRotationKind(rotationKind); + if (auto pprOp = dyn_cast(op.getOperation())) { + int16_t rotationKind = static_cast(pprOp.getRotationKind()); + int16_t sign = isNegated ? -1 : 1; + rotationKind = (rotationKind < 0 ? -rotationKind : rotationKind) * sign; + pprOp.setRotationKind(rotationKind); + } + else if (auto ppmOp = dyn_cast(op.getOperation())) { + int16_t rotationSign = static_cast(ppmOp.getRotationSign()); + rotationSign = (rotationSign < 0 ? -rotationSign : rotationSign) * (isNegated ? -1 : 1); + ppmOp.setRotationSign(rotationSign); + } } SmallVector extractPauliString(QECOpInterface op)