Skip to content

Rewrite py.Mult to squin.Mult in squin kernel #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/bloqade/squin/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from kirin.rewrite.walk import Walk

from . import op, wire, qubit
from .op.rewrite import PyMultToSquinMult
from .rewrite.measure_desugar import MeasureDesugarRule


Expand All @@ -13,16 +14,21 @@
typeinfer_pass = passes.TypeInfer(self)
ilist_desugar_pass = ilist.IListDesugar(self)
measure_desugar_pass = Walk(MeasureDesugarRule())
py_mult_to_mult_pass = PyMultToSquinMult(self)

def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
method.verify()
if fold:
fold_pass.fixpoint(method)

py_mult_to_mult_pass(method)

if typeinfer:
typeinfer_pass(method)
measure_desugar_pass.rewrite(method.code)

ilist_desugar_pass(method)

if typeinfer:
typeinfer_pass(method) # fix types after desugaring
method.verify_type()
Expand All @@ -32,7 +38,9 @@

@ir.dialect_group(structural_no_opt.union([op, wire]))
def wired(self):
py_mult_to_mult_pass = PyMultToSquinMult(self)

def run_pass(method):
pass
py_mult_to_mult_pass(method)

Check warning on line 44 in src/bloqade/squin/groups.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/groups.py#L44

Added line #L44 was not covered by tests

return run_pass
2 changes: 1 addition & 1 deletion src/bloqade/squin/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.prelude import structural_no_opt as _structural_no_opt
from kirin.lowering import wraps as _wraps

from . import stmts as stmts, types as types
from . import stmts as stmts, types as types, rewrite as rewrite
from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
from ._dialect import dialect as dialect

Expand Down
6 changes: 0 additions & 6 deletions src/bloqade/squin/op/complex.py

This file was deleted.

5 changes: 5 additions & 0 deletions src/bloqade/squin/op/number.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import numbers

from kirin.ir.attrs.types import PyClass

NumberType = PyClass(numbers.Number)
46 changes: 46 additions & 0 deletions src/bloqade/squin/op/rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Rewrite py.binop.mult to Mult stmt"""

from kirin import ir
from kirin.passes import Pass
from kirin.rewrite import Walk
from kirin.dialects import py
from kirin.rewrite.abc import RewriteRule, RewriteResult

from .stmts import Mult, Scale
from .types import OpType


class _PyMultToSquinMult(RewriteRule):

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not isinstance(node, py.Mult):
return RewriteResult()

lhs_is_op = node.lhs.type.is_subseteq(OpType)
rhs_is_op = node.rhs.type.is_subseteq(OpType)

if not lhs_is_op and not rhs_is_op:
return RewriteResult()

Check warning on line 23 in src/bloqade/squin/op/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/rewrite.py#L23

Added line #L23 was not covered by tests

if lhs_is_op and rhs_is_op:
mult = Mult(node.lhs, node.rhs)
node.replace_by(mult)
return RewriteResult(has_done_something=True)

if lhs_is_op:
scale = Scale(node.lhs, node.rhs)
node.replace_by(scale)
return RewriteResult(has_done_something=True)

if rhs_is_op:
scale = Scale(node.rhs, node.lhs)
node.replace_by(scale)
return RewriteResult(has_done_something=True)

return RewriteResult()

Check warning on line 40 in src/bloqade/squin/op/rewrite.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/rewrite.py#L40

Added line #L40 was not covered by tests


class PyMultToSquinMult(Pass):

def unsafe_run(self, mt: ir.Method) -> RewriteResult:
return Walk(_PyMultToSquinMult()).rewrite(mt.code)
4 changes: 2 additions & 2 deletions src/bloqade/squin/op/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from kirin.decl import info, statement

from .types import OpType
from .number import NumberType
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
from .complex import Complex
from ._dialect import dialect


Expand Down Expand Up @@ -54,7 +54,7 @@ class Scale(CompositeOp):
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()})
is_unitary: bool = info.attribute(default=False)
op: ir.SSAValue = info.argument(OpType)
factor: ir.SSAValue = info.argument(Complex)
factor: ir.SSAValue = info.argument(NumberType)
result: ir.ResultValue = info.result(OpType)


Expand Down
14 changes: 14 additions & 0 deletions src/bloqade/squin/op/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import overload

from kirin import types


Expand All @@ -6,5 +8,17 @@
def __matmul__(self, other: "Op") -> "Op":
raise NotImplementedError("@ can only be used within a squin kernel program")

@overload
def __mul__(self, other: "Op") -> "Op": ...

@overload
def __mul__(self, other: complex) -> "Op": ...

def __mul__(self, other) -> "Op":
raise NotImplementedError("@ can only be used within a squin kernel program")

Check warning on line 18 in src/bloqade/squin/op/types.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/types.py#L18

Added line #L18 was not covered by tests

def __rmul__(self, other: complex) -> "Op":
raise NotImplementedError("@ can only be used within a squin kernel program")

Check warning on line 21 in src/bloqade/squin/op/types.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/op/types.py#L21

Added line #L21 was not covered by tests


OpType = types.PyClass(Op)
166 changes: 166 additions & 0 deletions test/squin/test_mult_rewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from kirin.types import PyClass
from kirin.dialects import py, func

from bloqade import squin


def test_mult_rewrite():

@squin.kernel
def helper(x: squin.op.types.Op, y: squin.op.types.Op):
return x * y

@squin.kernel
def main():
q = squin.qubit.new(1)
x = squin.op.x()
y = squin.op.y()
z = x * y
t = helper(x, z)

squin.qubit.apply(t, q)
return q

helper.print()
main.print()

assert isinstance(helper.code, func.Function)

helper_stmts = list(helper.code.body.stmts())
assert len(helper_stmts) == 2 # [Mult(), Return()]
assert isinstance(helper_stmts[0], squin.op.stmts.Mult)

assert isinstance(main.code, func.Function)

count_mults_in_main = 0
for stmt in main.code.body.stmts():
assert not isinstance(stmt, py.Mult)

count_mults_in_main += isinstance(stmt, squin.op.stmts.Mult)

assert count_mults_in_main == 1


def test_scale_rewrite():

@squin.kernel
def simple_rmul():
x = squin.op.x()
y = 2 * x
return y

simple_rmul.print()

assert isinstance(simple_rmul.code, func.Function)

simple_rmul_stmts = list(simple_rmul.code.body.stmts())
assert any(
map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_rmul_stmts)
)
assert not any(
map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_rmul_stmts)
)
assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_rmul_stmts))

@squin.kernel
def simple_lmul():
x = squin.op.x()
y = x * 2
return y

simple_lmul.print()

assert isinstance(simple_lmul.code, func.Function)

simple_lmul_stmts = list(simple_lmul.code.body.stmts())
assert any(
map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_lmul_stmts)
)
assert not any(
map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_lmul_stmts)
)
assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_lmul_stmts))

@squin.kernel
def scale_mult():
x = squin.op.x()
y = squin.op.y()
return 2 * (x * y)

assert isinstance(scale_mult.code, func.Function)

scale_mult_stmts = list(scale_mult.code.body.stmts())
assert (
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult_stmts))
== 1
)
assert (
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult_stmts))
== 1
)

@squin.kernel
def scale_mult2():
x = squin.op.x()
y = squin.op.y()
return 2 * x * y

scale_mult2.print()

assert isinstance(scale_mult2.code, func.Function)

scale_mult2_stmts = list(scale_mult2.code.body.stmts())
assert (
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult2_stmts))
== 1
)
assert (
sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult2_stmts))
== 1
)


def test_scale_types():
@squin.kernel
def simple_lmul():
x = squin.op.x()
y = x * (2 + 0j)
return y

@squin.kernel
def simple_rmul():
x = squin.op.x()
y = 2.1 * x
return y

@squin.kernel
def nested_rmul():
x = squin.op.x()
y = squin.op.y()
return 2 * x * y

@squin.kernel
def nested_rmul2():
x = squin.op.x()
y = squin.op.y()
return 2 * (x * y)

@squin.kernel
def nested_lmul():
x = squin.op.x()
y = squin.op.y()
return x * y * 2.0j

def check_stmt_type(code, typ):
for stmt in code.body.stmts():
if isinstance(stmt, func.Return):
continue
is_op = stmt.result.type.is_subseteq(squin.op.types.OpType)
is_num = stmt.result.type.is_equal(PyClass(typ))
assert is_op or is_num

check_stmt_type(simple_lmul.code, complex)
check_stmt_type(simple_rmul.code, float)
check_stmt_type(nested_rmul.code, int)
check_stmt_type(nested_rmul2.code, int)
check_stmt_type(nested_lmul.code, complex)