Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion qualtran/_infra/bloq.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def call_graph(
according to `keep` and `max_depth` (if provided) or if a bloq cannot be
decomposed.
"""
from qualtran.resource_counting.bloq_counts import get_bloq_call_graph
from qualtran.resource_counting import get_bloq_call_graph

return get_bloq_call_graph(self, generalizer=generalizer, keep=keep, max_depth=max_depth)

Expand Down
8 changes: 4 additions & 4 deletions qualtran/cirq_interop/t_complexity_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,16 @@ def _from_iterable(it: Any) -> Optional[TComplexity]:

def _from_bloq_build_call_graph(stc: Any) -> Optional[TComplexity]:
# Uses the depth 1 call graph of Bloq `stc` to recursively compute the complexity.
from qualtran.resource_counting import get_bloq_callee_counts
from qualtran.resource_counting.generalizers import cirq_to_bloqs

if not isinstance(stc, Bloq):
return None
_, sigma = stc.call_graph(max_depth=1, generalizer=cirq_to_bloqs)
if sigma == {stc: 1}:
# No decomposition found.
callee_counts = get_bloq_callee_counts(bloq=stc, generalizer=cirq_to_bloqs)
if len(callee_counts) == 0:
return None
ret = TComplexity()
for bloq, n in sigma.items():
for bloq, n in callee_counts:
r = t_complexity(bloq)
if r is None:
return None
Expand Down
5 changes: 5 additions & 0 deletions qualtran/cirq_interop/t_complexity_protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import cirq
import pytest
from attrs import frozen

from qualtran import Bloq, GateWithRegisters, Signature
from qualtran._infra.gate_with_registers import get_named_qubits
Expand All @@ -36,6 +37,7 @@ class DoesNotSupportTComplexity:
...


@frozen
class SupportsTComplexityGateWithRegisters(GateWithRegisters):
@property
def signature(self) -> Signature:
Expand Down Expand Up @@ -64,6 +66,7 @@ def signature(self) -> 'Signature':
return Signature.build(q=1)


@frozen
class SupportsTComplexityBloqViaBuildCallGraph(Bloq):
@property
def signature(self) -> 'Signature':
Expand All @@ -75,6 +78,8 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> Set['BloqCountT']:

def test_t_complexity_for_bloq_via_build_call_graph():
bloq = SupportsTComplexityBloqViaBuildCallGraph()
_, sigma = bloq.call_graph(max_depth=1)
assert sigma != {}
assert t_complexity(bloq) == TComplexity(t=5, clifford=10)


Expand Down
9 changes: 5 additions & 4 deletions qualtran/drawing/bloq_counts_graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ def test_format_counts_graph_markdown():
ret = format_counts_graph_markdown(graph)
assert (
ret
== r""" - `MultiAnd(cvs=(1, 1, 1, 1, 1, 1))`
- `And(cv1=1, cv2=1, uncompute=False)`: $\displaystyle 5$
== """\
- `MultiAnd(cvs=(1, 1, 1, 1, 1, 1))`
- `And(cv1=1, cv2=1, uncompute=False)`: $\\displaystyle 5$
- `And(cv1=1, cv2=1, uncompute=False)`
- `ArbitraryClifford(n=2)`: $\displaystyle 9$
- `TGate()`: $\displaystyle 4$
- `ArbitraryClifford(n=2)`: $\\displaystyle 9$
- `TGate()`: $\\displaystyle 4$
"""
)

Expand Down
2 changes: 1 addition & 1 deletion qualtran/drawing/flame_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import sympy

from qualtran import Bloq
from qualtran.resource_counting.bloq_counts import _compute_sigma
from qualtran.resource_counting._call_graph import _compute_sigma
from qualtran.resource_counting.t_counts_from_sigma import t_counts_from_sigma


Expand Down
6 changes: 4 additions & 2 deletions qualtran/resource_counting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
isort:skip_file
"""

from .bloq_counts import (
from ._generalization import GeneralizerT

from ._call_graph import (
BloqCountT,
GeneralizerT,
big_O,
SympySymbolAllocator,
get_bloq_callee_counts,
get_bloq_call_graph,
print_counts_graph,
build_cbloq_call_graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from qualtran import Bloq, CompositeBloq, DecomposeNotImplementedError, DecomposeTypeError

BloqCountT = Tuple[Bloq, Union[int, sympy.Expr]]
GeneralizerT = Callable[[Bloq], Optional[Bloq]]
from ._generalization import _make_composite_generalizer, GeneralizerT


def big_O(expr) -> sympy.Order:
Expand Down Expand Up @@ -85,6 +85,38 @@ def _generalize_callees(
return callee_counts


def get_bloq_callee_counts(
bloq: 'Bloq', generalizer: 'GeneralizerT' = None, ssa: SympySymbolAllocator = None
) -> List[BloqCountT]:
"""Get the direct callees of a bloq and the number of times they are called.

This calls `bloq.build_call_graph()` with the correct configuration options.

Args:
bloq: The bloq.
generalizer: If provided, run this function on each callee to consolidate attributes
that do not affect resource estimates. If the callable
returns `None`, the bloq is omitted from the counts graph. If a sequence of
generalizers is provided, each generalizer will be run in order.
ssa: A sympy symbol allocator that can be provided if one already exists in your
computation.

Returns:
A list of (bloq, n) bloq counts.
"""
if generalizer is None:
generalizer = lambda b: b
if isinstance(generalizer, (list, tuple)):
generalizer = _make_composite_generalizer(*generalizer)
if ssa is None:
ssa = SympySymbolAllocator()

try:
return _generalize_callees(bloq.build_call_graph(ssa), generalizer)
except (DecomposeNotImplementedError, DecomposeTypeError):
return []


def _build_call_graph(
bloq: Bloq,
generalizer: GeneralizerT,
Expand All @@ -103,8 +135,7 @@ def _build_call_graph(
# We already visited this node.
return

# Make sure this node is present in the graph. You could annotate
# additional node properties here, too.
# Make sure this node is present in the graph.
g.add_node(bloq)

# Base case 1: This node is requested by the user to be a leaf node via the `keep` parameter.
Expand All @@ -116,12 +147,7 @@ def _build_call_graph(
return

# Prep for recursion: get the callees and modify them according to `generalizer`.
try:
callee_counts = _generalize_callees(bloq.build_call_graph(ssa), generalizer)
except (DecomposeNotImplementedError, DecomposeTypeError):
# Base case 3: Decomposition (or `bloq_counts`) is not implemented. This is left as a
# leaf node.
return
callee_counts = get_bloq_callee_counts(bloq, generalizer)

# Base case 3: Empty list of callees
if not callee_counts:
Expand Down Expand Up @@ -165,19 +191,6 @@ def _compute_sigma(root_bloq: Bloq, g: nx.DiGraph) -> Dict[Bloq, Union[int, symp
return dict(bloq_sigmas[root_bloq])


def _make_composite_generalizer(*funcs: GeneralizerT) -> GeneralizerT:
"""Return a generalizer that calls each `*funcs` generalizers in order."""

def _composite_generalize(b: Bloq) -> Optional[Bloq]:
for func in funcs:
b = func(b)
if b is None:
return
return b

return _composite_generalize


def get_bloq_call_graph(
bloq: Bloq,
generalizer: Optional[Union['GeneralizerT', Sequence['GeneralizerT']]] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@
from qualtran import Bloq, BloqBuilder, Signature, SoquetT
from qualtran.bloqs.basic_gates import TGate
from qualtran.bloqs.util_bloqs import ArbitraryClifford, Join, Split
from qualtran.resource_counting import BloqCountT, get_bloq_call_graph, SympySymbolAllocator
from qualtran.resource_counting import (
BloqCountT,
get_bloq_call_graph,
get_bloq_callee_counts,
SympySymbolAllocator,
)


@frozen
Expand Down Expand Up @@ -88,6 +93,23 @@ def test_bloq_counts_method():
assert str(expr) == '3*log(100)'


def test_get_bloq_callee_counts():
bloq = BigBloq(100)
callee_counts = get_bloq_callee_counts(bloq)
assert callee_counts == [(SubBloq(unrelated_param=0.5), sympy.log(100))]

bloq = DecompBloq(10)
callee_counts = get_bloq_callee_counts(bloq)
assert len(callee_counts) == 10 + 2 # 2 for split/join

bloq = SubBloq(unrelated_param=0.5)
callee_counts = get_bloq_callee_counts(bloq)
assert callee_counts == [(TGate(), 3)]

callee_counts = get_bloq_callee_counts(TGate())
assert callee_counts == []


def test_bloq_counts_decomp():
graph, sigma = get_bloq_call_graph(DecompBloq(10))
assert len(sigma) == 3 # includes split and join
Expand All @@ -107,7 +129,7 @@ def generalize(bloq):

@pytest.mark.notebook
def test_notebook():
qlt_testing.execute_notebook('bloq_counts')
qlt_testing.execute_notebook('call_graph')


def _to_tuple(x: Iterable[BloqCountT]) -> Sequence[BloqCountT]:
Expand Down
32 changes: 32 additions & 0 deletions qualtran/resource_counting/_generalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from qualtran import Bloq

GeneralizerT = Callable[['Bloq'], Optional['Bloq']]


def _make_composite_generalizer(*funcs: 'GeneralizerT') -> 'GeneralizerT':
"""Return a generalizer that calls each `*funcs` generalizers in order."""

def _composite_generalize(b: 'Bloq') -> Optional['Bloq']:
for func in funcs:
b = func(b)
if b is None:
return
return b

return _composite_generalize
46 changes: 46 additions & 0 deletions qualtran/resource_counting/_generalization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from qualtran import Bloq
from qualtran.bloqs.for_testing import TestAtom
from qualtran.resource_counting._generalization import _make_composite_generalizer


def test_make_composite_generalizer():
def func1(b: Bloq) -> Optional[Bloq]:
if isinstance(b, TestAtom):
return TestAtom()
return b

def func2(b: Bloq) -> Optional[Bloq]:
if isinstance(b, TestAtom):
return
return b

b = TestAtom(tag='test')
assert func1(b) == TestAtom()
assert func2(b) is None

g00 = _make_composite_generalizer()
g10 = _make_composite_generalizer(func1)
g01 = _make_composite_generalizer(func2)
g11 = _make_composite_generalizer(func1, func2)
g11_r = _make_composite_generalizer(func2, func1)

assert g00(b) == b
assert g10(b) == TestAtom()
assert g01(b) is None
assert g11(b) is None
assert g11_r(b) is None
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"source": [
"# The Call Graph Protocol\n",
"\n",
"The call graph protocol lets you query which subbloq are called in a bloq's decomposition. Proper accounting of the quantity of subroutine calls is a crucial tool in estimating resource requirements for an algorithm. For example, you can expand the call graph until you reach 'expensive' gates like `TGate` or `Toffoli`. The total number of these gates set the runtime of the algorithm."
"The call graph protocol lets you query which subbloq are called in a bloq's decomposition. Proper accounting of the quantity of subroutine calls is a crucial tool in estimating resource requirements for an algorithm. For example, the number of 'expensive' gates like `TGate` or `Toffoli` required by a bloq is the sum of the number of those gates used by the bloq's callees."
]
},
{
Expand All @@ -20,10 +20,8 @@
"from qualtran.drawing import show_call_graph, show_counts_sigma\n",
"from qualtran.bloqs.mcmt import MultiAnd, And\n",
"\n",
"graph, sigma = MultiAnd(cvs=(1,)*6).call_graph()\n",
"\n",
"show_call_graph(graph)\n",
"show_counts_sigma(sigma)"
"graph, _ = MultiAnd(cvs=(1,)*6).call_graph()\n",
"show_call_graph(graph)"
]
},
{
Expand All @@ -33,7 +31,7 @@
"source": [
"## Interface\n",
"\n",
"The primary method for accessing the call graph of a bloq is `Bloq.call_graph()`. It returns a networkx graph as well as a dictionary of totals for \"leaf\" bloqs. \n",
"The primary method for accessing the call graph of a bloq is `Bloq.call_graph()`. It returns a networkx graph as well as an accounting of total bloq counts for \"leaf\" bloqs. \n",
"\n",
"Another method is `Bloq.bloq_counts`, which will return a dictionary of immediate children."
]
Expand Down Expand Up @@ -190,8 +188,7 @@
"source": [
"myfunc = MyFunc(n=sympy.sympify('n'))\n",
"graph, sigma = myfunc.call_graph()\n",
"show_call_graph(graph)\n",
"show_counts_sigma(sigma)"
"show_call_graph(graph)"
]
},
{
Expand All @@ -203,7 +200,7 @@
"\n",
"If a bloq does not override `build_call_graph(...)`, the default fallback will be used by Qualtran to support the call graph protocol.\n",
"\n",
"By default, Qualtran will use the decomposition to count subbloqs called by the bloq. For example, below we author a `SWAP` bloq. We define a decomposition but do not explicitly provide the call graph counts."
"By default, Qualtran will extract the call graph from the full decomposition. For example, below we author a `SWAP` bloq. We define a decomposition but do not explicitly override `build_call_graph`."
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion qualtran/resource_counting/generalizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from qualtran.bloqs.mcmt.and_bloq import And, MultiAnd
from qualtran.bloqs.util_bloqs import Allocate, Free, Join, Partition, Split
from qualtran.cirq_interop import CirqGateAsBloq
from qualtran.resource_counting.bloq_counts import _make_composite_generalizer
from qualtran.resource_counting._generalization import _make_composite_generalizer
from qualtran.resource_counting.generalizers import (
cirq_to_bloqs,
CV,
Expand Down
2 changes: 1 addition & 1 deletion qualtran/resource_counting/t_counts_from_sigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import cirq

from qualtran.cirq_interop.t_complexity_protocol import TComplexity
from qualtran.resource_counting.symbolic_counting_utils import SymbolicInt

if TYPE_CHECKING:
Expand Down Expand Up @@ -46,6 +45,7 @@ def t_counts_from_sigma(
) -> SymbolicInt:
"""Aggregates T-counts from a sigma dictionary by summing T-costs for all rotation bloqs."""
from qualtran.bloqs.basic_gates import TGate
from qualtran.cirq_interop.t_complexity_protocol import TComplexity

if rotation_types is None:
rotation_types = _get_all_rotation_types()
Expand Down