Skip to content

Commit 5ce3275

Browse files
committed
fix suffix
1 parent 664e084 commit 5ce3275

File tree

5 files changed

+243
-24
lines changed

5 files changed

+243
-24
lines changed

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 119 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
from textwrap import dedent
33
import numpy as np
4+
import onnx.helper as oh
45
from onnx import ModelProto, TensorProto
56
from onnx.checker import check_model
67
from onnx.defs import onnx_opset_version
@@ -29,8 +30,9 @@ def test_exp(self):
2930
self.assertEqualArray(np.exp(a), got)
3031

3132
code = translate(onx, api="builder")
32-
expected = dedent(
33-
"""
33+
expected = (
34+
dedent(
35+
"""
3436
def light_api(
3537
op: "GraphBuilder",
3638
X: "FLOAT[]",
@@ -42,10 +44,13 @@ def light_api(
4244
g = GraphBuilder({'': 19}, ir_version=10)
4345
g.make_tensor_input("X", TensorProto.FLOAT, ())
4446
light_api(g.op, "X")
45-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
47+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
4648
model = g.to_onnx()
4749
"""
48-
).strip("\n")
50+
)
51+
.strip("\n")
52+
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
53+
)
4954
self.assertEqual(expected, code.strip("\n"))
5055

5156
def light_api(
@@ -59,7 +64,9 @@ def light_api(
5964
g2 = GraphBuilder({"": 19})
6065
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
6166
light_api(g2.op, "X")
62-
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
67+
g2.make_tensor_output(
68+
"Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
69+
)
6370
onx2 = g2.to_onnx()
6471

6572
ref = ReferenceEvaluator(onx2)
@@ -78,8 +85,9 @@ def test_zdoc(self):
7885
.to_onnx()
7986
)
8087
code = translate(onx, api="builder")
81-
expected = dedent(
82-
"""
88+
expected = (
89+
dedent(
90+
"""
8391
def light_api(
8492
op: "GraphBuilder",
8593
X: "FLOAT[]",
@@ -93,10 +101,13 @@ def light_api(
93101
g = GraphBuilder({'': 19}, ir_version=10)
94102
g.make_tensor_input("X", TensorProto.FLOAT, ())
95103
light_api(g.op, "X")
96-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
104+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
97105
model = g.to_onnx()
98106
"""
99-
).strip("\n")
107+
)
108+
.strip("\n")
109+
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
110+
)
100111
self.maxDiff = None
101112
self.assertEqual(expected, code.strip("\n"))
102113

@@ -130,8 +141,9 @@ def test_exp_f(self):
130141
tr = Translater(onx, emitter=BuilderEmitter("mm"))
131142
code = tr.export(as_str=True)
132143

133-
expected = dedent(
134-
"""
144+
expected = (
145+
dedent(
146+
"""
135147
def light_api(
136148
op: "GraphBuilder",
137149
X: "FLOAT[]",
@@ -145,14 +157,17 @@ def mm() -> "ModelProto":
145157
g = GraphBuilder({'': 19}, ir_version=10)
146158
g.make_tensor_input("X", TensorProto.FLOAT, ())
147159
light_api(g.op, "X")
148-
g.make_tensor_output("Y", TensorProto.FLOAT, ())
160+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
149161
model = g.to_onnx()
150162
return model
151163
152164
153165
model = mm()
154166
"""
155-
).strip("\n")
167+
)
168+
.strip("\n")
169+
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
170+
)
156171
self.assertEqual(expected, code.strip("\n"))
157172

158173
def light_api(
@@ -166,14 +181,104 @@ def light_api(
166181
g2 = GraphBuilder({"": 19})
167182
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
168183
light_api(g2.op, "X")
169-
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
184+
g2.make_tensor_output(
185+
"Y", TensorProto.FLOAT, ("A",), is_dimension=False, indexed=False
186+
)
170187
onx2 = g2.to_onnx()
171188

172189
ref = ReferenceEvaluator(onx2)
173190
a = np.arange(10).astype(np.float32)
174191
got = ref.run(None, {"X": a})[0]
175192
self.assertEqualArray(np.exp(a), got)
176193

194+
def test_local_function(self):
195+
new_domain = "custom"
196+
197+
linear_regression = oh.make_function(
198+
new_domain,
199+
"LinearRegression",
200+
["x", "a", "b"],
201+
["y"],
202+
[
203+
oh.make_node("MatMul", ["x", "a"], ["xa"]),
204+
oh.make_node("Add", ["xa", "b"], ["y"]),
205+
],
206+
[oh.make_opsetid("", 14)],
207+
[],
208+
)
209+
210+
graph = oh.make_graph(
211+
[
212+
oh.make_node(
213+
"LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
214+
),
215+
oh.make_node("Abs", ["Y1"], ["Y"]),
216+
],
217+
"example",
218+
[
219+
oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None]),
220+
oh.make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
221+
oh.make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
222+
],
223+
[oh.make_tensor_value_info("Y", TensorProto.FLOAT, None)],
224+
)
225+
226+
onnx_model = oh.make_model(
227+
graph,
228+
opset_imports=[oh.make_opsetid("", 14), oh.make_opsetid(new_domain, 1)],
229+
functions=[linear_regression],
230+
)
231+
tr = Translater(onnx_model, emitter=BuilderEmitter("mm"))
232+
code = tr.export(as_str=True)
233+
234+
expected = (
235+
dedent(
236+
"""
237+
def example(
238+
op: "GraphBuilder",
239+
X: "FLOAT[, ]",
240+
A: "FLOAT[, ]",
241+
B: "FLOAT[, ]",
242+
):
243+
Y1 = op.LinearRegression(X, A, B, domain='custom')
244+
Y = op.Abs(Y1)
245+
op.Identity(Y, outputs=["Y"])
246+
return Y
247+
248+
249+
def make_custom_LinearRegression(g: "GraphBuilder"):
250+
gr = GraphBuilder({'': 14}, as_function=True)
251+
x = gr.make_tensor_input('x')
252+
a = gr.make_tensor_input('a')
253+
b = gr.make_tensor_input('b')
254+
op = gr.op
255+
xa = op.MatMul(x, a)
256+
y = op.Add(xa, b)
257+
gr.make_tensor_output(y)
258+
g.add_function(builder=gr)
259+
return gr
260+
261+
262+
def mm() -> "ModelProto":
263+
g = GraphBuilder({'': 14, 'custom': 1}, ir_version=11)
264+
g.make_tensor_input("X", TensorProto.FLOAT, ('', ''))
265+
g.make_tensor_input("A", TensorProto.FLOAT, ('', ''))
266+
g.make_tensor_input("B", TensorProto.FLOAT, ('', ''))
267+
example(g.op, "X", "A", "B")
268+
g.make_tensor_output("Y", TensorProto.FLOAT, ()__SUFFIX__)
269+
make_custom_LinearRegression(g)
270+
model = g.to_onnx()
271+
return model
272+
273+
274+
model = mm()
275+
"""
276+
)
277+
.strip("\n")
278+
.replace("__SUFFIX__", ", is_dimension=False, indexed=False")
279+
)
280+
self.assertEqual(expected, code.strip("\n"))
281+
177282

178283
if __name__ == "__main__":
179284
unittest.main(verbosity=2)

onnx_array_api/graph_api/graph_builder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def __init__(
194194
self._known_shapes = {}
195195
self._known_types = {}
196196
self.constants_ = {}
197+
self.functions_ = {}
197198
elif isinstance(target_opset_or_existing_proto, ModelProto):
198199
assert (
199200
not input_names
@@ -223,6 +224,8 @@ def __init__(
223224
self.constants_[node.output[0]] = node
224225
self.set_shape(node.output[0], self._get_tensor_shape(node))
225226
self.set_type(node.output[0], self._get_tensor_type(node))
227+
for f in proto.function:
228+
self.add_function(f)
226229
else:
227230
raise NotImplementedError(
228231
f"{type(target_opset_or_existing_proto)} is not supported."
@@ -231,6 +234,14 @@ def __init__(
231234
self.op = Opset(self, self.opsets[""]) if "" in self.opsets else None
232235
self._cache_array = []
233236

237+
def add_local_function(self, domain: str, name: str, gr: "GraphBuilder"):
238+
"Adds a local function."
239+
assert (
240+
domain,
241+
name,
242+
) not in self.functions_, f"Function {(domain, name)} was already added."
243+
self.functions_[domain, name] = gr
244+
234245
def _get_tensor_shape(
235246
self, proto: Union[NodeProto, TensorProto]
236247
) -> Tuple[int, ...]:
@@ -417,6 +428,8 @@ def make_tensor_output(
417428
name: Union[str, List[str]],
418429
elem_type: Optional[int] = None,
419430
shape: Optional[Tuple[int, ...]] = None,
431+
is_dimension: bool = False,
432+
indexed: bool = False,
420433
) -> Union[str, List[str]]:
421434
if isinstance(name, list):
422435
res = []

onnx_array_api/translate_api/base_emitter.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class EventType(IntEnum):
2525
END_SIGNATURE = 16
2626
BEGIN_RETURN = 17
2727
END_RETURN = 18
28+
BEGIN_FUNCTION_SIGNATURE = 19
29+
END_FUNCTION_SIGNATURE = 20
30+
BEGIN_FUNCTION_RETURN = 21
31+
END_FUNCTION_RETURN = 22
2832

2933
@classmethod
3034
def to_str(cls, self) -> str:
@@ -76,6 +80,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
7680
if event == EventType.BEGIN_FUNCTION:
7781
return self._emit_begin_function(**kwargs)
7882

83+
if event == EventType.BEGIN_FUNCTION_SIGNATURE:
84+
return self._emit_begin_function_signature(**kwargs)
85+
86+
if event == EventType.END_FUNCTION_SIGNATURE:
87+
return self._emit_end_function_signature(**kwargs)
88+
7989
if event == EventType.END_FUNCTION:
8090
return self._emit_end_function(**kwargs)
8191

@@ -100,6 +110,12 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
100110
if event == EventType.END_RETURN:
101111
return self._emit_end_return(**kwargs)
102112

113+
if event == EventType.BEGIN_FUNCTION_RETURN:
114+
return self._emit_begin_function_return(**kwargs)
115+
116+
if event == EventType.END_FUNCTION_RETURN:
117+
return self._emit_end_function_return(**kwargs)
118+
103119
raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
104120

105121
def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
@@ -224,6 +240,12 @@ def _emit_begin_function(self, **kwargs: Dict[str, Any]) -> List[str]:
224240
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
225241
)
226242

243+
def _emit_begin_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
244+
return []
245+
246+
def _emit_end_function_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
247+
return []
248+
227249
def _emit_function_input(self, **kwargs: Dict[str, Any]) -> List[str]:
228250
raise NotImplementedError(
229251
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
@@ -250,3 +272,9 @@ def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
250272

251273
def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
252274
return []
275+
276+
def _emit_begin_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
277+
return []
278+
279+
def _emit_end_function_return(self, **kwargs: Dict[str, Any]) -> List[str]:
280+
return []

0 commit comments

Comments
 (0)