Skip to content

Commit 3413ce5

Browse files
committed
Add pass to tag external constants for delegates
generate pte+ptd file for a delegated linear example Differential Revision: [D73281924](https://our.internmc.facebook.com/intern/diff/D73281924/) ghstack-source-id: 279347167 Pull Request resolved: #10328
1 parent 08c5d93 commit 3413ce5

File tree

5 files changed

+85
-3
lines changed

5 files changed

+85
-3
lines changed

backends/xnnpack/operators/node_visitor.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,15 @@ def get_serialized_buffer_index(
592592
xnn_graph.constant_data.append(
593593
ConstantDataOffset(offset=UINT64_MAX, size=size, named_key=named_key)
594594
)
595+
596+
external_tag = None
597+
if tensor.meta.get("delegate_constant_tag", None) is not None:
598+
external_tag = tensor.meta["delegate_constant_tag"]
595599
self._named_data_store.add_named_data(
596-
named_key, bytes(array), alignment=CONSTANT_TENSOR_ALIGNMENT
600+
named_key,
601+
bytes(array),
602+
alignment=CONSTANT_TENSOR_ALIGNMENT,
603+
external_tag=external_tag,
597604
)
598605

599606
return buffer_idx

backends/xnnpack/runtime/XNNCompiler.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,9 @@ const uint8_t* getConstantDataPtr(
204204
if (!buffer.ok()) {
205205
ET_LOG(
206206
Error,
207-
"Failed to get constant data for key %s",
208-
data_name.c_str());
207+
"Failed to get constant data for key %s from named_data_map. Error code: %u",
208+
data_name.c_str(),
209+
static_cast<uint32_t>(buffer.error()));
209210
return nullptr;
210211
}
211212
const uint8_t* data_ptr =

exir/passes/external_constants_pass.py

+29
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# pyre-strict
88

9+
from typing import List, Optional
10+
911
import torch
1012
from executorch.exir.pass_base import PassResult
1113
from executorch.exir.tensor import TensorSpec
@@ -74,3 +76,30 @@ def external_mutable_weights_pass(
7476
node.meta["constant_tag"] = "_default_external_constant"
7577
mutated = True
7678
return PassResult(gm, mutated)
79+
80+
81+
def xnnpack_external_constants_pass(
82+
gm: GraphModule,
83+
names: Optional[List[str]] = None,
84+
) -> PassResult:
85+
"""
86+
Tag external constants before to_backend. Tagged constants will be saved
87+
to an external file.
88+
89+
Args:
90+
gm: GraphModule to tag.
91+
names: List of constant names to tag. If None, tag all constants.
92+
Returns:
93+
PassResult: The resulting gm, and if it was mutated or not.
94+
"""
95+
mutated = False
96+
for module in gm.modules():
97+
if not isinstance(module, torch.fx.GraphModule):
98+
continue
99+
for node in module.graph.nodes:
100+
if node.op == "placeholder":
101+
# Move specified constants to external file. If none, move all constants.
102+
if names is None or node.name in names:
103+
node.meta["delegate_constant_tag"] = "_default_external_constant"
104+
mutated = True
105+
return PassResult(gm, mutated)

test/models/export_delegated_program.py

+26
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import inspect
1111
import os
1212
import sys
13+
14+
from functools import partial
1315
from typing import Dict, final, Optional, Sequence, Type
1416

1517
import executorch.exir as exir
@@ -21,6 +23,9 @@
2123
from executorch.exir.backend.test.backend_with_compiler_demo import (
2224
BackendWithCompilerDemo,
2325
)
26+
from executorch.exir.passes.external_constants_pass import (
27+
xnnpack_external_constants_pass,
28+
)
2429
from executorch.exir.program import ExecutorchProgramManager
2530
from torch import nn
2631
from torch.export import export
@@ -129,6 +134,7 @@ def export_module_to_program(
129134
constant_tensor_alignment: Optional[int] = None,
130135
delegate_alignment: Optional[int] = None,
131136
method_name: str = "forward",
137+
external_constants: bool = False,
132138
) -> ExecutorchProgramManager:
133139
eager_module = module_class().eval()
134140
inputs = ()
@@ -158,8 +164,13 @@ def forward(self, *args, **kwargs):
158164
XnnpackPartitioner,
159165
)
160166

167+
transform_passes = []
168+
if external_constants:
169+
partial_function = partial(xnnpack_external_constants_pass, names=None)
170+
transform_passes.append(partial_function)
161171
executorch_program = to_edge_transform_and_lower(
162172
exported_program,
173+
transform_passes=transform_passes,
163174
compile_config=edge_config,
164175
partitioner=[XnnpackPartitioner()],
165176
).to_executorch(config=et_config)
@@ -221,6 +232,11 @@ def main() -> None:
221232
parser.add_argument(
222233
"--delegate_alignment", type=int, default=None, help="Delegate alignment."
223234
)
235+
parser.add_argument(
236+
"--external_constants",
237+
action="store_true",
238+
help="Export the model with all constants saved to an external file.",
239+
)
224240
parser.add_argument(
225241
"--outdir",
226242
type=str,
@@ -247,16 +263,26 @@ def main() -> None:
247263
suffix += "-nosegments"
248264
if args.delegate_alignment is not None:
249265
suffix += f"-da{args.delegate_alignment}"
266+
if args.external_constants:
267+
suffix += f"-e"
250268
outfile = os.path.join(args.outdir, f"{module_name}{suffix}.pte")
251269
executorch_program = export_module_to_program(
252270
module_class,
253271
backend_id=args.backend_id,
254272
extract_delegate_segments=not args.inline_delegate_segments,
255273
delegate_alignment=args.delegate_alignment,
274+
external_constants=args.external_constants,
256275
)
257276
with open(outfile, "wb") as fp:
258277
fp.write(executorch_program.buffer)
259278
print(f"Exported {module_name} and wrote program data to {outfile}")
279+
if args.external_constants:
280+
# current infra doesnt easily allow renaming this file, so just hackily do it here.
281+
executorch_program._tensor_data[f"{module_name}{suffix}"] = (
282+
executorch_program._tensor_data.pop("_default_external_constant")
283+
)
284+
print(f"Saving external constants to {module_name}{suffix}.ptd")
285+
executorch_program.write_tensor_data_to_file(args.outdir)
260286

261287

262288
if __name__ == "__main__":

test/models/targets.bzl

+19
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,22 @@ def define_common_targets():
206206
],
207207
env = {"PYTORCH_DISABLE_JUSTKNOBS": "1",},
208208
)
209+
210+
runtime.genrule(
211+
name = "exported_program_data",
212+
cmd = "$(exe :export_delegated_program)" +
213+
" --modules ModuleLinear" +
214+
" --backend_id XnnpackBackend" +
215+
" --external_constants" +
216+
" --outdir $OUT",
217+
218+
outs = {
219+
"ModuleLinear-e.pte": ["ModuleLinear-e.pte"],
220+
"ModuleLinear-e.ptd": ["ModuleLinear-e.ptd"],
221+
},
222+
default_outs = ["."],
223+
visibility = [
224+
"//executorch/runtime/executor/test/...",
225+
"//executorch/test/...",
226+
],
227+
)

0 commit comments

Comments
 (0)