Skip to content

Commit 5e3813a

Browse files
authored
Add error if transform applied inside a qnode with program capture. (#2256)
**Context:** We currently don't support applying transforms to a subcircuit. We need to explicitly error out on this case. For: ``` @qml.qjit @qml.qnode(qml.device('lightning.qubit', wires=1)) @qml.transforms.cancel_inverses def c(): qml.X(0) qml.X(0) return qml.probs() c(), print(c.mlir) ``` we would have gotten: ``` TypeError: No ir_type_handler for aval type: <class 'pennylane.measurements.capture_measurements._get_abstract_measurement.<locals>.AbstractMeasurement'> ``` Now we get: ``` NotImplementedError: transforms cannot currently be applied inside a QNode. ``` Which is much more helpful and informative. **Description of the Change:** `QFuncToPlxprInterpreter` now raises an error if it encounters a transform. **Benefits:** Improved usability. **Possible Drawbacks:** **Related GitHub Issues:** [sc-104320]
1 parent fd2e4be commit 5e3813a

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@
7070

7171
<h3>Improvements 🛠</h3>
7272

73+
* An error is now raised if a transform is applied inside a QNode when program capture is enabled.
74+
[(#2256)](https://github.com/PennyLaneAI/catalyst/pull/2256)
75+
7376
* A new ``"changed"`` option has been added to the ``keep_intermediate`` parameter of
7477
:func:`~.qjit`. This option saves intermediate IR files after each pass,
7578
but only when the IR is actually modified by the pass.

frontend/catalyst/from_plxpr/qfunc_interpreter.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from pennylane.capture.primitives import ctrl_transform_prim as plxpr_ctrl_transform_prim
3333
from pennylane.capture.primitives import measure_prim as plxpr_measure_prim
3434
from pennylane.capture.primitives import pauli_measure_prim as plxpr_pauli_measure_prim
35+
from pennylane.capture.primitives import transform_prim
3536
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
3637
from pennylane.measurements import CountsMP
3738

@@ -807,6 +808,11 @@ def calling_convention(*args_plus_qreg):
807808
return outvals
808809

809810

811+
@PLxPRToQuantumJaxprInterpreter.register_primitive(transform_prim)
812+
def _error_on_transform(*args, **kwargs):
813+
raise NotImplementedError("transforms cannot currently be applied inside a QNode.")
814+
815+
810816
_special_op_bind_call = {
811817
qml.QubitUnitary: _qubit_unitary_bind_call,
812818
qml.GlobalPhase: _gphase_bind_call,

frontend/test/pytest/from_plxpr/test_from_plxpr.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,21 @@ def c():
179179
with pytest.raises(NotImplementedError, match="not yet supported"):
180180
from_plxpr(jaxpr)()
181181

182+
def test_errors_transform_inside_qnode(self):
183+
"""Test that an error is raised if a transform is applied inside a transform."""
184+
185+
@qml.qnode(qml.device("lightning.qubit", wires=1))
186+
@qml.transforms.cancel_inverses
187+
def c():
188+
return qml.expval(qml.Z(0))
189+
190+
jaxpr = jax.make_jaxpr(c)()
191+
192+
with pytest.raises(
193+
NotImplementedError, match="transforms cannot currently be applied inside a QNode."
194+
):
195+
from_plxpr(jaxpr)()
196+
182197

183198
class TestCatalystCompareJaxpr:
184199
"""Test comparing catalyst and pennylane jaxpr for a variety of situations."""

0 commit comments

Comments
 (0)