1
1
import unittest
2
2
from textwrap import dedent
3
3
import numpy as np
4
+ import onnx .helper as oh
4
5
from onnx import ModelProto , TensorProto
5
6
from onnx .checker import check_model
6
7
from onnx .defs import onnx_opset_version
@@ -29,8 +30,9 @@ def test_exp(self):
29
30
self .assertEqualArray (np .exp (a ), got )
30
31
31
32
code = translate (onx , api = "builder" )
32
- expected = dedent (
33
- """
33
+ expected = (
34
+ dedent (
35
+ """
34
36
def light_api(
35
37
op: "GraphBuilder",
36
38
X: "FLOAT[]",
@@ -42,10 +44,13 @@ def light_api(
42
44
g = GraphBuilder({'': 19}, ir_version=10)
43
45
g.make_tensor_input("X", TensorProto.FLOAT, ())
44
46
light_api(g.op, "X")
45
- g.make_tensor_output("Y", TensorProto.FLOAT, ())
47
+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
46
48
model = g.to_onnx()
47
49
"""
48
- ).strip ("\n " )
50
+ )
51
+ .strip ("\n " )
52
+ .replace ("__SUFFIX__" , ", is_dimension=False, indexed=False" )
53
+ )
49
54
self .assertEqual (expected , code .strip ("\n " ))
50
55
51
56
def light_api (
@@ -59,7 +64,9 @@ def light_api(
59
64
g2 = GraphBuilder ({"" : 19 })
60
65
g2 .make_tensor_input ("X" , TensorProto .FLOAT , ("A" ,))
61
66
light_api (g2 .op , "X" )
62
- g2 .make_tensor_output ("Y" , TensorProto .FLOAT , ("A" ,))
67
+ g2 .make_tensor_output (
68
+ "Y" , TensorProto .FLOAT , ("A" ,), is_dimension = False , indexed = False
69
+ )
63
70
onx2 = g2 .to_onnx ()
64
71
65
72
ref = ReferenceEvaluator (onx2 )
@@ -78,8 +85,9 @@ def test_zdoc(self):
78
85
.to_onnx ()
79
86
)
80
87
code = translate (onx , api = "builder" )
81
- expected = dedent (
82
- """
88
+ expected = (
89
+ dedent (
90
+ """
83
91
def light_api(
84
92
op: "GraphBuilder",
85
93
X: "FLOAT[]",
@@ -93,10 +101,13 @@ def light_api(
93
101
g = GraphBuilder({'': 19}, ir_version=10)
94
102
g.make_tensor_input("X", TensorProto.FLOAT, ())
95
103
light_api(g.op, "X")
96
- g.make_tensor_output("Y", TensorProto.FLOAT, ())
104
+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
97
105
model = g.to_onnx()
98
106
"""
99
- ).strip ("\n " )
107
+ )
108
+ .strip ("\n " )
109
+ .replace ("__SUFFIX__" , ", is_dimension=False, indexed=False" )
110
+ )
100
111
self .maxDiff = None
101
112
self .assertEqual (expected , code .strip ("\n " ))
102
113
@@ -130,8 +141,9 @@ def test_exp_f(self):
130
141
tr = Translater (onx , emitter = BuilderEmitter ("mm" ))
131
142
code = tr .export (as_str = True )
132
143
133
- expected = dedent (
134
- """
144
+ expected = (
145
+ dedent (
146
+ """
135
147
def light_api(
136
148
op: "GraphBuilder",
137
149
X: "FLOAT[]",
@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
145
157
g = GraphBuilder({'': 19}, ir_version=10)
146
158
g.make_tensor_input("X", TensorProto.FLOAT, ())
147
159
light_api(g.op, "X")
148
- g.make_tensor_output("Y", TensorProto.FLOAT, ())
160
+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__ )
149
161
model = g.to_onnx()
150
162
return model
151
163
152
164
153
165
model = mm()
154
166
"""
155
- ).strip ("\n " )
167
+ )
168
+ .strip ("\n " )
169
+ .replace ("__SUFFIX__" , ", is_dimension=False, indexed=False" )
170
+ )
156
171
self .assertEqual (expected , code .strip ("\n " ))
157
172
158
173
def light_api (
@@ -166,14 +181,104 @@ def light_api(
166
181
g2 = GraphBuilder ({"" : 19 })
167
182
g2 .make_tensor_input ("X" , TensorProto .FLOAT , ("A" ,))
168
183
light_api (g2 .op , "X" )
169
- g2 .make_tensor_output ("Y" , TensorProto .FLOAT , ("A" ,))
184
+ g2 .make_tensor_output (
185
+ "Y" , TensorProto .FLOAT , ("A" ,), is_dimension = False , indexed = False
186
+ )
170
187
onx2 = g2 .to_onnx ()
171
188
172
189
ref = ReferenceEvaluator (onx2 )
173
190
a = np .arange (10 ).astype (np .float32 )
174
191
got = ref .run (None , {"X" : a })[0 ]
175
192
self .assertEqualArray (np .exp (a ), got )
176
193
194
+ def test_local_function (self ):
195
+ new_domain = "custom"
196
+
197
+ linear_regression = oh .make_function (
198
+ new_domain ,
199
+ "LinearRegression" ,
200
+ ["x" , "a" , "b" ],
201
+ ["y" ],
202
+ [
203
+ oh .make_node ("MatMul" , ["x" , "a" ], ["xa" ]),
204
+ oh .make_node ("Add" , ["xa" , "b" ], ["y" ]),
205
+ ],
206
+ [oh .make_opsetid ("" , 14 )],
207
+ [],
208
+ )
209
+
210
+ graph = oh .make_graph (
211
+ [
212
+ oh .make_node (
213
+ "LinearRegression" , ["X" , "A" , "B" ], ["Y1" ], domain = new_domain
214
+ ),
215
+ oh .make_node ("Abs" , ["Y1" ], ["Y" ]),
216
+ ],
217
+ "example" ,
218
+ [
219
+ oh .make_tensor_value_info ("X" , TensorProto .FLOAT , [None , None ]),
220
+ oh .make_tensor_value_info ("A" , TensorProto .FLOAT , [None , None ]),
221
+ oh .make_tensor_value_info ("B" , TensorProto .FLOAT , [None , None ]),
222
+ ],
223
+ [oh .make_tensor_value_info ("Y" , TensorProto .FLOAT , None )],
224
+ )
225
+
226
+ onnx_model = oh .make_model (
227
+ graph ,
228
+ opset_imports = [oh .make_opsetid ("" , 14 ), oh .make_opsetid (new_domain , 1 )],
229
+ functions = [linear_regression ],
230
+ )
231
+ tr = Translater (onnx_model , emitter = BuilderEmitter ("mm" ))
232
+ code = tr .export (as_str = True )
233
+
234
+ expected = (
235
+ dedent (
236
+ """
237
+ def example(
238
+ op: "GraphBuilder",
239
+ X: "FLOAT[, ]",
240
+ A: "FLOAT[, ]",
241
+ B: "FLOAT[, ]",
242
+ ):
243
+ Y1 = op.LinearRegression(X, A, B, domain='custom')
244
+ Y = op.Abs(Y1)
245
+ op.Identity(Y, outputs=["Y"])
246
+ return Y
247
+
248
+
249
+ def make_custom_LinearRegression(g: "GraphBuilder"):
250
+ gr = GraphBuilder({'': 14}, as_function=True)
251
+ x = gr.make_tensor_input('x')
252
+ a = gr.make_tensor_input('a')
253
+ b = gr.make_tensor_input('b')
254
+ op = gr.op
255
+ xa = op.MatMul(x, a)
256
+ y = op.Add(xa, b)
257
+ gr.make_tensor_output(y)
258
+ g.add_function(builder=gr)
259
+ return gr
260
+
261
+
262
+ def mm() -> "ModelProto":
263
+ g = GraphBuilder({'': 14, 'custom': 1}, ir_version=11)
264
+ g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
265
+ g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
266
+ g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
267
+ example(g.op, "X", "A", "B")
268
+ g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
269
+ make_custom_LinearRegression(g)
270
+ model = g.to_onnx()
271
+ return model
272
+
273
+
274
+ model = mm()
275
+ """
276
+ )
277
+ .strip ("\n " )
278
+ .replace ("__SUFFIX__" , ", is_dimension=False, indexed=False" )
279
+ )
280
+ self .assertEqual (expected , code .strip ("\n " ))
281
+
177
282
178
283
if __name__ == "__main__" :
179
284
unittest .main (verbosity = 2 )
0 commit comments