Skip to content

Commit eee76cc

Browse files
authored
Lint (#89)
* example * lint * exc * array " * fix * fix missing dependency * yml * disable some tests
1 parent 6076c1c commit eee76cc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+295
-189
lines changed

_doc/examples/plot_benchmark_rf.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
4040
)
4141

4242
options = scope.get_options(operator.raw_operator)
43-
if "split" in options:
44-
operator.split = options["split"]
45-
else:
46-
operator.split = None
43+
operator.split = options.get("split", None)
4744
convert_lightgbm(scope, operator, container)
4845

4946

@@ -103,7 +100,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
103100
:return: number of runs, sum of the time, average, median
104101
"""
105102
times = []
106-
for n in range(repeat):
103+
for _n in range(repeat):
107104
perf = time.perf_counter()
108105
fct(X)
109106
delta = time.perf_counter() - perf
@@ -241,7 +238,10 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
241238
# onnxruntime
242239
bar.set_description(f"J={n_j} E={n_estimators} D={max_depth} predictO")
243240
r, t, mean, med = measure_inference(
244-
lambda x: sess.run(None, {"X": x}), X, repeat=repeat, max_time=max_time
241+
lambda x, sess=sess: sess.run(None, {"X": x}),
242+
X,
243+
repeat=repeat,
244+
max_time=max_time,
245245
)
246246
o2 = obs.copy()
247247
o2.update(dict(avg=mean, med=med, n_runs=r, ttime=t, name="ort_"))

_doc/examples/plot_onnxruntime.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ def loop(n=1000):
8787
x = np.random.randn(n, 2).astype(np.float32)
8888
y = np.random.randn(n, 2).astype(np.float32)
8989

90-
obs = measure_time(lambda: myloss(x, y))
90+
obs = measure_time(lambda x=x, y=y: myloss(x, y))
9191
obs["name"] = "numpy"
9292
obs["n"] = n
9393
data.append(obs)
9494

9595
xort = OrtTensor.from_array(x)
9696
yort = OrtTensor.from_array(y)
97-
obs = measure_time(lambda: ort_myloss(xort, yort))
97+
obs = measure_time(lambda xort=xort, yort=yort: ort_myloss(xort, yort))
9898
obs["name"] = "ort"
9999
obs["n"] = n
100100
data.append(obs)

_unittests/ut_array_api/test_hypothesis_array_api.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from os import getenv
44
from functools import reduce
5+
import packaging.version as pv
56
import numpy as np
67
from operator import mul
78
from hypothesis import given
@@ -44,9 +45,12 @@ class TestHypothesisArraysApis(ExtTestCase):
4445

4546
@classmethod
4647
def setUpClass(cls):
47-
with warnings.catch_warnings():
48-
warnings.simplefilter("ignore")
49-
from numpy import array_api as xp
48+
try:
49+
import array_api_strict as xp
50+
except ImportError:
51+
with warnings.catch_warnings():
52+
warnings.simplefilter("ignore")
53+
from numpy import array_api as xp
5054

5155
api_version = getenv(
5256
"ARRAY_API_TESTS_VERSION",
@@ -63,6 +67,9 @@ def test_strategies(self):
6367
self.assertNotEmpty(self.xps)
6468
self.assertNotEmpty(self.onxps)
6569

70+
@unittest.skipIf(
71+
pv.Version(np.__version__) >= pv.Version("2.0"), reason="abandonned"
72+
)
6673
def test_scalar_strategies(self):
6774
dtypes = dict(
6875
integer_dtypes=self.xps.integer_dtypes(),
@@ -139,6 +146,9 @@ def fctonx(x, kw):
139146
fctonx()
140147
self.assertEqual(len(args_onxp), len(args_np))
141148

149+
@unittest.skipIf(
150+
pv.Version(np.__version__) >= pv.Version("2.0"), reason="abandonned"
151+
)
142152
def test_square_sizes_strategies(self):
143153
dtypes = dict(
144154
integer_dtypes=self.xps.integer_dtypes(),

_unittests/ut_light_api/test_backend_export.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import packaging.version as pv
66
import numpy
77
from numpy.testing import assert_allclose
8+
from onnx.defs import onnx_opset_version
89
import onnx.backend.base
910
import onnx.backend.test
1011
import onnx.shape_inference
@@ -31,7 +32,6 @@
3132

3233
class ReferenceImplementationError(RuntimeError):
3334
"Fails, export cannot be compared."
34-
pass
3535

3636

3737
class ExportWrapper:
@@ -64,7 +64,8 @@ def run(
6464
expected = self.expected_sess.run(names, feeds)
6565
except (RuntimeError, AssertionError, TypeError, KeyError) as e:
6666
raise ReferenceImplementationError(
67-
f"ReferenceImplementation fails with {onnx_simple_text_plot(self.model)}"
67+
f"ReferenceImplementation fails with "
68+
f"{onnx_simple_text_plot(self.model)}"
6869
f"\n--RAW--\n{self.model}"
6970
) from e
7071

@@ -85,7 +86,7 @@ def run(
8586
new_code = "\n".join(
8687
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
8788
)
88-
raise AssertionError(f"ERROR {e}\n{new_code}")
89+
raise AssertionError(f"ERROR {e}\n{new_code}") # noqa: B904
8990

9091
locs = {
9192
"np": numpy,
@@ -154,7 +155,8 @@ def run(
154155
):
155156
if a.tolist() != b.tolist():
156157
raise AssertionError(
157-
f"Text discrepancies for api {api!r} with a.dtype={a.dtype} "
158+
f"Text discrepancies for api {api!r} "
159+
f"with a.dtype={a.dtype} "
158160
f"and b.dtype={b.dtype}"
159161
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
160162
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
@@ -275,6 +277,22 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
275277
")"
276278
)
277279

280+
if onnx_opset_version() < 22:
281+
backend_test.exclude(
282+
"("
283+
"test_dft_inverse_cpu"
284+
"|test_dft_inverse_opset19_cpu"
285+
"|test_lppool_1d_default_cpu"
286+
"|test_lppool_2d_default_cpu"
287+
"|test_lppool_2d_dilations_cpu"
288+
"|test_lppool_2d_pads_cpu"
289+
"|test_lppool_2d_same_lower_cpu"
290+
"|test_lppool_2d_same_upper_cpu"
291+
"|test_lppool_2d_strides_cpu"
292+
"|test_lppool_3d_default_cpu"
293+
")"
294+
)
295+
278296
if pv.Version(onnx_version) < pv.Version("1.16.0"):
279297
backend_test.exclude("(test_strnorm|test_range_)")
280298

_unittests/ut_light_api/test_light_api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def g(self):
484484
def ah(self):
485485
return True
486486

487-
setattr(A, "h", ah)
487+
setattr(A, "h", ah) # noqa: B010
488488

489489
self.assertTrue(A().h())
490490
self.assertIn("(self)", str(inspect.signature(A.h)))

_unittests/ut_plotting/test_dot_plot.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
import os
32
import unittest
43

_unittests/ut_plotting/test_text_plot.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
import os
32
import textwrap
43
import unittest
@@ -95,6 +94,7 @@ def test_onnx_text_plot_tree_cls_2(self):
9594
+f 0:1 1:0 2:0
9695
"""
9796
).strip(" \n\r")
97+
res = res.replace("np.float32(", "").replace(")", "")
9898
self.assertEqual(expected, res.strip(" \n\r"))
9999

100100
@ignore_warnings((UserWarning, FutureWarning))

_unittests/ut_reference/test_backend_extended_reference_evaluator.py

+19
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,25 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
217217
# The following tests fail due to a type mismatch.
218218
backend_test.exclude("(test_eyelike_without_dtype)")
219219

220+
if onnx_opset_version() < 22:
221+
backend_test.exclude(
222+
"("
223+
"test_adagrad_cpu"
224+
"|test_adagrad_multiple_cpu"
225+
"|test_dft_inverse_cpu"
226+
"|test_dft_inverse_opset19_cpu"
227+
"|test_lppool_1d_default_cpu"
228+
"|test_lppool_2d_default_cpu"
229+
"|test_lppool_2d_dilations_cpu"
230+
"|test_lppool_2d_pads_cpu"
231+
"|test_lppool_2d_same_lower_cpu"
232+
"|test_lppool_2d_same_upper_cpu"
233+
"|test_lppool_2d_strides_cpu"
234+
"|test_lppool_3d_default_cpu"
235+
")"
236+
)
237+
238+
220239
# The following tests fail due to discrepancies (small but still higher than 1e-7).
221240
backend_test.exclude("test_adam_multiple") # 1e-2
222241

_unittests/ut_translate_api/test_translate.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,14 @@ def test_export_if(self):
160160
self.assertEqualArray(np.array([1], dtype=np.int64), got[0])
161161

162162
code = translate(onx)
163-
selse = "g().cst(np.array([0], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
164-
sthen = "g().cst(np.array([1], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
163+
selse = (
164+
"g().cst(np.array([0], dtype=np.int64)).rename('Z')."
165+
"bring('Z').vout(elem_type=TensorProto.FLOAT)"
166+
)
167+
sthen = (
168+
"g().cst(np.array([1], dtype=np.int64)).rename('Z')."
169+
"bring('Z').vout(elem_type=TensorProto.FLOAT)"
170+
)
165171
expected = dedent(
166172
f"""
167173
(

_unittests/ut_translate_api/test_translate_classic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def test_fft(self):
252252
new_code = "\n".join(
253253
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
254254
)
255-
raise AssertionError(f"ERROR {e}\n{new_code}")
255+
raise AssertionError(f"ERROR {e}\n{new_code}") # noqa: B904
256256

257257
def test_aionnxml(self):
258258
onx = (

_unittests/ut_validation/test_f8.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_fe5m2_to_float32_paper(self):
8888
self.assertEqual(fe5m2_to_float32(int("11111100", 2)), -numpy.inf)
8989

9090
def test_fe4m3fn_to_float32_all(self):
91-
for i in range(0, 256):
91+
for i in range(256):
9292
a = fe4m3_to_float32_float(i)
9393
b = fe4m3_to_float32(i)
9494
if numpy.isnan(a):
@@ -97,7 +97,7 @@ def test_fe4m3fn_to_float32_all(self):
9797
self.assertEqual(a, b)
9898

9999
def test_fe4m3fn_to_float32_all_ml_types(self):
100-
for i in range(0, 256):
100+
for i in range(256):
101101
a = fe4m3_to_float32_float(i)
102102
b = fe4m3_to_float32(i)
103103
c = new_cvt_float32_to_e4m3fn(b)
@@ -188,7 +188,7 @@ def test_search_float32_into_fe5m2_simple(self):
188188
self.assertEqual(b1, b2)
189189

190190
def test_search_float32_into_fe4m3fn_equal(self):
191-
values = [(fe4m3_to_float32_float(i), i) for i in range(0, 256)]
191+
values = [(fe4m3_to_float32_float(i), i) for i in range(256)]
192192
values.sort()
193193

194194
for value, expected in values:
@@ -208,7 +208,7 @@ def test_search_float32_into_fe4m3fn_equal(self):
208208
self.assertIn(nf, (0, 128))
209209

210210
def test_search_float32_into_fe5m2_equal(self):
211-
values = [(fe5m2_to_float32_float(i), i) for i in range(0, 256)]
211+
values = [(fe5m2_to_float32_float(i), i) for i in range(256)]
212212
values.sort()
213213

214214
for value, expected in values:
@@ -233,7 +233,7 @@ def test_search_float32_into_fe5m2_equal(self):
233233
self.assertEqual(fe5m2_to_float32(nf), float(cf))
234234

235235
def test_search_float32_into_fe4m3fn(self):
236-
values = [(fe4m3_to_float32_float(i), i) for i in range(0, 256)]
236+
values = [(fe4m3_to_float32_float(i), i) for i in range(256)]
237237
values.sort()
238238

239239
obs = []
@@ -308,7 +308,7 @@ def test_search_float32_into_fe4m3fn(self):
308308
)
309309

310310
def test_search_float32_into_fe5m2(self):
311-
values = [(fe5m2_to_float32_float(i), i) for i in range(0, 256)]
311+
values = [(fe5m2_to_float32_float(i), i) for i in range(256)]
312312
values.sort()
313313

314314
obs = []
@@ -651,7 +651,7 @@ def test_search_float32_into_fe5m2fnuz_simple(self):
651651
self.assertEqual(expected, got)
652652

653653
def test_fe4m3fnuz_to_float32_all(self):
654-
for i in range(0, 256):
654+
for i in range(256):
655655
a = fe4m3_to_float32_float(i, uz=True)
656656
b = fe4m3_to_float32(i, uz=True)
657657
if numpy.isnan(a):
@@ -660,7 +660,7 @@ def test_fe4m3fnuz_to_float32_all(self):
660660
self.assertEqual(a, b)
661661

662662
def test_fe5m2fnuz_to_float32_all(self):
663-
for i in range(0, 256):
663+
for i in range(256):
664664
a = fe5m2_to_float32_float(i, fn=True, uz=True)
665665
b = fe5m2_to_float32(i, fn=True, uz=True)
666666
if numpy.isnan(a):
@@ -669,7 +669,7 @@ def test_fe5m2fnuz_to_float32_all(self):
669669
self.assertEqual(a, b)
670670

671671
def test_search_float32_into_fe4m3fnuz(self):
672-
values = [(fe4m3_to_float32_float(i, uz=True), i) for i in range(0, 256)]
672+
values = [(fe4m3_to_float32_float(i, uz=True), i) for i in range(256)]
673673
values.sort()
674674

675675
obs = []
@@ -715,9 +715,7 @@ def test_search_float32_into_fe4m3fnuz(self):
715715
)
716716

717717
def test_search_float32_into_fe5m2fnuz(self):
718-
values = [
719-
(fe5m2_to_float32_float(i, fn=True, uz=True), i) for i in range(0, 256)
720-
]
718+
values = [(fe5m2_to_float32_float(i, fn=True, uz=True), i) for i in range(256)]
721719
values.sort()
722720

723721
obs = []
@@ -1235,7 +1233,7 @@ def test_nan(self):
12351233
expected,
12361234
)
12371235
]
1238-
for i in range(0, 23):
1236+
for i in range(23):
12391237
v = 0x7F800000 | (1 << i)
12401238
f = numpy.uint32(v).view(numpy.float32)
12411239
values.append((i, v, f, expected))

_unittests/ut_xrun_doc/test_documentation_examples.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
4949
if verbose:
5050
print(f"failed: {name!r} due to missing dot.")
5151
return 0
52-
raise AssertionError(
52+
raise AssertionError( # noqa: B904
5353
"Example '{}' (cmd: {} - exec_prefix='{}') "
5454
"failed due to\n{}"
5555
"".format(name, cmds, sys.exec_prefix, st)

0 commit comments

Comments
 (0)