Skip to content

Commit 0c2a92d

Browse files
committed
fix unit test
1 parent af88e8d commit 0c2a92d

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

_unittests/ut_translate_api/test_translate.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def test_transpose(self):
8080
"""
8181
(
8282
start(opset=19)
83+
.vin('X', elem_type=TensorProto.FLOAT)
8384
.cst(np.array([-1, 1], dtype=np.int64))
8485
.rename('r')
85-
.vin('X', elem_type=TensorProto.FLOAT)
8686
.bring('X', 'r')
8787
.Reshape()
8888
.rename('r0_0')
@@ -166,9 +166,9 @@ def test_export_if(self):
166166
f"""
167167
(
168168
start(opset=19)
169+
.vin('X', elem_type=TensorProto.FLOAT)
169170
.cst(np.array([0.0], dtype=np.float32))
170171
.rename('r')
171-
.vin('X', elem_type=TensorProto.FLOAT)
172172
.bring('X')
173173
.ReduceSum(keepdims=1, noop_with_empty_axes=0)
174174
.rename('Xs')
@@ -202,9 +202,9 @@ def test_aionnxml(self):
202202
"""
203203
(
204204
start(opset=19, opsets={'ai.onnx.ml': 3})
205+
.vin('X', elem_type=TensorProto.FLOAT)
205206
.cst(np.array([-1, 1], dtype=np.int64))
206207
.rename('r')
207-
.vin('X', elem_type=TensorProto.FLOAT)
208208
.bring('X', 'r')
209209
.Reshape()
210210
.rename('USE')

_unittests/ut_translate_api/test_translate_classic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ def test_transpose(self):
138138
initializers = []
139139
sparse_initializers = []
140140
functions = []
141+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
141142
initializers.append(
142143
from_array(
143144
np.array([-1, 1], dtype=np.int64),
144145
name='r'
145146
)
146147
)
147-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
148148
nodes.append(
149149
make_node_extended(
150150
'Reshape',
@@ -278,13 +278,13 @@ def test_aionnxml(self):
278278
initializers = []
279279
sparse_initializers = []
280280
functions = []
281+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
281282
initializers.append(
282283
from_array(
283284
np.array([-1, 1], dtype=np.int64),
284285
name='r'
285286
)
286287
)
287-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
288288
nodes.append(
289289
make_node_extended(
290290
'Reshape',

onnx_array_api/translate_api/translate.py

+2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
7575
domain=self.proto_.domain,
7676
)
7777
)
78+
elif isinstance(self.proto_, GraphProto):
79+
rows.extend(self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.name))
7880
else:
7981
rows.extend(
8082
self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name)

0 commit comments

Comments
 (0)