
Commit 191d2dd

util: function to extract part of an ONNX model (onnx#2994)
* util: function to extract part of an ONNX model

  Sometimes, people would like to _extract_ part of a model for development,
  validation or other purposes. With `onnx.util.extract`, this is doable by
  specifying the input and output tensor names of the subgraph.

  Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>

* Address review comments onnx#1

* Re-style function of Extractor

* Reject empty input/output names

* Misc change for readability

* Type annotation

Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
1 parent 6d16b32 commit 191d2dd

File tree: 4 files changed (+229, -4 lines)

  docs/PythonAPIOverview.md
  onnx/__init__.py
  onnx/test/utils_test.py
  onnx/utils.py

docs/PythonAPIOverview.md (20 additions, 1 deletion)

````diff
@@ -257,13 +257,32 @@ Function `polish_model` runs model checker, optimizer, shape inference engine on
 and also strips the doc_string for you.
 ```python
 import onnx
-import onnx.utils
 
 
 model = onnx.load('path/to/the/model.onnx')
 polished_model = onnx.utils.polish_model(model)
 ```
 
+### Extracting Sub-model with Inputs Outputs Tensor Names
+
+Function `extract_model()` extracts sub-model from an ONNX model.
+The sub-model is defined by the names of the input and output tensors *exactly*.
+
+```python
+import onnx
+
+input_path = 'path/to/the/original/model.onnx'
+output_path = 'path/to/save/the/extracted/model.onnx'
+input_names = ['input_0', 'input_1', 'input_2']
+output_names = ['output_0', 'output_1']
+
+onnx.utils.extract_model(input_path, output_path, input_names, output_names)
+```
+
+Note: For control-flow operators, e.g. If and Loop, the _boundary of sub-model_,
+which is defined by the input and output tensors, should not _cut through_ the
+subgraph that is connected to the _main graph_ as attributes of these operators.
+
 ## Tools
 ### Updating Model's Inputs Outputs Dimension Sizes with Variable Length
 Function `update_inputs_outputs_dims` updates the dimension of the inputs and outputs of the model,
````
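The documented `extract_model()` call works on file paths. When the model is already in memory, the `Extractor` class that this commit adds to `onnx/utils.py` can be driven directly. A minimal sketch, reusing the hypothetical tensor names from the example above; note that this in-memory path skips the checker runs that the module-level function performs:

```python
import onnx
import onnx.utils

model = onnx.load('path/to/the/original/model.onnx')  # placeholder path

# Extractor runs shape inference on construction, so intermediate
# activations have ValueInfo entries and can be promoted to
# graph inputs/outputs of the sub-model.
extractor = onnx.utils.Extractor(model)
extracted = extractor.extract_model(['input_0', 'input_1', 'input_2'],
                                    ['output_0', 'output_1'])

# The module-level extract_model() validates for us; here we do it by hand.
onnx.checker.check_model(extracted)
onnx.save(extracted, 'path/to/save/the/extracted/model.onnx')
```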

onnx/__init__.py (2 additions, 1 deletion)

````diff
@@ -13,9 +13,10 @@
 from .version import version as __version__  # noqa
 
 # Import common subpackages so they're available when you 'import onnx'
-import onnx.helper  # noqa
 import onnx.checker  # noqa
 import onnx.defs  # noqa
+import onnx.helper  # noqa
+import onnx.utils  # noqa
 
 import google.protobuf.message
 
````
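Besides alphabetizing the subpackage imports, this change makes `onnx.utils` available as soon as `onnx` itself is imported, which is why the documentation example above could drop its explicit `import onnx.utils`. A small sketch (the model path is a placeholder):

```python
import onnx  # onnx.utils is now imported automatically

model = onnx.load('path/to/the/model.onnx')
# Both the existing polish_model and the new extract_model are
# reachable without a separate `import onnx.utils`:
polished = onnx.utils.polish_model(model)
```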
onnx/test/utils_test.py (52 additions, 1 deletion)

````diff
@@ -3,8 +3,12 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
+import os
+import shutil
+import tempfile
 import unittest
-import onnx.utils
+
+import onnx
 from onnx import helper, TensorProto
 
 
@@ -23,6 +27,53 @@ def test_polish_model(self):  # type: () -> None
         self.assertEqual(len(polished_def.graph.node), 1)
         self.assertFalse(polished_def.graph.node[0].HasField('doc_string'))
 
+    def test_extract_model(self):  # type: () -> None
+        def create_tensor(name):  # type: ignore
+            return helper.make_tensor_value_info(name, TensorProto.FLOAT, [1, 2])
+        A0 = create_tensor("A0")
+        A1 = create_tensor("A1")
+        B0 = create_tensor("B0")
+        B1 = create_tensor("B1")
+        B2 = create_tensor("B2")
+        C0 = create_tensor("C0")
+        C1 = create_tensor("C1")
+        D0 = create_tensor("D0")
+        L0_0 = helper.make_node("Add", ["A0", "A1"], ["B0"])
+        L0_1 = helper.make_node("Sub", ["A0", "A1"], ["B1"])
+        L0_2 = helper.make_node("Mul", ["A0", "A1"], ["B2"])
+        L1_0 = helper.make_node("Add", ["B0", "B1"], ["C0"])
+        L1_1 = helper.make_node("Sub", ["B1", "B2"], ["C1"])
+        L2_0 = helper.make_node("Mul", ["C0", "C1"], ["D0"])
+
+        g0 = helper.make_graph(
+            [L0_0, L0_1, L0_2, L1_0, L1_1, L2_0],
+            "test",
+            [A0, A1],
+            [D0])
+        m0 = helper.make_model(g0, producer_name='test')
+        tdir = tempfile.mkdtemp()
+        p0 = os.path.join(tdir, "original.onnx")
+        onnx.save(m0, p0)
+
+        p1 = os.path.join(tdir, "extracted.onnx")
+        input_names = ["B0", "B1", "B2"]
+        output_names = ["C0", "C1"]
+        onnx.utils.extract_model(p0, p1, input_names, output_names)
+
+        m1 = onnx.load(p1)
+        self.assertEqual(m1.producer_name, 'onnx.utils.extract_model')
+        self.assertEqual(m1.ir_version, m0.ir_version)
+        self.assertEqual(m1.opset_import, m0.opset_import)
+        self.assertEqual(len(m1.graph.node), 2)
+        self.assertEqual(len(m1.graph.input), 3)
+        self.assertEqual(len(m1.graph.output), 2)
+        self.assertEqual(m1.graph.input[0], B0)
+        self.assertEqual(m1.graph.input[1], B1)
+        self.assertEqual(m1.graph.input[2], B2)
+        self.assertEqual(m1.graph.output[0], C0)
+        self.assertEqual(m1.graph.output[1], C1)
+        shutil.rmtree(tdir, ignore_errors=True)
+
 
 if __name__ == '__main__':
     unittest.main()
````
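To see why the assertions expect exactly two nodes, three inputs and two outputs: the test graph is three layers deep, and the requested boundary (inputs B0/B1/B2, outputs C0/C1) encloses only the middle layer, i.e. the Add producing C0 and the Sub producing C1:

```
A0, A1 --(Add, Sub, Mul)--> B0, B1, B2 --(Add, Sub)--> C0, C1 --(Mul)--> D0
                            |<------ extracted sub-model ----->|
```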

onnx/utils.py (155 additions, 1 deletion)

````diff
@@ -3,12 +3,15 @@
 from __future__ import print_function
 from __future__ import unicode_literals
 
+import os
+from typing import List, Tuple, Text
+
 import onnx.checker
 import onnx.helper
 import onnx.optimizer
 import onnx.shape_inference
 
-from onnx import ModelProto
+from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto
 
 
 def polish_model(model):  # type: (ModelProto) -> ModelProto
@@ -21,3 +24,154 @@ def polish_model(model):  # type: (ModelProto) -> ModelProto
     model = onnx.optimizer.optimize(model)
     onnx.checker.check_model(model)
     return model
+
+
+class Extractor:
+    def __init__(self, model):  # type: (ModelProto) -> None
+        self.model = onnx.shape_inference.infer_shapes(model)
+        self.graph = self.model.graph
+        self.wmap = self._build_name2obj_dict(self.graph.initializer)
+        self.vimap = self._build_name2obj_dict(self.graph.value_info)
+
+    @staticmethod
+    def _build_name2obj_dict(objs):  # type: ignore
+        return {obj.name: obj for obj in objs}
+
+    def _collect_new_io_core(self, original_io, io_names_to_extract):  # type: ignore
+        original_io_map = self._build_name2obj_dict(original_io)
+        original_io_names = set(original_io_map.keys())
+        s_io_names_to_extract = set(io_names_to_extract)
+        io_names_to_keep = s_io_names_to_extract & original_io_names
+        new_io_names_to_add = s_io_names_to_extract - original_io_names
+
+        new_io_tensors = []
+        for name in io_names_to_keep:
+            new_io_tensors.append(original_io_map[name])
+        for name in new_io_names_to_add:
+            # activation become input or output
+            new_io_tensors.append(self.vimap[name])
+
+        # adjust sequence
+        new_io_tensors_map = self._build_name2obj_dict(new_io_tensors)
+        return [new_io_tensors_map[name] for name in io_names_to_extract]
+
+    def _collect_new_inputs(self, names):  # type: (List[Text]) -> List[ValueInfoProto]
+        return self._collect_new_io_core(self.graph.input, names)  # type: ignore
+
+    def _collect_new_outputs(self, names):  # type: (List[Text]) -> List[ValueInfoProto]
+        return self._collect_new_io_core(self.graph.output, names)  # type: ignore
+
+    def _dfs_search_reachable_nodes(
+        self,
+        node_output_name,  # type: Text
+        graph_input_names,  # type: List[Text]
+        reachable_nodes,  # type: List[NodeProto]
+    ):  # type: (...) -> None
+        if node_output_name in graph_input_names:
+            return
+        for node in self.graph.node:
+            if node in reachable_nodes:
+                continue
+            if node_output_name not in node.output:
+                continue
+            reachable_nodes.append(node)
+            for name in node.input:
+                self._dfs_search_reachable_nodes(name, graph_input_names, reachable_nodes)
+
+    def _collect_reachable_nodes(
+        self,
+        input_names,  # type: List[Text]
+        output_names,  # type: List[Text]
+    ):  # type: (...) -> List[NodeProto]
+        reachable_nodes = list()  # type: ignore
+        for name in output_names:
+            self._dfs_search_reachable_nodes(name, input_names, reachable_nodes)
+        # needs to be topology sorted.
+        nodes = [n for n in self.graph.node if n in reachable_nodes]
+        return nodes
+
+    def _collect_reachable_tensors(
+        self,
+        nodes,  # type: List[NodeProto]
+    ):  # type: (...) -> Tuple[List[TensorProto], List[ValueInfoProto]]
+        all_tensors_name = set()
+        for node in nodes:
+            for name in node.input:
+                all_tensors_name.add(name)
+            for name in node.output:
+                all_tensors_name.add(name)
+
+        initializer = [self.wmap[t] for t in self.wmap.keys() if t in all_tensors_name]
+        value_info = [self.vimap[t] for t in self.vimap.keys() if t in all_tensors_name]
+        assert(len(self.graph.sparse_initializer) == 0)
+        assert(len(self.graph.quantization_annotation) == 0)
+        return (initializer, value_info)
+
+    def _make_model(
+        self,
+        nodes,  # type: List[NodeProto]
+        inputs,  # type: List[ValueInfoProto]
+        outputs,  # type: List[ValueInfoProto]
+        initializer,  # type: List[TensorProto]
+        value_info  # type: List[ValueInfoProto]
+    ):  # type: (...) -> ModelProto
+        name = 'Extracted from {' + self.graph.name + '}'
+        graph = onnx.helper.make_graph(nodes, name, inputs, outputs, initializer=initializer,
+                                       value_info=value_info)
+
+        meta = {
+            'ir_version': self.model.ir_version,
+            'opset_imports': self.model.opset_import,
+            'producer_name': 'onnx.utils.extract_model',
+        }
+        return onnx.helper.make_model(graph, **meta)
+
+    def extract_model(
+        self,
+        input_names,  # type: List[Text]
+        output_names,  # type: List[Text]
+    ):  # type: (...) -> ModelProto
+        inputs = self._collect_new_inputs(input_names)
+        outputs = self._collect_new_outputs(output_names)
+        nodes = self._collect_reachable_nodes(input_names, output_names)
+        initializer, value_info = self._collect_reachable_tensors(nodes)
+        model = self._make_model(nodes, inputs, outputs, initializer, value_info)
+
+        return model
+
+
+def extract_model(
+    input_path,  # type: Text
+    output_path,  # type: Text
+    input_names,  # type: List[Text]
+    output_names  # type: List[Text]
+):  # type: (...) -> None
+    """Extracts sub-model from an ONNX model.
+
+    The sub-model is defined by the names of the input and output tensors *exactly*.
+
+    Note: For control-flow operators, e.g. If and Loop, the _boundary of sub-model_,
+    which is defined by the input and output tensors, should not _cut through_ the
+    subgraph that is connected to the _main graph_ as attributes of these operators.
+
+    Arguments:
+        input_path (string): The path to original ONNX model.
+        output_path (string): The path to save the extracted ONNX model.
+        input_names (list of string): The names of the input tensors that to be extracted.
+        output_names (list of string): The names of the output tensors that to be extracted.
+    """
+    if not os.path.exists(input_path):
+        raise ValueError("Invalid input model path: %s" % input_path)
+    if not output_path:
+        raise ValueError("Output model path shall not be empty!")
+    if not output_names:
+        raise ValueError("Output tensor names shall not be empty!")
+
+    onnx.checker.check_model(input_path)
+    model = onnx.load(input_path)
+
+    e = Extractor(model)
+    extracted = e.extract_model(input_names, output_names)
+
+    onnx.save(extracted, output_path)
+    onnx.checker.check_model(output_path)
````
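The heart of the implementation is `_dfs_search_reachable_nodes`: starting from each requested output, it walks backwards through the node that produces each tensor and stops once it reaches a requested input. A minimal standalone sketch of the same idea on a toy node representation (plain tuples instead of `NodeProto`; the graph and names are made up):

```python
# Toy nodes as (op, inputs, outputs); mirrors the backward DFS in Extractor.
nodes = [
    ("Add", ["A0", "A1"], ["B0"]),
    ("Sub", ["A0", "A1"], ["B1"]),
    ("Mul", ["B0", "B1"], ["C0"]),
]

def search(tensor, boundary_inputs, reached):
    if tensor in boundary_inputs:   # stop at the sub-model boundary
        return
    for node in nodes:
        if node in reached or tensor not in node[2]:
            continue                # already visited, or not the producer
        reached.append(node)
        for name in node[1]:        # recurse into the producer's inputs
            search(name, boundary_inputs, reached)

reached = []                        # nodes reached from the requested outputs
for name in ["C0"]:                 # requested outputs
    search(name, ["A0", "A1"], reached)
# Re-filter in original order: graph.node is topologically sorted, so this
# keeps the extracted node list topologically sorted too.
ordered = [n for n in nodes if n in reached]
print([n[0] for n in ordered])      # ['Add', 'Sub', 'Mul']
```

One consequence of the recursive traversal worth noting: recursion depth grows with graph depth, so an extremely deep model could in principle hit Python's default recursion limit, though for typical models this is not a concern.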
