Skip to content

Commit 092dfa2

Browse files
committed
fix order
1 parent 0c2a92d commit 092dfa2

File tree

3 files changed

+39
-13
lines changed

3 files changed

+39
-13
lines changed

_unittests/ut_translate_api/test_translate.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def test_transpose(self):
8080
"""
8181
(
8282
start(opset=19)
83-
.vin('X', elem_type=TensorProto.FLOAT)
8483
.cst(np.array([-1, 1], dtype=np.int64))
8584
.rename('r')
85+
.vin('X', elem_type=TensorProto.FLOAT)
8686
.bring('X', 'r')
8787
.Reshape()
8888
.rename('r0_0')
@@ -166,9 +166,9 @@ def test_export_if(self):
166166
f"""
167167
(
168168
start(opset=19)
169-
.vin('X', elem_type=TensorProto.FLOAT)
170169
.cst(np.array([0.0], dtype=np.float32))
171170
.rename('r')
171+
.vin('X', elem_type=TensorProto.FLOAT)
172172
.bring('X')
173173
.ReduceSum(keepdims=1, noop_with_empty_axes=0)
174174
.rename('Xs')
@@ -202,9 +202,9 @@ def test_aionnxml(self):
202202
"""
203203
(
204204
start(opset=19, opsets={'ai.onnx.ml': 3})
205-
.vin('X', elem_type=TensorProto.FLOAT)
206205
.cst(np.array([-1, 1], dtype=np.int64))
207206
.rename('r')
207+
.vin('X', elem_type=TensorProto.FLOAT)
208208
.bring('X', 'r')
209209
.Reshape()
210210
.rename('USE')

_unittests/ut_translate_api/test_translate_builder.py

+26
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,32 @@ def light_api(
6565
got = ref.run(None, {"X": a})[0]
6666
self.assertEqualArray(np.exp(a), got)
6767

68+
def test_zdoc(self):
69+
onx = (
70+
start()
71+
.vin("X")
72+
.reshape((-1, 1))
73+
.Transpose(perm=[1, 0])
74+
.rename("Y")
75+
.vout()
76+
.to_onnx()
77+
)
78+
code = translate(onx, api="builder")
79+
expected = dedent(
80+
"""
81+
(
82+
start()
83+
.vin("X")
84+
.reshape((-1, 1))
85+
.Transpose(perm=[1, 0])
86+
.rename("Y")
87+
.vout()
88+
.to_onnx()
89+
)"""
90+
).strip("\n")
91+
self.maxDiff = None
92+
self.assertEqual(expected, code)
93+
6894

6995
if __name__ == "__main__":
7096
unittest.main(verbosity=2)

onnx_array_api/translate_api/translate.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,16 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
8282
self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name)
8383
)
8484

85+
for i in initializers:
86+
rows.extend(
87+
self.emitter(
88+
EventType.INITIALIZER,
89+
name=i.name,
90+
init=i,
91+
value=to_array_extended(i),
92+
)
93+
)
94+
8595
rows.extend(self.emitter(EventType.BEGIN_SIGNATURE))
8696

8797
for i in inputs:
@@ -107,16 +117,6 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
107117

108118
rows.extend(self.emitter(EventType.END_SIGNATURE))
109119

110-
for i in initializers:
111-
rows.extend(
112-
self.emitter(
113-
EventType.INITIALIZER,
114-
name=i.name,
115-
init=i,
116-
value=to_array_extended(i),
117-
)
118-
)
119-
120120
for node in nodes:
121121
atts = self.extract_attributes(node)
122122
rows.extend(

0 commit comments

Comments
 (0)