From a2ab3a5fde08d8fdd9c874bc6689a893f31f7415 Mon Sep 17 00:00:00 2001 From: John Long Date: Fri, 25 Apr 2025 17:16:19 -0400 Subject: [PATCH 1/2] initial noise statements sketch --- src/bloqade/squin/__init__.py | 2 +- src/bloqade/squin/noise/__init__.py | 31 ++++++++++++++ src/bloqade/squin/noise/_dialect.py | 3 ++ src/bloqade/squin/noise/stmts.py | 64 +++++++++++++++++++++++++++++ 4 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 src/bloqade/squin/noise/__init__.py create mode 100644 src/bloqade/squin/noise/_dialect.py create mode 100644 src/bloqade/squin/noise/stmts.py diff --git a/src/bloqade/squin/__init__.py b/src/bloqade/squin/__init__.py index f770f9ac..568c80f7 100644 --- a/src/bloqade/squin/__init__.py +++ b/src/bloqade/squin/__init__.py @@ -1,2 +1,2 @@ -from . import op as op, wire as wire, qubit as qubit +from . import op as op, wire as wire, noise as noise, qubit as qubit from .groups import wired as wired, kernel as kernel diff --git a/src/bloqade/squin/noise/__init__.py b/src/bloqade/squin/noise/__init__.py new file mode 100644 index 00000000..27b508c9 --- /dev/null +++ b/src/bloqade/squin/noise/__init__.py @@ -0,0 +1,31 @@ +# Put all the proper wrappers here + +# from kirin.types import Float, Tuple, Vararg +from kirin.lowering import wraps as _wraps + +from bloqade.squin.op.types import Op + +from . import stmts as stmts + + +@_wraps(stmts.PauliError) +def pauli_error(basis: Op, p: float) -> Op: ... + + +@_wraps(stmts.PPError) +def pp_error(op: Op, p: float) -> Op: ... + + +@_wraps(stmts.Depolarize) +def depolarize(n_qubits: int, p: float) -> Op: ... + + +# How do you give `types.Tuple[types.Vararg(types.Float)])` nicely to +# the wrapper? + +# @_wraps(stmts.PauliChannel) +# def pauli_channel(n_qubits: int, params:) -> Op:... + + +@_wraps(stmts.QubitLoss) +def qubit_loss(p: float) -> Op: ... diff --git a/src/bloqade/squin/noise/_dialect.py b/src/bloqade/squin/noise/_dialect.py new file mode 100644 index 00000000..025b2dfe --- /dev/null +++ b/src/bloqade/squin/noise/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect(name="squin.noise") diff --git a/src/bloqade/squin/noise/stmts.py b/src/bloqade/squin/noise/stmts.py new file mode 100644 index 00000000..743e26be --- /dev/null +++ b/src/bloqade/squin/noise/stmts.py @@ -0,0 +1,64 @@ +from kirin import ir, types +from kirin.decl import info, statement + +from bloqade.squin.op.types import OpType + +from ._dialect import dialect + + +@statement +class NoiseChannel(ir.Statement): + pass + + +@statement(dialect=dialect) +class PauliError(NoiseChannel): + name = "pauli_error" + basis: ir.SSAValue = info.argument(OpType) + p: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(OpType) + + +@statement(dialect=dialect) +class PPError(NoiseChannel): + """ + Pauli Product Error + """ + + name = "pp_error" + op: ir.SSAValue = info.argument(OpType) + p: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(OpType) + + +@statement(dialect=dialect) +class Depolarize(NoiseChannel): + """ + Apply n-qubit depolaize error to qubits + NOTE For Stim, this can only accept 1 or 2 qubits + """ + + name = "depolarize" + n_qubits: int = info.attribute(types.Int) + p: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(OpType) + + +@statement(dialect=dialect) +class PauliChannel(NoiseChannel): + # NOTE: + # 1-qubit 3 params px, py, pz + # 2-qubit 15 params pix, piy, piz, pxi, pxx, pxy, pxz, pyi, pyx ..., pzz + # TODO add validation for params (maybe during lowering via custom lower?) + name = "pauli_channel" + n_qubits: int = info.attribute() + params: ir.SSAValue = info.argument(types.Tuple[types.Vararg(types.Float)]) + result: ir.ResultValue = info.result(OpType) + + +@statement(dialect=dialect) +class QubitLoss(NoiseChannel): + # NOTE: qubit loss error (not supported by Stim) + name = "qubit_loss" + p: ir.SSAValue = info.argument(types.Float) + result: ir.ResultValue = info.result(OpType) From 747d8fce9af787a466186e22aa303c822a8c2329 Mon Sep 17 00:00:00 2001 From: John Long Date: Tue, 29 Apr 2025 14:06:33 -0400 Subject: [PATCH 2/2] remove snakecase, fix wrappers --- src/bloqade/squin/noise/__init__.py | 8 ++------ src/bloqade/squin/noise/stmts.py | 5 ----- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/bloqade/squin/noise/__init__.py b/src/bloqade/squin/noise/__init__.py index 27b508c9..f553b4aa 100644 --- a/src/bloqade/squin/noise/__init__.py +++ b/src/bloqade/squin/noise/__init__.py @@ -1,6 +1,5 @@ # Put all the proper wrappers here -# from kirin.types import Float, Tuple, Vararg from kirin.lowering import wraps as _wraps from bloqade.squin.op.types import Op @@ -20,11 +19,8 @@ def pp_error(op: Op, p: float) -> Op: ... def depolarize(n_qubits: int, p: float) -> Op: ... -# How do you give `types.Tuple[types.Vararg(types.Float)])` nicely to -# the wrapper? - -# @_wraps(stmts.PauliChannel) -# def pauli_channel(n_qubits: int, params:) -> Op:... +@_wraps(stmts.PauliChannel) +def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ... @_wraps(stmts.QubitLoss) diff --git a/src/bloqade/squin/noise/stmts.py b/src/bloqade/squin/noise/stmts.py index 743e26be..58b59ce5 100644 --- a/src/bloqade/squin/noise/stmts.py +++ b/src/bloqade/squin/noise/stmts.py @@ -13,7 +13,6 @@ class NoiseChannel(ir.Statement): @statement(dialect=dialect) class PauliError(NoiseChannel): - name = "pauli_error" basis: ir.SSAValue = info.argument(OpType) p: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -25,7 +24,6 @@ class PPError(NoiseChannel): Pauli Product Error """ - name = "pp_error" op: ir.SSAValue = info.argument(OpType) p: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -38,7 +36,6 @@ class Depolarize(NoiseChannel): NOTE For Stim, this can only accept 1 or 2 qubits """ - name = "depolarize" n_qubits: int = info.attribute(types.Int) p: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -50,7 +47,6 @@ class PauliChannel(NoiseChannel): # 1-qubit 3 params px, py, pz # 2-qubit 15 params pix, piy, piz, pxi, pxx, pxy, pxz, pyi, pyx ..., pzz # TODO add validation for params (maybe during lowering via custom lower?) - name = "pauli_channel" n_qubits: int = info.attribute() params: ir.SSAValue = info.argument(types.Tuple[types.Vararg(types.Float)]) result: ir.ResultValue = info.result(OpType) @@ -59,6 +55,5 @@ class PauliChannel(NoiseChannel): @statement(dialect=dialect) class QubitLoss(NoiseChannel): # NOTE: qubit loss error (not supported by Stim) - name = "qubit_loss" p: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType)