Skip to content

Commit 66b1f8c

Browse files
authored
[exir] Allow verifiers in _transform
Differential Revision: D73205727 Pull Request resolved: #10274
1 parent 2ecc819 commit 66b1f8c

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

Diff for: exir/program/_program.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,30 @@ def _get_updated_graph_signature(
212212
return new_signature
213213

214214

215-
def _transform(self, *passes: PassType) -> "ExportedProgram":
215+
def _transform(
216+
self,
217+
*passes: PassType,
218+
override_verifiers: None | list[Type[Verifier]] = None,
219+
) -> "ExportedProgram":
220+
"""
221+
Transforms the program according to the provided passes.
222+
223+
Args:
224+
self: The ExportedProgram instance to transform
225+
*passes: A sequence of passes to apply to the program
226+
override_verifiers: Optional list of verifier classes to use instead of the default verifiers.
227+
This is needed if the transforms yields illegal graph that the default verifier cannot handle.
228+
229+
Returns:
230+
ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made
231+
"""
232+
# A user friendly check to avoid vararg surprises, PEP 3102
233+
assert not any(
234+
isinstance(p, (list, Verifier)) for p in passes
235+
), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}"
236+
237+
for p in list(passes):
238+
print(type(p))
216239
pm = PassManager(list(passes))
217240
res = pm(self.graph_module)
218241
transformed_gm = res.graph_module if res is not None else self.graph_module
@@ -221,7 +244,9 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
221244
if transformed_gm is self.graph_module and not res.modified:
222245
return self
223246

224-
return _update_exported_program_graph_module(self, transformed_gm)
247+
return _update_exported_program_graph_module(
248+
self, transformed_gm, override_verifiers
249+
)
225250

226251

227252
def _update_exported_program_graph_module(

Diff for: exir/program/test/test_program.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from executorch.exir.pass_base import ExportPass
2323
from executorch.exir.passes import MemoryPlanningPass
2424
from executorch.exir.program._program import (
25+
_transform,
2526
EdgeProgramManager,
2627
ExecutorchProgramManager,
2728
to_edge,
@@ -34,6 +35,7 @@
3435
from executorch.extension.pybindings.portable_lib import (
3536
_load_for_executorch_from_buffer,
3637
)
38+
from torch._export.verifier import Verifier
3739
from torch.export import Dim, export, ExportedProgram
3840
from torch.export._trace import _export
3941

@@ -273,7 +275,6 @@ def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
273275
for output_val in method.outputs:
274276
evalue = method.values[output_val]
275277
self.assertNotEqual(evalue.val.allocation_info, None)
276-
else:
277278
for input_val in method.inputs:
278279
evalue = method.values[input_val]
279280
self.assertEqual(evalue.val.allocation_info, None)
@@ -847,3 +848,19 @@ def test_save_fails(self):
847848
et = edge.to_executorch()
848849
with self.assertRaises(ValueError):
849850
_ = et.save("/tmp/test_save.pt")
851+
852+
def test__transform_override_verifiers(self):
853+
"""Test that _transform can override verifiers in the exported program."""
854+
class MyVerifier(Verifier):
855+
dialect: str = "MY_DIALECT"
856+
def __init__(self):
857+
super().__init__()
858+
859+
model = TestLinear()
860+
program = torch.export.export(model, model._get_random_inputs(), strict=True)
861+
self.assertFalse(issubclass(program.verifiers[0], MyVerifier))
862+
863+
# Apply transformation with custom verifier
864+
transformed = _transform(program, AddToMulPassEdge(), override_verifiers=[MyVerifier])
865+
self.assertTrue(issubclass(transformed.verifiers[0], MyVerifier))
866+
self.assertFalse(issubclass(program.verifiers[0], MyVerifier))

0 commit comments

Comments
 (0)