Skip to content

Commit 53f4e4a

Browse files
authored
MAINT compatibility sklearn 1.5 (#1074)
1 parent 6739f3d commit 53f4e4a

File tree

6 files changed

+58
-18
lines changed

6 files changed

+58
-18
lines changed

doc/whats_new/0.13.rst

+3
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ Bug fixes
1414
Compatibility
1515
.............
1616

17+
- Compatibility with scikit-learn 1.5
18+
:pr:`1074` by :user:`Guillaume Lemaitre <glemaitre>`.
19+
1720
Deprecations
1821
............
1922

imblearn/ensemble/_bagging.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
386386
self.sampler_ = clone(self.sampler)
387387
# RandomUnderSampler is not supporting sample_weight. We need to pass
388388
# None.
389-
return super()._fit(X, y, self.max_samples, sample_weight=None)
389+
return super()._fit(X, y, self.max_samples)
390390

391391
# TODO: remove when minimum supported version of scikit-learn is 1.1
392392
@available_if(_estimator_has("decision_function"))

imblearn/ensemble/_easy_ensemble.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
300300
check_target_type(y)
301301
# RandomUnderSampler is not supporting sample_weight. We need to pass
302302
# None.
303-
return super()._fit(X, y, self.max_samples, sample_weight=None)
303+
return super()._fit(X, y, self.max_samples)
304304

305305
# TODO: remove when minimum supported version of scikit-learn is 1.1
306306
@available_if(_estimator_has("decision_function"))
@@ -365,9 +365,11 @@ def base_estimator_(self):
365365
raise error
366366
raise error
367367

368-
def _more_tags(self):
368+
def _get_estimator(self):
369369
if self.estimator is None:
370-
estimator = AdaBoostClassifier(algorithm="SAMME")
371-
else:
372-
estimator = self.estimator
373-
return {"allow_nan": _safe_tags(estimator, "allow_nan")}
370+
return AdaBoostClassifier(algorithm="SAMME")
371+
return self.estimator
372+
373+
# TODO: remove when minimum supported version of scikit-learn is 1.5
374+
def _more_tags(self):
375+
return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}

imblearn/over_sampling/_smote/base.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,17 @@
1111
import warnings
1212

1313
import numpy as np
14+
import sklearn
1415
from scipy import sparse
1516
from sklearn.base import clone
1617
from sklearn.exceptions import DataConversionWarning
1718
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
1819
from sklearn.utils import (
19-
_get_column_indices,
2020
_safe_indexing,
2121
check_array,
2222
check_random_state,
2323
)
24+
from sklearn.utils.fixes import parse_version
2425
from sklearn.utils.sparsefuncs_fast import (
2526
csr_mean_variance_axis0,
2627
)
@@ -34,6 +35,12 @@
3435
from ...utils.fixes import _is_pandas_df, _mode
3536
from ..base import BaseOverSampler
3637

38+
sklearn_version = parse_version(sklearn.__version__).base_version
39+
if parse_version(sklearn_version) < parse_version("1.5"):
40+
from sklearn.utils import _get_column_indices
41+
else:
42+
from sklearn.utils._indexing import _get_column_indices
43+
3744

3845
class BaseSMOTE(BaseOverSampler):
3946
"""Base class for the different SMOTE algorithms."""

imblearn/pipeline.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
# Christos Aridas
1313
# Guillaume Lemaitre <g.lemaitre58@gmail.com>
1414
# License: BSD
15+
import sklearn
1516
from sklearn import pipeline
1617
from sklearn.base import clone
17-
from sklearn.utils import Bunch, _print_elapsed_time
18+
from sklearn.utils import Bunch
19+
from sklearn.utils.fixes import parse_version
1820
from sklearn.utils.metaestimators import available_if
1921
from sklearn.utils.validation import check_memory
2022

@@ -34,6 +36,12 @@
3436

3537
__all__ = ["Pipeline", "make_pipeline"]
3638

39+
sklearn_version = parse_version(sklearn.__version__).base_version
40+
if parse_version(sklearn_version) < parse_version("1.5"):
41+
from sklearn.utils import _print_elapsed_time
42+
else:
43+
from sklearn.utils._user_interface import _print_elapsed_time
44+
3745

3846
class Pipeline(_ParamsValidationMixin, pipeline.Pipeline):
3947
"""Pipeline of transforms and resamples with a final estimator.

imblearn/utils/_metadata_requests.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -1086,9 +1086,12 @@ def _serialize(self):
10861086

10871087
def __iter__(self):
10881088
if self._self_request:
1089-
yield "$self_request", RouterMappingPair(
1090-
mapping=MethodMapping.from_str("one-to-one"),
1091-
router=self._self_request,
1089+
yield (
1090+
"$self_request",
1091+
RouterMappingPair(
1092+
mapping=MethodMapping.from_str("one-to-one"),
1093+
router=self._self_request,
1094+
),
10921095
)
10931096
for name, route_mapping in self._route_mappings.items():
10941097
yield (name, route_mapping)
@@ -1234,7 +1237,7 @@ def __init__(self, name, keys, validate_keys=True):
12341237

12351238
def __get__(self, instance, owner):
12361239
# we would want to have a method which accepts only the expected args
1237-
def func(**kw):
1240+
def func(*args, **kw):
12381241
"""Updates the request for provided parameters
12391242
12401243
This docstring is overwritten below.
@@ -1253,15 +1256,32 @@ def func(**kw):
12531256
f"arguments are: {set(self.keys)}"
12541257
)
12551258

1256-
requests = instance._get_metadata_request()
1259+
# This makes it possible to use the decorated method as an unbound
1260+
# method, for instance when monkeypatching.
1261+
# https://github.com/scikit-learn/scikit-learn/issues/28632
1262+
if instance is None:
1263+
_instance = args[0]
1264+
args = args[1:]
1265+
else:
1266+
_instance = instance
1267+
1268+
# Replicating python's behavior when positional args are given other
1269+
# than `self`, and `self` is only allowed if this method is unbound.
1270+
if args:
1271+
raise TypeError(
1272+
f"set_{self.name}_request() takes 0 positional argument but"
1273+
f" {len(args)} were given"
1274+
)
1275+
1276+
requests = _instance._get_metadata_request()
12571277
method_metadata_request = getattr(requests, self.name)
12581278

12591279
for prop, alias in kw.items():
12601280
if alias is not UNCHANGED:
12611281
method_metadata_request.add_request(param=prop, alias=alias)
1262-
instance._metadata_request = requests
1282+
_instance._metadata_request = requests
12631283

1264-
return instance
1284+
return _instance
12651285

12661286
# Now we set the relevant attributes of the function so that it seems
12671287
# like a normal method to the end user, with known expected arguments.
@@ -1525,13 +1545,13 @@ def process_routing(_obj, _method, /, **kwargs):
15251545
metadata to corresponding methods or corresponding child objects. The object
15261546
names are those defined in `obj.get_metadata_routing()`.
15271547
"""
1528-
if not _routing_enabled() and not kwargs:
1548+
if not kwargs:
15291549
# If routing is not enabled and kwargs are empty, then we don't have to
15301550
# try doing any routing, we can simply return a structure which returns
15311551
# an empty dict on routed_params.ANYTHING.ANY_METHOD.
15321552
class EmptyRequest:
15331553
def get(self, name, default=None):
1534-
return default if default else {}
1554+
return Bunch(**{method: dict() for method in METHODS})
15351555

15361556
def __getitem__(self, name):
15371557
return Bunch(**{method: dict() for method in METHODS})

0 commit comments

Comments
 (0)