Skip to content

Commit 664e084

Browse files
authored
Improves translation to GraphBuilder (#95)
* Improves translation to GraphBuilder * ch * fix issue * ir * urls * check
1 parent 689cc6f commit 664e084

File tree

6 files changed

+127
-17
lines changed

6 files changed

+127
-17
lines changed

.github/workflows/check-urls.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ jobs:
4242
print_all: false
4343
timeout: 2
4444
retry_count# : 2
45-
exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document
46-
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/
45+
exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx
46+
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx
4747
# force_pass : true

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.3.1
5+
+++++
6+
7+
* :pr:`95`: improves translation to GraphBuilder
8+
49
0.3.0
510
+++++
611

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from onnx_array_api.ext_test_case import ExtTestCase
99
from onnx_array_api.light_api import start
1010
from onnx_array_api.graph_api import GraphBuilder
11-
from onnx_array_api.translate_api import translate
11+
from onnx_array_api.translate_api import translate, Translater
12+
from onnx_array_api.translate_api.builder_emitter import BuilderEmitter
1213

1314

1415
OPSET_API = min(19, onnx_opset_version() - 1)
@@ -19,7 +20,7 @@ def setUp(self):
1920
self.maxDiff = None
2021

2122
def test_exp(self):
22-
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
23+
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
2324
self.assertIsInstance(onx, ModelProto)
2425
self.assertIn("Exp", str(onx))
2526
ref = ReferenceEvaluator(onx)
@@ -38,7 +39,7 @@ def light_api(
3839
op.Identity(Y, outputs=["Y"])
3940
return Y
4041
41-
g = GraphBuilder({'': 19})
42+
g = GraphBuilder({'': 19}, ir_version=10)
4243
g.make_tensor_input("X", TensorProto.FLOAT, ())
4344
light_api(g.op, "X")
4445
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -68,7 +69,7 @@ def light_api(
6869

6970
def test_zdoc(self):
7071
onx = (
71-
start(opset=19)
72+
start(opset=19, ir_version=10)
7273
.vin("X")
7374
.reshape((-1, 1))
7475
.Transpose(perm=[1, 0])
@@ -89,7 +90,7 @@ def light_api(
8990
op.Identity(Y, outputs=["Y"])
9091
return Y
9192
92-
g = GraphBuilder({'': 19})
93+
g = GraphBuilder({'': 19}, ir_version=10)
9394
g.make_tensor_input("X", TensorProto.FLOAT, ())
9495
light_api(g.op, "X")
9596
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -117,6 +118,62 @@ def light_api(
117118
self.assertNotEmpty(model)
118119
check_model(model)
119120

121+
def test_exp_f(self):
122+
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
123+
self.assertIsInstance(onx, ModelProto)
124+
self.assertIn("Exp", str(onx))
125+
ref = ReferenceEvaluator(onx)
126+
a = np.arange(10).astype(np.float32)
127+
got = ref.run(None, {"X": a})[0]
128+
self.assertEqualArray(np.exp(a), got)
129+
130+
tr = Translater(onx, emitter=BuilderEmitter("mm"))
131+
code = tr.export(as_str=True)
132+
133+
expected = dedent(
134+
"""
135+
def light_api(
136+
op: "GraphBuilder",
137+
X: "FLOAT[]",
138+
):
139+
Y = op.Exp(X)
140+
op.Identity(Y, outputs=["Y"])
141+
return Y
142+
143+
144+
def mm() -> "ModelProto":
145+
g = GraphBuilder({'': 19}, ir_version=10)
146+
g.make_tensor_input("X", TensorProto.FLOAT, ())
147+
light_api(g.op, "X")
148+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
149+
model = g.to_onnx()
150+
return model
151+
152+
153+
model = mm()
154+
"""
155+
).strip("\n")
156+
self.assertEqual(expected, code.strip("\n"))
157+
158+
def light_api(
159+
op: "GraphBuilder",
160+
X: "FLOAT[]", # noqa: F722
161+
):
162+
Y = op.Exp(X)
163+
op.Identity(Y, outputs=["Y"])
164+
return Y
165+
166+
g2 = GraphBuilder({"": 19})
167+
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
168+
light_api(g2.op, "X")
169+
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
170+
onx2 = g2.to_onnx()
171+
172+
ref = ReferenceEvaluator(onx2)
173+
a = np.arange(10).astype(np.float32)
174+
got = ref.run(None, {"X": a})[0]
175+
self.assertEqualArray(np.exp(a), got)
176+
120177

121178
if __name__ == "__main__":
122179
unittest.main(verbosity=2)

onnx_array_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
APIs to create ONNX Graphs.
33
"""
44

5-
__version__ = "0.3.0"
5+
__version__ = "0.3.1"
66
__author__ = "Xavier Dupré"

onnx_array_api/translate_api/builder_emitter.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44
from .base_emitter import BaseEmitter
55

66
_types = {
7+
TensorProto.DOUBLE: "DOUBLE",
78
TensorProto.FLOAT: "FLOAT",
89
TensorProto.FLOAT16: "FLOAT16",
910
TensorProto.INT64: "INT64",
1011
TensorProto.INT32: "INT32",
12+
TensorProto.INT16: "INT16",
13+
TensorProto.UINT64: "UINT64",
14+
TensorProto.UINT32: "UINT32",
15+
TensorProto.UINT16: "UINT16",
16+
TensorProto.STRING: "STRING",
17+
TensorProto.BOOL: "BOOL",
1118
}
1219

1320

@@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter):
2027
Converts event into proper code.
2128
"""
2229

30+
def __init__(self, make_model_function: str = ""):
31+
super().__init__()
32+
self.make_model_function = make_model_function
33+
2334
def join(self, rows: List[str], single_line: bool = False) -> str:
2435
"Join the rows"
2536
assert (
@@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
2940

3041
def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
3142
self.opsets = kwargs.get("opsets", {})
43+
self.ir_version = kwargs.get("ir_version", None)
3244
return []
3345

3446
def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
@@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
4355
)
4456
rows = [
4557
"",
46-
f"g = GraphBuilder({self.opsets})",
58+
(
59+
f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
60+
if self.ir_version
61+
else f"GraphBuilder({self.opsets})"
62+
),
4763
*inputs,
4864
f"{self.name}({inps})",
4965
*outputs,
5066
"model = g.to_onnx()",
5167
]
68+
if self.make_model_function:
69+
rows = [
70+
"",
71+
"",
72+
f'def {self.make_model_function}() -> "ModelProto":',
73+
*[" " + _ for _ in rows[1:]],
74+
" return model",
75+
"",
76+
"",
77+
f"model = {self.make_model_function}()",
78+
]
5279
return rows
5380

5481
def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
@@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
78105
name = kwargs["name"]
79106
itype = kwargs.get("elem_type", 0)
80107
shape = kwargs.get("shape", None)
108+
name = self._clean_result_name(name)
81109
if itype == 0:
82-
inp = "X"
110+
inp = name or "X"
83111
else:
84112
if shape is None:
85-
inp = f'X: "{_itype_to_string(itype)}"'
113+
inp = f'{name}: "{_itype_to_string(itype)}"'
86114
else:
87-
inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
115+
inp = (
116+
f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
117+
)
88118
self.inputs_full.append(inp)
89119
self.inputs.append(name)
90120
self.inputs_full_.append((name, _itype_to_string(itype), shape))
@@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
113143

114144
def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
115145
name = kwargs["name"]
146+
name = self._clean_result_name(name)
116147
itype = kwargs.get("elem_type", 0)
117148
shape = kwargs.get("shape", None)
118149
self.outputs.append(name)
@@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
126157
if kwargs.get("domain", "") != "":
127158
domain = kwargs["domain"]
128159
op_type = f"{domain}.{op_type}"
160+
else:
161+
domain = ""
129162
atts = kwargs.get("atts", {})
130163
args = []
131164
for k, v in atts.items():
@@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
134167
raise NotImplementedError("Graph attribute not supported yet.")
135168
args.append(f"{k}={vatt}")
136169

137-
outs = ", ".join(outputs)
138-
inps = ", ".join(inputs)
170+
outs = ", ".join(map(self._clean_result_name, outputs))
171+
inps = ", ".join(map(self._clean_result_name, inputs))
172+
op_type = self._emit_node_type(op_type, domain)
173+
sdomain = "" if not domain else f", domain={domain!r}"
139174
if args:
140175
sargs = ", ".join(args)
141-
row = f" {outs} = op.{op_type}({inps}, {sargs})"
176+
if inps:
177+
row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})"
178+
else:
179+
row = f" {outs} = op.{op_type}({sargs}{sdomain})"
142180
else:
143-
row = f" {outs} = op.{op_type}({inps})"
181+
row = f" {outs} = op.{op_type}({inps}{sdomain})"
144182
return [row]
183+
184+
def _clean_result_name(self, name):
185+
return name
186+
187+
def _emit_node_type(self, op_type, domain):
188+
return op_type

onnx_array_api/translate_api/translate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
3535
last_event = None
3636
if isinstance(self.proto_, ModelProto):
3737
opsets = {d.domain: d.version for d in self.proto_.opset_import}
38-
rows.extend(self.emitter(EventType.START, opsets=opsets))
38+
rows.extend(
39+
self.emitter(
40+
EventType.START, opsets=opsets, ir_version=self.proto_.ir_version
41+
)
42+
)
3943
inputs = self.proto_.graph.input
4044
outputs = self.proto_.graph.output
4145
nodes = self.proto_.graph.node

0 commit comments

Comments
 (0)