From ab36064645881f844f76739d74433a60abee0c0c Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 28 Apr 2025 10:07:02 +0200 Subject: [PATCH 01/10] Rewrite py.Mult to squin.Mult in squin kernel --- src/bloqade/squin/groups.py | 5 ++++ src/bloqade/squin/op/__init__.py | 2 +- src/bloqade/squin/op/rewrite.py | 48 ++++++++++++++++++++++++++++++++ src/bloqade/squin/op/types.py | 3 ++ test/squin/test_mult_rewrite.py | 39 ++++++++++++++++++++++++++ 5 files changed, 96 insertions(+), 1 deletion(-) create mode 100644 src/bloqade/squin/op/rewrite.py create mode 100644 test/squin/test_mult_rewrite.py diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index dda7a5b1..beb443ec 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -5,6 +5,7 @@ from . import op, wire, qubit from .rewrite.measure_desugar import MeasureDesugarRule +from .op.rewrite import PyMultToSquinMult @ir.dialect_group(structural_no_opt.union([op, qubit])) @@ -13,6 +14,7 @@ def kernel(self): typeinfer_pass = passes.TypeInfer(self) ilist_desugar_pass = ilist.IListDesugar(self) measure_desugar_pass = Walk(MeasureDesugarRule()) + py_mult_to_mult_pass = PyMultToSquinMult(self) def run_pass(method: ir.Method, *, fold=True, typeinfer=True): method.verify() @@ -22,7 +24,10 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True): if typeinfer: typeinfer_pass(method) measure_desugar_pass.rewrite(method.code) + ilist_desugar_pass(method) + py_mult_to_mult_pass(method) + if typeinfer: typeinfer_pass(method) # fix types after desugaring method.verify_type() diff --git a/src/bloqade/squin/op/__init__.py b/src/bloqade/squin/op/__init__.py index 77b07c64..42cd426a 100644 --- a/src/bloqade/squin/op/__init__.py +++ b/src/bloqade/squin/op/__init__.py @@ -2,7 +2,7 @@ from kirin.prelude import structural_no_opt as _structural_no_opt from kirin.lowering import wraps as _wraps -from . import stmts as stmts, types as types +from . import stmts as stmts, types as types, rewrite as rewrite from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary from ._dialect import dialect as dialect diff --git a/src/bloqade/squin/op/rewrite.py b/src/bloqade/squin/op/rewrite.py new file mode 100644 index 00000000..f02657fd --- /dev/null +++ b/src/bloqade/squin/op/rewrite.py @@ -0,0 +1,48 @@ +"""Rewrite py.binop.mult to Mult stmt""" + +from kirin import ir +from kirin.passes import Pass +from kirin.rewrite import Walk +from kirin.dialects import py +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from .stmts import Mult, Scale +from .types import OpType + + +class _PyMultToSquinMult(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if not isinstance(node, py.Mult): + return RewriteResult() + + lhs_is_op = node.lhs.type.is_subseteq(OpType) + rhs_is_op = node.lhs.type.is_subseteq(OpType) + + if not lhs_is_op and not rhs_is_op: + return RewriteResult() + + if lhs_is_op and rhs_is_op: + mult = Mult(node.lhs, node.rhs) + node.replace_by(mult) + return RewriteResult(has_done_something=True) + + if lhs_is_op: + scale = Scale(node.lhs, node.rhs) + node.replace_by(scale) + return RewriteResult(has_done_something=True) + + if rhs_is_op: + scale = Scale(node.rhs, node.lhs) + node.replace_by(scale) + return RewriteResult(has_done_something=True) + + raise ValueError( + "Rewrite of py.binop.mult failed. This exception should not be reachable, please report this issue." + ) + + +class PyMultToSquinMult(Pass): + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + return Walk(_PyMultToSquinMult()).rewrite(mt.code) diff --git a/src/bloqade/squin/op/types.py b/src/bloqade/squin/op/types.py index 0c4564e6..d0e0cf52 100644 --- a/src/bloqade/squin/op/types.py +++ b/src/bloqade/squin/op/types.py @@ -6,5 +6,8 @@ class Op: def __matmul__(self, other: "Op") -> "Op": raise NotImplementedError("@ can only be used within a squin kernel program") + def __mul__(self, other: "Op") -> "Op": + raise NotImplementedError("@ can only be used within a squin kernel program") + OpType = types.PyClass(Op) diff --git a/test/squin/test_mult_rewrite.py b/test/squin/test_mult_rewrite.py new file mode 100644 index 00000000..3ec8481d --- /dev/null +++ b/test/squin/test_mult_rewrite.py @@ -0,0 +1,39 @@ +from kirin.dialects import py, func + +from bloqade import squin + + +def test_mult_rewrite(): + + @squin.kernel + def helper(x: squin.op.types.Op, y: squin.op.types.Op): + return x * y + + @squin.kernel + def main(): + q = squin.qubit.new(1) + x = squin.op.x() + y = squin.op.y() + z = x * y + t = helper(x, z) + + squin.qubit.apply(t, q) + return q + + helper.print() + + assert isinstance(helper.code, func.Function) + + helper_stmts = list(helper.code.body.stmts()) + assert len(helper_stmts) == 2 # [Mult(), Return()] + assert isinstance(helper_stmts[0], squin.op.stmts.Mult) + + assert isinstance(main.code, func.Function) + + count_mults_in_main = 0 + for stmt in main.code.body.stmts(): + assert not isinstance(stmt, py.Mult) + + count_mults_in_main += isinstance(stmt, squin.op.stmts.Mult) + + assert count_mults_in_main == 1 From a09a38c8dbf3cbfa900d25920f34674f03631bef Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Mon, 28 Apr 2025 11:55:36 +0200 Subject: [PATCH 02/10] Fix Scale rewrite --- src/bloqade/squin/op/rewrite.py | 10 ++++- src/bloqade/squin/op/types.py | 13 +++++- test/squin/test_mult_rewrite.py | 80 +++++++++++++++++++++++++++++++++ 3 files changed, 100 insertions(+), 3 deletions(-) diff --git a/src/bloqade/squin/op/rewrite.py b/src/bloqade/squin/op/rewrite.py index f02657fd..387c458a 100644 --- a/src/bloqade/squin/op/rewrite.py +++ b/src/bloqade/squin/op/rewrite.py @@ -6,7 +6,7 @@ from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from .stmts import Mult, Scale +from .stmts import Mult, Scale, Operator from .types import OpType @@ -17,11 +17,17 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: return RewriteResult() lhs_is_op = node.lhs.type.is_subseteq(OpType) - rhs_is_op = node.lhs.type.is_subseteq(OpType) + rhs_is_op = node.rhs.type.is_subseteq(OpType) if not lhs_is_op and not rhs_is_op: return RewriteResult() + if isinstance(node.lhs, ir.ResultValue): + lhs_is_op = isinstance(node.lhs.stmt, Operator) + + if isinstance(node.rhs, ir.ResultValue): + rhs_is_op = isinstance(node.rhs.stmt, Operator) + if lhs_is_op and rhs_is_op: mult = Mult(node.lhs, node.rhs) node.replace_by(mult) diff --git a/src/bloqade/squin/op/types.py b/src/bloqade/squin/op/types.py index d0e0cf52..25f95c8f 100644 --- a/src/bloqade/squin/op/types.py +++ b/src/bloqade/squin/op/types.py @@ -1,3 +1,5 @@ +from typing import overload + from kirin import types @@ -6,7 +8,16 @@ class Op: def __matmul__(self, other: "Op") -> "Op": raise NotImplementedError("@ can only be used within a squin kernel program") - def __mul__(self, other: "Op") -> "Op": + @overload + def __mul__(self, other: "Op") -> "Op": ... + + @overload + def __mul__(self, other: int | float | complex) -> "Op": ... + + def __mul__(self, other) -> "Op": + raise NotImplementedError("@ can only be used within a squin kernel program") + + def __rmul__(self, other: int | float | complex) -> "Op": raise NotImplementedError("@ can only be used within a squin kernel program") diff --git a/test/squin/test_mult_rewrite.py b/test/squin/test_mult_rewrite.py index 3ec8481d..a7951456 100644 --- a/test/squin/test_mult_rewrite.py +++ b/test/squin/test_mult_rewrite.py @@ -21,6 +21,7 @@ def main(): return q helper.print() + main.print() assert isinstance(helper.code, func.Function) @@ -37,3 +38,82 @@ def main(): count_mults_in_main += isinstance(stmt, squin.op.stmts.Mult) assert count_mults_in_main == 1 + + +def test_scale_rewrite(): + + @squin.kernel + def simple_rmul(): + x = squin.op.x() + y = 2 * x + return y + + simple_rmul.print() + + assert isinstance(simple_rmul.code, func.Function) + + simple_rmul_stmts = list(simple_rmul.code.body.stmts()) + assert any( + map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_rmul_stmts) + ) + assert not any( + map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_rmul_stmts) + ) + assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_rmul_stmts)) + + @squin.kernel + def simple_lmul(): + x = squin.op.x() + y = x * 2 + return y + + simple_lmul.print() + + assert isinstance(simple_lmul.code, func.Function) + + simple_lmul_stmts = list(simple_lmul.code.body.stmts()) + assert any( + map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_lmul_stmts) + ) + assert not any( + map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_lmul_stmts) + ) + assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_lmul_stmts)) + + @squin.kernel + def scale_mult(): + x = squin.op.x() + y = squin.op.y() + return 2 * (x * y) + + assert isinstance(scale_mult.code, func.Function) + + scale_mult_stmts = list(scale_mult.code.body.stmts()) + assert ( + sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult_stmts)) + == 1 + ) + assert ( + sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult_stmts)) + == 1 + ) + + @squin.kernel + def scale_mult2(): + x = squin.op.x() + y = squin.op.y() + return 2 * x * y + + scale_mult2.print() + + assert isinstance(scale_mult2.code, func.Function) + + scale_mult2_stmts = list(scale_mult2.code.body.stmts()) + assert ( + sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult2_stmts)) + == 1 + ) + assert ( + sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult2_stmts)) + == 1 + ) From 95a29b177fff94592a0c1ecc6d08f1088306742b Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 29 Apr 2025 10:02:39 +0200 Subject: [PATCH 03/10] Fix type inference --- src/bloqade/squin/groups.py | 3 ++- src/bloqade/squin/op/complex.py | 6 ----- src/bloqade/squin/op/number.py | 5 +++++ src/bloqade/squin/op/rewrite.py | 8 +------ src/bloqade/squin/op/stmts.py | 4 ++-- src/bloqade/squin/op/types.py | 4 ++-- test/squin/test_mult_rewrite.py | 40 +++++++++++++++++++++++++++++++++ 7 files changed, 52 insertions(+), 18 deletions(-) delete mode 100644 src/bloqade/squin/op/complex.py create mode 100644 src/bloqade/squin/op/number.py diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index beb443ec..96e44e2d 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -21,12 +21,13 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True): if fold: fold_pass.fixpoint(method) + py_mult_to_mult_pass(method) + if typeinfer: typeinfer_pass(method) measure_desugar_pass.rewrite(method.code) ilist_desugar_pass(method) - py_mult_to_mult_pass(method) if typeinfer: typeinfer_pass(method) # fix types after desugaring diff --git a/src/bloqade/squin/op/complex.py b/src/bloqade/squin/op/complex.py deleted file mode 100644 index 10e0d630..00000000 --- a/src/bloqade/squin/op/complex.py +++ /dev/null @@ -1,6 +0,0 @@ -# Stopgap Measure, squin dialect needs Complex type but -# this is only available in Kirin 0.15.x - -from kirin.ir.attrs.types import PyClass - -Complex = PyClass(complex) diff --git a/src/bloqade/squin/op/number.py b/src/bloqade/squin/op/number.py new file mode 100644 index 00000000..71968837 --- /dev/null +++ b/src/bloqade/squin/op/number.py @@ -0,0 +1,5 @@ +import numbers + +from kirin.ir.attrs.types import PyClass + +Number = PyClass(numbers.Number) diff --git a/src/bloqade/squin/op/rewrite.py b/src/bloqade/squin/op/rewrite.py index 387c458a..4e63edaf 100644 --- a/src/bloqade/squin/op/rewrite.py +++ b/src/bloqade/squin/op/rewrite.py @@ -6,7 +6,7 @@ from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from .stmts import Mult, Scale, Operator +from .stmts import Mult, Scale from .types import OpType @@ -22,12 +22,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if not lhs_is_op and not rhs_is_op: return RewriteResult() - if isinstance(node.lhs, ir.ResultValue): - lhs_is_op = isinstance(node.lhs.stmt, Operator) - - if isinstance(node.rhs, ir.ResultValue): - rhs_is_op = isinstance(node.rhs.stmt, Operator) - if lhs_is_op and rhs_is_op: mult = Mult(node.lhs, node.rhs) node.replace_by(mult) diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index a17dd6e7..fa21e5f0 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -2,8 +2,8 @@ from kirin.decl import info, statement from .types import OpType +from .number import Number from .traits import Unitary, HasSites, FixedSites, MaybeUnitary -from .complex import Complex from ._dialect import dialect @@ -54,7 +54,7 @@ class Scale(CompositeOp): traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) - factor: ir.SSAValue = info.argument(Complex) + factor: ir.SSAValue = info.argument(Number) result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/op/types.py b/src/bloqade/squin/op/types.py index 25f95c8f..4f4388c4 100644 --- a/src/bloqade/squin/op/types.py +++ b/src/bloqade/squin/op/types.py @@ -12,12 +12,12 @@ def __matmul__(self, other: "Op") -> "Op": def __mul__(self, other: "Op") -> "Op": ... @overload - def __mul__(self, other: int | float | complex) -> "Op": ... + def __mul__(self, other: complex) -> "Op": ... def __mul__(self, other) -> "Op": raise NotImplementedError("@ can only be used within a squin kernel program") - def __rmul__(self, other: int | float | complex) -> "Op": + def __rmul__(self, other: complex) -> "Op": raise NotImplementedError("@ can only be used within a squin kernel program") diff --git a/test/squin/test_mult_rewrite.py b/test/squin/test_mult_rewrite.py index a7951456..3a81e205 100644 --- a/test/squin/test_mult_rewrite.py +++ b/test/squin/test_mult_rewrite.py @@ -1,3 +1,4 @@ +from kirin.types import PyClass from kirin.dialects import py, func from bloqade import squin @@ -117,3 +118,42 @@ def scale_mult2(): sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult2_stmts)) == 1 ) + + +def test_scale_types(): + @squin.kernel + def simple_lmul(): + x = squin.op.x() + y = x * (2 + 0j) + return y + + @squin.kernel + def simple_rmul(): + x = squin.op.x() + y = 2.1 * x + return y + + @squin.kernel + def nested_rmul(): + x = squin.op.x() + y = squin.op.y() + return 2 * x * y + + @squin.kernel + def nested_lmul(): + x = squin.op.x() + y = squin.op.y() + return x * y * 2.0j + + def check_stmt_type(code, typ): + for stmt in code.body.stmts(): + if isinstance(stmt, func.Return): + continue + is_op = stmt.result.type.is_subseteq(squin.op.types.OpType) + is_num = stmt.result.type.is_equal(PyClass(typ)) + assert is_op or is_num + + check_stmt_type(simple_lmul.code, complex) + check_stmt_type(simple_rmul.code, float) + check_stmt_type(nested_rmul.code, int) + check_stmt_type(nested_lmul.code, complex) From 255024d38ffbdd7b605aad438793fbdd00fd9942 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Tue, 29 Apr 2025 10:11:31 +0200 Subject: [PATCH 04/10] Also add the pass to wired --- src/bloqade/squin/groups.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index 96e44e2d..8b72b36e 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -38,7 +38,9 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True): @ir.dialect_group(structural_no_opt.union([op, wire])) def wired(self): + py_mult_to_mult_pass = PyMultToSquinMult(self) + def run_pass(method): - pass + py_mult_to_mult_pass(method) return run_pass From c4e860b855da74e5de330dba2b5c8d6eb8a96400 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 30 Apr 2025 10:43:44 +0200 Subject: [PATCH 05/10] Add another test --- test/squin/test_mult_rewrite.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/squin/test_mult_rewrite.py b/test/squin/test_mult_rewrite.py index 3a81e205..8a0bbaea 100644 --- a/test/squin/test_mult_rewrite.py +++ b/test/squin/test_mult_rewrite.py @@ -139,6 +139,12 @@ def nested_rmul(): y = squin.op.y() return 2 * x * y + @squin.kernel + def nested_rmul2(): + x = squin.op.x() + y = squin.op.y() + return 2 * (x * y) + @squin.kernel def nested_lmul(): x = squin.op.x() @@ -156,4 +162,5 @@ def check_stmt_type(code, typ): check_stmt_type(simple_lmul.code, complex) check_stmt_type(simple_rmul.code, float) check_stmt_type(nested_rmul.code, int) + check_stmt_type(nested_rmul2.code, int) check_stmt_type(nested_lmul.code, complex) From 9e94bf75087c45be400de3f2a2195ed4a682fa47 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 30 Apr 2025 10:45:00 +0200 Subject: [PATCH 06/10] Run isort on rebased file --- src/bloqade/squin/groups.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index 8b72b36e..9b94b63d 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -4,8 +4,8 @@ from kirin.rewrite.walk import Walk from . import op, wire, qubit -from .rewrite.measure_desugar import MeasureDesugarRule from .op.rewrite import PyMultToSquinMult +from .rewrite.measure_desugar import MeasureDesugarRule @ir.dialect_group(structural_no_opt.union([op, qubit])) From 593f2dcd096c3fc7cc13138741fb54160fc445a8 Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 30 Apr 2025 16:13:59 +0200 Subject: [PATCH 07/10] Update src/bloqade/squin/op/number.py Co-authored-by: Phillip Weinberg --- src/bloqade/squin/op/number.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bloqade/squin/op/number.py b/src/bloqade/squin/op/number.py index 71968837..3cdc12b1 100644 --- a/src/bloqade/squin/op/number.py +++ b/src/bloqade/squin/op/number.py @@ -2,4 +2,4 @@ from kirin.ir.attrs.types import PyClass -Number = PyClass(numbers.Number) +NumberType = PyClass(numbers.Number) From 063183ee8c71f947b9debe46d1354be1f636f23e Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 30 Apr 2025 16:14:12 +0200 Subject: [PATCH 08/10] Update src/bloqade/squin/op/rewrite.py Co-authored-by: Phillip Weinberg --- src/bloqade/squin/op/rewrite.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/bloqade/squin/op/rewrite.py b/src/bloqade/squin/op/rewrite.py index 4e63edaf..001f906e 100644 --- a/src/bloqade/squin/op/rewrite.py +++ b/src/bloqade/squin/op/rewrite.py @@ -37,9 +37,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: node.replace_by(scale) return RewriteResult(has_done_something=True) - raise ValueError( - "Rewrite of py.binop.mult failed. This exception should not be reachable, please report this issue." - ) + return RewriteResult() class PyMultToSquinMult(Pass): From c4560b15584b0d8f1f11de24b87d870d4d23375a Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 30 Apr 2025 16:16:38 +0200 Subject: [PATCH 09/10] Fix renaming to NumberType --- src/bloqade/squin/op/stmts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index fa21e5f0..6b5b44ff 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -2,7 +2,7 @@ from kirin.decl import info, statement from .types import OpType -from .number import Number +from .number import NumberType from .traits import Unitary, HasSites, FixedSites, MaybeUnitary from ._dialect import dialect @@ -54,7 +54,7 @@ class Scale(CompositeOp): traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) - factor: ir.SSAValue = info.argument(Number) + factor: ir.SSAValue = info.argument(NumberType) result: ir.ResultValue = info.result(OpType) From b98e266a851c644b39a3bd1e05f204f9c5e5f3bd Mon Sep 17 00:00:00 2001 From: David Plankensteiner Date: Wed, 30 Apr 2025 16:20:01 +0200 Subject: [PATCH 10/10] Fix indent in rewrite --- src/bloqade/squin/op/rewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bloqade/squin/op/rewrite.py b/src/bloqade/squin/op/rewrite.py index 001f906e..64000343 100644 --- a/src/bloqade/squin/op/rewrite.py +++ b/src/bloqade/squin/op/rewrite.py @@ -37,7 +37,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: node.replace_by(scale) return RewriteResult(has_done_something=True) - return RewriteResult() + return RewriteResult() class PyMultToSquinMult(Pass):