Skip to content

Commit 1242748

Browse files
authored
Remove setRotationKind and getRotationKind from QEC interface (#2250)
**Context:** Previously, `setRotationKind` and `getRotationKind` were implemented for hooks in the QEC interface, but this should not be in the `QECOpInterface` because that interface also applies `PPMeasurementOp`, where `RotationKind` is not available. [[sc-104780]]
1 parent a452617 commit 1242748

File tree

7 files changed

+57
-32
lines changed

7 files changed

+57
-32
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,10 @@
322322
[(#2253](https://github.com/PennyLaneAI/catalyst/pull/2253)
323323

324324

325+
* Removed the `getRotationKind` and `setRotationKind` methods from
326+
the QEC interface `QECOpInterface` to simplify the interface.
327+
[(#2250)](https://github.com/PennyLaneAI/catalyst/pull/2250)
328+
325329
<h3>Documentation 📝</h3>
326330

327331
* A typo in the code example for :func:`~.passes.ppr_to_ppm` has been corrected.

mlir/include/QEC/IR/QECOpInterfaces.td

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,16 @@
1717

1818
include "mlir/IR/OpBase.td"
1919

20+
//===----------------------------------------------------------------------===//
21+
// QEC Operation Interface
22+
//===----------------------------------------------------------------------===//
23+
2024
def QECOpInterface : OpInterface<"QECOpInterface"> {
2125

2226
let description = [{
2327
This interface provides a generic way to interact with instructions that are
24-
considered QEC Operations. These are characterized by operating on zero
25-
or more qubit values, and returning the same amount of qubit values.
28+
considered QEC Operations. These are characterized by operating on one or more qubit values,
29+
and returning the same amount of qubit values.
2630
}];
2731

2832
let cppNamespace = "::catalyst::qec";
@@ -40,32 +44,16 @@ def QECOpInterface : OpInterface<"QECOpInterface"> {
4044
>,
4145
InterfaceMethod<
4246
/*desc=*/"Get the Pauli product for this operation.",
43-
/*retTy=*/" ::mlir::ArrayAttr",
47+
/*retTy=*/"::mlir::ArrayAttr",
4448
/*methodName=*/"getPauliProduct"
4549
>,
46-
InterfaceMethod<
47-
/*desc=*/"Get the Pauli product for this operation.",
48-
/*retTy=*/" ::mlir::ArrayAttr",
49-
/*methodName=*/"getPauliProductAttr"
50-
>,
51-
InterfaceMethod<
52-
/*desc=*/"Get the rotation kind for this operation.",
53-
/*retTy=*/"uint16_t",
54-
/*methodName=*/"getRotationKind"
55-
>,
5650
InterfaceMethod<
5751
/*desc=*/"Set the Pauli product for this operation.",
5852
/*retTy=*/"void",
59-
/*methodName=*/"setPauliProductAttr", (ins " ::mlir::ArrayAttr":$attr)
60-
>,
61-
InterfaceMethod<
62-
/*desc=*/"Set the rotation kind for this operation.",
63-
/*retTy=*/"void",
64-
/*methodName=*/"setRotationKind", (ins "uint16_t":$attrValue)
53+
/*methodName=*/"setPauliProductAttr", (ins "const ::mlir::ArrayAttr&":$propValue)
6554
>
6655
];
6756

68-
6957
}
7058

7159
#endif // QECOP_INTERFACES

mlir/lib/QEC/Transforms/CountPPMSpecs.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,11 @@ struct CountPPMSpecsPass : public impl::CountPPMSpecsPassBase<CountPPMSpecsPass>
157157
assert(!layer.empty() && "Layer is empty");
158158

159159
auto op = layer.getOps().back();
160-
int16_t absRk = std::abs(static_cast<int16_t>(op.getRotationKind()));
160+
161+
int16_t absRk = 0;
162+
if (auto pprOp = dyn_cast<PPRotationOp>(op.getOperation())) {
163+
absRk = std::abs(static_cast<int16_t>(pprOp.getRotationKind()));
164+
}
161165
auto parentFuncOp = op->getParentOfType<func::FuncOp>();
162166
StringRef funcName = parentFuncOp.getName();
163167
llvm::StringSaver saver(stringAllocator);

mlir/lib/QEC/Transforms/DecomposeNonCliffordPPR.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ using namespace catalyst::quantum;
3434
namespace {
3535

3636
// Return the magic state or complex conjugate of the magic state
37-
LogicalInitKind getMagicState(QECOpInterface op)
37+
LogicalInitKind getMagicState(PPRotationOp op)
3838
{
3939
int16_t rotationKind = static_cast<int16_t>(op.getRotationKind());
4040
if (rotationKind > 0) {

mlir/lib/QEC/Transforms/PPRToMBQC.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,14 @@ void constructKernelOperation(SmallVector<Value> &qubits, Value &measResult, QEC
142142
measResult = measOp.getMres();
143143
qubits[0] = measOp.getOutQubit();
144144
}
145-
else {
146-
int16_t signedRk = static_cast<int16_t>(op.getRotationKind());
145+
else if (auto pprOp = dyn_cast<PPRotationOp>(op.getOperation())) {
146+
int16_t signedRk = static_cast<int16_t>(pprOp.getRotationKind());
147147
double rk = llvm::numbers::pi / (static_cast<double>(signedRk) / 2);
148148
qubits[0] = buildSingleQubitGate(qubits[0], "RZ", {rk}, rewriter).getOutQubits().front();
149149
}
150+
else if (isa<PPRotationArbitraryOp>(op)) {
151+
op->emitError("Unsupported qec.ppr.arbitrary operation.");
152+
}
150153
}
151154

152155
//===----------------------------------------------------------------------===//

mlir/lib/QEC/Transforms/TLayerReduction.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/PatternMatch.h"
2121
#include "mlir/Pass/Pass.h"
2222

23+
#include "QEC/IR/QECOps.h"
2324
#include "QEC/Utils/PauliStringWrapper.h"
2425
#include "QEC/Utils/QECLayer.h"
2526
#include "QEC/Utils/QECOpUtils.h"
@@ -64,6 +65,8 @@ std::pair<bool, QECOpInterface> checkCommutationAndFindMerge(QECOpInterface rhsO
6465
auto normalizedOps =
6566
normalizePPROps(lhsOp, rhsOp, lhsOp.getInQubits(), rhsOpInQubitsFromLhsOp);
6667

68+
// TODO: Handle PPRotationArbitraryOp properly
69+
6770
if (!normalizedOps.first.commutes(normalizedOps.second)) {
6871
return std::pair(false, nullptr);
6972
}
@@ -133,8 +136,11 @@ void moveOpToLayer(QECOpInterface rhsOp, QECLayer &rhsLayer, QECOpInterface merg
133136
// then just remove the `rhsOp` from the rhsLayer.
134137
void mergePPR(QECOpInterface rhsOp, QECLayer &rhsLayer, QECOpInterface mergeOp, IRRewriter &writer)
135138
{
136-
int16_t signedRk = static_cast<int16_t>(mergeOp.getRotationKind());
137-
mergeOp.setRotationKind(static_cast<uint16_t>(signedRk / 2));
139+
auto mergeOpPprOp = dyn_cast<PPRotationOp>(mergeOp.getOperation());
140+
assert(mergeOpPprOp != nullptr && "Op is not a PPRotationOp");
141+
142+
int16_t signedRk = static_cast<int16_t>(mergeOpPprOp.getRotationKind());
143+
mergeOpPprOp.setRotationKind(static_cast<uint16_t>(signedRk / 2));
138144

139145
rhsLayer.eraseOp(rhsOp);
140146
writer.replaceOp(rhsOp, rhsOp->getOperands());

mlir/lib/QEC/Utils/PauliStringWrapper.cpp

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,21 @@ PauliWordPair normalizePPROps(QECOpInterface lhs, QECOpInterface rhs, ValueRange
151151
lhsPSWrapper.op = lhs;
152152
rhsPSWrapper.op = rhs;
153153

154-
lhsPSWrapper.updateSign((int16_t)lhs.getRotationKind() < 0);
155-
rhsPSWrapper.updateSign((int16_t)rhs.getRotationKind() < 0);
154+
auto applySignFromOp = [](PauliStringWrapper &wrapper, QECOpInterface qecOp) {
155+
Operation *operation = qecOp.getOperation();
156+
157+
if (auto pprOp = dyn_cast<PPRotationOp>(operation)) {
158+
wrapper.updateSign(static_cast<int16_t>(pprOp.getRotationKind()) < 0);
159+
return;
160+
}
161+
162+
if (auto ppmOp = dyn_cast<PPMeasurementOp>(operation)) {
163+
wrapper.updateSign(static_cast<int16_t>(ppmOp.getRotationSign()) < 0);
164+
}
165+
};
166+
167+
applySignFromOp(lhsPSWrapper, lhs);
168+
applySignFromOp(rhsPSWrapper, rhs);
156169

157170
return std::make_pair(std::move(lhsPSWrapper), std::move(rhsPSWrapper));
158171
}
@@ -211,10 +224,17 @@ void updatePauliWord(QECOpInterface op, const PauliWord &newPauliWord, PatternRe
211224

212225
void updatePauliWordSign(QECOpInterface op, bool isNegated, PatternRewriter &rewriter)
213226
{
214-
int16_t rotationKind = static_cast<int16_t>(op.getRotationKind());
215-
int16_t sign = isNegated ? -1 : 1;
216-
rotationKind = (rotationKind < 0 ? -rotationKind : rotationKind) * sign;
217-
op.setRotationKind(rotationKind);
227+
if (auto pprOp = dyn_cast<PPRotationOp>(op.getOperation())) {
228+
int16_t rotationKind = static_cast<int16_t>(pprOp.getRotationKind());
229+
int16_t sign = isNegated ? -1 : 1;
230+
rotationKind = (rotationKind < 0 ? -rotationKind : rotationKind) * sign;
231+
pprOp.setRotationKind(rotationKind);
232+
}
233+
else if (auto ppmOp = dyn_cast<PPMeasurementOp>(op.getOperation())) {
234+
int16_t rotationSign = static_cast<int16_t>(ppmOp.getRotationSign());
235+
rotationSign = (rotationSign < 0 ? -rotationSign : rotationSign) * (isNegated ? -1 : 1);
236+
ppmOp.setRotationSign(rotationSign);
237+
}
218238
}
219239

220240
SmallVector<StringRef> extractPauliString(QECOpInterface op)

0 commit comments

Comments
 (0)