3
3
from __future__ import print_function
4
4
from __future__ import unicode_literals
5
5
6
+ import os
7
+ from typing import List , Tuple , Text
8
+
6
9
import onnx .checker
7
10
import onnx .helper
8
11
import onnx .optimizer
9
12
import onnx .shape_inference
10
13
11
- from onnx import ModelProto
14
+ from onnx import ModelProto , NodeProto , TensorProto , ValueInfoProto
12
15
13
16
14
17
def polish_model (model ): # type: (ModelProto) -> ModelProto
@@ -21,3 +24,154 @@ def polish_model(model): # type: (ModelProto) -> ModelProto
21
24
model = onnx .optimizer .optimize (model )
22
25
onnx .checker .check_model (model )
23
26
return model
27
+
28
+
29
class Extractor:
    """Extracts the sub-graph reachable between named input and output tensors
    of an ONNX model and packages it as a standalone ModelProto.
    """

    def __init__(self, model):  # type: (ModelProto) -> None
        # Shape inference populates graph.value_info so intermediate
        # activations can later be promoted to sub-model inputs/outputs.
        self.model = onnx.shape_inference.infer_shapes(model)
        self.graph = self.model.graph
        self.wmap = self._build_name2obj_dict(self.graph.initializer)
        self.vimap = self._build_name2obj_dict(self.graph.value_info)

    @staticmethod
    def _build_name2obj_dict(objs):  # type: ignore
        """Index proto objects by their .name field for O(1) lookup."""
        return {obj.name: obj for obj in objs}

    def _collect_new_io_core(self, original_io, io_names_to_extract):  # type: ignore
        """Build the input/output list of the sub-model for the given names.

        Names already present in ``original_io`` keep their original proto;
        any other name must refer to an intermediate activation and is looked
        up in ``self.vimap`` (populated by shape inference in ``__init__``).
        The result follows the order of ``io_names_to_extract``.
        """
        original_io_map = self._build_name2obj_dict(original_io)
        original_io_names = set(original_io_map.keys())
        s_io_names_to_extract = set(io_names_to_extract)
        io_names_to_keep = s_io_names_to_extract & original_io_names
        new_io_names_to_add = s_io_names_to_extract - original_io_names

        new_io_tensors = []
        for name in io_names_to_keep:
            new_io_tensors.append(original_io_map[name])
        for name in new_io_names_to_add:
            # an activation becomes an input or output of the sub-model
            new_io_tensors.append(self.vimap[name])

        # adjust sequence to match the caller-requested order
        new_io_tensors_map = self._build_name2obj_dict(new_io_tensors)
        return [new_io_tensors_map[name] for name in io_names_to_extract]

    def _collect_new_inputs(self, names):  # type: (List[Text]) -> List[ValueInfoProto]
        """Resolve ``names`` against the original graph inputs."""
        return self._collect_new_io_core(self.graph.input, names)  # type: ignore

    def _collect_new_outputs(self, names):  # type: (List[Text]) -> List[ValueInfoProto]
        """Resolve ``names`` against the original graph outputs."""
        return self._collect_new_io_core(self.graph.output, names)  # type: ignore

    def _dfs_search_reachable_nodes(
            self,
            node_output_name,  # type: Text
            graph_input_names,  # type: List[Text]
            reachable_nodes,  # type: List[NodeProto]
    ):  # type: (...) -> None
        """Append to ``reachable_nodes`` every node that (transitively)
        produces ``node_output_name``, stopping at ``graph_input_names``.

        Implemented iteratively with an explicit stack: a per-tensor recursive
        walk raises RecursionError on graphs deeper than Python's recursion
        limit (~1000 frames). A visited set also avoids re-expanding tensor
        names reachable through multiple paths.
        """
        input_name_set = set(graph_input_names)  # O(1) boundary tests
        visited = set()  # tensor names already expanded
        stack = [node_output_name]
        while stack:
            name = stack.pop()
            if name in visited:
                continue
            visited.add(name)
            if name in input_name_set:
                # boundary of the sub-model: do not search past a
                # caller-requested input tensor
                continue
            for node in self.graph.node:
                if node in reachable_nodes:
                    continue
                if name not in node.output:
                    continue
                reachable_nodes.append(node)
                stack.extend(node.input)

    def _collect_reachable_nodes(
            self,
            input_names,  # type: List[Text]
            output_names,  # type: List[Text]
    ):  # type: (...) -> List[NodeProto]
        """Collect every node on a path from the requested inputs to the
        requested outputs, in the original graph's (topological) node order.
        """
        reachable_nodes = list()  # type: ignore
        for name in output_names:
            self._dfs_search_reachable_nodes(name, input_names, reachable_nodes)
        # re-filter against graph.node so the result stays topologically sorted
        nodes = [n for n in self.graph.node if n in reachable_nodes]
        return nodes

    def _collect_reachable_tensors(
            self,
            nodes,  # type: List[NodeProto]
    ):  # type: (...) -> Tuple[List[TensorProto], List[ValueInfoProto]]
        """Return the (initializers, value_infos) referenced by ``nodes``."""
        all_tensors_name = set()
        for node in nodes:
            for name in node.input:
                all_tensors_name.add(name)
            for name in node.output:
                all_tensors_name.add(name)

        initializer = [self.wmap[t] for t in self.wmap.keys() if t in all_tensors_name]
        value_info = [self.vimap[t] for t in self.vimap.keys() if t in all_tensors_name]
        # Sparse initializers and quantization annotations would need the same
        # name-based filtering; extraction of models using them is unsupported.
        assert (len(self.graph.sparse_initializer) == 0)
        assert (len(self.graph.quantization_annotation) == 0)
        return (initializer, value_info)

    def _make_model(
            self,
            nodes,  # type: List[NodeProto]
            inputs,  # type: List[ValueInfoProto]
            outputs,  # type: List[ValueInfoProto]
            initializer,  # type: List[TensorProto]
            value_info  # type: List[ValueInfoProto]
    ):  # type: (...) -> ModelProto
        """Assemble a ModelProto from the extracted pieces, carrying over the
        original model's IR version and opset imports.
        """
        name = 'Extracted from {' + self.graph.name + '}'
        graph = onnx.helper.make_graph(nodes, name, inputs, outputs, initializer=initializer,
                                       value_info=value_info)

        meta = {
            'ir_version': self.model.ir_version,
            'opset_imports': self.model.opset_import,
            'producer_name': 'onnx.utils.extract_model',
        }
        return onnx.helper.make_model(graph, **meta)

    def extract_model(
            self,
            input_names,  # type: List[Text]
            output_names,  # type: List[Text]
    ):  # type: (...) -> ModelProto
        """Return the sub-model bounded exactly by the named input and output
        tensors.
        """
        inputs = self._collect_new_inputs(input_names)
        outputs = self._collect_new_outputs(output_names)
        nodes = self._collect_reachable_nodes(input_names, output_names)
        initializer, value_info = self._collect_reachable_tensors(nodes)
        model = self._make_model(nodes, inputs, outputs, initializer, value_info)

        return model
141
+
142
+
143
def extract_model(
        input_path,  # type: Text
        output_path,  # type: Text
        input_names,  # type: List[Text]
        output_names  # type: List[Text]
):  # type: (...) -> None
    """Extracts sub-model from an ONNX model.

    The sub-model is bounded *exactly* by the given input and output tensor
    names: the extracted graph contains every node on a path between them.

    Note: For control-flow operators, e.g. If and Loop, the _boundary of
    sub-model_, which is defined by the input and output tensors, should not
    _cut through_ the subgraph that is connected to the _main graph_ as
    attributes of these operators.

    Arguments:
        input_path (string): The path to original ONNX model.
        output_path (string): The path to save the extracted ONNX model.
        input_names (list of string): The names of the input tensors that to be extracted.
        output_names (list of string): The names of the output tensors that to be extracted.
    """
    # Validate arguments up front so we fail before touching the model.
    if not os.path.exists(input_path):
        raise ValueError("Invalid input model path: %s" % input_path)
    if not output_path:
        raise ValueError("Output model path shall not be empty!")
    if not output_names:
        raise ValueError("Output tensor names shall not be empty!")

    # Ensure the source model is valid before extracting from it.
    onnx.checker.check_model(input_path)
    source_model = onnx.load(input_path)

    submodel = Extractor(source_model).extract_model(input_names, output_names)

    # Persist first, then re-validate the saved artifact.
    onnx.save(submodel, output_path)
    onnx.checker.check_model(output_path)
0 commit comments