From 4e21ec1104bc188e243f1c65158038a3510e6987 Mon Sep 17 00:00:00 2001 From: John Long Date: Sun, 13 Apr 2025 23:45:57 -0400 Subject: [PATCH 01/29] initial steps for squin to stim rewrite --- src/bloqade/main.py | 0 src/bloqade/squin/analysis/nsites/__init__.py | 1 + .../squin/analysis/rewrite/__init__.py | 0 src/bloqade/squin/analysis/rewrite/stim.py | 30 +++++++++++++++++++ 4 files changed, 31 insertions(+) create mode 100644 src/bloqade/main.py create mode 100644 src/bloqade/squin/analysis/rewrite/__init__.py create mode 100644 src/bloqade/squin/analysis/rewrite/stim.py diff --git a/src/bloqade/main.py b/src/bloqade/main.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/squin/analysis/nsites/__init__.py b/src/bloqade/squin/analysis/nsites/__init__.py index da0a8e86..e3177322 100644 --- a/src/bloqade/squin/analysis/nsites/__init__.py +++ b/src/bloqade/squin/analysis/nsites/__init__.py @@ -1,6 +1,7 @@ # Need this for impl registration to work properly! from . import impls as impls from .lattice import ( + Sites as Sites, NoSites as NoSites, AnySites as AnySites, NumberSites as NumberSites, diff --git a/src/bloqade/squin/analysis/rewrite/__init__.py b/src/bloqade/squin/analysis/rewrite/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/squin/analysis/rewrite/stim.py b/src/bloqade/squin/analysis/rewrite/stim.py new file mode 100644 index 00000000..e655bec2 --- /dev/null +++ b/src/bloqade/squin/analysis/rewrite/stim.py @@ -0,0 +1,30 @@ +from typing import Dict +from dataclasses import dataclass + +from bloqade.analysis.address import Address +from bloqade.squin.analysis.nsites import Sites + +from kirin import ir +from kirin.rewrite.abc import RewriteResult, RewriteRule + + +@dataclass +class SquinToStim(RewriteRule): + + # Somehow need to plug in Address and Sites + # into the SSAValue Hints field, which only accepts + # Attribute types + + ## Could literally just plug in `ir.Attribute` into + ## the Address and Site lattices? + ## Couldn't I just create my own attributes instead? + + address_analysis: Dict[ir.SSAValue, Address] + op_site_analysis: Dict[ir.SSAValue, Sites] + + # need to plug in data into the SSAValue + # for the rewrite from these passes, + # then something should look at those hints + # and generate the corresponding stim statements + + pass \ No newline at end of file From 89a17fdb87d5c93148fecb629ac30f1c086fedfd Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 14 Apr 2025 10:58:50 -0400 Subject: [PATCH 02/29] Wrap rewrite pass --- src/bloqade/squin/analysis/rewrite/stim.py | 79 +++++++++++++++++----- 1 file changed, 63 insertions(+), 16 deletions(-) diff --git a/src/bloqade/squin/analysis/rewrite/stim.py b/src/bloqade/squin/analysis/rewrite/stim.py index e655bec2..77049df7 100644 --- a/src/bloqade/squin/analysis/rewrite/stim.py +++ b/src/bloqade/squin/analysis/rewrite/stim.py @@ -1,30 +1,77 @@ from typing import Dict from dataclasses import dataclass +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.print.printer import Printer + +from bloqade.squin import op, wire from bloqade.analysis.address import Address from bloqade.squin.analysis.nsites import Sites -from kirin import ir -from kirin.rewrite.abc import RewriteResult, RewriteRule +# Probably best to move these attributes to a +# separate file? Keep here for now +# to get things working first + + +@wire.dialect.register +@dataclass +class AddressAttribute(ir.Attribute): + + name = "Address" + address: Address + + def __hash__(self) -> int: + return hash(self.address) + + def print_impl(self, printer: Printer) -> None: + # Can return to implementing this later + pass +@op.dialect.register @dataclass -class SquinToStim(RewriteRule): - - # Somehow need to plug in Address and Sites - # into the SSAValue Hints field, which only accepts - # Attribute types +class SitesAttribute(ir.Attribute): - ## Could literally just plug in `ir.Attribute` into - ## the Address and Site lattices? - ## Couldn't I just create my own attributes instead? + name = "Sites" + sites: Sites + + def __hash__(self) -> int: + return hash(self.sites) + + def print_impl(self, printer: Printer) -> None: + # Can return to implementing this later + pass + + +@dataclass +class WrapSquinAnalysis(RewriteRule): address_analysis: Dict[ir.SSAValue, Address] - op_site_analysis: Dict[ir.SSAValue, Sites] + op_site_analysis: Dict[ir.SSAValue, Sites] + + def wrap(self, value: ir.SSAValue) -> bool: + address_analysis_result = self.address_analysis[value] + op_site_analysis_result = self.op_site_analysis[value] + + if value.hints["address"] and value.hints["sites"]: + return False + else: + value.hints["address"] = AddressAttribute(address_analysis_result) + value.hints["sites"] = SitesAttribute(op_site_analysis_result) + + return True - # need to plug in data into the SSAValue - # for the rewrite from these passes, - # then something should look at those hints - # and generate the corresponding stim statements + def rewrite_Block(self, node: ir.Block) -> RewriteResult: + has_done_something = False + for arg in node.args: + if self.wrap(arg): + has_done_something = True + return RewriteResult(has_done_something=has_done_something) - pass \ No newline at end of file + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + has_done_something = False + for result in node.results: + if self.wrap(result): + has_done_something = True + return RewriteResult(has_done_something=has_done_something) From f6105bb9958aa446e87ec9ec5ec8eef4956181d0 Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 14 Apr 2025 14:59:34 -0400 Subject: [PATCH 03/29] confirm analysis wrapping works --- src/bloqade/main.py | 0 src/bloqade/squin/analysis/rewrite/__init__.py | 0 src/bloqade/squin/rewrite/__init__.py | 5 +++++ src/bloqade/squin/{analysis => }/rewrite/stim.py | 6 +++--- 4 files changed, 8 insertions(+), 3 deletions(-) delete mode 100644 src/bloqade/main.py delete mode 100644 src/bloqade/squin/analysis/rewrite/__init__.py create mode 100644 src/bloqade/squin/rewrite/__init__.py rename src/bloqade/squin/{analysis => }/rewrite/stim.py (93%) diff --git a/src/bloqade/main.py b/src/bloqade/main.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/bloqade/squin/analysis/rewrite/__init__.py b/src/bloqade/squin/analysis/rewrite/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py new file mode 100644 index 00000000..6d733a9c --- /dev/null +++ b/src/bloqade/squin/rewrite/__init__.py @@ -0,0 +1,5 @@ +from .stim import ( + SitesAttribute as SitesAttribute, + AddressAttribute as AddressAttribute, + WrapSquinAnalysis as WrapSquinAnalysis, +) diff --git a/src/bloqade/squin/analysis/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py similarity index 93% rename from src/bloqade/squin/analysis/rewrite/stim.py rename to src/bloqade/squin/rewrite/stim.py index 77049df7..102fcce7 100644 --- a/src/bloqade/squin/analysis/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -26,7 +26,7 @@ def __hash__(self) -> int: def print_impl(self, printer: Printer) -> None: # Can return to implementing this later - pass + printer.print(self.address) @op.dialect.register @@ -41,7 +41,7 @@ def __hash__(self) -> int: def print_impl(self, printer: Printer) -> None: # Can return to implementing this later - pass + printer.print(self.sites) @dataclass @@ -54,7 +54,7 @@ def wrap(self, value: ir.SSAValue) -> bool: address_analysis_result = self.address_analysis[value] op_site_analysis_result = self.op_site_analysis[value] - if value.hints["address"] and value.hints["sites"]: + if value.hints.get("address") and value.hints.get("sites"): return False else: value.hints["address"] = AddressAttribute(address_analysis_result) From 7c70b5e15374e0c552e86833ce6eebf633a5c8e0 Mon Sep 17 00:00:00 2001 From: John Long Date: Mon, 14 Apr 2025 23:47:49 -0400 Subject: [PATCH 04/29] going to bed --- src/bloqade/squin/rewrite/stim.py | 92 ++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 1 deletion(-) diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 102fcce7..c56b788b 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -2,10 +2,12 @@ from dataclasses import dataclass from kirin import ir +from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult from kirin.print.printer import Printer -from bloqade.squin import op, wire +from bloqade import stim +from bloqade.squin import op, wire, qubit from bloqade.analysis.address import Address from bloqade.squin.analysis.nsites import Sites @@ -75,3 +77,91 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if self.wrap(result): has_done_something = True return RewriteResult(has_done_something=has_done_something) + + +@dataclass +class SquinToStim(RewriteRule): + + def get_address(self, value: ir.SSAValue): + return value.hints.get("address") + + def get_sites(self, value: ir.SSAValue): + return value.hints.get("sites") + + # Go from (most) squin 1Q Ops to stim Ops + ## X, Y, Z, H, S, (no T!) + def get_stim_1q_gate(self, squin_op: op.stmts.Operator): + match squin_op: + case op.stmts.X(): + return stim.gate.X + case op.stmts.Y(): + return stim.gate.Y + case op.stmts.Z(): + return stim.gate.Z + case op.stmts.H(): + return stim.gate.H + case op.stmts.S(): + return stim.gate.S + case _: + return None + + # might be worth attempting multiple dispatch like qasm2 rewrites + # for Glob and Parallel to UOp + # The problem is I'd have to introduce names for all the statements + # as a ClassVar str. Maybe hold off for now. + + # Don't translate constants to Stim Aux Constants just yet, + # The Stim operations don't even rely on those particular + # constants, seems to be more for lowering from Python AST + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + pass + + def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: + + # this is an SSAValue, need it to be the actual operator + applied_op = apply_stmt.operator.owner + + # need to handle Identity and Control through separate means + # but we can handle X, Y, Z, and H here just fine + stim_1q_op = self.get_stim_1q_gate(applied_op) + + if isinstance(apply_stmt, qubit.Apply): + qubits = apply_stmt.qubits + address_attribute: AddressAttribute = self.get_address(qubits) + # Should get an AddressTuple out of the address stored in attribute + address_tuple = address_attribute.address + qubit_idx_ssas: list[ir.SSAValue] = [] + for address_qubit in address_tuple.data: + qubit_idx = address_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_ssas.append(qubit_idx_stmt.result) + qubit_idx_stmt.insert_before(apply_stmt) + + stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) + + apply_stmt.replace_by(stim_1q_stmt) + apply_stmt.delete() + + return RewriteResult(has_done_something=True) + + elif isinstance(apply_stmt, wire.Apply): + wires_ssa = apply_stmt.inputs + qubit_idx_ssas: list[ir.SSAValue] = [] + for wire_ssa in wires_ssa: + address_attribute = self.get_address(wire_ssa) + # get parent qubit idx + wire_address = address_attribute.data + qubit_idx = wire_address.origin_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_ssas.append(qubit_idx_stmt.result) + qubit_idx_stmt.insert_before(apply_stmt) + + stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) + + apply_stmt.replace_by(stim_1q_stmt) + apply_stmt.delete() + + return RewriteResult(has_done_something=True) + + return RewriteResult() From e55a8fbdda6031ee27835eb83a353d3644cd86eb Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 15 Apr 2025 08:10:04 -0400 Subject: [PATCH 05/29] preliminary handling of Apply --- src/bloqade/squin/rewrite/__init__.py | 1 + src/bloqade/squin/rewrite/stim.py | 37 +++++++-- test/squin/stim/stim.py | 110 ++++++++++++++++++++++++++ 3 files changed, 142 insertions(+), 6 deletions(-) create mode 100644 test/squin/stim/stim.py diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index 6d733a9c..ecb6aab7 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,4 +1,5 @@ from .stim import ( + SquinToStim as SquinToStim, SitesAttribute as SitesAttribute, AddressAttribute as AddressAttribute, WrapSquinAnalysis as WrapSquinAnalysis, diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index c56b788b..04b3eae8 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -115,7 +115,27 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator): # constants, seems to be more for lowering from Python AST def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - pass + + match node: + case wire.Apply() | qubit.Apply(): + return self.rewrite_Apply(node) + case wire.Wrap(): + return self.rewrite_Wrap(node) + case _: + return RewriteResult() + + return RewriteResult() + + def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult: + + # get the wire going into the statement + wire_ssa = wrap_stmt.wire + # remove the wrap statement altogether, then the wire that went into it + wrap_stmt.delete() + wire_ssa.delete() + + # do NOT want to delete the qubit SSA! Leave that alone! + return RewriteResult(has_done_something=True) def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: @@ -140,8 +160,8 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) - apply_stmt.replace_by(stim_1q_stmt) - apply_stmt.delete() + # can't do any of this because of dependencies downstream + # apply_stmt.replace_by(stim_1q_stmt) return RewriteResult(has_done_something=True) @@ -151,16 +171,21 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: for wire_ssa in wires_ssa: address_attribute = self.get_address(wire_ssa) # get parent qubit idx - wire_address = address_attribute.data + wire_address = address_attribute.address qubit_idx = wire_address.origin_qubit.data qubit_idx_stmt = py.Constant(qubit_idx) + # accumulate all qubit idx SSA to instantiate stim gate stmt qubit_idx_ssas.append(qubit_idx_stmt.result) qubit_idx_stmt.insert_before(apply_stmt) stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) + stim_1q_stmt.insert_before(apply_stmt) + + # There is something depending on the results of the statement, + # need to handle that so replacement/deletion can occur without problems - apply_stmt.replace_by(stim_1q_stmt) - apply_stmt.delete() + # apply's results become wires that go to other apply's/wrap stmts + # apply_stmt.replace_by(stim_1q_stmt) return RewriteResult(has_done_something=True) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py new file mode 100644 index 00000000..c2cfaa6b --- /dev/null +++ b/test/squin/stim/stim.py @@ -0,0 +1,110 @@ +from kirin import ir, types +from kirin.passes import Fold +from kirin.rewrite import Walk, Fixpoint, DeadCodeElimination +from kirin.dialects import py, func, ilist + +from bloqade import qasm2, squin +from bloqade.analysis import address +from bloqade.squin.rewrite import SquinToStim, WrapSquinAnalysis +from bloqade.squin.analysis import nsites + + +def as_int(value: int): + return py.constant.Constant(value=value) + + +def as_float(value: float): + return py.constant.Constant(value=value) + + +def gen_func_from_stmts(stmts): + + extended_dialect = squin.groups.wired.add(qasm2.core).add(ilist) + + block = ir.Block(stmts) + block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + func_wrapper = func.Function( + sym_name="main", + signature=func.Signature(inputs=(), output=types.NoneType), + body=ir.Region(blocks=block), + ) + + constructed_method = ir.Method( + mod=None, + py_func=None, + sym_name="main", + dialects=extended_dialect, + code=func_wrapper, + arg_names=[], + ) + + fold_pass = Fold(extended_dialect) + fold_pass(constructed_method) + + return constructed_method + + +def test_1q(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubit out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + # pass the wires through some 1 Qubit operators + (op1 := squin.op.stmts.S()), + (op2 := squin.op.stmts.H()), + (op3 := squin.op.stmts.X()), + (v0 := squin.wire.Apply(op1.result, w0.result)), + (v1 := squin.wire.Apply(op2.result, v0.results[0])), + (v2 := squin.wire.Apply(op3.result, v1.results[0])), + ( + squin.wire.Wrap(v2.results[0], q0.result) + ), # for wrap, just free a use for the result SSAval + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + # the fact I return a wire here means DCE will NOT go ahead and + # eliminate all the other wire.Apply stmts + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + address_frame, _ = address.AddressAnalysis( + constructed_method.dialects + ).run_analysis(constructed_method, no_raise=False) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + constructed_method.print(analysis=address_frame.entries) + constructed_method.print(analysis=nsites_frame.entries) + + # attempt to wrap analysis results + + wrap_squin_analysis = WrapSquinAnalysis( + address_analysis=address_frame.entries, op_site_analysis=nsites_frame.entries + ) + fix_walk_squin_analysis = Fixpoint(Walk(wrap_squin_analysis)) + rewrite_res = fix_walk_squin_analysis.rewrite(constructed_method.code) + + # attempt rewrite to Stim + # Be careful with Fixpoint, can go to infinity until reaches defined threshold + squin_to_stim = Walk(SquinToStim()) + rewrite_res = squin_to_stim.rewrite(constructed_method.code) + + # Get rid of the unused statements + dce = Fixpoint(Walk(DeadCodeElimination())) + rewrite_res = dce.rewrite(constructed_method.code) + print(rewrite_res) + + constructed_method.print() + + +test_1q() From e76b3f303fdb249f2f224244ce9656686da87ce5 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 15 Apr 2025 11:08:12 -0400 Subject: [PATCH 06/29] support for control gates confirmed --- src/bloqade/squin/analysis/nsites/impls.py | 11 +- src/bloqade/squin/rewrite/stim.py | 130 ++++++++++++++------- test/squin/stim/stim.py | 68 ++++++++++- 3 files changed, 163 insertions(+), 46 deletions(-) diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py index 3a4f94f1..36ea44fa 100644 --- a/src/bloqade/squin/analysis/nsites/impls.py +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -2,7 +2,7 @@ from kirin import ir, interp -from bloqade.squin import op +from bloqade.squin import op, wire from .lattice import ( NoSites, @@ -11,6 +11,15 @@ from .analysis import NSitesAnalysis +@wire.dialect.register(key="op.nsites") +class SquinWire(interp.MethodTable): + + @interp.impl(wire.Apply) + def apply(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.Apply): + + return tuple([frame.get(input) for input in stmt.inputs]) + + @op.dialect.register(key="op.nsites") class SquinOp(interp.MethodTable): diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 04b3eae8..50552c7d 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -102,9 +102,46 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator): return stim.gate.H case op.stmts.S(): return stim.gate.S + case op.stmts.Identity(): # enforce sites defined = num wires in + return stim.gate.Identity case _: return None + # get the qubit indices from the Apply statement argument + # wires/qubits + def insert_qubit_idx_ssa( + self, apply_stmt: wire.Apply | qubit.Apply + ) -> tuple[ir.SSAValue, ...]: + + if isinstance(apply_stmt, qubit.Apply): + qubits = apply_stmt.qubits + address_attribute: AddressAttribute = self.get_address(qubits) + # Should get an AddressTuple out of the address stored in attribute + address_tuple = address_attribute.address + qubit_idx_ssas: list[ir.SSAValue] = [] + for address_qubit in address_tuple.data: + qubit_idx = address_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_stmt.insert_before(apply_stmt) + qubit_idx_ssas.append(qubit_idx_stmt.result) + + return tuple(qubit_idx_ssas) + + elif isinstance(apply_stmt, wire.Apply): + wire_ssas = apply_stmt.inputs + qubit_idx_ssas: list[ir.SSAValue] = [] + for wire_ssa in wire_ssas: + address_attribute = self.get_address(wire_ssa) + # get parent qubit idx + wire_address = address_attribute.address + qubit_idx = wire_address.origin_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + # accumulate all qubit idx SSA to instantiate stim gate stmt + qubit_idx_ssas.append(qubit_idx_stmt.result) + qubit_idx_stmt.insert_before(apply_stmt) + + return tuple(qubit_idx_ssas) + # might be worth attempting multiple dispatch like qasm2 rewrites # for Glob and Parallel to UOp # The problem is I'd have to introduce names for all the statements @@ -142,51 +179,60 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: # this is an SSAValue, need it to be the actual operator applied_op = apply_stmt.operator.owner - # need to handle Identity and Control through separate means - # but we can handle X, Y, Z, and H here just fine - stim_1q_op = self.get_stim_1q_gate(applied_op) - - if isinstance(apply_stmt, qubit.Apply): - qubits = apply_stmt.qubits - address_attribute: AddressAttribute = self.get_address(qubits) - # Should get an AddressTuple out of the address stored in attribute - address_tuple = address_attribute.address - qubit_idx_ssas: list[ir.SSAValue] = [] - for address_qubit in address_tuple.data: - qubit_idx = address_qubit.data - qubit_idx_stmt = py.Constant(qubit_idx) - qubit_idx_ssas.append(qubit_idx_stmt.result) - qubit_idx_stmt.insert_before(apply_stmt) - - stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) - - # can't do any of this because of dependencies downstream - # apply_stmt.replace_by(stim_1q_stmt) - - return RewriteResult(has_done_something=True) + if isinstance(applied_op, op.stmts.Control): + return self.rewrite_Control(apply_stmt) - elif isinstance(apply_stmt, wire.Apply): - wires_ssa = apply_stmt.inputs - qubit_idx_ssas: list[ir.SSAValue] = [] - for wire_ssa in wires_ssa: - address_attribute = self.get_address(wire_ssa) - # get parent qubit idx - wire_address = address_attribute.address - qubit_idx = wire_address.origin_qubit.data - qubit_idx_stmt = py.Constant(qubit_idx) - # accumulate all qubit idx SSA to instantiate stim gate stmt - qubit_idx_ssas.append(qubit_idx_stmt.result) - qubit_idx_stmt.insert_before(apply_stmt) + # need to handle Control through separate means + # but we can handle X, Y, Z, H, and S here just fine + stim_1q_op = self.get_stim_1q_gate(applied_op) - stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) - stim_1q_stmt.insert_before(apply_stmt) + qubit_idx_ssas = self.insert_qubit_idx_ssa(apply_stmt=apply_stmt) + stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) + stim_1q_stmt.insert_before(apply_stmt) - # There is something depending on the results of the statement, - # need to handle that so replacement/deletion can occur without problems + return RewriteResult(has_done_something=True) - # apply's results become wires that go to other apply's/wrap stmts - # apply_stmt.replace_by(stim_1q_stmt) + def rewrite_Control( + self, apply_stmt_ctrl: qubit.Apply | wire.Apply + ) -> RewriteResult: + # stim only supports CX, CY, CZ so we have to check the + # operator of Apply is a Control gate, enforce it's only asking for 1 control qubit, + # and that the target of the control is X, Y, Z in squin + + ctrl_op: op.stmts.Control = apply_stmt_ctrl.operator.owner + # enforce that n_controls is 1 + + ctrl_op_target_gate = ctrl_op.op.owner + + # should enforce that this is some multiple of 2 + qubit_idx_ssas = self.insert_qubit_idx_ssa(apply_stmt=apply_stmt_ctrl) + # according to stim, final result can be: + # CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4) + target_qubits = [] + ctrl_qubits = [] + # definitely a better way to do this but + # can't think of it right now + for i in range(len(qubit_idx_ssas)): + if (i % 2) == 0: + ctrl_qubits.append(qubit_idx_ssas[i]) + else: + target_qubits.append(qubit_idx_ssas[i]) + + target_qubits = tuple(target_qubits) + ctrl_qubits = tuple(ctrl_qubits) + + match ctrl_op_target_gate: + case op.stmts.X(): + stim_stmt = stim.CX(controls=ctrl_qubits, targets=target_qubits) + case op.stmts.Y(): + stim_stmt = stim.CY(controls=ctrl_qubits, targets=target_qubits) + case op.stmts.Z(): + stim_stmt = stim.CZ(controls=ctrl_qubits, targets=target_qubits) + case _: + raise NotImplementedError( + "Control gates beyond CX, CY, and CZ are not supported" + ) - return RewriteResult(has_done_something=True) + stim_stmt.insert_before(apply_stmt_ctrl) - return RewriteResult() + return RewriteResult(has_done_something=True) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index c2cfaa6b..2df3ea08 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -58,12 +58,14 @@ def test_1q(): # pass the wires through some 1 Qubit operators (op1 := squin.op.stmts.S()), (op2 := squin.op.stmts.H()), - (op3 := squin.op.stmts.X()), + (op3 := squin.op.stmts.Identity(sites=1)), + (op4 := squin.op.stmts.Identity(sites=1)), (v0 := squin.wire.Apply(op1.result, w0.result)), (v1 := squin.wire.Apply(op2.result, v0.results[0])), (v2 := squin.wire.Apply(op3.result, v1.results[0])), + (v3 := squin.wire.Apply(op4.result, v2.results[0])), ( - squin.wire.Wrap(v2.results[0], q0.result) + squin.wire.Wrap(v3.results[0], q0.result) ), # for wrap, just free a use for the result SSAval (ret_none := func.ConstantNone()), (func.Return(ret_none)), @@ -107,4 +109,64 @@ def test_1q(): constructed_method.print() -test_1q() +def test_control(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(2)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubis out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + # set up control gate + (op1 := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op1.result, n_controls=1)), + (app := squin.wire.Apply(cx.result, w0.result, w1.result)), + # wrap things back + (squin.wire.Wrap(wire=app.results[0], qubit=q0.result)), + (squin.wire.Wrap(wire=app.results[1], qubit=q1.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + address_frame, _ = address.AddressAnalysis( + constructed_method.dialects + ).run_analysis(constructed_method, no_raise=False) + + nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( + constructed_method, no_raise=False + ) + + constructed_method.print(analysis=address_frame.entries) + constructed_method.print(analysis=nsites_frame.entries) + + wrap_squin_analysis = WrapSquinAnalysis( + address_analysis=address_frame.entries, op_site_analysis=nsites_frame.entries + ) + fix_walk_squin_analysis = Fixpoint(Walk(wrap_squin_analysis)) + rewrite_res = fix_walk_squin_analysis.rewrite(constructed_method.code) + + # attempt rewrite to Stim + # Be careful with Fixpoint, can go to infinity until reaches defined threshold + squin_to_stim = Walk(SquinToStim()) + rewrite_res = squin_to_stim.rewrite(constructed_method.code) + + constructed_method.print() + + # Get rid of the unused statements + dce = Fixpoint(Walk(DeadCodeElimination())) + rewrite_res = dce.rewrite(constructed_method.code) + print(rewrite_res) + + constructed_method.print() + + +test_control() From 7ba56708f46cb539dcdfc2173170543eba87ed31 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 15 Apr 2025 20:57:24 -0400 Subject: [PATCH 07/29] finally put everything into a pass --- src/bloqade/squin/passes/__init__.py | 1 + src/bloqade/squin/passes/stim.py | 59 +++++++++++++++++++++++ src/bloqade/squin/rewrite/__init__.py | 2 +- src/bloqade/squin/rewrite/stim.py | 58 ++++++++++++++++++++-- test/squin/stim/stim.py | 69 ++------------------------- 5 files changed, 119 insertions(+), 70 deletions(-) create mode 100644 src/bloqade/squin/passes/__init__.py create mode 100644 src/bloqade/squin/passes/stim.py diff --git a/src/bloqade/squin/passes/__init__.py b/src/bloqade/squin/passes/__init__.py new file mode 100644 index 00000000..6368db40 --- /dev/null +++ b/src/bloqade/squin/passes/__init__.py @@ -0,0 +1 @@ +from .stim import SquinToStim as SquinToStim diff --git a/src/bloqade/squin/passes/stim.py b/src/bloqade/squin/passes/stim.py new file mode 100644 index 00000000..774ec37f --- /dev/null +++ b/src/bloqade/squin/passes/stim.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass + +from kirin.passes import Fold +from kirin.rewrite import ( + Walk, + Chain, + Fixpoint, + DeadCodeElimination, + CommonSubexpressionElimination, +) +from kirin.ir.method import Method +from kirin.passes.abc import Pass +from kirin.rewrite.abc import RewriteResult + +import bloqade.squin.rewrite as squin_rewrite +from bloqade.analysis.address import AddressAnalysis +from bloqade.squin.analysis.nsites import ( + NSitesAnalysis, +) + + +@dataclass +class SquinToStim(Pass): + + def unsafe_run(self, mt: Method) -> RewriteResult: + fold_pass = Fold(mt.dialects) + # propagate constants + rewrite_result = fold_pass(mt) + + # Get necessary analysis results to plug into hints + address_analysis = AddressAnalysis(mt.dialects) + address_frame, _ = address_analysis.run_analysis(mt) + site_analysis = NSitesAnalysis(mt.dialects) + sites_frame, _ = site_analysis.run_analysis(mt) + + # Wrap Rewrite + SquinToStim can happen w/ standard walk + rewrite_result = ( + Walk( + Chain( + squin_rewrite.WrapSquinAnalysis( + address_analysis=address_frame.entries, + op_site_analysis=sites_frame.entries, + ), + squin_rewrite._SquinToStim(), + ) + ) + .rewrite(mt.code) + .join(rewrite_result) + ) + + rewrite_result = ( + Fixpoint( + Walk(Chain(DeadCodeElimination(), CommonSubexpressionElimination())) + ) + .rewrite(mt.code) + .join(rewrite_result) + ) + + return rewrite_result diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index ecb6aab7..5a475fcc 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,6 +1,6 @@ from .stim import ( - SquinToStim as SquinToStim, SitesAttribute as SitesAttribute, AddressAttribute as AddressAttribute, WrapSquinAnalysis as WrapSquinAnalysis, + _SquinToStim as _SquinToStim, ) diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 50552c7d..8aec111b 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -8,7 +8,7 @@ from bloqade import stim from bloqade.squin import op, wire, qubit -from bloqade.analysis.address import Address +from bloqade.analysis.address import Address, AddressWire, AddressTuple from bloqade.squin.analysis.nsites import Sites # Probably best to move these attributes to a @@ -80,13 +80,19 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: @dataclass -class SquinToStim(RewriteRule): +class _SquinToStim(RewriteRule): def get_address(self, value: ir.SSAValue): - return value.hints.get("address") + try: + return value.hints["address"] + except KeyError: + raise KeyError(f"The address analysis hint for {value} does not exist") def get_sites(self, value: ir.SSAValue): - return value.hints.get("sites") + try: + return value.hints["sites"] + except KeyError: + raise KeyError(f"The sites analysis hint for {value} does not exist") # Go from (most) squin 1Q Ops to stim Ops ## X, Y, Z, H, S, (no T!) @@ -105,7 +111,9 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator): case op.stmts.Identity(): # enforce sites defined = num wires in return stim.gate.Identity case _: - return None + raise NotImplementedError( + f"The squin operator {squin_op} is not supported in the stim dialect" + ) # get the qubit indices from the Apply statement argument # wires/qubits @@ -158,6 +166,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return self.rewrite_Apply(node) case wire.Wrap(): return self.rewrite_Wrap(node) + case wire.Measure() | qubit.Measure(): + return self.rewrite_Measure(node) case _: return RewriteResult() @@ -236,3 +246,41 @@ def rewrite_Control( stim_stmt.insert_before(apply_stmt_ctrl) return RewriteResult(has_done_something=True) + + def rewrite_Measure( + self, measure_stmt: qubit.Measure | wire.Measure + ) -> RewriteResult: + + if isinstance(measure_stmt, qubit.Measure): + qubit_ilist_ssa = measure_stmt.qubits + # qubits are in an ilist which makes up an AddressTuple + address_tuple: AddressTuple = self.get_address(qubit_ilist_ssa).address + qubit_idx_ssas = [] + for qubit_address in address_tuple: + qubit_idx = qubit_address.data + qubit_idx_stmt = py.constant.Constant(qubit_idx) + qubit_idx_stmt.insert_before(measure_stmt) + qubit_idx_ssas.append(qubit_idx_stmt.result) + qubit_idx_ssas = tuple(qubit_idx_ssas) + + elif isinstance(measure_stmt, wire.Measure): + wire_ssa = measure_stmt.wire + wire_address: AddressWire = self.get_address(wire_ssa).address + + qubit_idx = wire_address.origin_qubit.data + qubit_idx_stmt = py.constant.Constant(qubit_idx) + qubit_idx_stmt.insert_before(measure_stmt) + qubit_idx_ssas = (qubit_idx_stmt.result,) + + else: + return RewriteResult() + + prob_noise_stmt = py.constant.Constant(0.0) + stim_measure_stmt = stim.collapse.MZ( + p=prob_noise_stmt.result, + targets=qubit_idx_ssas, + ) + prob_noise_stmt.insert_before(measure_stmt) + stim_measure_stmt.insert_before(measure_stmt) + + return RewriteResult(has_done_something=True) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 2df3ea08..22b4d018 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -1,12 +1,8 @@ from kirin import ir, types -from kirin.passes import Fold -from kirin.rewrite import Walk, Fixpoint, DeadCodeElimination from kirin.dialects import py, func, ilist +import bloqade.squin.passes as squin_passes from bloqade import qasm2, squin -from bloqade.analysis import address -from bloqade.squin.rewrite import SquinToStim, WrapSquinAnalysis -from bloqade.squin.analysis import nsites def as_int(value: int): @@ -38,9 +34,6 @@ def gen_func_from_stmts(stmts): arg_names=[], ) - fold_pass = Fold(extended_dialect) - fold_pass(constructed_method) - return constructed_method @@ -77,34 +70,8 @@ def test_1q(): constructed_method.print() - address_frame, _ = address.AddressAnalysis( - constructed_method.dialects - ).run_analysis(constructed_method, no_raise=False) - - nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( - constructed_method, no_raise=False - ) - - constructed_method.print(analysis=address_frame.entries) - constructed_method.print(analysis=nsites_frame.entries) - - # attempt to wrap analysis results - - wrap_squin_analysis = WrapSquinAnalysis( - address_analysis=address_frame.entries, op_site_analysis=nsites_frame.entries - ) - fix_walk_squin_analysis = Fixpoint(Walk(wrap_squin_analysis)) - rewrite_res = fix_walk_squin_analysis.rewrite(constructed_method.code) - - # attempt rewrite to Stim - # Be careful with Fixpoint, can go to infinity until reaches defined threshold - squin_to_stim = Walk(SquinToStim()) - rewrite_res = squin_to_stim.rewrite(constructed_method.code) - - # Get rid of the unused statements - dce = Fixpoint(Walk(DeadCodeElimination())) - rewrite_res = dce.rewrite(constructed_method.code) - print(rewrite_res) + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) constructed_method.print() @@ -137,34 +104,8 @@ def test_control(): constructed_method = gen_func_from_stmts(stmts) constructed_method.print() - address_frame, _ = address.AddressAnalysis( - constructed_method.dialects - ).run_analysis(constructed_method, no_raise=False) - - nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis( - constructed_method, no_raise=False - ) - - constructed_method.print(analysis=address_frame.entries) - constructed_method.print(analysis=nsites_frame.entries) - - wrap_squin_analysis = WrapSquinAnalysis( - address_analysis=address_frame.entries, op_site_analysis=nsites_frame.entries - ) - fix_walk_squin_analysis = Fixpoint(Walk(wrap_squin_analysis)) - rewrite_res = fix_walk_squin_analysis.rewrite(constructed_method.code) - - # attempt rewrite to Stim - # Be careful with Fixpoint, can go to infinity until reaches defined threshold - squin_to_stim = Walk(SquinToStim()) - rewrite_res = squin_to_stim.rewrite(constructed_method.code) - - constructed_method.print() - - # Get rid of the unused statements - dce = Fixpoint(Walk(DeadCodeElimination())) - rewrite_res = dce.rewrite(constructed_method.code) - print(rewrite_res) + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) constructed_method.print() From f3203dadde4c6ae2c82ce0d788f9f3b69df7f057 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 16 Apr 2025 20:30:24 -0400 Subject: [PATCH 08/29] partially working reset rewrite --- src/bloqade/squin/rewrite/stim.py | 193 +++++++++++++++++++++++------- test/squin/stim/stim.py | 63 +++++++++- 2 files changed, 207 insertions(+), 49 deletions(-) diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 8aec111b..94c02c04 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -8,7 +8,7 @@ from bloqade import stim from bloqade.squin import op, wire, qubit -from bloqade.analysis.address import Address, AddressWire, AddressTuple +from bloqade.analysis.address import Address, AddressWire, AddressQubit, AddressTuple from bloqade.squin.analysis.nsites import Sites # Probably best to move these attributes to a @@ -82,9 +82,12 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: @dataclass class _SquinToStim(RewriteRule): - def get_address(self, value: ir.SSAValue): + def get_address(self, value: ir.SSAValue) -> AddressAttribute: + try: - return value.hints["address"] + address_attr = value.hints["address"] + assert isinstance(address_attr, AddressAttribute) + return address_attr except KeyError: raise KeyError(f"The address analysis hint for {value} does not exist") @@ -115,9 +118,61 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator): f"The squin operator {squin_op} is not supported in the stim dialect" ) + def insert_qubit_idx_from_address( + self, address: AddressAttribute, stmt_to_insert_before: ir.Statement + ) -> tuple[ir.SSAValue, ...]: + + address_data = address.address + + qubit_idx_ssas = [] + + if isinstance(address_data, AddressTuple): + for address_qubit in address_data.data: + + # ensure that the stuff in the AddressTuple should be AddressQubit + # could handle AddressWires as well but don't see the need for that right now + if not isinstance(address_qubit, AddressQubit): + raise ValueError( + "Unsupported Address type detected inside AddressTuple, must be AddressQubit" + ) + qubit_idx = address_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + qubit_idx_ssas.append(qubit_idx_stmt.result) + elif isinstance(address_data, AddressWire): + address_qubit = address_data.origin_qubit + qubit_idx = address_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + qubit_idx_ssas.append(qubit_idx_stmt.result) + else: + NotImplementedError( + "qubit idx extraction and insertion only support for AddressTuple[AddressQubit] and AddressWire instances" + ) + + return tuple(qubit_idx_ssas) + + def insert_qubit_idx_from_wire_ssa( + self, wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement + ) -> tuple[ir.SSAValue, ...]: + qubit_idx_ssas = [] + for wire_ssa in wire_ssas: + address_attribute = self.get_address(wire_ssa) # get AddressWire + # get parent qubit idx + wire_address = address_attribute.address + assert isinstance(wire_address, AddressWire) + qubit_idx = wire_address.origin_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + # accumulate all qubit idx SSA to instantiate stim gate stmt + qubit_idx_ssas.append(qubit_idx_stmt.result) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + + return tuple(qubit_idx_ssas) + # get the qubit indices from the Apply statement argument # wires/qubits - def insert_qubit_idx_ssa( + + def insert_qubit_idx_after_apply( self, apply_stmt: wire.Apply | qubit.Apply ) -> tuple[ir.SSAValue, ...]: @@ -125,30 +180,18 @@ def insert_qubit_idx_ssa( qubits = apply_stmt.qubits address_attribute: AddressAttribute = self.get_address(qubits) # Should get an AddressTuple out of the address stored in attribute - address_tuple = address_attribute.address - qubit_idx_ssas: list[ir.SSAValue] = [] - for address_qubit in address_tuple.data: - qubit_idx = address_qubit.data - qubit_idx_stmt = py.Constant(qubit_idx) - qubit_idx_stmt.insert_before(apply_stmt) - qubit_idx_ssas.append(qubit_idx_stmt.result) - - return tuple(qubit_idx_ssas) - + return self.insert_qubit_idx_from_address( + address=address_attribute, stmt_to_insert_before=apply_stmt + ) elif isinstance(apply_stmt, wire.Apply): wire_ssas = apply_stmt.inputs - qubit_idx_ssas: list[ir.SSAValue] = [] - for wire_ssa in wire_ssas: - address_attribute = self.get_address(wire_ssa) - # get parent qubit idx - wire_address = address_attribute.address - qubit_idx = wire_address.origin_qubit.data - qubit_idx_stmt = py.Constant(qubit_idx) - # accumulate all qubit idx SSA to instantiate stim gate stmt - qubit_idx_ssas.append(qubit_idx_stmt.result) - qubit_idx_stmt.insert_before(apply_stmt) - - return tuple(qubit_idx_ssas) + return self.insert_qubit_idx_from_wire_ssa( + wire_ssas=wire_ssas, stmt_to_insert_before=apply_stmt + ) + else: + raise TypeError( + "unsupported statement detected, only wire.Apply and qubit.Apply statements are supported by this method" + ) # might be worth attempting multiple dispatch like qasm2 rewrites # for Glob and Parallel to UOp @@ -168,6 +211,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return self.rewrite_Wrap(node) case wire.Measure() | qubit.Measure(): return self.rewrite_Measure(node) + case wire.Reset() | qubit.Reset(): + return self.rewrite_Reset(node) case _: return RewriteResult() @@ -188,6 +233,7 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: # this is an SSAValue, need it to be the actual operator applied_op = apply_stmt.operator.owner + assert isinstance(applied_op, op.stmts.Operator) if isinstance(applied_op, op.stmts.Control): return self.rewrite_Control(apply_stmt) @@ -196,7 +242,25 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: # but we can handle X, Y, Z, H, and S here just fine stim_1q_op = self.get_stim_1q_gate(applied_op) - qubit_idx_ssas = self.insert_qubit_idx_ssa(apply_stmt=apply_stmt) + # wire.Apply -> tuple of SSA -> AddressTuple + # qubit.Apply -> list of qubits -> AddressTuple + ## Both cases the statements follow the Stim semantics of + ## 1QGate a b c d .... + + if isinstance(apply_stmt, qubit.Apply): + address_attr = self.get_address(apply_stmt.qubits) + qubit_idx_ssas = self.insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=apply_stmt + ) + elif isinstance(apply_stmt, wire.Apply): + qubit_idx_ssas = self.insert_qubit_idx_from_wire_ssa( + wire_ssas=apply_stmt.inputs, stmt_to_insert_before=apply_stmt + ) + else: + raise TypeError( + "Unsupported statement detected, only qubit.Apply and wire.Apply are permitted" + ) + stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) stim_1q_stmt.insert_before(apply_stmt) @@ -209,13 +273,16 @@ def rewrite_Control( # operator of Apply is a Control gate, enforce it's only asking for 1 control qubit, # and that the target of the control is X, Y, Z in squin - ctrl_op: op.stmts.Control = apply_stmt_ctrl.operator.owner + ctrl_op = apply_stmt_ctrl.operator.owner + assert isinstance(ctrl_op, op.stmts.Control) + # enforce that n_controls is 1 ctrl_op_target_gate = ctrl_op.op.owner + assert isinstance(ctrl_op_target_gate, op.stmts.Operator) # should enforce that this is some multiple of 2 - qubit_idx_ssas = self.insert_qubit_idx_ssa(apply_stmt=apply_stmt_ctrl) + qubit_idx_ssas = self.insert_qubit_idx_after_apply(apply_stmt=apply_stmt_ctrl) # according to stim, final result can be: # CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4) target_qubits = [] @@ -254,26 +321,26 @@ def rewrite_Measure( if isinstance(measure_stmt, qubit.Measure): qubit_ilist_ssa = measure_stmt.qubits # qubits are in an ilist which makes up an AddressTuple - address_tuple: AddressTuple = self.get_address(qubit_ilist_ssa).address - qubit_idx_ssas = [] - for qubit_address in address_tuple: - qubit_idx = qubit_address.data - qubit_idx_stmt = py.constant.Constant(qubit_idx) - qubit_idx_stmt.insert_before(measure_stmt) - qubit_idx_ssas.append(qubit_idx_stmt.result) - qubit_idx_ssas = tuple(qubit_idx_ssas) + address_attr = self.get_address(qubit_ilist_ssa) elif isinstance(measure_stmt, wire.Measure): + # Wire Terminator, should kill the existence of + # the wire here so DCE can sweep up the rest like with rewriting wrap wire_ssa = measure_stmt.wire - wire_address: AddressWire = self.get_address(wire_ssa).address + address_attr = self.get_address(wire_ssa) - qubit_idx = wire_address.origin_qubit.data - qubit_idx_stmt = py.constant.Constant(qubit_idx) - qubit_idx_stmt.insert_before(measure_stmt) - qubit_idx_ssas = (qubit_idx_stmt.result,) + # DCE can't remove the old measure_stmt for both wire and qubit versions + # because of the fact it has a result that can be depended on by other statements + # whereas Stim Measure has no such notion else: - return RewriteResult() + raise TypeError( + "unsupported Statement, only qubit.Measure and wire.Measure are supported" + ) + + qubit_idx_ssas = self.insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=measure_stmt + ) prob_noise_stmt = py.constant.Constant(0.0) stim_measure_stmt = stim.collapse.MZ( @@ -284,3 +351,43 @@ def rewrite_Measure( stim_measure_stmt.insert_before(measure_stmt) return RewriteResult(has_done_something=True) + + def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult: + """ + qubit.Reset(ilist of qubits) -> nothing + # safe to delete the statement afterwards, no depending results + # DCE could probably do this automatically? + + wire.Reset(single wire) -> new wire + # DO NOT DELETE + + # assume RZ, but could extend to RY and RX later + Stim RZ(targets = tuple[int of SSAVals]) + """ + + if isinstance(reset_stmt, qubit.Reset): + qubit_ilist_ssa = reset_stmt.qubits + # qubits are in an ilist which makes up an AddressTuple + address_attr = self.get_address(qubit_ilist_ssa) + qubit_idx_ssas = self.insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=reset_stmt + ) + elif isinstance(reset_stmt, wire.Reset): + address_attr = self.get_address(reset_stmt.wire) + qubit_idx_ssas = self.insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=reset_stmt + ) + else: + raise TypeError( + "unsupported statement, only qubit.Reset and wire.Reset are supported" + ) + + stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas) + stim_rz_stmt.insert_before(reset_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_MeasureAndReset( + self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset + ): + pass diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 22b4d018..fcb65674 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -13,15 +13,15 @@ def as_float(value: float): return py.constant.Constant(value=value) -def gen_func_from_stmts(stmts): +def gen_func_from_stmts(stmts, output=types.NoneType): - extended_dialect = squin.groups.wired.add(qasm2.core).add(ilist) + extended_dialect = squin.groups.wired.add(qasm2.core).add(ilist).add(squin.qubit) block = ir.Block(stmts) block.args.append_from(types.MethodType[[], types.NoneType], "main_self") func_wrapper = func.Function( sym_name="main", - signature=func.Signature(inputs=(), output=types.NoneType), + signature=func.Signature(inputs=(), output=output), body=ir.Region(blocks=block), ) @@ -37,7 +37,7 @@ def gen_func_from_stmts(stmts): return constructed_method -def test_1q(): +def test_wire_1q(): stmts: list[ir.Statement] = [ # Create qubit register @@ -76,7 +76,7 @@ def test_1q(): constructed_method.print() -def test_control(): +def test_wire_control(): stmts: list[ir.Statement] = [ # Create qubit register @@ -110,4 +110,55 @@ def test_control(): constructed_method.print() -test_control() +def test_wire_measure(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(2)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubis out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + # measure the wires out + (r0 := squin.wire.Measure(w0.result)), + # return ints so DCE doesn't get + # rid of everything + # (ret_none := func.ConstantNone()), + (func.Return(r0)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() + + +def test_qubit_reset(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # qubit.reset only accepts ilist of qubits + (qlist := ilist.New(values=[q0.result])), + (squin.qubit.Reset(qubits=qlist.result)), + (squin.qubit.Measure(qubits=qlist.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() From 20d4214587307c4903ee5ec270c0d6fdfb110d0c Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 16 Apr 2025 21:52:35 -0400 Subject: [PATCH 09/29] account for MeasureAndReset --- src/bloqade/squin/analysis/nsites/impls.py | 9 ++++ src/bloqade/squin/rewrite/stim.py | 61 ++++++++++++++++++---- test/squin/stim/stim.py | 58 ++++++++++++++++++++ 3 files changed, 118 insertions(+), 10 deletions(-) diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py index 36ea44fa..fe94ff42 100644 --- a/src/bloqade/squin/analysis/nsites/impls.py +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -19,6 +19,15 @@ def apply(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.Apply): return tuple([frame.get(input) for input in stmt.inputs]) + @interp.impl(wire.MeasureAndReset) + def measure_and_reset( + self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.MeasureAndReset + ): + + # MeasureAndReset produces both a new wire + # and an integer which don't have any sites at all + return (NoSites(), NoSites()) + @op.dialect.register(key="op.nsites") class SquinOp(interp.MethodTable): diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 94c02c04..09e23d76 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -82,7 +82,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: @dataclass class _SquinToStim(RewriteRule): - def get_address(self, value: ir.SSAValue) -> AddressAttribute: + def get_address_attr(self, value: ir.SSAValue) -> AddressAttribute: try: address_attr = value.hints["address"] @@ -91,7 +91,7 @@ def get_address(self, value: ir.SSAValue) -> AddressAttribute: except KeyError: raise KeyError(f"The address analysis hint for {value} does not exist") - def get_sites(self, value: ir.SSAValue): + def get_sites_attr(self, value: ir.SSAValue): try: return value.hints["sites"] except KeyError: @@ -157,7 +157,7 @@ def insert_qubit_idx_from_wire_ssa( ) -> tuple[ir.SSAValue, ...]: qubit_idx_ssas = [] for wire_ssa in wire_ssas: - address_attribute = self.get_address(wire_ssa) # get AddressWire + address_attribute = self.get_address_attr(wire_ssa) # get AddressWire # get parent qubit idx wire_address = address_attribute.address assert isinstance(wire_address, AddressWire) @@ -178,7 +178,7 @@ def insert_qubit_idx_after_apply( if isinstance(apply_stmt, qubit.Apply): qubits = apply_stmt.qubits - address_attribute: AddressAttribute = self.get_address(qubits) + address_attribute: AddressAttribute = self.get_address_attr(qubits) # Should get an AddressTuple out of the address stored in attribute return self.insert_qubit_idx_from_address( address=address_attribute, stmt_to_insert_before=apply_stmt @@ -213,6 +213,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return self.rewrite_Measure(node) case wire.Reset() | qubit.Reset(): return self.rewrite_Reset(node) + case wire.MeasureAndReset() | qubit.MeasureAndReset(): + return self.rewrite_MeasureAndReset(node) case _: return RewriteResult() @@ -248,7 +250,7 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: ## 1QGate a b c d .... if isinstance(apply_stmt, qubit.Apply): - address_attr = self.get_address(apply_stmt.qubits) + address_attr = self.get_address_attr(apply_stmt.qubits) qubit_idx_ssas = self.insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=apply_stmt ) @@ -321,13 +323,13 @@ def rewrite_Measure( if isinstance(measure_stmt, qubit.Measure): qubit_ilist_ssa = measure_stmt.qubits # qubits are in an ilist which makes up an AddressTuple - address_attr = self.get_address(qubit_ilist_ssa) + address_attr = self.get_address_attr(qubit_ilist_ssa) elif isinstance(measure_stmt, wire.Measure): # Wire Terminator, should kill the existence of # the wire here so DCE can sweep up the rest like with rewriting wrap wire_ssa = measure_stmt.wire - address_attr = self.get_address(wire_ssa) + address_attr = self.get_address_attr(wire_ssa) # DCE can't remove the old measure_stmt for both wire and qubit versions # because of the fact it has a result that can be depended on by other statements @@ -368,12 +370,12 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult: if isinstance(reset_stmt, qubit.Reset): qubit_ilist_ssa = reset_stmt.qubits # qubits are in an ilist which makes up an AddressTuple - address_attr = self.get_address(qubit_ilist_ssa) + address_attr = self.get_address_attr(qubit_ilist_ssa) qubit_idx_ssas = self.insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=reset_stmt ) elif isinstance(reset_stmt, wire.Reset): - address_attr = self.get_address(reset_stmt.wire) + address_attr = self.get_address_attr(reset_stmt.wire) qubit_idx_ssas = self.insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=reset_stmt ) @@ -390,4 +392,43 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult: def rewrite_MeasureAndReset( self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset ): - pass + """ + qubit.MeasureAndReset(qubits) -> result + Could be translated (roughly equivalent) to + + stim.MZ(tuple[SSAvals for ints]) + stim.RZ(tuple[SSAvals for ints]) + + Stim does have MRZ, might be more reflective of what we want/ + lines up the semantics better + + """ + + if isinstance(meas_and_reset_stmt, qubit.MeasureAndReset): + + address_attr = self.get_address_attr(meas_and_reset_stmt.qubits) + qubit_idx_ssas = self.insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=meas_and_reset_stmt + ) + + elif isinstance(meas_and_reset_stmt, wire.MeasureAndReset): + address_attr = self.get_address_attr(meas_and_reset_stmt.wire) + qubit_idx_ssas = self.insert_qubit_idx_from_address( + address_attr, stmt_to_insert_before=meas_and_reset_stmt + ) + + else: + raise TypeError( + "Unsupported statement detected, only qubit.MeasureAndReset and wire.MeasureAndReset are supported" + ) + + error_p_stmt = py.Constant(0.0) + stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) + stim_rz_stmt = stim.collapse.RZ( + targets=qubit_idx_ssas, + ) + error_p_stmt.insert_before(meas_and_reset_stmt) + stim_mz_stmt.insert_before(meas_and_reset_stmt) + stim_rz_stmt.insert_before(meas_and_reset_stmt) + + return RewriteResult(has_done_something=True) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index fcb65674..4b2eaee4 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -162,3 +162,61 @@ def test_qubit_reset(): rewrite_result = squin_to_stim(constructed_method) print(rewrite_result) constructed_method.print() + + +def test_qubit_measure_and_reset(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # qubit.reset only accepts ilist of qubits + (qlist := ilist.New(values=[q0.result])), + (squin.qubit.MeasureAndReset(qlist.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + # analysis_res, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(constructed_method) + # constructed_method.print(analysis=analysis_res.entries) + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() + + +def test_wire_measure_and_reset(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # get wire out + (w0 := squin.wire.Unwrap(q0.result)), + # qubit.reset only accepts ilist of qubits + (squin.wire.MeasureAndReset(w0.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() + + +# test_wire_measure_and_reset() +# test_qubit_measure_and_reset() From 300f9d70682bd824e0b25456c0d5c4688086ce55 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 16 Apr 2025 22:37:54 -0400 Subject: [PATCH 10/29] account for MeasureAndReset, fix up address analysis --- src/bloqade/analysis/address/impls.py | 14 +++++++++ src/bloqade/squin/rewrite/stim.py | 2 +- test/squin/stim/stim.py | 43 +++++++++++++++++++++++++-- 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index a9ae40e8..a32d221d 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -210,6 +210,20 @@ def apply( ) return new_address_wires + @interp.impl(squin.wire.MeasureAndReset) + def measure_and_reset( + self, + interp_: AddressAnalysis, + frame: ForwardFrame[Address], + stmt: squin.wire.MeasureAndReset, + ): + + # take the address data from the incoming wire + # and propagate that forward to the new wire generated. + # The first entry can safely be NotQubit because + # it's an integer + return (NotQubit(), frame.get(stmt.wire)) + @squin.qubit.dialect.register(key="qubit.address") class SquinQubitMethodTable(interp.MethodTable): diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 09e23d76..e15fef7a 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -322,7 +322,6 @@ def rewrite_Measure( if isinstance(measure_stmt, qubit.Measure): qubit_ilist_ssa = measure_stmt.qubits - # qubits are in an ilist which makes up an AddressTuple address_attr = self.get_address_attr(qubit_ilist_ssa) elif isinstance(measure_stmt, wire.Measure): @@ -386,6 +385,7 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult: stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas) stim_rz_stmt.insert_before(reset_stmt) + reset_stmt.delete() return RewriteResult(has_done_something=True) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 4b2eaee4..14fa5801 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -1,8 +1,10 @@ from kirin import ir, types +from kirin.passes import Fold from kirin.dialects import py, func, ilist import bloqade.squin.passes as squin_passes from bloqade import qasm2, squin +from bloqade.analysis import address def as_int(value: int): @@ -150,7 +152,33 @@ def test_qubit_reset(): # qubit.reset only accepts ilist of qubits (qlist := ilist.New(values=[q0.result])), (squin.qubit.Reset(qubits=qlist.result)), - (squin.qubit.Measure(qubits=qlist.result)), + # (squin.qubit.Measure(qubits=qlist.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + rewrite_result = squin_to_stim(constructed_method) + print(rewrite_result) + constructed_method.print() + + +def test_wire_reset(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # get wire + (w0 := squin.wire.Unwrap(q0.result)), + # reset the wire + (squin.wire.Reset(w0.result)), (ret_none := func.ConstantNone()), (func.Return(ret_none)), ] @@ -212,11 +240,22 @@ def test_wire_measure_and_reset(): constructed_method = gen_func_from_stmts(stmts) constructed_method.print() + fold_pass = Fold(constructed_method.dialects) + fold_pass(constructed_method) + # need to make sure the origin qubit data is properly + # propagated to the new wire that wire.MeasureAndReset spits out + address_res, _ = address.AddressAnalysis(constructed_method.dialects).run_analysis( + constructed_method + ) + constructed_method.print(analysis=address_res.entries) + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) rewrite_result = squin_to_stim(constructed_method) print(rewrite_result) constructed_method.print() -# test_wire_measure_and_reset() +test_wire_measure_and_reset() # test_qubit_measure_and_reset() +# test_wire_reset() +# test_qubit_reset() From 13ae8a52c540a78721a0e50b06888e71dd660708 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 18 Apr 2025 10:33:56 -0400 Subject: [PATCH 11/29] more testing, verification implemented --- src/bloqade/squin/rewrite/stim.py | 86 ++++++++++++++++--- test/squin/stim/stim.py | 136 +++++++++++++++++++++++++++++- 2 files changed, 208 insertions(+), 14 deletions(-) diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index e15fef7a..a0b74130 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, cast from dataclasses import dataclass from kirin import ir @@ -9,7 +9,7 @@ from bloqade import stim from bloqade.squin import op, wire, qubit from bloqade.analysis.address import Address, AddressWire, AddressQubit, AddressTuple -from bloqade.squin.analysis.nsites import Sites +from bloqade.squin.analysis.nsites import Sites, NumberSites # Probably best to move these attributes to a # separate file? Keep here for now @@ -93,7 +93,9 @@ def get_address_attr(self, value: ir.SSAValue) -> AddressAttribute: def get_sites_attr(self, value: ir.SSAValue): try: - return value.hints["sites"] + sites_attr = value.hints["sites"] + assert isinstance(sites_attr, SitesAttribute) + return sites_attr except KeyError: raise KeyError(f"The sites analysis hint for {value} does not exist") @@ -111,7 +113,7 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator): return stim.gate.H case op.stmts.S(): return stim.gate.S - case op.stmts.Identity(): # enforce sites defined = num wires in + case op.stmts.Identity(): return stim.gate.Identity case _: raise NotImplementedError( @@ -193,6 +195,64 @@ def insert_qubit_idx_after_apply( "unsupported statement detected, only wire.Apply and qubit.Apply statements are supported by this method" ) + def verify_num_site_Apply(self, apply_stmt: wire.Apply | qubit.Apply): + + # get the number of wires/qubits that went into the statement + if isinstance(apply_stmt, wire.Apply): + num_sites_targeted = len(apply_stmt.inputs) + elif isinstance(apply_stmt, qubit.Apply): + address_attr = self.get_address_attr(apply_stmt.qubits) + # ilist has AddressTuple type, + # should be the case that the types INSIDE the AddressTuple + # are all AddressQubit + address_tuple = address_attr.address + assert isinstance(address_tuple, AddressTuple) + num_sites_targeted = len(address_tuple.data) + else: + raise TypeError( + "Number of sites verification cannot occur on statements other than wire.Apply and qubit.Apply" + ) + + # The only single qubit operator that can have its size customized is the Identity gate. + # There are two possible valid uses for size. + # Either: + ## Apply(Identity(size=n), wire0, ..., wire_n) + # Or: + ## Apply(Identity(size=1), wire0, ..., wire_n) + # both should have the same effect, and can naturally be represented in Stim as: + # 1QGate q1 q2 q3 q4 + + op_ssa = apply_stmt.operator + op_stmt = op_ssa.owner + cast(ir.Statement, op_stmt) + + sites_attr = self.get_sites_attr(op_ssa) + sites_type = sites_attr.sites + assert isinstance(sites_type, NumberSites) + num_sites_supported = sites_type.sites + + if isinstance(op_stmt, op.stmts.Identity): + if num_sites_supported != 1 or num_sites_supported != num_sites_targeted: + raise ValueError( + "squin.op.Identity must either have sites = 1 or sites = the number of qubits/wires it is being applied on" + ) + elif isinstance(op_stmt, op.stmts.Control): + # in Stim control gates have the following supported syntax + ## CX 1 2 + ## CX 1 2 3 4 (equivalent to CX 1 2, then CX 3 4) + + if ( + num_sites_targeted < num_sites_supported + or num_sites_targeted % num_sites_supported != 0 + ): + raise ValueError( + "Mismatch found between Control gate supported number of qubits/wires and number of qubits/wires being supplied." + ) + else: + return None + + return None + # might be worth attempting multiple dispatch like qasm2 rewrites # for Glob and Parallel to UOp # The problem is I'd have to introduce names for all the statements @@ -206,6 +266,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: case wire.Apply() | qubit.Apply(): + self.verify_num_site_Apply(node) return self.rewrite_Apply(node) case wire.Wrap(): return self.rewrite_Wrap(node) @@ -244,11 +305,6 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: # but we can handle X, Y, Z, H, and S here just fine stim_1q_op = self.get_stim_1q_gate(applied_op) - # wire.Apply -> tuple of SSA -> AddressTuple - # qubit.Apply -> list of qubits -> AddressTuple - ## Both cases the statements follow the Stim semantics of - ## 1QGate a b c d .... - if isinstance(apply_stmt, qubit.Apply): address_attr = self.get_address_attr(apply_stmt.qubits) qubit_idx_ssas = self.insert_qubit_idx_from_address( @@ -266,6 +322,12 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) stim_1q_stmt.insert_before(apply_stmt) + # Could I safely delete the apply statements? + # If it's a a qubit.Apply yes, because it doesn't return anything + # If it's a wire.Apply no, because the `results` of that Apply get used later on + if isinstance(apply_stmt, qubit.Apply): + apply_stmt.delete() + return RewriteResult(has_done_something=True) def rewrite_Control( @@ -278,12 +340,9 @@ def rewrite_Control( ctrl_op = apply_stmt_ctrl.operator.owner assert isinstance(ctrl_op, op.stmts.Control) - # enforce that n_controls is 1 - ctrl_op_target_gate = ctrl_op.op.owner assert isinstance(ctrl_op_target_gate, op.stmts.Operator) - # should enforce that this is some multiple of 2 qubit_idx_ssas = self.insert_qubit_idx_after_apply(apply_stmt=apply_stmt_ctrl) # according to stim, final result can be: # CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4) @@ -314,6 +373,9 @@ def rewrite_Control( stim_stmt.insert_before(apply_stmt_ctrl) + if isinstance(apply_stmt_ctrl, qubit.Apply): + apply_stmt_ctrl.delete() + return RewriteResult(has_done_something=True) def rewrite_Measure( diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 14fa5801..1cc2dc7d 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -78,6 +78,135 @@ def test_wire_1q(): constructed_method.print() +def test_parallel_wire_1q_application(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + (w2 := squin.wire.Unwrap(qubit=q2.result)), + (w3 := squin.wire.Unwrap(qubit=q3.result)), + # Apply with stim semantics + (h_op := squin.op.stmts.H()), + ( + app_res := squin.wire.Apply( + h_op.result, w0.result, w1.result, w2.result, w3.result + ) + ), + # Wrap everything back + (squin.wire.Wrap(app_res.results[0], q0.result)), + (squin.wire.Wrap(app_res.results[1], q1.result)), + (squin.wire.Wrap(app_res.results[2], q2.result)), + (squin.wire.Wrap(app_res.results[3], q3.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +def test_parallel_qubit_1q_application(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # create ilist of qubits + (q_list := ilist.New(values=(q0.result, q1.result, q2.result, q3.result))), + # Apply with stim semantics + (h_op := squin.op.stmts.H()), + (app_res := squin.qubit.Apply(h_op.result, q_list.result)), # noqa: F841 + # Measure everything out + (meas_res := squin.qubit.Measure(q_list.result)), # noqa: F841 + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +def test_parallel_control_gate_wire_application(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + (w2 := squin.wire.Unwrap(qubit=q2.result)), + (w3 := squin.wire.Unwrap(qubit=q3.result)), + # Create and apply CX gate + (x_op := squin.op.stmts.X()), + (ctrl_x_op := squin.op.stmts.Control(x_op.result, n_controls=1)), + ( + app_res := squin.wire.Apply( + ctrl_x_op.result, w0.result, w1.result, w2.result, w3.result + ) + ), + # measure it all out + (meas_res_0 := squin.wire.Measure(app_res.results[0])), # noqa: F841 + (meas_res_1 := squin.wire.Measure(app_res.results[1])), # noqa: F841 + (meas_res_2 := squin.wire.Measure(app_res.results[2])), # noqa: F841 + (meas_res_3 := squin.wire.Measure(app_res.results[3])), # noqa: F841 + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + def test_wire_control(): stmts: list[ir.Statement] = [ @@ -255,7 +384,10 @@ def test_wire_measure_and_reset(): constructed_method.print() -test_wire_measure_and_reset() +# test_wire_measure_and_reset() # test_qubit_measure_and_reset() # test_wire_reset() -# test_qubit_reset() + +# test_parallel_qubit_1q_application() +# test_parallel_wire_1q_application() +test_parallel_control_gate_wire_application() From 59f763d83ae8df624c3e5c4155177f0c3639f359 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 18 Apr 2025 10:35:13 -0400 Subject: [PATCH 12/29] remove test call --- test/squin/stim/stim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 1cc2dc7d..3b44f25b 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -390,4 +390,4 @@ def test_wire_measure_and_reset(): # test_parallel_qubit_1q_application() # test_parallel_wire_1q_application() -test_parallel_control_gate_wire_application() +# test_parallel_control_gate_wire_application() From 9427571b67f929a1e0185385227ec347364e7f7e Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 18 Apr 2025 10:45:27 -0400 Subject: [PATCH 13/29] simple site verification test --- test/squin/stim/stim.py | 42 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 3b44f25b..92cfa91d 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -1,3 +1,4 @@ +import pytest from kirin import ir, types from kirin.passes import Fold from kirin.dialects import py, func, ilist @@ -384,6 +385,45 @@ def test_wire_measure_and_reset(): constructed_method.print() +def test_wire_apply_site_verification(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(3)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubis out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + (w2 := squin.wire.Unwrap(qubit=q2.result)), + # set up control gate + (op1 := squin.op.stmts.X()), + (cx := squin.op.stmts.Control(op1.result, n_controls=1)), + # improper application, dangling qubit that verification should catch! + (app := squin.wire.Apply(cx.result, w0.result, w1.result, w2.result)), + # wrap things back + (squin.wire.Wrap(wire=app.results[0], qubit=q0.result)), + (squin.wire.Wrap(wire=app.results[1], qubit=q1.result)), + (squin.wire.Wrap(wire=app.results[2], qubit=q2.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + + with pytest.raises(ValueError): + squin_to_stim(constructed_method) + + # test_wire_measure_and_reset() # test_qubit_measure_and_reset() # test_wire_reset() @@ -391,3 +431,5 @@ def test_wire_measure_and_reset(): # test_parallel_qubit_1q_application() # test_parallel_wire_1q_application() # test_parallel_control_gate_wire_application() + +# test_wire_apply_site_verification() From e45c10c1618da7ce1096301d4898f6794fcf0667 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 22 Apr 2025 10:22:04 -0400 Subject: [PATCH 14/29] saving remaining work before move to codegen --- src/bloqade/squin/qubit.py | 5 +++++ src/bloqade/stim/dialects/aux/stmts/annotate.py | 4 ++++ test/squin/stim/stim.py | 2 ++ 3 files changed, 11 insertions(+) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index f54882f6..b3c376d7 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -48,6 +48,11 @@ class MeasureAndReset(ir.Statement): result: ir.ResultValue = info.result(types.Int) +# MZ -> RZ +# MRZ (never really used, but could have performance benefit) +# + + @statement(dialect=dialect) class Reset(ir.Statement): qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) diff --git a/src/bloqade/stim/dialects/aux/stmts/annotate.py b/src/bloqade/stim/dialects/aux/stmts/annotate.py index eacd881f..28acf873 100644 --- a/src/bloqade/stim/dialects/aux/stmts/annotate.py +++ b/src/bloqade/stim/dialects/aux/stmts/annotate.py @@ -45,3 +45,7 @@ class NewPauliString(ir.Statement): flipped: tuple[ir.SSAValue, ...] = info.argument(types.Bool) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) result: ir.ResultValue = info.result(type=PauliStringType) + + +# dialect_group = squin_extended = squin in bloqade-circuit + physical in bloqade-qec +# Chen will need squin -> Stim rewrite IF it is a subroutine of rewrite from bloqade-qec extension diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 92cfa91d..9992c73d 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -424,6 +424,8 @@ def test_wire_apply_site_verification(): squin_to_stim(constructed_method) +test_wire_measure() + # test_wire_measure_and_reset() # test_qubit_measure_and_reset() # test_wire_reset() From 30b8a973b6a96f1185d2ea7adae6d013e55d10c8 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 13 May 2025 22:25:32 -0400 Subject: [PATCH 15/29] account for MeasureQubit, MeasureQubitIlist as well as Broadcast functionality --- src/bloqade/squin/rewrite/stim.py | 209 ++++++++++++++---------------- 1 file changed, 98 insertions(+), 111 deletions(-) diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index a0b74130..321bc1cf 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -123,6 +123,12 @@ def get_stim_1q_gate(self, squin_op: op.stmts.Operator): def insert_qubit_idx_from_address( self, address: AddressAttribute, stmt_to_insert_before: ir.Statement ) -> tuple[ir.SSAValue, ...]: + """ + Given an AddressAttribute which wraps the result of address analysis for a statement, + extract the qubit indices from the address type and insert them into the SSA form. + + Currently supports AddressTuple[AddressQubit] and AddressWire types. + """ address_data = address.address @@ -171,58 +177,33 @@ def insert_qubit_idx_from_wire_ssa( return tuple(qubit_idx_ssas) - # get the qubit indices from the Apply statement argument - # wires/qubits - - def insert_qubit_idx_after_apply( - self, apply_stmt: wire.Apply | qubit.Apply - ) -> tuple[ir.SSAValue, ...]: - - if isinstance(apply_stmt, qubit.Apply): - qubits = apply_stmt.qubits - address_attribute: AddressAttribute = self.get_address_attr(qubits) - # Should get an AddressTuple out of the address stored in attribute - return self.insert_qubit_idx_from_address( - address=address_attribute, stmt_to_insert_before=apply_stmt - ) - elif isinstance(apply_stmt, wire.Apply): - wire_ssas = apply_stmt.inputs - return self.insert_qubit_idx_from_wire_ssa( - wire_ssas=wire_ssas, stmt_to_insert_before=apply_stmt - ) - else: - raise TypeError( - "unsupported statement detected, only wire.Apply and qubit.Apply statements are supported by this method" - ) - - def verify_num_site_Apply(self, apply_stmt: wire.Apply | qubit.Apply): + def verify_num_sites( + self, stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast + ): + """ + Ensure for Apply statements that the number of qubits/wires strictly matches the number of sites + supported by the operator, and for Broadcast statements that the number of qubits/wires + is a multiple of the number of sites supported by the operator. + """ - # get the number of wires/qubits that went into the statement - if isinstance(apply_stmt, wire.Apply): - num_sites_targeted = len(apply_stmt.inputs) - elif isinstance(apply_stmt, qubit.Apply): - address_attr = self.get_address_attr(apply_stmt.qubits) - # ilist has AddressTuple type, - # should be the case that the types INSIDE the AddressTuple - # are all AddressQubit + # Determine the number of sites targeted + ## wire.Apply and wire.Broadcast takes a standard python tuple of SSAValues, + ## qubit.Apply and qubit.Broadcast takes an AddressTuple of AddressQubits + ## and need some extra logic to extract the number of sites targeted + if isinstance(stmt, (wire.Apply, wire.Broadcast)): + num_sites_targeted = len(stmt.inputs) + elif isinstance(stmt, (qubit.Apply, qubit.Broadcast)): + address_attr = self.get_address_attr(stmt.qubits) address_tuple = address_attr.address assert isinstance(address_tuple, AddressTuple) num_sites_targeted = len(address_tuple.data) else: raise TypeError( - "Number of sites verification cannot occur on statements other than wire.Apply and qubit.Apply" + "Number of sites verification can only occur on Apply or Broadcast statements" ) - # The only single qubit operator that can have its size customized is the Identity gate. - # There are two possible valid uses for size. - # Either: - ## Apply(Identity(size=n), wire0, ..., wire_n) - # Or: - ## Apply(Identity(size=1), wire0, ..., wire_n) - # both should have the same effect, and can naturally be represented in Stim as: - # 1QGate q1 q2 q3 q4 - - op_ssa = apply_stmt.operator + # Get the operator and its supported number of sites + op_ssa = stmt.operator op_stmt = op_ssa.owner cast(ir.Statement, op_stmt) @@ -231,33 +212,20 @@ def verify_num_site_Apply(self, apply_stmt: wire.Apply | qubit.Apply): assert isinstance(sites_type, NumberSites) num_sites_supported = sites_type.sites - if isinstance(op_stmt, op.stmts.Identity): - if num_sites_supported != 1 or num_sites_supported != num_sites_targeted: + # Perform the verification + if isinstance(stmt, (wire.Broadcast, qubit.Broadcast)): + if num_sites_targeted % num_sites_supported != 0: raise ValueError( - "squin.op.Identity must either have sites = 1 or sites = the number of qubits/wires it is being applied on" + "Number of qubits/wires to broadcast to must be a multiple of the number of sites supported by the operator" ) - elif isinstance(op_stmt, op.stmts.Control): - # in Stim control gates have the following supported syntax - ## CX 1 2 - ## CX 1 2 3 4 (equivalent to CX 1 2, then CX 3 4) - - if ( - num_sites_targeted < num_sites_supported - or num_sites_targeted % num_sites_supported != 0 - ): + elif isinstance(stmt, (wire.Apply, qubit.Apply)): + if num_sites_targeted != num_sites_supported: raise ValueError( - "Mismatch found between Control gate supported number of qubits/wires and number of qubits/wires being supplied." + "Number of qubits/wires to apply to must match the number of sites supported by the operator" ) - else: - return None return None - # might be worth attempting multiple dispatch like qasm2 rewrites - # for Glob and Parallel to UOp - # The problem is I'd have to introduce names for all the statements - # as a ClassVar str. Maybe hold off for now. - # Don't translate constants to Stim Aux Constants just yet, # The Stim operations don't even rely on those particular # constants, seems to be more for lowering from Python AST @@ -265,12 +233,12 @@ def verify_num_site_Apply(self, apply_stmt: wire.Apply | qubit.Apply): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: - case wire.Apply() | qubit.Apply(): - self.verify_num_site_Apply(node) - return self.rewrite_Apply(node) + case wire.Apply() | qubit.Apply() | wire.Broadcast() | qubit.Broadcast(): + self.verify_num_sites(node) + return self.rewrite_Apply_and_Broadcast(node) case wire.Wrap(): return self.rewrite_Wrap(node) - case wire.Measure() | qubit.Measure(): + case wire.Measure() | qubit.MeasureQubit() | qubit.MeasureQubitList(): return self.rewrite_Measure(node) case wire.Reset() | qubit.Reset(): return self.rewrite_Reset(node) @@ -279,8 +247,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: case _: return RewriteResult() - return RewriteResult() - def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult: # get the wire going into the statement @@ -292,8 +258,9 @@ def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult: # do NOT want to delete the qubit SSA! Leave that alone! return RewriteResult(has_done_something=True) - def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: - + def rewrite_Apply_and_Broadcast( + self, apply_stmt: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast + ) -> RewriteResult: # this is an SSAValue, need it to be the actual operator applied_op = apply_stmt.operator.owner assert isinstance(applied_op, op.stmts.Operator) @@ -305,51 +272,49 @@ def rewrite_Apply(self, apply_stmt: qubit.Apply | wire.Apply) -> RewriteResult: # but we can handle X, Y, Z, H, and S here just fine stim_1q_op = self.get_stim_1q_gate(applied_op) - if isinstance(apply_stmt, qubit.Apply): + if isinstance(apply_stmt, (qubit.Apply, qubit.Broadcast)): address_attr = self.get_address_attr(apply_stmt.qubits) qubit_idx_ssas = self.insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=apply_stmt ) - elif isinstance(apply_stmt, wire.Apply): + elif isinstance(apply_stmt, (wire.Apply, wire.Broadcast)): qubit_idx_ssas = self.insert_qubit_idx_from_wire_ssa( wire_ssas=apply_stmt.inputs, stmt_to_insert_before=apply_stmt ) else: raise TypeError( - "Unsupported statement detected, only qubit.Apply and wire.Apply are permitted" + "Unsupported statement detected, only Apply and Broadcast statements are permitted" ) stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) stim_1q_stmt.insert_before(apply_stmt) # Could I safely delete the apply statements? - # If it's a a qubit.Apply yes, because it doesn't return anything - # If it's a wire.Apply no, because the `results` of that Apply get used later on - if isinstance(apply_stmt, qubit.Apply): + # If it's a qubit.Apply or qubit.Broadcast, yes, because it doesn't return anything + # If it's a wire.Apply or wire.Broadcast, no, because the `results` of that Apply/Broadcast get used later on + if isinstance(apply_stmt, (qubit.Apply, qubit.Broadcast)): apply_stmt.delete() return RewriteResult(has_done_something=True) def rewrite_Control( - self, apply_stmt_ctrl: qubit.Apply | wire.Apply + self, + stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast, ) -> RewriteResult: - # stim only supports CX, CY, CZ so we have to check the - # operator of Apply is a Control gate, enforce it's only asking for 1 control qubit, - # and that the target of the control is X, Y, Z in squin - - ctrl_op = apply_stmt_ctrl.operator.owner + """ + Handle control gates for Apply and Broadcast statements. + """ + ctrl_op = stmt_with_ctrl.operator.owner assert isinstance(ctrl_op, op.stmts.Control) ctrl_op_target_gate = ctrl_op.op.owner assert isinstance(ctrl_op_target_gate, op.stmts.Operator) - qubit_idx_ssas = self.insert_qubit_idx_after_apply(apply_stmt=apply_stmt_ctrl) - # according to stim, final result can be: - # CX 1 2 3 4 -> CX(1, targ=2), CX(3, targ=4) + qubit_idx_ssas = self.insert_qubit_idx_after_apply(stmt=stmt_with_ctrl) + + # Separate control and target qubits target_qubits = [] ctrl_qubits = [] - # definitely a better way to do this but - # can't think of it right now for i in range(len(qubit_idx_ssas)): if (i % 2) == 0: ctrl_qubits.append(qubit_idx_ssas[i]) @@ -359,6 +324,7 @@ def rewrite_Control( target_qubits = tuple(target_qubits) ctrl_qubits = tuple(ctrl_qubits) + # Handle supported gates match ctrl_op_target_gate: case op.stmts.X(): stim_stmt = stim.CX(controls=ctrl_qubits, targets=target_qubits) @@ -371,36 +337,57 @@ def rewrite_Control( "Control gates beyond CX, CY, and CZ are not supported" ) - stim_stmt.insert_before(apply_stmt_ctrl) + stim_stmt.insert_before(stmt_with_ctrl) - if isinstance(apply_stmt_ctrl, qubit.Apply): - apply_stmt_ctrl.delete() + # Delete the original statement if it's a qubit.Apply or qubit.Broadcast + if isinstance(stmt_with_ctrl, (qubit.Apply, qubit.Broadcast)): + stmt_with_ctrl.delete() return RewriteResult(has_done_something=True) - def rewrite_Measure( - self, measure_stmt: qubit.Measure | wire.Measure - ) -> RewriteResult: - - if isinstance(measure_stmt, qubit.Measure): - qubit_ilist_ssa = measure_stmt.qubits - address_attr = self.get_address_attr(qubit_ilist_ssa) - - elif isinstance(measure_stmt, wire.Measure): - # Wire Terminator, should kill the existence of - # the wire here so DCE can sweep up the rest like with rewriting wrap - wire_ssa = measure_stmt.wire - address_attr = self.get_address_attr(wire_ssa) - - # DCE can't remove the old measure_stmt for both wire and qubit versions - # because of the fact it has a result that can be depended on by other statements - # whereas Stim Measure has no such notion - + def insert_qubit_idx_after_apply( + self, stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast + ) -> tuple[ir.SSAValue, ...]: + """ + Extract qubit indices from Apply or Broadcast statements. + """ + if isinstance(stmt, (qubit.Apply, qubit.Broadcast)): + qubits = stmt.qubits + address_attribute: AddressAttribute = self.get_address_attr(qubits) + return self.insert_qubit_idx_from_address( + address=address_attribute, stmt_to_insert_before=stmt + ) + elif isinstance(stmt, (wire.Apply, wire.Broadcast)): + wire_ssas = stmt.inputs + return self.insert_qubit_idx_from_wire_ssa( + wire_ssas=wire_ssas, stmt_to_insert_before=stmt + ) else: raise TypeError( - "unsupported Statement, only qubit.Measure and wire.Measure are supported" + "Unsupported statement detected, only Apply and Broadcast statements are supported by this method" ) + # qubit.Measure no longer exists, need to handle + # qubit.MeasureQubit and MeasureQubitList + def rewrite_Measure( + self, measure_stmt: wire.Measure | qubit.MeasureQubit | qubit.MeasureQubitList + ) -> RewriteResult: + + match measure_stmt: + case qubit.MeasureQubit(): + qubit_ilist_ssa = measure_stmt.qubit + address_attr = self.get_address_attr(qubit_ilist_ssa) + case qubit.MeasureQubitList(): + qubit_ssa = measure_stmt.qubits + address_attr = self.get_address_attr(qubit_ssa) + case wire.Measure(): + wire_ssa = measure_stmt.wire + address_attr = self.get_address_attr(wire_ssa) + case _: + raise TypeError( + "Unsupported Statement, only qubit.MeasureQubit, qubit.MeasureQubitList, and wire.Measure are supported" + ) + qubit_idx_ssas = self.insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=measure_stmt ) From 615a30b0427d7cc8d7ff4b4583941a3d40c73d91 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 13 May 2025 22:37:29 -0400 Subject: [PATCH 16/29] revise tests --- src/bloqade/squin/rewrite/stim.py | 7 ------- test/squin/stim/stim.py | 27 +++++++-------------------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 321bc1cf..6a1f5ed6 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -11,10 +11,6 @@ from bloqade.analysis.address import Address, AddressWire, AddressQubit, AddressTuple from bloqade.squin.analysis.nsites import Sites, NumberSites -# Probably best to move these attributes to a -# separate file? Keep here for now -# to get things working first - @wire.dialect.register @dataclass @@ -448,9 +444,6 @@ def rewrite_MeasureAndReset( stim.MZ(tuple[SSAvals for ints]) stim.RZ(tuple[SSAvals for ints]) - Stim does have MRZ, might be more reflective of what we want/ - lines up the semantics better - """ if isinstance(meas_and_reset_stmt, qubit.MeasureAndReset): diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 9992c73d..1f8fd4aa 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -79,7 +79,7 @@ def test_wire_1q(): constructed_method.print() -def test_parallel_wire_1q_application(): +def test_broadcast_wire_1q_application(): stmts: list[ir.Statement] = [ # Create qubit register @@ -102,7 +102,7 @@ def test_parallel_wire_1q_application(): # Apply with stim semantics (h_op := squin.op.stmts.H()), ( - app_res := squin.wire.Apply( + app_res := squin.wire.Broadcast( h_op.result, w0.result, w1.result, w2.result, w3.result ) ), @@ -125,7 +125,7 @@ def test_parallel_wire_1q_application(): constructed_method.print() -def test_parallel_qubit_1q_application(): +def test_broadcast_qubit_1q_application(): stmts: list[ir.Statement] = [ # Create qubit register @@ -144,9 +144,9 @@ def test_parallel_qubit_1q_application(): (q_list := ilist.New(values=(q0.result, q1.result, q2.result, q3.result))), # Apply with stim semantics (h_op := squin.op.stmts.H()), - (app_res := squin.qubit.Apply(h_op.result, q_list.result)), # noqa: F841 + (app_res := squin.qubit.Broadcast(h_op.result, q_list.result)), # noqa: F841 # Measure everything out - (meas_res := squin.qubit.Measure(q_list.result)), # noqa: F841 + (meas_res := squin.qubit.MeasureQubitList(q_list.result)), # noqa: F841 (ret_none := func.ConstantNone()), (func.Return(ret_none)), ] @@ -161,7 +161,7 @@ def test_parallel_qubit_1q_application(): constructed_method.print() -def test_parallel_control_gate_wire_application(): +def test_broadcast_control_gate_wire_application(): stmts: list[ir.Statement] = [ # Create qubit register @@ -405,7 +405,7 @@ def test_wire_apply_site_verification(): # set up control gate (op1 := squin.op.stmts.X()), (cx := squin.op.stmts.Control(op1.result, n_controls=1)), - # improper application, dangling qubit that verification should catch! + # improper application, cx should only support 2 sites (app := squin.wire.Apply(cx.result, w0.result, w1.result, w2.result)), # wrap things back (squin.wire.Wrap(wire=app.results[0], qubit=q0.result)), @@ -422,16 +422,3 @@ def test_wire_apply_site_verification(): with pytest.raises(ValueError): squin_to_stim(constructed_method) - - -test_wire_measure() - -# test_wire_measure_and_reset() -# test_qubit_measure_and_reset() -# test_wire_reset() - -# test_parallel_qubit_1q_application() -# test_parallel_wire_1q_application() -# test_parallel_control_gate_wire_application() - -# test_wire_apply_site_verification() From 47691e2e369e182c8f2c94d77452d145ad3fd6e7 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 13 May 2025 22:40:25 -0400 Subject: [PATCH 17/29] remove unnecessary comment --- src/bloqade/squin/qubit.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index b1563b68..355b37d8 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -75,11 +75,6 @@ class MeasureAndReset(ir.Statement): result: ir.ResultValue = info.result(ilist.IListType[types.Bool]) -# MZ -> RZ -# MRZ (never really used, but could have performance benefit) -# - - @statement(dialect=dialect) class Reset(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) From c8c38aee27beac8078d6f453043e681c34abc11a Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 14 May 2025 14:49:52 -0400 Subject: [PATCH 18/29] Move wrap analysis rewrite into its own file --- src/bloqade/squin/rewrite/__init__.py | 2 +- src/bloqade/squin/rewrite/stim.py | 71 ++------------------- src/bloqade/squin/rewrite/wrap_analysis.py | 73 ++++++++++++++++++++++ 3 files changed, 78 insertions(+), 68 deletions(-) create mode 100644 src/bloqade/squin/rewrite/wrap_analysis.py diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index 5a475fcc..33cf281c 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,6 +1,6 @@ from .stim import ( SitesAttribute as SitesAttribute, AddressAttribute as AddressAttribute, - WrapSquinAnalysis as WrapSquinAnalysis, _SquinToStim as _SquinToStim, ) +from .wrap_analysis import WrapSquinAnalysis as WrapSquinAnalysis diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 6a1f5ed6..53f125d9 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -1,78 +1,15 @@ -from typing import Dict, cast +from typing import cast from dataclasses import dataclass from kirin import ir from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from kirin.print.printer import Printer from bloqade import stim from bloqade.squin import op, wire, qubit -from bloqade.analysis.address import Address, AddressWire, AddressQubit, AddressTuple -from bloqade.squin.analysis.nsites import Sites, NumberSites - - -@wire.dialect.register -@dataclass -class AddressAttribute(ir.Attribute): - - name = "Address" - address: Address - - def __hash__(self) -> int: - return hash(self.address) - - def print_impl(self, printer: Printer) -> None: - # Can return to implementing this later - printer.print(self.address) - - -@op.dialect.register -@dataclass -class SitesAttribute(ir.Attribute): - - name = "Sites" - sites: Sites - - def __hash__(self) -> int: - return hash(self.sites) - - def print_impl(self, printer: Printer) -> None: - # Can return to implementing this later - printer.print(self.sites) - - -@dataclass -class WrapSquinAnalysis(RewriteRule): - - address_analysis: Dict[ir.SSAValue, Address] - op_site_analysis: Dict[ir.SSAValue, Sites] - - def wrap(self, value: ir.SSAValue) -> bool: - address_analysis_result = self.address_analysis[value] - op_site_analysis_result = self.op_site_analysis[value] - - if value.hints.get("address") and value.hints.get("sites"): - return False - else: - value.hints["address"] = AddressAttribute(address_analysis_result) - value.hints["sites"] = SitesAttribute(op_site_analysis_result) - - return True - - def rewrite_Block(self, node: ir.Block) -> RewriteResult: - has_done_something = False - for arg in node.args: - if self.wrap(arg): - has_done_something = True - return RewriteResult(has_done_something=has_done_something) - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - has_done_something = False - for result in node.results: - if self.wrap(result): - has_done_something = True - return RewriteResult(has_done_something=has_done_something) +from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple +from bloqade.squin.analysis.nsites import NumberSites +from bloqade.squin.rewrite.wrap_analysis import SitesAttribute, AddressAttribute @dataclass diff --git a/src/bloqade/squin/rewrite/wrap_analysis.py b/src/bloqade/squin/rewrite/wrap_analysis.py new file mode 100644 index 00000000..fafd3fc4 --- /dev/null +++ b/src/bloqade/squin/rewrite/wrap_analysis.py @@ -0,0 +1,73 @@ +from typing import Dict +from dataclasses import dataclass + +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult +from kirin.print.printer import Printer + +from bloqade.squin import op, wire +from bloqade.analysis.address import Address +from bloqade.squin.analysis.nsites import Sites + + +@wire.dialect.register +@dataclass +class AddressAttribute(ir.Attribute): + + name = "Address" + address: Address + + def __hash__(self) -> int: + return hash(self.address) + + def print_impl(self, printer: Printer) -> None: + # Can return to implementing this later + printer.print(self.address) + + +@op.dialect.register +@dataclass +class SitesAttribute(ir.Attribute): + + name = "Sites" + sites: Sites + + def __hash__(self) -> int: + return hash(self.sites) + + def print_impl(self, printer: Printer) -> None: + # Can return to implementing this later + printer.print(self.sites) + + +@dataclass +class WrapSquinAnalysis(RewriteRule): + + address_analysis: Dict[ir.SSAValue, Address] + op_site_analysis: Dict[ir.SSAValue, Sites] + + def wrap(self, value: ir.SSAValue) -> bool: + address_analysis_result = self.address_analysis[value] + op_site_analysis_result = self.op_site_analysis[value] + + if value.hints.get("address") and value.hints.get("sites"): + return False + else: + value.hints["address"] = AddressAttribute(address_analysis_result) + value.hints["sites"] = SitesAttribute(op_site_analysis_result) + + return True + + def rewrite_Block(self, node: ir.Block) -> RewriteResult: + has_done_something = False + for arg in node.args: + if self.wrap(arg): + has_done_something = True + return RewriteResult(has_done_something=has_done_something) + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + has_done_something = False + for result in node.results: + if self.wrap(result): + has_done_something = True + return RewriteResult(has_done_something=has_done_something) From 72e5f3fb7123d1cab17801e6fd1f8cef5e7ba048 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 14 May 2025 16:55:52 -0400 Subject: [PATCH 19/29] split out reusable utility functions into a seperate file --- src/bloqade/squin/rewrite/qubit_to_stim.py | 0 src/bloqade/squin/rewrite/stim.py | 207 ++++----------------- src/bloqade/squin/rewrite/stim_util.py | 125 +++++++++++++ src/bloqade/squin/rewrite/wire_to_stim.py | 0 4 files changed, 160 insertions(+), 172 deletions(-) create mode 100644 src/bloqade/squin/rewrite/qubit_to_stim.py create mode 100644 src/bloqade/squin/rewrite/stim_util.py create mode 100644 src/bloqade/squin/rewrite/wire_to_stim.py diff --git a/src/bloqade/squin/rewrite/qubit_to_stim.py b/src/bloqade/squin/rewrite/qubit_to_stim.py new file mode 100644 index 00000000..e69de29b diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py index 53f125d9..38c9e02f 100644 --- a/src/bloqade/squin/rewrite/stim.py +++ b/src/bloqade/squin/rewrite/stim.py @@ -1,4 +1,3 @@ -from typing import cast from dataclasses import dataclass from kirin import ir @@ -7,167 +6,23 @@ from bloqade import stim from bloqade.squin import op, wire, qubit -from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple -from bloqade.squin.analysis.nsites import NumberSites -from bloqade.squin.rewrite.wrap_analysis import SitesAttribute, AddressAttribute +from bloqade.squin.rewrite.stim_util import ( + get_stim_1q_gate, + verify_num_sites, + insert_qubit_idx_from_address, + insert_qubit_idx_from_wire_ssa, +) +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute @dataclass class _SquinToStim(RewriteRule): - def get_address_attr(self, value: ir.SSAValue) -> AddressAttribute: - - try: - address_attr = value.hints["address"] - assert isinstance(address_attr, AddressAttribute) - return address_attr - except KeyError: - raise KeyError(f"The address analysis hint for {value} does not exist") - - def get_sites_attr(self, value: ir.SSAValue): - try: - sites_attr = value.hints["sites"] - assert isinstance(sites_attr, SitesAttribute) - return sites_attr - except KeyError: - raise KeyError(f"The sites analysis hint for {value} does not exist") - - # Go from (most) squin 1Q Ops to stim Ops - ## X, Y, Z, H, S, (no T!) - def get_stim_1q_gate(self, squin_op: op.stmts.Operator): - match squin_op: - case op.stmts.X(): - return stim.gate.X - case op.stmts.Y(): - return stim.gate.Y - case op.stmts.Z(): - return stim.gate.Z - case op.stmts.H(): - return stim.gate.H - case op.stmts.S(): - return stim.gate.S - case op.stmts.Identity(): - return stim.gate.Identity - case _: - raise NotImplementedError( - f"The squin operator {squin_op} is not supported in the stim dialect" - ) - - def insert_qubit_idx_from_address( - self, address: AddressAttribute, stmt_to_insert_before: ir.Statement - ) -> tuple[ir.SSAValue, ...]: - """ - Given an AddressAttribute which wraps the result of address analysis for a statement, - extract the qubit indices from the address type and insert them into the SSA form. - - Currently supports AddressTuple[AddressQubit] and AddressWire types. - """ - - address_data = address.address - - qubit_idx_ssas = [] - - if isinstance(address_data, AddressTuple): - for address_qubit in address_data.data: - - # ensure that the stuff in the AddressTuple should be AddressQubit - # could handle AddressWires as well but don't see the need for that right now - if not isinstance(address_qubit, AddressQubit): - raise ValueError( - "Unsupported Address type detected inside AddressTuple, must be AddressQubit" - ) - qubit_idx = address_qubit.data - qubit_idx_stmt = py.Constant(qubit_idx) - qubit_idx_stmt.insert_before(stmt_to_insert_before) - qubit_idx_ssas.append(qubit_idx_stmt.result) - elif isinstance(address_data, AddressWire): - address_qubit = address_data.origin_qubit - qubit_idx = address_qubit.data - qubit_idx_stmt = py.Constant(qubit_idx) - qubit_idx_stmt.insert_before(stmt_to_insert_before) - qubit_idx_ssas.append(qubit_idx_stmt.result) - else: - NotImplementedError( - "qubit idx extraction and insertion only support for AddressTuple[AddressQubit] and AddressWire instances" - ) - - return tuple(qubit_idx_ssas) - - def insert_qubit_idx_from_wire_ssa( - self, wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement - ) -> tuple[ir.SSAValue, ...]: - qubit_idx_ssas = [] - for wire_ssa in wire_ssas: - address_attribute = self.get_address_attr(wire_ssa) # get AddressWire - # get parent qubit idx - wire_address = address_attribute.address - assert isinstance(wire_address, AddressWire) - qubit_idx = wire_address.origin_qubit.data - qubit_idx_stmt = py.Constant(qubit_idx) - # accumulate all qubit idx SSA to instantiate stim gate stmt - qubit_idx_ssas.append(qubit_idx_stmt.result) - qubit_idx_stmt.insert_before(stmt_to_insert_before) - - return tuple(qubit_idx_ssas) - - def verify_num_sites( - self, stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast - ): - """ - Ensure for Apply statements that the number of qubits/wires strictly matches the number of sites - supported by the operator, and for Broadcast statements that the number of qubits/wires - is a multiple of the number of sites supported by the operator. - """ - - # Determine the number of sites targeted - ## wire.Apply and wire.Broadcast takes a standard python tuple of SSAValues, - ## qubit.Apply and qubit.Broadcast takes an AddressTuple of AddressQubits - ## and need some extra logic to extract the number of sites targeted - if isinstance(stmt, (wire.Apply, wire.Broadcast)): - num_sites_targeted = len(stmt.inputs) - elif isinstance(stmt, (qubit.Apply, qubit.Broadcast)): - address_attr = self.get_address_attr(stmt.qubits) - address_tuple = address_attr.address - assert isinstance(address_tuple, AddressTuple) - num_sites_targeted = len(address_tuple.data) - else: - raise TypeError( - "Number of sites verification can only occur on Apply or Broadcast statements" - ) - - # Get the operator and its supported number of sites - op_ssa = stmt.operator - op_stmt = op_ssa.owner - cast(ir.Statement, op_stmt) - - sites_attr = self.get_sites_attr(op_ssa) - sites_type = sites_attr.sites - assert isinstance(sites_type, NumberSites) - num_sites_supported = sites_type.sites - - # Perform the verification - if isinstance(stmt, (wire.Broadcast, qubit.Broadcast)): - if num_sites_targeted % num_sites_supported != 0: - raise ValueError( - "Number of qubits/wires to broadcast to must be a multiple of the number of sites supported by the operator" - ) - elif isinstance(stmt, (wire.Apply, qubit.Apply)): - if num_sites_targeted != num_sites_supported: - raise ValueError( - "Number of qubits/wires to apply to must match the number of sites supported by the operator" - ) - - return None - - # Don't translate constants to Stim Aux Constants just yet, - # The Stim operations don't even rely on those particular - # constants, seems to be more for lowering from Python AST - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: case wire.Apply() | qubit.Apply() | wire.Broadcast() | qubit.Broadcast(): - self.verify_num_sites(node) + verify_num_sites(node) return self.rewrite_Apply_and_Broadcast(node) case wire.Wrap(): return self.rewrite_Wrap(node) @@ -203,15 +58,16 @@ def rewrite_Apply_and_Broadcast( # need to handle Control through separate means # but we can handle X, Y, Z, H, and S here just fine - stim_1q_op = self.get_stim_1q_gate(applied_op) + stim_1q_op = get_stim_1q_gate(applied_op) if isinstance(apply_stmt, (qubit.Apply, qubit.Broadcast)): - address_attr = self.get_address_attr(apply_stmt.qubits) - qubit_idx_ssas = self.insert_qubit_idx_from_address( + address_attr = apply_stmt.qubits.hints.get("address") + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=apply_stmt ) elif isinstance(apply_stmt, (wire.Apply, wire.Broadcast)): - qubit_idx_ssas = self.insert_qubit_idx_from_wire_ssa( + qubit_idx_ssas = insert_qubit_idx_from_wire_ssa( wire_ssas=apply_stmt.inputs, stmt_to_insert_before=apply_stmt ) else: @@ -286,13 +142,13 @@ def insert_qubit_idx_after_apply( """ if isinstance(stmt, (qubit.Apply, qubit.Broadcast)): qubits = stmt.qubits - address_attribute: AddressAttribute = self.get_address_attr(qubits) - return self.insert_qubit_idx_from_address( + address_attribute: AddressAttribute = qubits.hints.get("address") + return insert_qubit_idx_from_address( address=address_attribute, stmt_to_insert_before=stmt ) elif isinstance(stmt, (wire.Apply, wire.Broadcast)): wire_ssas = stmt.inputs - return self.insert_qubit_idx_from_wire_ssa( + return insert_qubit_idx_from_wire_ssa( wire_ssas=wire_ssas, stmt_to_insert_before=stmt ) else: @@ -309,19 +165,22 @@ def rewrite_Measure( match measure_stmt: case qubit.MeasureQubit(): qubit_ilist_ssa = measure_stmt.qubit - address_attr = self.get_address_attr(qubit_ilist_ssa) + address_attr = qubit_ilist_ssa.hints.get("address") + assert isinstance(address_attr, AddressAttribute) case qubit.MeasureQubitList(): qubit_ssa = measure_stmt.qubits - address_attr = self.get_address_attr(qubit_ssa) + address_attr = qubit_ssa.hints.get("address") + assert isinstance(address_attr, AddressAttribute) case wire.Measure(): wire_ssa = measure_stmt.wire - address_attr = self.get_address_attr(wire_ssa) + address_attr = wire_ssa.hints.get("address") + assert isinstance(address_attr, AddressAttribute) case _: raise TypeError( "Unsupported Statement, only qubit.MeasureQubit, qubit.MeasureQubitList, and wire.Measure are supported" ) - qubit_idx_ssas = self.insert_qubit_idx_from_address( + qubit_idx_ssas = insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=measure_stmt ) @@ -351,13 +210,15 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult: if isinstance(reset_stmt, qubit.Reset): qubit_ilist_ssa = reset_stmt.qubits # qubits are in an ilist which makes up an AddressTuple - address_attr = self.get_address_attr(qubit_ilist_ssa) - qubit_idx_ssas = self.insert_qubit_idx_from_address( + address_attr = qubit_ilist_ssa.hints.get("address") + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=reset_stmt ) elif isinstance(reset_stmt, wire.Reset): - address_attr = self.get_address_attr(reset_stmt.wire) - qubit_idx_ssas = self.insert_qubit_idx_from_address( + address_attr = reset_stmt.wire.hints.get("address") + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=reset_stmt ) else: @@ -385,14 +246,16 @@ def rewrite_MeasureAndReset( if isinstance(meas_and_reset_stmt, qubit.MeasureAndReset): - address_attr = self.get_address_attr(meas_and_reset_stmt.qubits) - qubit_idx_ssas = self.insert_qubit_idx_from_address( + address_attr = meas_and_reset_stmt.qubits.hints.get("address") + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( address=address_attr, stmt_to_insert_before=meas_and_reset_stmt ) elif isinstance(meas_and_reset_stmt, wire.MeasureAndReset): - address_attr = self.get_address_attr(meas_and_reset_stmt.wire) - qubit_idx_ssas = self.insert_qubit_idx_from_address( + address_attr = meas_and_reset_stmt.wire.hints.get("address") + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( address_attr, stmt_to_insert_before=meas_and_reset_stmt ) diff --git a/src/bloqade/squin/rewrite/stim_util.py b/src/bloqade/squin/rewrite/stim_util.py new file mode 100644 index 00000000..979271ff --- /dev/null +++ b/src/bloqade/squin/rewrite/stim_util.py @@ -0,0 +1,125 @@ +from typing import cast + +from kirin import ir +from kirin.dialects import py + +from bloqade import stim +from bloqade.squin import op, wire, qubit +from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple +from bloqade.squin.analysis.nsites import NumberSites +from bloqade.squin.rewrite.wrap_analysis import SitesAttribute, AddressAttribute + + +def get_stim_1q_gate(squin_op: op.stmts.Operator): + """ + Map squin 1Q Ops to stim Ops. + """ + match squin_op: + case op.stmts.X(): + return stim.gate.X + case op.stmts.Y(): + return stim.gate.Y + case op.stmts.Z(): + return stim.gate.Z + case op.stmts.H(): + return stim.gate.H + case op.stmts.S(): + return stim.gate.S + case op.stmts.Identity(): + return stim.gate.Identity + case _: + raise NotImplementedError( + f"The squin operator {squin_op} is not supported in the stim dialect" + ) + + +def insert_qubit_idx_from_address( + address: AddressAttribute, stmt_to_insert_before: ir.Statement +) -> tuple[ir.SSAValue, ...]: + """ + Extract qubit indices from an AddressAttribute and insert them into the SSA form. + """ + address_data = address.address + qubit_idx_ssas = [] + + if isinstance(address_data, AddressTuple): + for address_qubit in address_data.data: + if not isinstance(address_qubit, AddressQubit): + raise ValueError( + "Unsupported Address type detected inside AddressTuple, must be AddressQubit" + ) + qubit_idx = address_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + qubit_idx_ssas.append(qubit_idx_stmt.result) + elif isinstance(address_data, AddressWire): + address_qubit = address_data.origin_qubit + qubit_idx = address_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + qubit_idx_ssas.append(qubit_idx_stmt.result) + else: + raise NotImplementedError( + "qubit idx extraction and insertion only supported for AddressTuple[AddressQubit] and AddressWire instances" + ) + + return tuple(qubit_idx_ssas) + + +def insert_qubit_idx_from_wire_ssa( + wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement +) -> tuple[ir.SSAValue, ...]: + """ + Extract qubit indices from wire SSA values and insert them into the SSA form. + """ + qubit_idx_ssas = [] + for wire_ssa in wire_ssas: + address_attribute = wire_ssa.hints.get("address") + assert isinstance(address_attribute, AddressAttribute) + wire_address = address_attribute.address + assert isinstance(wire_address, AddressWire) + qubit_idx = wire_address.origin_qubit.data + qubit_idx_stmt = py.Constant(qubit_idx) + qubit_idx_ssas.append(qubit_idx_stmt.result) + qubit_idx_stmt.insert_before(stmt_to_insert_before) + + return tuple(qubit_idx_ssas) + + +def verify_num_sites(stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast): + """ + Verify that the number of qubits/wires matches the number of sites supported by the operator. + """ + if isinstance(stmt, (wire.Apply, wire.Broadcast)): + num_sites_targeted = len(stmt.inputs) + elif isinstance(stmt, (qubit.Apply, qubit.Broadcast)): + address_attr = stmt.qubits.hints.get("address") + assert isinstance(address_attr, AddressAttribute) + address_tuple = address_attr.address + assert isinstance(address_tuple, AddressTuple) + num_sites_targeted = len(address_tuple.data) + else: + raise TypeError( + "Number of sites verification can only occur on Apply or Broadcast statements" + ) + + op_ssa = stmt.operator + op_stmt = op_ssa.owner + cast(ir.Statement, op_stmt) + + sites_attr = op_ssa.hints.get("sites") + assert isinstance(sites_attr, SitesAttribute) + sites_type = sites_attr.sites + assert isinstance(sites_type, NumberSites) + num_sites_supported = sites_type.sites + + if isinstance(stmt, (wire.Broadcast, qubit.Broadcast)): + if num_sites_targeted % num_sites_supported != 0: + raise ValueError( + "Number of qubits/wires to broadcast to must be a multiple of the number of sites supported by the operator" + ) + elif isinstance(stmt, (wire.Apply, qubit.Apply)): + if num_sites_targeted != num_sites_supported: + raise ValueError( + "Number of qubits/wires to apply to must match the number of sites supported by the operator" + ) diff --git a/src/bloqade/squin/rewrite/wire_to_stim.py b/src/bloqade/squin/rewrite/wire_to_stim.py new file mode 100644 index 00000000..e69de29b From 4731fe0ee4099512e0f0cdb491b27a32cf60724e Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 14 May 2025 17:04:28 -0400 Subject: [PATCH 20/29] fix export problem --- src/bloqade/squin/rewrite/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index 33cf281c..b79118c3 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,6 +1,8 @@ from .stim import ( + _SquinToStim as _SquinToStim, +) +from .wrap_analysis import ( SitesAttribute as SitesAttribute, AddressAttribute as AddressAttribute, - _SquinToStim as _SquinToStim, + WrapSquinAnalysis as WrapSquinAnalysis, ) -from .wrap_analysis import WrapSquinAnalysis as WrapSquinAnalysis From 410bcbc83b08f56e2c3320079bdd61c41a56d0f6 Mon Sep 17 00:00:00 2001 From: John Long Date: Wed, 14 May 2025 23:31:02 -0400 Subject: [PATCH 21/29] split out rewrite rules, factor in feedback on rewriting wrap/unwrap without blind deletion --- src/bloqade/squin/passes/stim.py | 7 +- src/bloqade/squin/rewrite/__init__.py | 5 +- src/bloqade/squin/rewrite/qubit_to_stim.py | 155 ++++++++++ src/bloqade/squin/rewrite/stim.py | 276 ------------------ .../{stim_util.py => stim_rewrite_util.py} | 134 ++++++--- src/bloqade/squin/rewrite/wire_to_stim.py | 144 +++++++++ 6 files changed, 405 insertions(+), 316 deletions(-) delete mode 100644 src/bloqade/squin/rewrite/stim.py rename src/bloqade/squin/rewrite/{stim_util.py => stim_rewrite_util.py} (52%) diff --git a/src/bloqade/squin/passes/stim.py b/src/bloqade/squin/passes/stim.py index 774ec37f..d8b80635 100644 --- a/src/bloqade/squin/passes/stim.py +++ b/src/bloqade/squin/passes/stim.py @@ -12,7 +12,7 @@ from kirin.passes.abc import Pass from kirin.rewrite.abc import RewriteResult -import bloqade.squin.rewrite as squin_rewrite +from bloqade.squin.rewrite import SquinWireToStim, SquinQubitToStim, WrapSquinAnalysis from bloqade.analysis.address import AddressAnalysis from bloqade.squin.analysis.nsites import ( NSitesAnalysis, @@ -37,11 +37,12 @@ def unsafe_run(self, mt: Method) -> RewriteResult: rewrite_result = ( Walk( Chain( - squin_rewrite.WrapSquinAnalysis( + WrapSquinAnalysis( address_analysis=address_frame.entries, op_site_analysis=sites_frame.entries, ), - squin_rewrite._SquinToStim(), + SquinQubitToStim(), + SquinWireToStim(), ) ) .rewrite(mt.code) diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index b79118c3..f3efd62c 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,6 +1,5 @@ -from .stim import ( - _SquinToStim as _SquinToStim, -) +from .wire_to_stim import SquinWireToStim as SquinWireToStim +from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim from .wrap_analysis import ( SitesAttribute as SitesAttribute, AddressAttribute as AddressAttribute, diff --git a/src/bloqade/squin/rewrite/qubit_to_stim.py b/src/bloqade/squin/rewrite/qubit_to_stim.py index e69de29b..c2667ead 100644 --- a/src/bloqade/squin/rewrite/qubit_to_stim.py +++ b/src/bloqade/squin/rewrite/qubit_to_stim.py @@ -0,0 +1,155 @@ +from kirin import ir +from kirin.dialects import py +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import stim +from bloqade.squin import op, qubit +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute +from bloqade.squin.rewrite.stim_rewrite_util import ( + rewrite_Control, + get_stim_1q_gate, + are_sites_compatible, + insert_qubit_idx_from_address, +) + + +class SquinQubitToStim(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + rewrite_methods = { + qubit.Apply: self.rewrite_Apply_and_Broadcast, + qubit.Broadcast: self.rewrite_Apply_and_Broadcast, + qubit.MeasureQubit: self.rewrite_Measure, + qubit.MeasureQubitList: self.rewrite_Measure, + qubit.Reset: self.rewrite_Reset, + qubit.MeasureAndReset: self.rewrite_MeasureAndReset, + } + + rewrite_method = rewrite_methods.get(type(node)) + if rewrite_method is None: + return RewriteResult() + + return rewrite_method(node) + + # handle Control + def rewrite_Apply_and_Broadcast( + self, stmt: qubit.Apply | qubit.Broadcast + ) -> RewriteResult: + """ + Rewrite Apply and Broadcast nodes to their stim equivalent statements. + """ + if not are_sites_compatible(stmt): + return RewriteResult() + + # this is an SSAValue, need it to be the actual operator + applied_op = stmt.operator.owner + assert isinstance(applied_op, op.stmts.Operator) + + if isinstance(applied_op, op.stmts.Control): + return rewrite_Control(stmt) + + # need to handle Control through separate means + # but we can handle X, Y, Z, H, and S here just fine + stim_1q_op = get_stim_1q_gate(applied_op) + if stim_1q_op is None: + return RewriteResult() + + address_attr = stmt.qubits.hints.get("address") + if address_attr is None: + return RewriteResult() + + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=stmt + ) + + if qubit_idx_ssas is None: + return RewriteResult() + + stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) + stmt.replace_by(stim_1q_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_Measure( + self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList + ) -> RewriteResult: + + # qubit_ssa will always be an ilist of qubits + # but need to be careful with singular vs plural "qubit" attribute name + if isinstance(measure_stmt, qubit.MeasureQubit): + qubit_ssa = measure_stmt.qubit + elif isinstance(measure_stmt, qubit.MeasureQubitList): + qubit_ssa = measure_stmt.qubits + else: + return RewriteResult() + + address_attr = qubit_ssa.hints.get("address") + if address_attr is None: + return RewriteResult() + + assert isinstance(address_attr, AddressAttribute) + + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=measure_stmt + ) + + if qubit_idx_ssas is None: + return RewriteResult() + + prob_noise_stmt = py.constant.Constant(0.0) + stim_measure_stmt = stim.collapse.MZ( + p=prob_noise_stmt.result, + targets=qubit_idx_ssas, + ) + prob_noise_stmt.insert_before(measure_stmt) + stim_measure_stmt.insert_before(measure_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult: + qubit_ilist_ssa = reset_stmt.qubits + # qubits are in an ilist which makes up an AddressTuple + address_attr = qubit_ilist_ssa.hints.get("address") + if address_attr is None: + return RewriteResult() + + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=reset_stmt + ) + + if qubit_idx_ssas is None: + return RewriteResult() + + stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas) + reset_stmt.replace_by(stim_rz_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_MeasureAndReset( + self, meas_and_reset_stmt: qubit.MeasureAndReset + ) -> RewriteResult: + + address_attr = meas_and_reset_stmt.qubits.hints.get("address") + if address_attr is None: + return RewriteResult() + + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=meas_and_reset_stmt + ) + + if qubit_idx_ssas is None: + return RewriteResult() + + error_p_stmt = py.Constant(0.0) + stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) + stim_rz_stmt = stim.collapse.RZ( + targets=qubit_idx_ssas, + ) + error_p_stmt.insert_before(meas_and_reset_stmt) + stim_mz_stmt.insert_before(meas_and_reset_stmt) + stim_rz_stmt.insert_before(meas_and_reset_stmt) + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/stim.py b/src/bloqade/squin/rewrite/stim.py deleted file mode 100644 index 38c9e02f..00000000 --- a/src/bloqade/squin/rewrite/stim.py +++ /dev/null @@ -1,276 +0,0 @@ -from dataclasses import dataclass - -from kirin import ir -from kirin.dialects import py -from kirin.rewrite.abc import RewriteRule, RewriteResult - -from bloqade import stim -from bloqade.squin import op, wire, qubit -from bloqade.squin.rewrite.stim_util import ( - get_stim_1q_gate, - verify_num_sites, - insert_qubit_idx_from_address, - insert_qubit_idx_from_wire_ssa, -) -from bloqade.squin.rewrite.wrap_analysis import AddressAttribute - - -@dataclass -class _SquinToStim(RewriteRule): - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - - match node: - case wire.Apply() | qubit.Apply() | wire.Broadcast() | qubit.Broadcast(): - verify_num_sites(node) - return self.rewrite_Apply_and_Broadcast(node) - case wire.Wrap(): - return self.rewrite_Wrap(node) - case wire.Measure() | qubit.MeasureQubit() | qubit.MeasureQubitList(): - return self.rewrite_Measure(node) - case wire.Reset() | qubit.Reset(): - return self.rewrite_Reset(node) - case wire.MeasureAndReset() | qubit.MeasureAndReset(): - return self.rewrite_MeasureAndReset(node) - case _: - return RewriteResult() - - def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult: - - # get the wire going into the statement - wire_ssa = wrap_stmt.wire - # remove the wrap statement altogether, then the wire that went into it - wrap_stmt.delete() - wire_ssa.delete() - - # do NOT want to delete the qubit SSA! Leave that alone! - return RewriteResult(has_done_something=True) - - def rewrite_Apply_and_Broadcast( - self, apply_stmt: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast - ) -> RewriteResult: - # this is an SSAValue, need it to be the actual operator - applied_op = apply_stmt.operator.owner - assert isinstance(applied_op, op.stmts.Operator) - - if isinstance(applied_op, op.stmts.Control): - return self.rewrite_Control(apply_stmt) - - # need to handle Control through separate means - # but we can handle X, Y, Z, H, and S here just fine - stim_1q_op = get_stim_1q_gate(applied_op) - - if isinstance(apply_stmt, (qubit.Apply, qubit.Broadcast)): - address_attr = apply_stmt.qubits.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=apply_stmt - ) - elif isinstance(apply_stmt, (wire.Apply, wire.Broadcast)): - qubit_idx_ssas = insert_qubit_idx_from_wire_ssa( - wire_ssas=apply_stmt.inputs, stmt_to_insert_before=apply_stmt - ) - else: - raise TypeError( - "Unsupported statement detected, only Apply and Broadcast statements are permitted" - ) - - stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) - stim_1q_stmt.insert_before(apply_stmt) - - # Could I safely delete the apply statements? - # If it's a qubit.Apply or qubit.Broadcast, yes, because it doesn't return anything - # If it's a wire.Apply or wire.Broadcast, no, because the `results` of that Apply/Broadcast get used later on - if isinstance(apply_stmt, (qubit.Apply, qubit.Broadcast)): - apply_stmt.delete() - - return RewriteResult(has_done_something=True) - - def rewrite_Control( - self, - stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast, - ) -> RewriteResult: - """ - Handle control gates for Apply and Broadcast statements. - """ - ctrl_op = stmt_with_ctrl.operator.owner - assert isinstance(ctrl_op, op.stmts.Control) - - ctrl_op_target_gate = ctrl_op.op.owner - assert isinstance(ctrl_op_target_gate, op.stmts.Operator) - - qubit_idx_ssas = self.insert_qubit_idx_after_apply(stmt=stmt_with_ctrl) - - # Separate control and target qubits - target_qubits = [] - ctrl_qubits = [] - for i in range(len(qubit_idx_ssas)): - if (i % 2) == 0: - ctrl_qubits.append(qubit_idx_ssas[i]) - else: - target_qubits.append(qubit_idx_ssas[i]) - - target_qubits = tuple(target_qubits) - ctrl_qubits = tuple(ctrl_qubits) - - # Handle supported gates - match ctrl_op_target_gate: - case op.stmts.X(): - stim_stmt = stim.CX(controls=ctrl_qubits, targets=target_qubits) - case op.stmts.Y(): - stim_stmt = stim.CY(controls=ctrl_qubits, targets=target_qubits) - case op.stmts.Z(): - stim_stmt = stim.CZ(controls=ctrl_qubits, targets=target_qubits) - case _: - raise NotImplementedError( - "Control gates beyond CX, CY, and CZ are not supported" - ) - - stim_stmt.insert_before(stmt_with_ctrl) - - # Delete the original statement if it's a qubit.Apply or qubit.Broadcast - if isinstance(stmt_with_ctrl, (qubit.Apply, qubit.Broadcast)): - stmt_with_ctrl.delete() - - return RewriteResult(has_done_something=True) - - def insert_qubit_idx_after_apply( - self, stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast - ) -> tuple[ir.SSAValue, ...]: - """ - Extract qubit indices from Apply or Broadcast statements. - """ - if isinstance(stmt, (qubit.Apply, qubit.Broadcast)): - qubits = stmt.qubits - address_attribute: AddressAttribute = qubits.hints.get("address") - return insert_qubit_idx_from_address( - address=address_attribute, stmt_to_insert_before=stmt - ) - elif isinstance(stmt, (wire.Apply, wire.Broadcast)): - wire_ssas = stmt.inputs - return insert_qubit_idx_from_wire_ssa( - wire_ssas=wire_ssas, stmt_to_insert_before=stmt - ) - else: - raise TypeError( - "Unsupported statement detected, only Apply and Broadcast statements are supported by this method" - ) - - # qubit.Measure no longer exists, need to handle - # qubit.MeasureQubit and MeasureQubitList - def rewrite_Measure( - self, measure_stmt: wire.Measure | qubit.MeasureQubit | qubit.MeasureQubitList - ) -> RewriteResult: - - match measure_stmt: - case qubit.MeasureQubit(): - qubit_ilist_ssa = measure_stmt.qubit - address_attr = qubit_ilist_ssa.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - case qubit.MeasureQubitList(): - qubit_ssa = measure_stmt.qubits - address_attr = qubit_ssa.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - case wire.Measure(): - wire_ssa = measure_stmt.wire - address_attr = wire_ssa.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - case _: - raise TypeError( - "Unsupported Statement, only qubit.MeasureQubit, qubit.MeasureQubitList, and wire.Measure are supported" - ) - - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=measure_stmt - ) - - prob_noise_stmt = py.constant.Constant(0.0) - stim_measure_stmt = stim.collapse.MZ( - p=prob_noise_stmt.result, - targets=qubit_idx_ssas, - ) - prob_noise_stmt.insert_before(measure_stmt) - stim_measure_stmt.insert_before(measure_stmt) - - return RewriteResult(has_done_something=True) - - def rewrite_Reset(self, reset_stmt: qubit.Reset | wire.Reset) -> RewriteResult: - """ - qubit.Reset(ilist of qubits) -> nothing - # safe to delete the statement afterwards, no depending results - # DCE could probably do this automatically? - - wire.Reset(single wire) -> new wire - # DO NOT DELETE - - # assume RZ, but could extend to RY and RX later - Stim RZ(targets = tuple[int of SSAVals]) - """ - - if isinstance(reset_stmt, qubit.Reset): - qubit_ilist_ssa = reset_stmt.qubits - # qubits are in an ilist which makes up an AddressTuple - address_attr = qubit_ilist_ssa.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=reset_stmt - ) - elif isinstance(reset_stmt, wire.Reset): - address_attr = reset_stmt.wire.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=reset_stmt - ) - else: - raise TypeError( - "unsupported statement, only qubit.Reset and wire.Reset are supported" - ) - - stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas) - stim_rz_stmt.insert_before(reset_stmt) - reset_stmt.delete() - - return RewriteResult(has_done_something=True) - - def rewrite_MeasureAndReset( - self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset - ): - """ - qubit.MeasureAndReset(qubits) -> result - Could be translated (roughly equivalent) to - - stim.MZ(tuple[SSAvals for ints]) - stim.RZ(tuple[SSAvals for ints]) - - """ - - if isinstance(meas_and_reset_stmt, qubit.MeasureAndReset): - - address_attr = meas_and_reset_stmt.qubits.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=meas_and_reset_stmt - ) - - elif isinstance(meas_and_reset_stmt, wire.MeasureAndReset): - address_attr = meas_and_reset_stmt.wire.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address_attr, stmt_to_insert_before=meas_and_reset_stmt - ) - - else: - raise TypeError( - "Unsupported statement detected, only qubit.MeasureAndReset and wire.MeasureAndReset are supported" - ) - - error_p_stmt = py.Constant(0.0) - stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) - stim_rz_stmt = stim.collapse.RZ( - targets=qubit_idx_ssas, - ) - error_p_stmt.insert_before(meas_and_reset_stmt) - stim_mz_stmt.insert_before(meas_and_reset_stmt) - stim_rz_stmt.insert_before(meas_and_reset_stmt) - - return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/stim_util.py b/src/bloqade/squin/rewrite/stim_rewrite_util.py similarity index 52% rename from src/bloqade/squin/rewrite/stim_util.py rename to src/bloqade/squin/rewrite/stim_rewrite_util.py index 979271ff..ff5c3d78 100644 --- a/src/bloqade/squin/rewrite/stim_util.py +++ b/src/bloqade/squin/rewrite/stim_rewrite_util.py @@ -2,6 +2,7 @@ from kirin import ir from kirin.dialects import py +from kirin.rewrite.abc import RewriteResult from bloqade import stim from bloqade.squin import op, wire, qubit @@ -14,28 +15,20 @@ def get_stim_1q_gate(squin_op: op.stmts.Operator): """ Map squin 1Q Ops to stim Ops. """ - match squin_op: - case op.stmts.X(): - return stim.gate.X - case op.stmts.Y(): - return stim.gate.Y - case op.stmts.Z(): - return stim.gate.Z - case op.stmts.H(): - return stim.gate.H - case op.stmts.S(): - return stim.gate.S - case op.stmts.Identity(): - return stim.gate.Identity - case _: - raise NotImplementedError( - f"The squin operator {squin_op} is not supported in the stim dialect" - ) + gate_mapping = { + op.stmts.X: stim.gate.X, + op.stmts.Y: stim.gate.Y, + op.stmts.Z: stim.gate.Z, + op.stmts.H: stim.gate.H, + op.stmts.S: stim.gate.S, + op.stmts.Identity: stim.gate.Identity, + } + return gate_mapping.get(type(squin_op)) def insert_qubit_idx_from_address( address: AddressAttribute, stmt_to_insert_before: ir.Statement -) -> tuple[ir.SSAValue, ...]: +) -> tuple[ir.SSAValue, ...] | None: """ Extract qubit indices from an AddressAttribute and insert them into the SSA form. """ @@ -45,9 +38,7 @@ def insert_qubit_idx_from_address( if isinstance(address_data, AddressTuple): for address_qubit in address_data.data: if not isinstance(address_qubit, AddressQubit): - raise ValueError( - "Unsupported Address type detected inside AddressTuple, must be AddressQubit" - ) + return qubit_idx = address_qubit.data qubit_idx_stmt = py.Constant(qubit_idx) qubit_idx_stmt.insert_before(stmt_to_insert_before) @@ -59,22 +50,22 @@ def insert_qubit_idx_from_address( qubit_idx_stmt.insert_before(stmt_to_insert_before) qubit_idx_ssas.append(qubit_idx_stmt.result) else: - raise NotImplementedError( - "qubit idx extraction and insertion only supported for AddressTuple[AddressQubit] and AddressWire instances" - ) + return return tuple(qubit_idx_ssas) def insert_qubit_idx_from_wire_ssa( wire_ssas: tuple[ir.SSAValue, ...], stmt_to_insert_before: ir.Statement -) -> tuple[ir.SSAValue, ...]: +) -> tuple[ir.SSAValue, ...] | None: """ Extract qubit indices from wire SSA values and insert them into the SSA form. """ qubit_idx_ssas = [] for wire_ssa in wire_ssas: address_attribute = wire_ssa.hints.get("address") + if address_attribute is None: + return assert isinstance(address_attribute, AddressAttribute) wire_address = address_attribute.address assert isinstance(wire_address, AddressWire) @@ -86,7 +77,9 @@ def insert_qubit_idx_from_wire_ssa( return tuple(qubit_idx_ssas) -def verify_num_sites(stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast): +def are_sites_compatible( + stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast, +): """ Verify that the number of qubits/wires matches the number of sites supported by the operator. """ @@ -99,15 +92,16 @@ def verify_num_sites(stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Bro assert isinstance(address_tuple, AddressTuple) num_sites_targeted = len(address_tuple.data) else: - raise TypeError( - "Number of sites verification can only occur on Apply or Broadcast statements" - ) + return False op_ssa = stmt.operator op_stmt = op_ssa.owner cast(ir.Statement, op_stmt) sites_attr = op_ssa.hints.get("sites") + if sites_attr is None: + return False + assert isinstance(sites_attr, SitesAttribute) sites_type = sites_attr.sites assert isinstance(sites_type, NumberSites) @@ -115,11 +109,83 @@ def verify_num_sites(stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Bro if isinstance(stmt, (wire.Broadcast, qubit.Broadcast)): if num_sites_targeted % num_sites_supported != 0: - raise ValueError( - "Number of qubits/wires to broadcast to must be a multiple of the number of sites supported by the operator" - ) + return False elif isinstance(stmt, (wire.Apply, qubit.Apply)): if num_sites_targeted != num_sites_supported: - raise ValueError( - "Number of qubits/wires to apply to must match the number of sites supported by the operator" + return False + + return True + + +def insert_qubit_idx_after_apply( + stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast, +) -> tuple[ir.SSAValue, ...] | None: + """ + Extract qubit indices from Apply or Broadcast statements. + """ + if isinstance(stmt, (qubit.Apply, qubit.Broadcast)): + qubits = stmt.qubits + address_attribute = qubits.hints.get("address") + if address_attribute is None: + return + assert isinstance(address_attribute, AddressAttribute) + return insert_qubit_idx_from_address( + address=address_attribute, stmt_to_insert_before=stmt + ) + elif isinstance(stmt, (wire.Apply, wire.Broadcast)): + wire_ssas = stmt.inputs + return insert_qubit_idx_from_wire_ssa( + wire_ssas=wire_ssas, stmt_to_insert_before=stmt + ) + + +def rewrite_Control( + stmt_with_ctrl: qubit.Apply | wire.Apply | qubit.Broadcast | wire.Broadcast, +) -> RewriteResult: + """ + Handle control gates for Apply and Broadcast statements. + """ + ctrl_op = stmt_with_ctrl.operator.owner + assert isinstance(ctrl_op, op.stmts.Control) + + ctrl_op_target_gate = ctrl_op.op.owner + assert isinstance(ctrl_op_target_gate, op.stmts.Operator) + + qubit_idx_ssas = insert_qubit_idx_after_apply(stmt=stmt_with_ctrl) + if qubit_idx_ssas is None: + return RewriteResult() + + # Separate control and target qubits + target_qubits = [] + ctrl_qubits = [] + for i in range(len(qubit_idx_ssas)): + if (i % 2) == 0: + ctrl_qubits.append(qubit_idx_ssas[i]) + else: + target_qubits.append(qubit_idx_ssas[i]) + + target_qubits = tuple(target_qubits) + ctrl_qubits = tuple(ctrl_qubits) + + # Handle supported gates + match ctrl_op_target_gate: + case op.stmts.X(): + stim_stmt = stim.CX(controls=ctrl_qubits, targets=target_qubits) + case op.stmts.Y(): + stim_stmt = stim.CY(controls=ctrl_qubits, targets=target_qubits) + case op.stmts.Z(): + stim_stmt = stim.CZ(controls=ctrl_qubits, targets=target_qubits) + case _: + raise NotImplementedError( + "Control gates beyond CX, CY, and CZ are not supported" ) + + stim_stmt.insert_before(stmt_with_ctrl) + + # Delete the original statement if it's a qubit.Apply or qubit.Broadcast + if isinstance(stmt_with_ctrl, (qubit.Apply, qubit.Broadcast)): + stmt_with_ctrl.delete() + + # need to think about how to handle wire.Apply and wire.Broadcast instance + + return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/wire_to_stim.py b/src/bloqade/squin/rewrite/wire_to_stim.py index e69de29b..c92df792 100644 --- a/src/bloqade/squin/rewrite/wire_to_stim.py +++ b/src/bloqade/squin/rewrite/wire_to_stim.py @@ -0,0 +1,144 @@ +from kirin import ir +from kirin.dialects import py +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import stim +from bloqade.squin import op, wire +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute +from bloqade.squin.rewrite.stim_rewrite_util import ( + rewrite_Control, + get_stim_1q_gate, + are_sites_compatible, + insert_qubit_idx_from_address, + insert_qubit_idx_from_wire_ssa, +) + + +class SquinWireToStim(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case wire.Apply() | wire.Broadcast(): + are_sites_compatible(node) + return self.rewrite_Apply_and_Broadcast(node) + case wire.Wrap(): + return self.rewrite_Wrap(node) + case wire.Measure(): + return self.rewrite_Measure(node) + case wire.Reset(): + return self.rewrite_Reset(node) + case wire.MeasureAndReset(): + return self.rewrite_MeasureAndReset(node) + case _: + return RewriteResult() + + def rewrite_Apply_and_Broadcast( + self, stmt: wire.Apply | wire.Broadcast + ) -> RewriteResult: + + if not are_sites_compatible(stmt): + return RewriteResult() + + # this is an SSAValue, need it to be the actual operator + applied_op = stmt.operator.owner + assert isinstance(applied_op, op.stmts.Operator) + + if isinstance(applied_op, op.stmts.Control): + return rewrite_Control(stmt) + + stim_1q_op = get_stim_1q_gate(applied_op) + if stim_1q_op is None: + return RewriteResult() + + qubit_idx_ssas = insert_qubit_idx_from_wire_ssa( + wire_ssas=stmt.inputs, stmt_to_insert_before=stmt + ) + if qubit_idx_ssas is None: + return RewriteResult() + + stim_1q_stmt = stim_1q_op(targets=tuple(qubit_idx_ssas)) + + # Get the wires from the inputs of Apply or Broadcast, + # then put those as the result of the current stmt + # before replacing it entirely + for input_wire, output_wire in zip(stmt.inputs, stmt.results): + output_wire.replace_by(input_wire) + + stmt.replace_by(stim_1q_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult: + + wire_origin_stmt = wrap_stmt.wire.owner + if isinstance(wire_origin_stmt, wire.Unwrap): + wire_origin_stmt.delete() + wrap_stmt.delete() + return RewriteResult(has_done_something=True) + + return RewriteResult() + + def rewrite_Measure(self, measure_stmt: wire.Measure) -> RewriteResult: + + wire_ssa = measure_stmt.wire + address_attr = wire_ssa.hints.get("address") + if address_attr is None: + return RewriteResult() + assert isinstance(address_attr, AddressAttribute) + + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=measure_stmt + ) + + if qubit_idx_ssas is None: + return RewriteResult() + + prob_noise_stmt = py.constant.Constant(0.0) + stim_measure_stmt = stim.collapse.MZ( + p=prob_noise_stmt.result, + targets=qubit_idx_ssas, + ) + prob_noise_stmt.insert_before(measure_stmt) + stim_measure_stmt.insert_before(measure_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult: + address_attr = reset_stmt.wire.hints.get("address") + if address_attr is None: + return RewriteResult() + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=reset_stmt + ) + if qubit_idx_ssas is None: + return RewriteResult() + + stim_rz_stmt = stim.collapse.stmts.RZ(targets=qubit_idx_ssas) + reset_stmt.replace_by(stim_rz_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_MeasureAndReset(self, meas_and_reset_stmt: wire.MeasureAndReset): + + address_attr = meas_and_reset_stmt.wire.hints.get("address") + if address_attr is None: + return RewriteResult() + assert isinstance(address_attr, AddressAttribute) + qubit_idx_ssas = insert_qubit_idx_from_address( + address_attr, stmt_to_insert_before=meas_and_reset_stmt + ) + if qubit_idx_ssas is None: + return RewriteResult() + + error_p_stmt = py.Constant(0.0) + stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) + stim_rz_stmt = stim.collapse.RZ( + targets=qubit_idx_ssas, + ) + error_p_stmt.insert_before(meas_and_reset_stmt) + stim_mz_stmt.insert_before(meas_and_reset_stmt) + stim_rz_stmt.insert_before(meas_and_reset_stmt) + + return RewriteResult(has_done_something=True) From cab4ab698579d3b1d34cf485751dc099384457ab Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 15 May 2025 08:44:24 -0400 Subject: [PATCH 22/29] get control statement logic to work in wire dialect, simplify statement replacement --- .../squin/rewrite/stim_rewrite_util.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/bloqade/squin/rewrite/stim_rewrite_util.py b/src/bloqade/squin/rewrite/stim_rewrite_util.py index ff5c3d78..37e02d62 100644 --- a/src/bloqade/squin/rewrite/stim_rewrite_util.py +++ b/src/bloqade/squin/rewrite/stim_rewrite_util.py @@ -167,25 +167,26 @@ def rewrite_Control( target_qubits = tuple(target_qubits) ctrl_qubits = tuple(ctrl_qubits) - # Handle supported gates - match ctrl_op_target_gate: - case op.stmts.X(): - stim_stmt = stim.CX(controls=ctrl_qubits, targets=target_qubits) - case op.stmts.Y(): - stim_stmt = stim.CY(controls=ctrl_qubits, targets=target_qubits) - case op.stmts.Z(): - stim_stmt = stim.CZ(controls=ctrl_qubits, targets=target_qubits) - case _: - raise NotImplementedError( - "Control gates beyond CX, CY, and CZ are not supported" - ) - - stim_stmt.insert_before(stmt_with_ctrl) - - # Delete the original statement if it's a qubit.Apply or qubit.Broadcast - if isinstance(stmt_with_ctrl, (qubit.Apply, qubit.Broadcast)): - stmt_with_ctrl.delete() - - # need to think about how to handle wire.Apply and wire.Broadcast instance + supported_gate_mapping = { + op.stmts.X: stim.CX, + op.stmts.Y: stim.CY, + op.stmts.Z: stim.CZ, + } + + stim_gate = supported_gate_mapping.get(type(ctrl_op_target_gate)) + if stim_gate is None: + return RewriteResult() + + stim_stmt = stim_gate(controls=ctrl_qubits, targets=target_qubits) + + if isinstance(stmt_with_ctrl, (wire.Apply, wire.Broadcast)): + # have to "reroute" the input of these statements to directly plug in + # to subsequent statements, remove dependency on the current statement + for input_wire, output_wire in zip( + stmt_with_ctrl.inputs, stmt_with_ctrl.results + ): + output_wire.replace_by(input_wire) + + stmt_with_ctrl.replace_by(stim_stmt) return RewriteResult(has_done_something=True) From c8105446a70a1ef4418cfe594f90e17ae5952280 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 15 May 2025 08:46:09 -0400 Subject: [PATCH 23/29] use dict instead of match, just care about type comparison --- src/bloqade/squin/rewrite/wire_to_stim.py | 28 +++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/bloqade/squin/rewrite/wire_to_stim.py b/src/bloqade/squin/rewrite/wire_to_stim.py index c92df792..7e29e04d 100644 --- a/src/bloqade/squin/rewrite/wire_to_stim.py +++ b/src/bloqade/squin/rewrite/wire_to_stim.py @@ -18,20 +18,20 @@ class SquinWireToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - match node: - case wire.Apply() | wire.Broadcast(): - are_sites_compatible(node) - return self.rewrite_Apply_and_Broadcast(node) - case wire.Wrap(): - return self.rewrite_Wrap(node) - case wire.Measure(): - return self.rewrite_Measure(node) - case wire.Reset(): - return self.rewrite_Reset(node) - case wire.MeasureAndReset(): - return self.rewrite_MeasureAndReset(node) - case _: - return RewriteResult() + rewrite_methods = { + wire.Apply: self.rewrite_Apply_and_Broadcast, + wire.Broadcast: self.rewrite_Apply_and_Broadcast, + wire.Wrap: self.rewrite_Wrap, + wire.Measure: self.rewrite_Measure, + wire.Reset: self.rewrite_Reset, + wire.MeasureAndReset: self.rewrite_MeasureAndReset, + } + + rewrite_method = rewrite_methods.get(type(node)) + if rewrite_method is None: + return RewriteResult() + + return rewrite_method(node) def rewrite_Apply_and_Broadcast( self, stmt: wire.Apply | wire.Broadcast From fb9ff37dff6f84fa735545a560d5ec5f1d68f728 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 15 May 2025 10:11:40 -0400 Subject: [PATCH 24/29] use replace_by, fix analysis impl --- src/bloqade/squin/analysis/nsites/impls.py | 8 ++- src/bloqade/squin/rewrite/qubit_to_stim.py | 5 +- src/bloqade/squin/rewrite/wire_to_stim.py | 10 ++- test/squin/stim/stim.py | 78 ++++++++++------------ 4 files changed, 54 insertions(+), 47 deletions(-) diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py index cb44a81e..7d4e0610 100644 --- a/src/bloqade/squin/analysis/nsites/impls.py +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -13,7 +13,13 @@ class SquinWire(interp.MethodTable): @interp.impl(wire.Apply) - def apply(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: wire.Apply): + @interp.impl(wire.Broadcast) + def apply( + self, + interp: NSitesAnalysis, + frame: interp.Frame, + stmt: wire.Apply | wire.Broadcast, + ): return tuple([frame.get(input) for input in stmt.inputs]) diff --git a/src/bloqade/squin/rewrite/qubit_to_stim.py b/src/bloqade/squin/rewrite/qubit_to_stim.py index c2667ead..e1bebf9d 100644 --- a/src/bloqade/squin/rewrite/qubit_to_stim.py +++ b/src/bloqade/squin/rewrite/qubit_to_stim.py @@ -103,7 +103,8 @@ def rewrite_Measure( targets=qubit_idx_ssas, ) prob_noise_stmt.insert_before(measure_stmt) - stim_measure_stmt.insert_before(measure_stmt) + # assume properly structured program + measure_stmt.replace_by(stim_measure_stmt) return RewriteResult(has_done_something=True) @@ -150,6 +151,6 @@ def rewrite_MeasureAndReset( ) error_p_stmt.insert_before(meas_and_reset_stmt) stim_mz_stmt.insert_before(meas_and_reset_stmt) - stim_rz_stmt.insert_before(meas_and_reset_stmt) + meas_and_reset_stmt.replace_by(stim_rz_stmt) return RewriteResult(has_done_something=True) diff --git a/src/bloqade/squin/rewrite/wire_to_stim.py b/src/bloqade/squin/rewrite/wire_to_stim.py index 7e29e04d..a2ec9b9f 100644 --- a/src/bloqade/squin/rewrite/wire_to_stim.py +++ b/src/bloqade/squin/rewrite/wire_to_stim.py @@ -71,9 +71,12 @@ def rewrite_Apply_and_Broadcast( def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult: + # structure at this point should be: + ## w = wire.Unwrap(wire) + ## wire.Wrap(qubit, w) + wire_origin_stmt = wrap_stmt.wire.owner if isinstance(wire_origin_stmt, wire.Unwrap): - wire_origin_stmt.delete() wrap_stmt.delete() return RewriteResult(has_done_something=True) @@ -100,7 +103,8 @@ def rewrite_Measure(self, measure_stmt: wire.Measure) -> RewriteResult: targets=qubit_idx_ssas, ) prob_noise_stmt.insert_before(measure_stmt) - stim_measure_stmt.insert_before(measure_stmt) + # stim_measure_stmt.insert_before(measure_stmt) + measure_stmt.replace_by(stim_measure_stmt) return RewriteResult(has_done_something=True) @@ -139,6 +143,6 @@ def rewrite_MeasureAndReset(self, meas_and_reset_stmt: wire.MeasureAndReset): ) error_p_stmt.insert_before(meas_and_reset_stmt) stim_mz_stmt.insert_before(meas_and_reset_stmt) - stim_rz_stmt.insert_before(meas_and_reset_stmt) + meas_and_reset_stmt.replace_by(stim_rz_stmt) return RewriteResult(has_done_something=True) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 1f8fd4aa..3a38e21a 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -1,4 +1,3 @@ -import pytest from kirin import ir, types from kirin.passes import Fold from kirin.dialects import py, func, ilist @@ -40,6 +39,39 @@ def gen_func_from_stmts(stmts, output=types.NoneType): return constructed_method +def test_wire_1q_singular_apply(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(1)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubit out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + # Unwrap to get wires + (w0 := squin.wire.Unwrap(qubit=q0.result)), + # pass the wires through some 1 Qubit operators + (op1 := squin.op.stmts.S()), + (v0 := squin.wire.Apply(op1.result, w0.result)), + ( + squin.wire.Wrap(v0.results[0], q0.result) + ), # for wrap, just free a use for the result SSAval + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + # the fact I return a wire here means DCE will NOT go ahead and + # eliminate all the other wire.Apply stmts + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + def test_wire_1q(): stmts: list[ir.Statement] = [ @@ -185,7 +217,7 @@ def test_broadcast_control_gate_wire_application(): (x_op := squin.op.stmts.X()), (ctrl_x_op := squin.op.stmts.Control(x_op.result, n_controls=1)), ( - app_res := squin.wire.Apply( + app_res := squin.wire.Broadcast( ctrl_x_op.result, w0.result, w1.result, w2.result, w3.result ) ), @@ -242,6 +274,9 @@ def test_wire_control(): constructed_method.print() +# Measure being depended on, internal replace_by call +# will not be happy but assumption with rewrite is the +# program is in a valid form def test_wire_measure(): stmts: list[ir.Statement] = [ @@ -383,42 +418,3 @@ def test_wire_measure_and_reset(): rewrite_result = squin_to_stim(constructed_method) print(rewrite_result) constructed_method.print() - - -def test_wire_apply_site_verification(): - - stmts: list[ir.Statement] = [ - # Create qubit register - (n_qubits := as_int(3)), - (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), - # Get qubis out - (idx0 := as_int(0)), - (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), - (idx1 := as_int(1)), - (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), - (idx2 := as_int(2)), - (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), - # Unwrap to get wires - (w0 := squin.wire.Unwrap(qubit=q0.result)), - (w1 := squin.wire.Unwrap(qubit=q1.result)), - (w2 := squin.wire.Unwrap(qubit=q2.result)), - # set up control gate - (op1 := squin.op.stmts.X()), - (cx := squin.op.stmts.Control(op1.result, n_controls=1)), - # improper application, cx should only support 2 sites - (app := squin.wire.Apply(cx.result, w0.result, w1.result, w2.result)), - # wrap things back - (squin.wire.Wrap(wire=app.results[0], qubit=q0.result)), - (squin.wire.Wrap(wire=app.results[1], qubit=q1.result)), - (squin.wire.Wrap(wire=app.results[2], qubit=q2.result)), - (ret_none := func.ConstantNone()), - (func.Return(ret_none)), - ] - - constructed_method = gen_func_from_stmts(stmts) - constructed_method.print() - - squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) - - with pytest.raises(ValueError): - squin_to_stim(constructed_method) From e6c92b78bee9db2790b4795f8c65dcee338d54f4 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 15 May 2025 13:29:27 -0400 Subject: [PATCH 25/29] first round of meeting feedback implemented --- src/bloqade/squin/passes/stim.py | 8 ++- src/bloqade/squin/rewrite/__init__.py | 3 ++ src/bloqade/squin/rewrite/qubit_to_stim.py | 18 +++++-- .../squin/rewrite/stim_rewrite_util.py | 37 +++++++++----- .../rewrite/wire_identity_elimination.py | 24 +++++++++ src/bloqade/squin/rewrite/wire_to_stim.py | 51 ++++++++----------- test/squin/stim/stim.py | 10 ++++ 7 files changed, 103 insertions(+), 48 deletions(-) create mode 100644 src/bloqade/squin/rewrite/wire_identity_elimination.py diff --git a/src/bloqade/squin/passes/stim.py b/src/bloqade/squin/passes/stim.py index d8b80635..85cf477d 100644 --- a/src/bloqade/squin/passes/stim.py +++ b/src/bloqade/squin/passes/stim.py @@ -12,7 +12,12 @@ from kirin.passes.abc import Pass from kirin.rewrite.abc import RewriteResult -from bloqade.squin.rewrite import SquinWireToStim, SquinQubitToStim, WrapSquinAnalysis +from bloqade.squin.rewrite import ( + SquinWireToStim, + SquinQubitToStim, + WrapSquinAnalysis, + SquinWireIdentityElimination, +) from bloqade.analysis.address import AddressAnalysis from bloqade.squin.analysis.nsites import ( NSitesAnalysis, @@ -43,6 +48,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult: ), SquinQubitToStim(), SquinWireToStim(), + SquinWireIdentityElimination(), ) ) .rewrite(mt.code) diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index f3efd62c..fad9ad29 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -5,3 +5,6 @@ AddressAttribute as AddressAttribute, WrapSquinAnalysis as WrapSquinAnalysis, ) +from .wire_identity_elimination import ( + SquinWireIdentityElimination as SquinWireIdentityElimination, +) diff --git a/src/bloqade/squin/rewrite/qubit_to_stim.py b/src/bloqade/squin/rewrite/qubit_to_stim.py index e1bebf9d..763047ee 100644 --- a/src/bloqade/squin/rewrite/qubit_to_stim.py +++ b/src/bloqade/squin/rewrite/qubit_to_stim.py @@ -6,9 +6,10 @@ from bloqade.squin import op, qubit from bloqade.squin.rewrite.wrap_analysis import AddressAttribute from bloqade.squin.rewrite.stim_rewrite_util import ( + SQUIN_STIM_GATE_MAPPING, rewrite_Control, - get_stim_1q_gate, are_sites_compatible, + is_measure_result_used, insert_qubit_idx_from_address, ) @@ -16,6 +17,8 @@ class SquinQubitToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + # don't want to alloc dict, change back to if else/match case rewrite_methods = { qubit.Apply: self.rewrite_Apply_and_Broadcast, qubit.Broadcast: self.rewrite_Apply_and_Broadcast, @@ -31,13 +34,13 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return rewrite_method(node) - # handle Control def rewrite_Apply_and_Broadcast( self, stmt: qubit.Apply | qubit.Broadcast ) -> RewriteResult: """ Rewrite Apply and Broadcast nodes to their stim equivalent statements. """ + # get rid of are_sites_compatible, assume program is properly structured if not are_sites_compatible(stmt): return RewriteResult() @@ -50,7 +53,7 @@ def rewrite_Apply_and_Broadcast( # need to handle Control through separate means # but we can handle X, Y, Z, H, and S here just fine - stim_1q_op = get_stim_1q_gate(applied_op) + stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op)) if stim_1q_op is None: return RewriteResult() @@ -75,6 +78,9 @@ def rewrite_Measure( self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList ) -> RewriteResult: + if is_measure_result_used(measure_stmt): + return RewriteResult() + # qubit_ssa will always be an ilist of qubits # but need to be careful with singular vs plural "qubit" attribute name if isinstance(measure_stmt, qubit.MeasureQubit): @@ -132,6 +138,9 @@ def rewrite_MeasureAndReset( self, meas_and_reset_stmt: qubit.MeasureAndReset ) -> RewriteResult: + if is_measure_result_used(meas_and_reset_stmt): + return RewriteResult() + address_attr = meas_and_reset_stmt.qubits.hints.get("address") if address_attr is None: return RewriteResult() @@ -154,3 +163,6 @@ def rewrite_MeasureAndReset( meas_and_reset_stmt.replace_by(stim_rz_stmt) return RewriteResult(has_done_something=True) + + +# put rewrites for measure statements in separate rule, then just have to dispatch diff --git a/src/bloqade/squin/rewrite/stim_rewrite_util.py b/src/bloqade/squin/rewrite/stim_rewrite_util.py index 37e02d62..2bf42c10 100644 --- a/src/bloqade/squin/rewrite/stim_rewrite_util.py +++ b/src/bloqade/squin/rewrite/stim_rewrite_util.py @@ -10,20 +10,14 @@ from bloqade.squin.analysis.nsites import NumberSites from bloqade.squin.rewrite.wrap_analysis import SitesAttribute, AddressAttribute - -def get_stim_1q_gate(squin_op: op.stmts.Operator): - """ - Map squin 1Q Ops to stim Ops. - """ - gate_mapping = { - op.stmts.X: stim.gate.X, - op.stmts.Y: stim.gate.Y, - op.stmts.Z: stim.gate.Z, - op.stmts.H: stim.gate.H, - op.stmts.S: stim.gate.S, - op.stmts.Identity: stim.gate.Identity, - } - return gate_mapping.get(type(squin_op)) +SQUIN_STIM_GATE_MAPPING = { + op.stmts.X: stim.gate.X, + op.stmts.Y: stim.gate.Y, + op.stmts.Z: stim.gate.Z, + op.stmts.H: stim.gate.H, + op.stmts.S: stim.gate.S, + op.stmts.Identity: stim.gate.Identity, +} def insert_qubit_idx_from_address( @@ -190,3 +184,18 @@ def rewrite_Control( stmt_with_ctrl.replace_by(stim_stmt) return RewriteResult(has_done_something=True) + + +def is_measure_result_used( + stmt: ( + qubit.MeasureAndReset + | qubit.MeasureQubit + | qubit.MeasureQubitList + | wire.MeasureAndReset + | wire.Measure + ), +) -> bool: + """ + Check if the result of a measure statement is used in the program. + """ + return bool(stmt.result.uses) diff --git a/src/bloqade/squin/rewrite/wire_identity_elimination.py b/src/bloqade/squin/rewrite/wire_identity_elimination.py new file mode 100644 index 00000000..a9dcc837 --- /dev/null +++ b/src/bloqade/squin/rewrite/wire_identity_elimination.py @@ -0,0 +1,24 @@ +from kirin import ir +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade.squin import wire + + +class SquinWireIdentityElimination(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + """ + Handle the case where an unwrap feeds a wire directly into a wrap, + equivalent to nothing happening/identity operation + + w = unwrap(qubit) + wrap(qubit, w) + """ + if isinstance(node, wire.Wrap): + wire_origin_stmt = node.wire.owner + if isinstance(wire_origin_stmt, wire.Unwrap): + node.delete() # get rid of wrap + wire_origin_stmt.delete() # get rid of the unwrap + return RewriteResult(has_done_something=True) + + return RewriteResult() diff --git a/src/bloqade/squin/rewrite/wire_to_stim.py b/src/bloqade/squin/rewrite/wire_to_stim.py index a2ec9b9f..00612adc 100644 --- a/src/bloqade/squin/rewrite/wire_to_stim.py +++ b/src/bloqade/squin/rewrite/wire_to_stim.py @@ -6,9 +6,10 @@ from bloqade.squin import op, wire from bloqade.squin.rewrite.wrap_analysis import AddressAttribute from bloqade.squin.rewrite.stim_rewrite_util import ( + SQUIN_STIM_GATE_MAPPING, rewrite_Control, - get_stim_1q_gate, are_sites_compatible, + is_measure_result_used, insert_qubit_idx_from_address, insert_qubit_idx_from_wire_ssa, ) @@ -17,21 +18,17 @@ class SquinWireToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - - rewrite_methods = { - wire.Apply: self.rewrite_Apply_and_Broadcast, - wire.Broadcast: self.rewrite_Apply_and_Broadcast, - wire.Wrap: self.rewrite_Wrap, - wire.Measure: self.rewrite_Measure, - wire.Reset: self.rewrite_Reset, - wire.MeasureAndReset: self.rewrite_MeasureAndReset, - } - - rewrite_method = rewrite_methods.get(type(node)) - if rewrite_method is None: - return RewriteResult() - - return rewrite_method(node) + match node: + case wire.Apply() | wire.Broadcast(): + return self.rewrite_Apply_and_Broadcast(node) + case wire.Measure(): + return self.rewrite_Measure(node) + case wire.Reset(): + return self.rewrite_Reset(node) + case wire.MeasureAndReset(): + return self.rewrite_MeasureAndReset(node) + case _: + return RewriteResult() def rewrite_Apply_and_Broadcast( self, stmt: wire.Apply | wire.Broadcast @@ -47,7 +44,7 @@ def rewrite_Apply_and_Broadcast( if isinstance(applied_op, op.stmts.Control): return rewrite_Control(stmt) - stim_1q_op = get_stim_1q_gate(applied_op) + stim_1q_op = SQUIN_STIM_GATE_MAPPING.get(type(applied_op)) if stim_1q_op is None: return RewriteResult() @@ -69,21 +66,11 @@ def rewrite_Apply_and_Broadcast( return RewriteResult(has_done_something=True) - def rewrite_Wrap(self, wrap_stmt: wire.Wrap) -> RewriteResult: - - # structure at this point should be: - ## w = wire.Unwrap(wire) - ## wire.Wrap(qubit, w) - - wire_origin_stmt = wrap_stmt.wire.owner - if isinstance(wire_origin_stmt, wire.Unwrap): - wrap_stmt.delete() - return RewriteResult(has_done_something=True) - - return RewriteResult() - def rewrite_Measure(self, measure_stmt: wire.Measure) -> RewriteResult: + if is_measure_result_used(measure_stmt): + return RewriteResult() + wire_ssa = measure_stmt.wire address_attr = wire_ssa.hints.get("address") if address_attr is None: @@ -126,6 +113,9 @@ def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult: def rewrite_MeasureAndReset(self, meas_and_reset_stmt: wire.MeasureAndReset): + if is_measure_result_used(meas_and_reset_stmt): + return RewriteResult() + address_attr = meas_and_reset_stmt.wire.hints.get("address") if address_attr is None: return RewriteResult() @@ -141,6 +131,7 @@ def rewrite_MeasureAndReset(self, meas_and_reset_stmt: wire.MeasureAndReset): stim_rz_stmt = stim.collapse.RZ( targets=qubit_idx_ssas, ) + error_p_stmt.insert_before(meas_and_reset_stmt) stim_mz_stmt.insert_before(meas_and_reset_stmt) meas_and_reset_stmt.replace_by(stim_rz_stmt) diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 3a38e21a..55b3aee1 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -157,6 +157,16 @@ def test_broadcast_wire_1q_application(): constructed_method.print() +# before ANY rewrite, aggressively inline everything, then do the rewrite +# for Stim pass, need to call validation , check any invoke + +# Put one codegen test to stim +# finish measurement analysis Friday - if painful, ask help from Kai +# work on other detector rewrite + +# later on lower for loop to repeat + + def test_broadcast_qubit_1q_application(): stmts: list[ir.Statement] = [ From 9316a0eb1822e3321a4a119f17e291d60747b2ef Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 15 May 2025 15:31:43 -0400 Subject: [PATCH 26/29] split out Measure rewrite into its own rule --- src/bloqade/squin/passes/stim.py | 2 + src/bloqade/squin/rewrite/__init__.py | 1 + src/bloqade/squin/rewrite/qubit_to_stim.py | 94 ++------------------- src/bloqade/squin/rewrite/squin_measure.py | 96 ++++++++++++++++++++++ src/bloqade/squin/rewrite/wire_to_stim.py | 62 -------------- 5 files changed, 106 insertions(+), 149 deletions(-) create mode 100644 src/bloqade/squin/rewrite/squin_measure.py diff --git a/src/bloqade/squin/passes/stim.py b/src/bloqade/squin/passes/stim.py index 85cf477d..3c172fb5 100644 --- a/src/bloqade/squin/passes/stim.py +++ b/src/bloqade/squin/passes/stim.py @@ -16,6 +16,7 @@ SquinWireToStim, SquinQubitToStim, WrapSquinAnalysis, + SquinMeasureToStim, SquinWireIdentityElimination, ) from bloqade.analysis.address import AddressAnalysis @@ -48,6 +49,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult: ), SquinQubitToStim(), SquinWireToStim(), + SquinMeasureToStim(), SquinWireIdentityElimination(), ) ) diff --git a/src/bloqade/squin/rewrite/__init__.py b/src/bloqade/squin/rewrite/__init__.py index fad9ad29..1280ef2c 100644 --- a/src/bloqade/squin/rewrite/__init__.py +++ b/src/bloqade/squin/rewrite/__init__.py @@ -1,5 +1,6 @@ from .wire_to_stim import SquinWireToStim as SquinWireToStim from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim +from .squin_measure import SquinMeasureToStim as SquinMeasureToStim from .wrap_analysis import ( SitesAttribute as SitesAttribute, AddressAttribute as AddressAttribute, diff --git a/src/bloqade/squin/rewrite/qubit_to_stim.py b/src/bloqade/squin/rewrite/qubit_to_stim.py index 763047ee..1b722f83 100644 --- a/src/bloqade/squin/rewrite/qubit_to_stim.py +++ b/src/bloqade/squin/rewrite/qubit_to_stim.py @@ -1,5 +1,4 @@ from kirin import ir -from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade import stim @@ -9,7 +8,6 @@ SQUIN_STIM_GATE_MAPPING, rewrite_Control, are_sites_compatible, - is_measure_result_used, insert_qubit_idx_from_address, ) @@ -18,21 +16,13 @@ class SquinQubitToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - # don't want to alloc dict, change back to if else/match case - rewrite_methods = { - qubit.Apply: self.rewrite_Apply_and_Broadcast, - qubit.Broadcast: self.rewrite_Apply_and_Broadcast, - qubit.MeasureQubit: self.rewrite_Measure, - qubit.MeasureQubitList: self.rewrite_Measure, - qubit.Reset: self.rewrite_Reset, - qubit.MeasureAndReset: self.rewrite_MeasureAndReset, - } - - rewrite_method = rewrite_methods.get(type(node)) - if rewrite_method is None: - return RewriteResult() - - return rewrite_method(node) + match node: + case qubit.Apply() | qubit.Broadcast(): + return self.rewrite_Apply_and_Broadcast(node) + case qubit.Reset(): + return self.rewrite_Reset(node) + case _: + return RewriteResult() def rewrite_Apply_and_Broadcast( self, stmt: qubit.Apply | qubit.Broadcast @@ -74,46 +64,6 @@ def rewrite_Apply_and_Broadcast( return RewriteResult(has_done_something=True) - def rewrite_Measure( - self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList - ) -> RewriteResult: - - if is_measure_result_used(measure_stmt): - return RewriteResult() - - # qubit_ssa will always be an ilist of qubits - # but need to be careful with singular vs plural "qubit" attribute name - if isinstance(measure_stmt, qubit.MeasureQubit): - qubit_ssa = measure_stmt.qubit - elif isinstance(measure_stmt, qubit.MeasureQubitList): - qubit_ssa = measure_stmt.qubits - else: - return RewriteResult() - - address_attr = qubit_ssa.hints.get("address") - if address_attr is None: - return RewriteResult() - - assert isinstance(address_attr, AddressAttribute) - - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=measure_stmt - ) - - if qubit_idx_ssas is None: - return RewriteResult() - - prob_noise_stmt = py.constant.Constant(0.0) - stim_measure_stmt = stim.collapse.MZ( - p=prob_noise_stmt.result, - targets=qubit_idx_ssas, - ) - prob_noise_stmt.insert_before(measure_stmt) - # assume properly structured program - measure_stmt.replace_by(stim_measure_stmt) - - return RewriteResult(has_done_something=True) - def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult: qubit_ilist_ssa = reset_stmt.qubits # qubits are in an ilist which makes up an AddressTuple @@ -134,35 +84,5 @@ def rewrite_Reset(self, reset_stmt: qubit.Reset) -> RewriteResult: return RewriteResult(has_done_something=True) - def rewrite_MeasureAndReset( - self, meas_and_reset_stmt: qubit.MeasureAndReset - ) -> RewriteResult: - - if is_measure_result_used(meas_and_reset_stmt): - return RewriteResult() - - address_attr = meas_and_reset_stmt.qubits.hints.get("address") - if address_attr is None: - return RewriteResult() - - assert isinstance(address_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=meas_and_reset_stmt - ) - - if qubit_idx_ssas is None: - return RewriteResult() - - error_p_stmt = py.Constant(0.0) - stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) - stim_rz_stmt = stim.collapse.RZ( - targets=qubit_idx_ssas, - ) - error_p_stmt.insert_before(meas_and_reset_stmt) - stim_mz_stmt.insert_before(meas_and_reset_stmt) - meas_and_reset_stmt.replace_by(stim_rz_stmt) - - return RewriteResult(has_done_something=True) - # put rewrites for measure statements in separate rule, then just have to dispatch diff --git a/src/bloqade/squin/rewrite/squin_measure.py b/src/bloqade/squin/rewrite/squin_measure.py new file mode 100644 index 00000000..1efd71e4 --- /dev/null +++ b/src/bloqade/squin/rewrite/squin_measure.py @@ -0,0 +1,96 @@ +# create rewrite rule name SquinMeasureToStim using kirin +from kirin import ir +from kirin.dialects import py +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from bloqade import stim +from bloqade.squin import wire, qubit +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute +from bloqade.squin.rewrite.stim_rewrite_util import ( + is_measure_result_used, + insert_qubit_idx_from_address, +) + + +class SquinMeasureToStim(RewriteRule): + """ + Rewrite squin measure-related statements to stim statements. + """ + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + + match node: + case qubit.MeasureQubit() | qubit.MeasureQubitList() | wire.Measure(): + return self.rewrite_Measure(node) + case qubit.MeasureAndReset() | wire.MeasureAndReset(): + return self.rewrite_MeasureAndReset(node) + case _: + return RewriteResult() + + def rewrite_Measure( + self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure + ) -> RewriteResult: + if not is_measure_result_used(measure_stmt): + return RewriteResult() + + qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt) + if qubit_idx_ssas is None: + return RewriteResult() + + prob_noise_stmt = py.constant.Constant(0.0) + stim_measure_stmt = stim.collapse.MZ( + p=prob_noise_stmt.result, + targets=qubit_idx_ssas, + ) + prob_noise_stmt.insert_before(measure_stmt) + measure_stmt.replace_by(stim_measure_stmt) + + return RewriteResult(has_done_something=True) + + def rewrite_MeasureAndReset( + self, meas_and_reset_stmt: qubit.MeasureAndReset | wire.MeasureAndReset + ) -> RewriteResult: + if not is_measure_result_used(meas_and_reset_stmt): + return RewriteResult() + + qubit_idx_ssas = self.get_qubit_idx_ssas(meas_and_reset_stmt) + + if qubit_idx_ssas is None: + return RewriteResult() + + error_p_stmt = py.Constant(0.0) + stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) + stim_rz_stmt = stim.collapse.RZ( + targets=qubit_idx_ssas, + ) + + error_p_stmt.insert_before(meas_and_reset_stmt) + stim_mz_stmt.insert_before(meas_and_reset_stmt) + meas_and_reset_stmt.replace_by(stim_rz_stmt) + + return RewriteResult(has_done_something=True) + + def get_qubit_idx_ssas(self, measure_stmt) -> tuple[ir.SSAValue, ...] | None: + """ + Extract the address attribute and insert qubit indices for the given measure statement. + """ + match measure_stmt: + case qubit.MeasureQubit(): + address_attr = measure_stmt.qubit.hints.get("address") + case qubit.MeasureQubitList(): + address_attr = measure_stmt.qubits.hints.get("address") + case wire.Measure(): + address_attr = measure_stmt.wire.hints.get("address") + case _: + return None + + if address_attr is None: + return None + + assert isinstance(address_attr, AddressAttribute) + + qubit_idx_ssas = insert_qubit_idx_from_address( + address=address_attr, stmt_to_insert_before=measure_stmt + ) + + return qubit_idx_ssas diff --git a/src/bloqade/squin/rewrite/wire_to_stim.py b/src/bloqade/squin/rewrite/wire_to_stim.py index 00612adc..c71cd462 100644 --- a/src/bloqade/squin/rewrite/wire_to_stim.py +++ b/src/bloqade/squin/rewrite/wire_to_stim.py @@ -1,5 +1,4 @@ from kirin import ir -from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult from bloqade import stim @@ -9,7 +8,6 @@ SQUIN_STIM_GATE_MAPPING, rewrite_Control, are_sites_compatible, - is_measure_result_used, insert_qubit_idx_from_address, insert_qubit_idx_from_wire_ssa, ) @@ -21,12 +19,8 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: case wire.Apply() | wire.Broadcast(): return self.rewrite_Apply_and_Broadcast(node) - case wire.Measure(): - return self.rewrite_Measure(node) case wire.Reset(): return self.rewrite_Reset(node) - case wire.MeasureAndReset(): - return self.rewrite_MeasureAndReset(node) case _: return RewriteResult() @@ -66,35 +60,6 @@ def rewrite_Apply_and_Broadcast( return RewriteResult(has_done_something=True) - def rewrite_Measure(self, measure_stmt: wire.Measure) -> RewriteResult: - - if is_measure_result_used(measure_stmt): - return RewriteResult() - - wire_ssa = measure_stmt.wire - address_attr = wire_ssa.hints.get("address") - if address_attr is None: - return RewriteResult() - assert isinstance(address_attr, AddressAttribute) - - qubit_idx_ssas = insert_qubit_idx_from_address( - address=address_attr, stmt_to_insert_before=measure_stmt - ) - - if qubit_idx_ssas is None: - return RewriteResult() - - prob_noise_stmt = py.constant.Constant(0.0) - stim_measure_stmt = stim.collapse.MZ( - p=prob_noise_stmt.result, - targets=qubit_idx_ssas, - ) - prob_noise_stmt.insert_before(measure_stmt) - # stim_measure_stmt.insert_before(measure_stmt) - measure_stmt.replace_by(stim_measure_stmt) - - return RewriteResult(has_done_something=True) - def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult: address_attr = reset_stmt.wire.hints.get("address") if address_attr is None: @@ -110,30 +75,3 @@ def rewrite_Reset(self, reset_stmt: wire.Reset) -> RewriteResult: reset_stmt.replace_by(stim_rz_stmt) return RewriteResult(has_done_something=True) - - def rewrite_MeasureAndReset(self, meas_and_reset_stmt: wire.MeasureAndReset): - - if is_measure_result_used(meas_and_reset_stmt): - return RewriteResult() - - address_attr = meas_and_reset_stmt.wire.hints.get("address") - if address_attr is None: - return RewriteResult() - assert isinstance(address_attr, AddressAttribute) - qubit_idx_ssas = insert_qubit_idx_from_address( - address_attr, stmt_to_insert_before=meas_and_reset_stmt - ) - if qubit_idx_ssas is None: - return RewriteResult() - - error_p_stmt = py.Constant(0.0) - stim_mz_stmt = stim.collapse.MZ(targets=qubit_idx_ssas, p=error_p_stmt.result) - stim_rz_stmt = stim.collapse.RZ( - targets=qubit_idx_ssas, - ) - - error_p_stmt.insert_before(meas_and_reset_stmt) - stim_mz_stmt.insert_before(meas_and_reset_stmt) - meas_and_reset_stmt.replace_by(stim_rz_stmt) - - return RewriteResult(has_done_something=True) From e3b478a58b07d82ab97d9f87e0e400e5d55e110d Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 15 May 2025 17:44:12 -0400 Subject: [PATCH 27/29] add tests but codegen is acting weird --- src/bloqade/squin/passes/stim.py | 2 +- src/bloqade/squin/rewrite/squin_measure.py | 6 +- test/squin/stim/stim.py | 127 ++++++++++++++++++++- 3 files changed, 127 insertions(+), 8 deletions(-) diff --git a/src/bloqade/squin/passes/stim.py b/src/bloqade/squin/passes/stim.py index 3c172fb5..d919173c 100644 --- a/src/bloqade/squin/passes/stim.py +++ b/src/bloqade/squin/passes/stim.py @@ -49,7 +49,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult: ), SquinQubitToStim(), SquinWireToStim(), - SquinMeasureToStim(), + SquinMeasureToStim(), # reduce duplicated logic, can split out even more rules later SquinWireIdentityElimination(), ) ) diff --git a/src/bloqade/squin/rewrite/squin_measure.py b/src/bloqade/squin/rewrite/squin_measure.py index 1efd71e4..ee397233 100644 --- a/src/bloqade/squin/rewrite/squin_measure.py +++ b/src/bloqade/squin/rewrite/squin_measure.py @@ -30,7 +30,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: def rewrite_Measure( self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure ) -> RewriteResult: - if not is_measure_result_used(measure_stmt): + if is_measure_result_used(measure_stmt): return RewriteResult() qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt) @@ -70,7 +70,9 @@ def rewrite_MeasureAndReset( return RewriteResult(has_done_something=True) - def get_qubit_idx_ssas(self, measure_stmt) -> tuple[ir.SSAValue, ...] | None: + def get_qubit_idx_ssas( + self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList | wire.Measure + ) -> tuple[ir.SSAValue, ...] | None: """ Extract the address attribute and insert qubit indices for the given measure statement. """ diff --git a/test/squin/stim/stim.py b/test/squin/stim/stim.py index 55b3aee1..533963b1 100644 --- a/test/squin/stim/stim.py +++ b/test/squin/stim/stim.py @@ -3,8 +3,18 @@ from kirin.dialects import py, func, ilist import bloqade.squin.passes as squin_passes -from bloqade import qasm2, squin +from bloqade import stim, qasm2, squin from bloqade.analysis import address +from bloqade.stim.emit import EmitStimMain + + +# Taken gratuitously from Kai's unit test +def stim_codegen(mt: ir.Method): + # method should not have any arguments! + emit = EmitStimMain(mt.dialects) + emit.initialize() + emit.run(mt=mt, args=()) + return emit.get_output() def as_int(value: int): @@ -15,15 +25,21 @@ def as_float(value: float): return py.constant.Constant(value=value) -def gen_func_from_stmts(stmts, output=types.NoneType): +def gen_func_from_stmts(stmts, output_type=types.NoneType): - extended_dialect = squin.groups.wired.add(qasm2.core).add(ilist).add(squin.qubit) + extended_dialect = ( + squin.groups.wired.add(qasm2.core) + .add(ilist) + .add(squin.qubit) + .add(stim.collapse) + .add(stim.gate) + ) block = ir.Block(stmts) - block.args.append_from(types.MethodType[[], types.NoneType], "main_self") + block.args.append_from(types.MethodType[[], types.NoneType], "main") func_wrapper = func.Function( sym_name="main", - signature=func.Signature(inputs=(), output=output), + signature=func.Signature(inputs=(), output=output_type), body=ir.Region(blocks=block), ) @@ -39,6 +55,107 @@ def gen_func_from_stmts(stmts, output=types.NoneType): return constructed_method +def test_qubit_to_stim(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # create ilist of qubits + (q_list := ilist.New(values=(q0.result, q1.result, q2.result, q3.result))), + # Broadcast with stim semantics + (h_op := squin.op.stmts.H()), + (app_res := squin.qubit.Broadcast(h_op.result, q_list.result)), # noqa: F841 + # try Apply now + (x_op := squin.op.stmts.X()), + (sub_q_list := ilist.New(values=(q0.result,))), + (squin.qubit.Apply(x_op.result, sub_q_list.result)), + # go for a control gate + (ctrl_op := squin.op.stmts.Control(x_op.result, n_controls=1)), + (sub_q_list2 := ilist.New(values=(q1.result, q3.result))), + (squin.qubit.Apply(ctrl_op.result, sub_q_list2.result)), + # Measure everything out + (meas_res := squin.qubit.MeasureQubitList(q_list.result)), # noqa: F841 + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_passes.SquinToStim(constructed_method.dialects, no_raise=False)( + constructed_method + ) + + constructed_method.print() + + # some problem with stim codegen in terms of + # stim_prog_str = stim_codegen(constructed_method) + # print(stim_prog_str) + + +def test_wire_to_stim(): + + stmts: list[ir.Statement] = [ + # Create qubit register + (n_qubits := as_int(4)), + (qreg := qasm2.core.QRegNew(n_qubits=n_qubits.result)), + # Get qubits out + (idx0 := as_int(0)), + (q0 := qasm2.core.QRegGet(reg=qreg.result, idx=idx0.result)), + (idx1 := as_int(1)), + (q1 := qasm2.core.QRegGet(reg=qreg.result, idx=idx1.result)), + (idx2 := as_int(2)), + (q2 := qasm2.core.QRegGet(reg=qreg.result, idx=idx2.result)), + (idx3 := as_int(3)), + (q3 := qasm2.core.QRegGet(reg=qreg.result, idx=idx3.result)), + # get wires from qubits + (w0 := squin.wire.Unwrap(qubit=q0.result)), + (w1 := squin.wire.Unwrap(qubit=q1.result)), + (w2 := squin.wire.Unwrap(qubit=q2.result)), + (w3 := squin.wire.Unwrap(qubit=q3.result)), + # try Apply + (op0 := squin.op.stmts.S()), + (app0 := squin.wire.Apply(op0.result, w0.result)), + # try Broadcast + (op1 := squin.op.stmts.H()), + ( + broad0 := squin.wire.Broadcast( + op1.result, app0.results[0], w1.result, w2.result, w3.result + ) + ), + # wrap everything back + (squin.wire.Wrap(broad0.results[0], q0.result)), + (squin.wire.Wrap(broad0.results[1], q1.result)), + (squin.wire.Wrap(broad0.results[2], q2.result)), + (squin.wire.Wrap(broad0.results[3], q3.result)), + (ret_none := func.ConstantNone()), + (func.Return(ret_none)), + ] + + constructed_method = gen_func_from_stmts(stmts) + + constructed_method.print() + + squin_to_stim = squin_passes.SquinToStim(constructed_method.dialects) + squin_to_stim(constructed_method) + + constructed_method.print() + + +test_wire_to_stim() + + def test_wire_1q_singular_apply(): stmts: list[ir.Statement] = [ From 0f759b8701701aedaa27bdea15c76d37e9093a36 Mon Sep 17 00:00:00 2001 From: John Long Date: Thu, 15 May 2025 17:49:46 -0400 Subject: [PATCH 28/29] remove site-target dimension check --- src/bloqade/squin/rewrite/qubit_to_stim.py | 4 -- .../squin/rewrite/stim_rewrite_util.py | 45 +------------------ src/bloqade/squin/rewrite/wire_to_stim.py | 4 -- 3 files changed, 1 insertion(+), 52 deletions(-) diff --git a/src/bloqade/squin/rewrite/qubit_to_stim.py b/src/bloqade/squin/rewrite/qubit_to_stim.py index 1b722f83..6a414d34 100644 --- a/src/bloqade/squin/rewrite/qubit_to_stim.py +++ b/src/bloqade/squin/rewrite/qubit_to_stim.py @@ -7,7 +7,6 @@ from bloqade.squin.rewrite.stim_rewrite_util import ( SQUIN_STIM_GATE_MAPPING, rewrite_Control, - are_sites_compatible, insert_qubit_idx_from_address, ) @@ -30,9 +29,6 @@ def rewrite_Apply_and_Broadcast( """ Rewrite Apply and Broadcast nodes to their stim equivalent statements. """ - # get rid of are_sites_compatible, assume program is properly structured - if not are_sites_compatible(stmt): - return RewriteResult() # this is an SSAValue, need it to be the actual operator applied_op = stmt.operator.owner diff --git a/src/bloqade/squin/rewrite/stim_rewrite_util.py b/src/bloqade/squin/rewrite/stim_rewrite_util.py index 2bf42c10..1148558a 100644 --- a/src/bloqade/squin/rewrite/stim_rewrite_util.py +++ b/src/bloqade/squin/rewrite/stim_rewrite_util.py @@ -1,5 +1,3 @@ -from typing import cast - from kirin import ir from kirin.dialects import py from kirin.rewrite.abc import RewriteResult @@ -7,8 +5,7 @@ from bloqade import stim from bloqade.squin import op, wire, qubit from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple -from bloqade.squin.analysis.nsites import NumberSites -from bloqade.squin.rewrite.wrap_analysis import SitesAttribute, AddressAttribute +from bloqade.squin.rewrite.wrap_analysis import AddressAttribute SQUIN_STIM_GATE_MAPPING = { op.stmts.X: stim.gate.X, @@ -71,46 +68,6 @@ def insert_qubit_idx_from_wire_ssa( return tuple(qubit_idx_ssas) -def are_sites_compatible( - stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast, -): - """ - Verify that the number of qubits/wires matches the number of sites supported by the operator. - """ - if isinstance(stmt, (wire.Apply, wire.Broadcast)): - num_sites_targeted = len(stmt.inputs) - elif isinstance(stmt, (qubit.Apply, qubit.Broadcast)): - address_attr = stmt.qubits.hints.get("address") - assert isinstance(address_attr, AddressAttribute) - address_tuple = address_attr.address - assert isinstance(address_tuple, AddressTuple) - num_sites_targeted = len(address_tuple.data) - else: - return False - - op_ssa = stmt.operator - op_stmt = op_ssa.owner - cast(ir.Statement, op_stmt) - - sites_attr = op_ssa.hints.get("sites") - if sites_attr is None: - return False - - assert isinstance(sites_attr, SitesAttribute) - sites_type = sites_attr.sites - assert isinstance(sites_type, NumberSites) - num_sites_supported = sites_type.sites - - if isinstance(stmt, (wire.Broadcast, qubit.Broadcast)): - if num_sites_targeted % num_sites_supported != 0: - return False - elif isinstance(stmt, (wire.Apply, qubit.Apply)): - if num_sites_targeted != num_sites_supported: - return False - - return True - - def insert_qubit_idx_after_apply( stmt: wire.Apply | qubit.Apply | wire.Broadcast | qubit.Broadcast, ) -> tuple[ir.SSAValue, ...] | None: diff --git a/src/bloqade/squin/rewrite/wire_to_stim.py b/src/bloqade/squin/rewrite/wire_to_stim.py index c71cd462..82971c86 100644 --- a/src/bloqade/squin/rewrite/wire_to_stim.py +++ b/src/bloqade/squin/rewrite/wire_to_stim.py @@ -7,7 +7,6 @@ from bloqade.squin.rewrite.stim_rewrite_util import ( SQUIN_STIM_GATE_MAPPING, rewrite_Control, - are_sites_compatible, insert_qubit_idx_from_address, insert_qubit_idx_from_wire_ssa, ) @@ -28,9 +27,6 @@ def rewrite_Apply_and_Broadcast( self, stmt: wire.Apply | wire.Broadcast ) -> RewriteResult: - if not are_sites_compatible(stmt): - return RewriteResult() - # this is an SSAValue, need it to be the actual operator applied_op = stmt.operator.owner assert isinstance(applied_op, op.stmts.Operator) From dbc105cda0b8a49c6615317811b1180c705da0c5 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 16 May 2025 09:16:17 -0400 Subject: [PATCH 29/29] implement second round review feedback --- src/bloqade/squin/analysis/nsites/impls.py | 2 +- src/bloqade/squin/rewrite/wrap_analysis.py | 5 ++--- src/bloqade/stim/dialects/auxiliary/stmts/annotate.py | 4 ---- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/bloqade/squin/analysis/nsites/impls.py b/src/bloqade/squin/analysis/nsites/impls.py index 7d4e0610..e8089bc9 100644 --- a/src/bloqade/squin/analysis/nsites/impls.py +++ b/src/bloqade/squin/analysis/nsites/impls.py @@ -21,7 +21,7 @@ def apply( stmt: wire.Apply | wire.Broadcast, ): - return tuple([frame.get(input) for input in stmt.inputs]) + return tuple(frame.get(input) for input in stmt.inputs) @interp.impl(wire.MeasureAndReset) def measure_and_reset( diff --git a/src/bloqade/squin/rewrite/wrap_analysis.py b/src/bloqade/squin/rewrite/wrap_analysis.py index fafd3fc4..f4f47c0a 100644 --- a/src/bloqade/squin/rewrite/wrap_analysis.py +++ b/src/bloqade/squin/rewrite/wrap_analysis.py @@ -1,4 +1,3 @@ -from typing import Dict from dataclasses import dataclass from kirin import ir @@ -43,8 +42,8 @@ def print_impl(self, printer: Printer) -> None: @dataclass class WrapSquinAnalysis(RewriteRule): - address_analysis: Dict[ir.SSAValue, Address] - op_site_analysis: Dict[ir.SSAValue, Sites] + address_analysis: dict[ir.SSAValue, Address] + op_site_analysis: dict[ir.SSAValue, Sites] def wrap(self, value: ir.SSAValue) -> bool: address_analysis_result = self.address_analysis[value] diff --git a/src/bloqade/stim/dialects/auxiliary/stmts/annotate.py b/src/bloqade/stim/dialects/auxiliary/stmts/annotate.py index 28acf873..eacd881f 100644 --- a/src/bloqade/stim/dialects/auxiliary/stmts/annotate.py +++ b/src/bloqade/stim/dialects/auxiliary/stmts/annotate.py @@ -45,7 +45,3 @@ class NewPauliString(ir.Statement): flipped: tuple[ir.SSAValue, ...] = info.argument(types.Bool) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) result: ir.ResultValue = info.result(type=PauliStringType) - - -# dialect_group = squin_extended = squin in bloqade-circuit + physical in bloqade-qec -# Chen will need squin -> Stim rewrite IF it is a subroutine of rewrite from bloqade-qec extension