|
22 | 22 | from executorch.exir.pass_base import ExportPass
|
23 | 23 | from executorch.exir.passes import MemoryPlanningPass
|
24 | 24 | from executorch.exir.program._program import (
|
| 25 | + _transform, |
25 | 26 | EdgeProgramManager,
|
26 | 27 | ExecutorchProgramManager,
|
27 | 28 | to_edge,
|
|
34 | 35 | from executorch.extension.pybindings.portable_lib import (
|
35 | 36 | _load_for_executorch_from_buffer,
|
36 | 37 | )
|
| 38 | +from torch._export.verifier import Verifier |
37 | 39 | from torch.export import Dim, export, ExportedProgram
|
38 | 40 | from torch.export._trace import _export
|
39 | 41 |
|
@@ -273,7 +275,6 @@ def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
|
273 | 275 | for output_val in method.outputs:
|
274 | 276 | evalue = method.values[output_val]
|
275 | 277 | self.assertNotEqual(evalue.val.allocation_info, None)
|
276 |
| - else: |
277 | 278 | for input_val in method.inputs:
|
278 | 279 | evalue = method.values[input_val]
|
279 | 280 | self.assertEqual(evalue.val.allocation_info, None)
|
@@ -847,3 +848,19 @@ def test_save_fails(self):
|
847 | 848 | et = edge.to_executorch()
|
848 | 849 | with self.assertRaises(ValueError):
|
849 | 850 | _ = et.save("/tmp/test_save.pt")
|
| 851 | + |
| 852 | + def test__transform_override_verifiers(self): |
| 853 | + """Test that _transform can override verifiers in the exported program.""" |
| 854 | + class MyVerifier(Verifier): |
| 855 | + dialect: str = "MY_DIALECT" |
| 856 | + def __init__(self): |
| 857 | + super().__init__() |
| 858 | + |
| 859 | + model = TestLinear() |
| 860 | + program = torch.export.export(model, model._get_random_inputs(), strict=True) |
| 861 | + self.assertFalse(issubclass(program.verifiers[0], MyVerifier)) |
| 862 | + |
| 863 | + # Apply transformation with custom verifier |
| 864 | + transformed = _transform(program, AddToMulPassEdge(), override_verifiers=[MyVerifier]) |
| 865 | + self.assertTrue(issubclass(transformed.verifiers[0], MyVerifier)) |
| 866 | + self.assertFalse(issubclass(program.verifiers[0], MyVerifier)) |
0 commit comments