Skip to content

Commit a54de21

Browse files
authored
Better support for ir_version (#82)
* fixes for ir_version * fix ut * fix ut
1 parent 492b6d4 commit a54de21

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

_unittests/ut_light_api/test_backend_export.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
242242

243243
# The following tests are too slow with the reference implementation (Conv).
244244
backend_test.exclude(
245-
"(FLOAT8|BFLOAT16|_opt_|_3d_|_momentum_|_4d_"
245+
"(FLOAT8|BFLOAT16|INT4|_opt_|_3d_|_momentum_|_4d_|int4"
246246
"|test_adagrad"
247247
"|test_adam"
248248
"|test_ai_onnx_ml_"
@@ -270,6 +270,8 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
270270
"|test_squeezenet"
271271
"|test_vgg19"
272272
"|test_zfnet512"
273+
"|test_range_float_type_positive_delta_expanded"
274+
"|test_range_int32_type_negative_delta_expanded"
273275
")"
274276
)
275277

_unittests/ut_reference/test_backend_extended_reference_evaluator.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
149149
"|test_scan_sum)"
150150
)
151151

152-
if onnx_opset_version() < 21:
152+
if onnx_opset_version() < 200:
153153
# The following tests are using types not supported by NumPy.
154154
# They could be if method to_array is extended to support custom
155155
# types the same as the reference implementation does
@@ -164,8 +164,10 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
164164
"|test_cast_no_saturate_"
165165
"|_to_FLOAT8"
166166
"|_FLOAT8"
167+
"|INT4"
167168
"|test_quantizelinear_e4m3fn"
168169
"|test_quantizelinear_e5m2"
170+
"|test_scatter_with"
169171
")"
170172
)
171173

onnx_array_api/graph_api/graph_builder.py

+5
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def __init__(
156156
optimization_options: Optional[OptimizationOptions] = None,
157157
args: Optional[List[Any]] = None,
158158
verbose: int = 0,
159+
ir_version: Optional[int] = None,
159160
):
160161
self.optimization_options = optimization_options or OptimizationOptions()
161162
self.as_function = as_function
@@ -170,6 +171,7 @@ def __init__(
170171
if isinstance(target_opset_or_existing_proto, int)
171172
else target_opset_or_existing_proto
172173
)
174+
self.ir_version = ir_version
173175
self.nodes = []
174176
self.initializers_dict = {}
175177
self.inputs = []
@@ -186,6 +188,7 @@ def __init__(
186188
), "input_names must be empty if the input is an existing model."
187189
proto = target_opset_or_existing_proto
188190
self.opsets = {d.domain: d.version for d in proto.opset_import}
191+
self.ir_version = ir_version or target_opset_or_existing_proto.ir_version
189192
self.nodes = list(proto.graph.node)
190193
self.initializers_dict = {i.name: i for i in proto.graph.initializer}
191194
self.initializers_dict.update(
@@ -674,6 +677,8 @@ def to_onnx(
674677
if self.verbose:
675678
print("[GraphBuilder] onh.make_model")
676679
model = oh.make_model(graph, opset_imports=opsets)
680+
if self.ir_version:
681+
model.ir_version = self.ir_version
677682
return model
678683

679684
def _check_order_node(self, ind: int, node: NodeProto, existing: Set[str]):

0 commit comments

Comments
 (0)