Skip to content

Commit 1b49fee

Browse files
committedMar 31, 2024
export to builder
1 parent a54de21 commit 1b49fee

File tree

6 files changed

+278
-12
lines changed

6 files changed

+278
-12
lines changed
 

‎_unittests/ut_translate_api/test_translate.py

-1
Original file line numberDiff line numberDiff line change
@@ -221,5 +221,4 @@ def test_aionnxml(self):
221221

222222

223223
if __name__ == "__main__":
224-
TestTranslate().test_export_if()
225224
unittest.main(verbosity=2)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import unittest
2+
from textwrap import dedent
3+
import numpy as np
4+
from onnx import ModelProto, TensorProto
5+
from onnx.defs import onnx_opset_version
6+
from onnx.reference import ReferenceEvaluator
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.light_api import start
9+
from onnx_array_api.graph_api import GraphBuilder
10+
from onnx_array_api.translate_api import translate
11+
12+
13+
OPSET_API = min(19, onnx_opset_version() - 1)
14+
15+
16+
class TestTranslateBuilder(ExtTestCase):
17+
def setUp(self):
18+
self.maxDiff = None
19+
20+
def test_exp(self):
21+
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
22+
self.assertIsInstance(onx, ModelProto)
23+
self.assertIn("Exp", str(onx))
24+
ref = ReferenceEvaluator(onx)
25+
a = np.arange(10).astype(np.float32)
26+
got = ref.run(None, {"X": a})[0]
27+
self.assertEqualArray(np.exp(a), got)
28+
29+
code = translate(onx, api="builder")
30+
expected = dedent(
31+
"""
32+
def light_api(
33+
op: "GraphBuilder",
34+
X: "FLOAT[]",
35+
):
36+
Y = op.Exp(X)
37+
op.Identity(Y, outputs=["Y"])
38+
return Y
39+
40+
g = GraphBuilder({'': 19})
41+
g.make_tensor_input("X", TensorProto.FLOAT, ())
42+
light_api(g.op, X)
43+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
44+
model = g.to_onnx()
45+
"""
46+
).strip("\n")
47+
self.assertEqual(expected, code.strip("\n"))
48+
49+
def light_api(
50+
op: "GraphBuilder",
51+
X: "FLOAT[]", # noqa: F722
52+
):
53+
Y = op.Exp(X)
54+
op.Identity(Y, outputs=["Y"])
55+
return Y
56+
57+
g2 = GraphBuilder({"": 19})
58+
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
59+
light_api(g2.op, "X")
60+
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
61+
onx2 = g2.to_onnx()
62+
63+
ref = ReferenceEvaluator(onx2)
64+
a = np.arange(10).astype(np.float32)
65+
got = ref.run(None, {"X": a})[0]
66+
self.assertEqualArray(np.exp(a), got)
67+
68+
69+
if __name__ == "__main__":
70+
unittest.main(verbosity=2)

‎onnx_array_api/translate_api/__init__.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from onnx import ModelProto
22
from .translate import Translater
33
from .inner_emitter import InnerEmitter
4+
from .builder_emitter import BuilderEmitter
45

56

67
def translate(proto: ModelProto, single_line: bool = False, api: str = "light") -> str:
@@ -14,7 +15,8 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
1415
default is `"light"` and this is handle by class
1516
:class:`onnx_array_api.translate_api.light_emitter.LightEmitter`,
1617
another value is `"onnx"` which is the inner API implemented
17-
in onnx package.
18+
in onnx package, `"builder"` follows the syntax for the
19+
class :class:`onnx_array_api.graph_api.GraphBuilder`
1820
:return: code
1921
2022
.. runpython::
@@ -35,7 +37,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
3537
code = translate(onx)
3638
print(code)
3739
38-
The inner API from onnx packahe is also available.
40+
The inner API from onnx package is also available.
3941
4042
.. runpython::
4143
:showcode:
@@ -54,11 +56,35 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
5456
)
5557
code = translate(onx, api="onnx")
5658
print(code)
59+
60+
The :class:`GraphBuilder
61+
<onnx_array_api.graph_api.GraphBuilder>` API returns this:
62+
63+
.. runpython::
64+
:showcode:
65+
66+
from onnx_array_api.light_api import start
67+
from onnx_array_api.translate_api import translate
68+
69+
onx = (
70+
start()
71+
.vin("X")
72+
.reshape((-1, 1))
73+
.Transpose(perm=[1, 0])
74+
.rename("Y")
75+
.vout()
76+
.to_onnx()
77+
)
78+
code = translate(onx, api="builder")
79+
print(code)
5780
"""
5881
if api == "light":
5982
tr = Translater(proto)
6083
return tr.export(single_line=single_line, as_str=True)
6184
if api == "onnx":
6285
tr = Translater(proto, emitter=InnerEmitter())
6386
return tr.export(as_str=True)
87+
if api == "builder":
88+
tr = Translater(proto, emitter=BuilderEmitter())
89+
return tr.export(as_str=True)
6490
raise ValueError(f"Unexpected value {api!r} for api.")

‎onnx_array_api/translate_api/base_emitter.py

+28
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class EventType(IntEnum):
2121
FUNCTION_OUTPUT = 12
2222
FUNCTION_ATTRIBUTES = 13
2323
TO_ONNX_FUNCTION = 14
24+
BEGIN_SIGNATURE = 15
25+
END_SIGNATURE = 16
26+
BEGIN_RETURN = 17
27+
END_RETURN = 18
2428

2529
@classmethod
2630
def to_str(cls, self) -> str:
@@ -84,6 +88,18 @@ def __call__(self, event: EventType, **kwargs: Dict[str, Any]) -> List[str]:
8488
if event == EventType.FUNCTION_ATTRIBUTES:
8589
return self._emit_function_attributes(**kwargs)
8690

91+
if event == EventType.BEGIN_SIGNATURE:
92+
return self._emit_begin_signature(**kwargs)
93+
94+
if event == EventType.END_SIGNATURE:
95+
return self._emit_end_signature(**kwargs)
96+
97+
if event == EventType.BEGIN_RETURN:
98+
return self._emit_begin_return(**kwargs)
99+
100+
if event == EventType.END_RETURN:
101+
return self._emit_end_return(**kwargs)
102+
87103
raise ValueError(f"Unexpected event {EventType.to_str(event)}.")
88104

89105
def render_attribute_value(self, value: Any) -> Tuple[List[str], str]:
@@ -222,3 +238,15 @@ def _emit_function_attributes(self, **kwargs: Dict[str, Any]) -> List[str]:
222238
raise NotImplementedError(
223239
f"Method {inspect.currentframe().f_code.co_name!r} was not overloaded."
224240
)
241+
242+
def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
243+
return []
244+
245+
def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
246+
return []
247+
248+
def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
249+
return []
250+
251+
def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
252+
return []
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import Any, Dict, List
2+
from onnx import TensorProto
3+
from .base_emitter import BaseEmitter
4+
5+
_types = {
6+
TensorProto.FLOAT: "FLOAT",
7+
TensorProto.FLOAT16: "FLOAT16",
8+
TensorProto.INT64: "INT64",
9+
TensorProto.INT32: "INT32",
10+
}
11+
12+
13+
def _itype_to_string(itype: int) -> str:
14+
return _types[itype]
15+
16+
17+
class BuilderEmitter(BaseEmitter):
18+
"""
19+
Converts event into proper code.
20+
"""
21+
22+
def join(self, rows: List[str], single_line: bool = False) -> str:
23+
"Join the rows"
24+
assert (
25+
not single_line
26+
), f"The emitter {type(self)} does not work with single_line=True."
27+
return "\n".join(rows)
28+
29+
def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
30+
self.opsets = kwargs.get("opsets", {})
31+
return []
32+
33+
def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
34+
inps = ", ".join(["g.op", *self.inputs])
35+
inputs = []
36+
for inp, stype, shape in self.inputs_full_:
37+
inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype}, {shape})')
38+
outputs = []
39+
for inp, stype, shape in self.outputs_full_:
40+
outputs.append(
41+
f'g.make_tensor_output("{inp}", TensorProto.{stype}, {shape})'
42+
)
43+
rows = [
44+
"",
45+
f"g = GraphBuilder({self.opsets})",
46+
*inputs,
47+
f"{self.name}({inps})",
48+
*outputs,
49+
"model = g.to_onnx()",
50+
]
51+
return rows
52+
53+
def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
54+
self.inputs = []
55+
self.inputs_full = []
56+
self.outputs = []
57+
self.inits = []
58+
self.inputs_full_ = []
59+
self.outputs_full_ = []
60+
self.name = kwargs.get("name", "make_graph")
61+
return []
62+
63+
def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
64+
return []
65+
66+
def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
67+
assert False, f"not implemented yet with {kwargs}"
68+
69+
def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
70+
name = kwargs["name"]
71+
itype = kwargs.get("elem_type", 0)
72+
shape = kwargs.get("shape", None)
73+
if itype == 0:
74+
inp = "X"
75+
else:
76+
if shape is None:
77+
inp = f'X: "{_itype_to_string(itype)}"'
78+
else:
79+
inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
80+
self.inputs_full.append(inp)
81+
self.inputs.append(name)
82+
self.inputs_full_.append((name, _itype_to_string(itype), shape))
83+
return []
84+
85+
def _emit_begin_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
86+
return []
87+
88+
def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
89+
rows = ["", f"def {self.name}(", ' op: "GraphBuilder",']
90+
for i in self.inputs_full:
91+
rows.append(f" {i},")
92+
rows.append("):")
93+
return rows
94+
95+
def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:
96+
return []
97+
98+
def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
99+
outs = ", ".join(self.outputs)
100+
return [f" return {outs}"]
101+
102+
def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
103+
name = kwargs["name"]
104+
itype = kwargs.get("elem_type", 0)
105+
shape = kwargs.get("shape", None)
106+
self.outputs.append(name)
107+
self.outputs_full_.append((name, _itype_to_string(itype), shape))
108+
return [f' op.Identity({name}, outputs=["{name}"])']
109+
110+
def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
111+
op_type = kwargs["op_type"]
112+
inputs = kwargs["inputs"]
113+
outputs = kwargs["outputs"]
114+
if kwargs.get("domain", "") != "":
115+
domain = kwargs["domain"]
116+
op_type = f"{domain}.{op_type}"
117+
atts = kwargs.get("atts", {})
118+
args = []
119+
for k, v in atts.items():
120+
before, vatt = self.render_attribute_value(v)
121+
if before:
122+
raise NotImplementedError("Graph attribute not supported yet.")
123+
args.append(f"{k}={vatt}")
124+
125+
outs = ", ".join(outputs)
126+
inps = ", ".join(inputs)
127+
if args:
128+
sargs = ", ".join(args)
129+
row = f" {outs} = op.{op_type}({inps}, {sargs})"
130+
else:
131+
row = f" {outs} = op.{op_type}({inps})"
132+
return [row]

‎onnx_array_api/translate_api/translate.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,12 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
7676
)
7777
)
7878
else:
79-
rows.extend(self.emitter(EventType.BEGIN_GRAPH))
80-
81-
for i in initializers:
8279
rows.extend(
83-
self.emitter(
84-
EventType.INITIALIZER,
85-
name=i.name,
86-
init=i,
87-
value=to_array_extended(i),
88-
)
80+
self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name)
8981
)
9082

83+
rows.extend(self.emitter(EventType.BEGIN_SIGNATURE))
84+
9185
for i in inputs:
9286
if is_function:
9387
rows.extend(self.emitter(EventType.FUNCTION_INPUT, name=i))
@@ -109,6 +103,18 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
109103
self.emitter(EventType.FUNCTION_ATTRIBUTES, attributes=list(attributes))
110104
)
111105

106+
rows.extend(self.emitter(EventType.END_SIGNATURE))
107+
108+
for i in initializers:
109+
rows.extend(
110+
self.emitter(
111+
EventType.INITIALIZER,
112+
name=i.name,
113+
init=i,
114+
value=to_array_extended(i),
115+
)
116+
)
117+
112118
for node in nodes:
113119
atts = self.extract_attributes(node)
114120
rows.extend(
@@ -122,6 +128,8 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
122128
)
123129
)
124130

131+
rows.extend(self.emitter(EventType.BEGIN_RETURN))
132+
125133
for o in outputs:
126134
if is_function:
127135
rows.extend(self.emitter(EventType.FUNCTION_OUTPUT, name=o))
@@ -137,6 +145,9 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
137145
),
138146
)
139147
)
148+
149+
rows.extend(self.emitter(EventType.END_RETURN))
150+
140151
if isinstance(self.proto_, (GraphProto, FunctionProto)):
141152
name = self.proto_.name
142153
else:

0 commit comments

Comments
 (0)