Skip to content

Commit baa25d8

Browse files
committedApr 2, 2024
fix initializer
1 parent 092dfa2 commit baa25d8

File tree

3 files changed

+63
-13
lines changed

3 files changed

+63
-13
lines changed
 

‎_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from textwrap import dedent
33
import numpy as np
44
from onnx import ModelProto, TensorProto
5+
from onnx.checker import check_model
56
from onnx.defs import onnx_opset_version
67
from onnx.reference import ReferenceEvaluator
78
from onnx_array_api.ext_test_case import ExtTestCase
@@ -39,7 +40,7 @@ def light_api(
3940
4041
g = GraphBuilder({'': 19})
4142
g.make_tensor_input("X", TensorProto.FLOAT, ())
42-
light_api(g.op, X)
43+
light_api(g.op, "X")
4344
g.make_tensor_output("Y", TensorProto.FLOAT, ())
4445
model = g.to_onnx()
4546
"""
@@ -78,18 +79,43 @@ def test_zdoc(self):
7879
code = translate(onx, api="builder")
7980
expected = dedent(
8081
"""
81-
(
82-
start()
83-
.vin("X")
84-
.reshape((-1, 1))
85-
.Transpose(perm=[1, 0])
86-
.rename("Y")
87-
.vout()
88-
.to_onnx()
89-
)"""
82+
def light_api(
83+
op: "GraphBuilder",
84+
X: "FLOAT[]",
85+
):
86+
r = np.array([-1, 1], dtype=np.int64)
87+
r0_0 = op.Reshape(X, r)
88+
Y = op.Transpose(r0_0, perm=[1, 0])
89+
op.Identity(Y, outputs=["Y"])
90+
return Y
91+
92+
g = GraphBuilder({'': 21})
93+
g.make_tensor_input("X", TensorProto.FLOAT, ())
94+
light_api(g.op, "X")
95+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
96+
model = g.to_onnx()
97+
"""
9098
).strip("\n")
9199
self.maxDiff = None
92-
self.assertEqual(expected, code)
100+
self.assertEqual(expected, code.strip("\n"))
101+
102+
def light_api(
103+
op: "GraphBuilder",
104+
X: "FLOAT[]", # noqa: F722
105+
):
106+
r = np.array([-1, 1], dtype=np.int64)
107+
r0_0 = op.Reshape(X, r)
108+
Y = op.Transpose(r0_0, perm=[1, 0])
109+
op.Identity(Y, outputs=["Y"])
110+
return Y
111+
112+
g = GraphBuilder({"": 21})
113+
X = g.make_tensor_input("X", TensorProto.FLOAT, ())
114+
light_api(g.op, X)
115+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
116+
model = g.to_onnx()
117+
self.assertNotEmpty(model)
118+
check_model(model)
93119

94120

95121
if __name__ == "__main__":

‎onnx_array_api/graph_api/graph_builder.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ def __getattr__(self, name):
119119
except AttributeError as e:
120120
raise AttributeError(f"Unable to access attribute {name!r}.") from e
121121

122+
def Initializer(
123+
self, init: Union[TensorProto, np.ndarray], name: Optional[str] = None
124+
) -> str:
125+
"""
126+
Creates an initializer.
127+
128+
:param init: value
129+
:param name: name if value is not a TensorProto
130+
:return: its name
131+
"""
132+
return self.builder.make_initializer(init, name=name, exists=True)
133+
122134
def make_node(
123135
self,
124136
op_type: str,

‎onnx_array_api/translate_api/builder_emitter.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, Dict, List
22
from onnx import TensorProto
3+
from onnx.numpy_helper import to_array
34
from .base_emitter import BaseEmitter
45

56
_types = {
@@ -31,7 +32,7 @@ def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
3132
return []
3233

3334
def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
34-
inps = ", ".join(["g.op", *self.inputs])
35+
inps = ", ".join(["g.op", *[f'"{i}"' for i in self.inputs]])
3536
inputs = []
3637
for inp, stype, shape in self.inputs_full_:
3738
inputs.append(f'g.make_tensor_input("{inp}", TensorProto.{stype}, {shape})')
@@ -64,7 +65,14 @@ def _emit_end_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
6465
return []
6566

6667
def _emit_initializer(self, **kwargs: Dict[str, Any]) -> List[str]:
67-
assert False, f"not implemented yet with {kwargs}"
68+
init = kwargs["init"]
69+
if isinstance(init, TensorProto):
70+
assert (
71+
kwargs["name"] == init.name
72+
), f"Name mismatch init.name={init.name!r}, name={kwargs['name']!r}"
73+
self.inits.append(init)
74+
return []
75+
raise AssertionError(f"Unsupported type for an initializer {type(init)}")
6876

6977
def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
7078
name = kwargs["name"]
@@ -90,6 +98,10 @@ def _emit_end_signature(self, **kwargs: Dict[str, Any]) -> List[str]:
9098
for i in self.inputs_full:
9199
rows.append(f" {i},")
92100
rows.append("):")
101+
for init in self.inits:
102+
val = to_array(init)
103+
stype = str(val.dtype).split(".")[-1]
104+
rows.append(f" {init.name} = np.array({val.tolist()}, dtype=np.{stype})")
93105
return rows
94106

95107
def _emit_begin_return(self, **kwargs: Dict[str, Any]) -> List[str]:

0 commit comments

Comments
 (0)
Please sign in to comment.