Skip to content

Rewrite py.Mult to squin.Mult in squin kernel #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from

Conversation

david-pl
Copy link
Contributor

Seemed sufficiently different from #207, so I created a new branch for it.

@Roger-luo two things:

  • I'm not sure whether this should be a Fixpoint.
  • Should this also be part of the wire kernel?

@david-pl david-pl requested a review from Roger-luo April 28, 2025 08:09
@david-pl
Copy link
Contributor Author

CI fails because of kirin v0.17.3 (passes locally with 0.17.2).

@david-pl
Copy link
Contributor Author

Looks like CI failure is caused by this commit: QuEraComputing/kirin@9cb1f5d

The SimpleMergePolicy seems to rely on something being there that is already poped due to the change above.

Copy link
Member

@Roger-luo Roger-luo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need to use fixpoint here because the termination condition is simple here - you walk once, all py.binop.Mult is being replaced?

Also on 0.18 this will be simpler - we offer a dialect.canonicalize decorator so you can register the rule only without implementing this Pass. We can come back and cleanup this part after I finish the full refactor.

@Roger-luo
Copy link
Member

The SimpleMergePolicy seems to rely on something being there that is already poped due to the change above.

@weinbe58 I'm wondering if you know what was causing the issue there? should we change the Walk order by changing the reverse option?

@Roger-luo
Copy link
Member

Should this also be part of the wire kernel?

No, this is due to op dialect, so you only need one rule asscociate with op dialect.

@david-pl
Copy link
Contributor Author

It was actually two issues, both of which you hinted at @Roger-luo :

  • There are two type inference passes in the kernel: the second is apparently required because the desugaring pass messes up the types. I put the rewrite pass before the second inference pass, but it needed to be before the first. That fixed the inference.
  • The typing of Scale was factor = info.argument(Complex), which didn't work properly since isinstance(2, complex) is False. I changed it to factor = info.argument(PyClass(numbers.Number)) and it works. Should this be added as a type upstream?

One more question:

No, this is due to op dialect, so you only need one rule asscociate with op dialect.

The op dialect doesn't have its own kernel. Should I create one just to add the rewriting pass there and then define both wired and kernel on top of it? How do I ensure the order with the type inference passes then?

Other than that: ready for review!

@weinbe58
Copy link
Member

@weinbe58 I'm wondering if you know what was causing the issue there? should we change the Walk order by changing the reverse option?

This error is easily fixed in my pr #214 The heuristic noise model needs a big refactor to get it to work, unfortunately. I'll work on that today.

@david-pl david-pl force-pushed the david/squin-mult-rewrite branch from 028251c to 858c4c0 Compare April 30, 2025 08:43
@david-pl
Copy link
Contributor Author

Okay, so after rebasing the tests I added now fail due to kirin v0.17.3. Specifically, the type inference for Scale is again broken somehow. The change is once more the update to the worklist, i.e. this commit: QuEraComputing/kirin@9cb1f5d

The failing test checks to see if this kernel

@squin.kernel
def scale_mult():
    x = squin.op.x()
    y = squin.op.y()
    return 2 * (x * y)

gets rewritten such that it includes exactly one Scale and one Mult statement. Here are the IRs:

kirin v0.17.2:

scale_mult.print()
func.func scale_mult() -> !py.Op {
  ^0(%scale_mult_self):
  │ %x = squin.op.x() : !py.Op%y = squin.op.y() : !py.Op%0 = py.constant.constant 2 : !py.int%1 = squin.op.mult(lhs=%x, rhs=%y){is_unitary=False : !py.bool} : !py.Op%2 = squin.op.scale(op=%1, factor=%0){is_unitary=False : !py.bool} : !py.Opfunc.return %2
} // func.func scale_mult

kirin v0.17.3:

scale_mult.print()
func.func scale_mult() -> !py.int {
  ^0(%scale_mult_self):
  │ %x = squin.op.x() : !py.Op%y = squin.op.y() : !py.Op%0 = py.constant.constant 2 : !Bottom%1 = squin.op.mult(lhs=%x, rhs=%y){is_unitary=False : !py.bool} : !Bottom%2 = py.binop.mult(%0, %1) : !py.intfunc.return %2
} // func.func scale_mult

@Roger-luo do you have an idea what's going on here?

david-pl and others added 4 commits April 30, 2025 16:13
Co-authored-by: Phillip Weinberg <weinbe58@gmail.com>
Co-authored-by: Phillip Weinberg <weinbe58@gmail.com>
@david-pl david-pl requested a review from weinbe58 April 30, 2025 14:23
@david-pl
Copy link
Contributor Author

CI passes on v0.17.4 of kirin. Just need to wait for the release to be available.

@Roger-luo Roger-luo closed this Apr 30, 2025
@Roger-luo Roger-luo reopened this Apr 30, 2025
@Roger-luo
Copy link
Member

@david-pl
Copy link
Contributor Author

@Roger-luo that's a different error though. The new tests pass, maybe I just need another rebase to main.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants