Skip to content

Commit c23acc1

Browse files
committed
Adding compatibility check in GeoDA for neural network-based classifiers
Signed-off-by: Vishal Gawade <vishalgawade311@gmail.com>
1 parent 34961d3 commit c23acc1

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

art/attacks/evasion/geometric_decision_based_attack.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ def __init__(
9494
"""
9595
super().__init__(estimator=estimator)
9696

97+
# **Compatibility Checks**
98+
if not hasattr(estimator, "input_shape") or not hasattr(estimator, "channels_first"):
99+
raise ValueError(
100+
f"GeoDA is incompatible with {type(estimator)}. "
101+
"Please use a neural network-based classifier."
102+
)
103+
97104
self.batch_size = batch_size
98105
self.norm = norm
99106
self.sub_dim = sub_dim
@@ -102,16 +109,18 @@ def __init__(
102109
self.lambda_param = lambda_param
103110
self.sigma = sigma
104111
self._targeted = False
105-
106112
self.verbose = verbose
113+
107114
self._check_params()
108115

109116
self.sub_basis: np.ndarray
110117
self.nb_calls = 0
111118
self.clip_min = 0.0
112119
self.clip_max = 0.0
120+
113121
if self.estimator.input_shape is None: # pragma: no cover
114-
raise ValueError("The `input_shape` of the is required but None.")
122+
raise ValueError("The `input_shape` of the estimator is required but None.")
123+
115124
self.nb_channels = (
116125
self.estimator.input_shape[0] if self.estimator.channels_first else self.estimator.input_shape[2]
117126
)
@@ -450,4 +459,4 @@ def _check_params(self) -> None:
450459
# raise ValueError("The argument `targeted` has to be of type bool.")
451460

452461
if not isinstance(self.verbose, bool):
453-
raise ValueError("The argument `verbose` has to be of type bool.")
462+
raise ValueError("The argument `verbose` has to be of type bool.")

requirements_test.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pycodestyle==2.12.1
6060
black==24.8.0
6161
ruff==0.5.5
6262
types-six==1.16.21.9
63-
types-PyYAML==6.0.12.20240724
63+
types-PyYAML==6.0.12.20240917
6464
types-setuptools==71.1.0.20240726
6565

6666
# other
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# test_geoda_incompatibility.py
2+
3+
import pytest
4+
from sklearn.ensemble import RandomForestClassifier
5+
from sklearn.datasets import load_iris
6+
from sklearn.model_selection import train_test_split
7+
from art.estimators.classification import SklearnClassifier
8+
from art.attacks.evasion import GeoDA
9+
10+
11+
def test_geoda_with_random_forest():
12+
# Load the Iris dataset
13+
data = load_iris()
14+
X_train, X_test, y_train, y_test = train_test_split(
15+
data.data, data.target, test_size=0.2, random_state=42
16+
)
17+
18+
# Train a RandomForestClassifier
19+
model = RandomForestClassifier()
20+
model.fit(X_train, y_train)
21+
22+
# Wrap the model with ART's SklearnClassifier
23+
classifier = SklearnClassifier(model=model)
24+
25+
# Expect GeoDA to raise ValueError when used with RandomForestClassifier
26+
with pytest.raises(ValueError, match="GeoDA is incompatible with"):
27+
attack = GeoDA(classifier)
28+
attack.generate(X_test)

0 commit comments

Comments
 (0)