Skip to content

Commit cfa469f

Browse files
committed
[exir] Refactor EdgeProgramManager.transform
Mainly refactor, but also update the dialect verifier in the EP created by the `_transform` when the edge_config has been updated. Differential Revision: [D73205728](https://our.internmc.facebook.com/intern/diff/D73205728/) ghstack-source-id: 278776367 Pull Request resolved: #10275
1 parent d6d68d5 commit cfa469f

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

exir/program/_program.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -1435,22 +1435,32 @@ def transform(
14351435
"""
14361436
compile_config = compile_config or self.compile_config
14371437
new_programs: Dict[str, ExportedProgram] = {}
1438+
1439+
def _transform_and_verify(
1440+
program: ExportedProgram,
1441+
passes: Sequence[PassType],
1442+
verifier: EXIREdgeDialectVerifier,
1443+
) -> ExportedProgram:
1444+
# Overwrite the original verifier with the new one
1445+
# This should be a no-op for the most cases where compile_config is none.
1446+
new_program = _transform(program, *passes, [verifier])
1447+
# ExportedProgram constructor should call the verifier, but
1448+
# the validate() function in the constructor is marked for deprecation.
1449+
verifier(new_program.graph_module)
1450+
return new_program
1451+
1452+
verifier = EXIREdgeDialectVerifier(edge_compile_config=compile_config)
14381453
if isinstance(passes, dict):
14391454
for name, program in self._edge_programs.items():
14401455
if name in passes.keys():
1441-
new_programs[name] = _transform(program, *passes[name])
1442-
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1443-
new_programs[name].graph_module
1456+
new_programs[name] = _transform_and_verify(
1457+
program, passes[name], verifier
14441458
)
14451459
else:
14461460
new_programs[name] = copy.deepcopy(program)
1447-
14481461
else: # apply passes to every method
14491462
for name, program in self._edge_programs.items():
1450-
new_programs[name] = _transform(program, *passes)
1451-
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1452-
new_programs[name].graph_module
1453-
)
1463+
new_programs[name] = _transform_and_verify(program, passes, verifier)
14541464

14551465
return EdgeProgramManager(
14561466
new_programs, copy.deepcopy(self._config_methods), compile_config

0 commit comments

Comments
 (0)