Commit 5c301d7

added explanation about the GridSearchCV best_score_ attribute
1 parent a8cf7b7 commit 5c301d7

2 files changed (+329 -16 lines changed)

code/bonus/svm_iris_pipeline_and_gridsearch.ipynb

+156 -13
@@ -42,10 +42,10 @@
 "output_type": "stream",
 "text": [
 "Sebastian Raschka \n",
-"Last updated: 11/30/2015 \n",
+"Last updated: 01/20/2016 \n",
 "\n",
-"CPython 3.5.0\n",
-"IPython 4.0.0\n",
+"CPython 3.5.1\n",
+"IPython 4.0.1\n",
 "\n",
 "numpy 1.10.1\n",
 "pandas 0.17.1\n",
@@ -77,7 +77,7 @@
 "name": "stderr",
 "output_type": "stream",
 "text": [
-"[Parallel(n_jobs=-1)]: Done 40 out of 40 | elapsed: 0.1s finished\n"
+"[Parallel(n_jobs=-1)]: Done 40 out of 40 | elapsed: 0.2s finished\n"
 ]
 },
 {
@@ -89,7 +89,7 @@
 " max_iter=-1, probability=False, random_state=None, shrinking=True,\n",
 " tol=0.001, verbose=False))]),\n",
 " fit_params={}, iid=True, n_jobs=-1,\n",
-" param_grid=[{'svc__C': [1, 10, 100, 1000], 'svc__kernel': ['rbf'], 'svc__gamma': [0.001, 0.0001]}],\n",
+" param_grid=[{'svc__kernel': ['rbf'], 'svc__C': [1, 10, 100, 1000], 'svc__gamma': [0.001, 0.0001]}],\n",
 " pre_dispatch='2*n_jobs', refit=True, scoring='accuracy', verbose=1)"
 ]
 },
@@ -143,7 +143,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 4,
+"execution_count": 3,
 "metadata": {
 "collapsed": false
 },
@@ -153,7 +153,7 @@
 "output_type": "stream",
 "text": [
 "Best GS Score 0.96\n",
-"best GS Params {'svc__C': 100, 'svc__kernel': 'rbf', 'svc__gamma': 0.001}\n",
+"best GS Params {'svc__kernel': 'rbf', 'svc__C': 100, 'svc__gamma': 0.001}\n",
 "\n",
 "Train Accuracy: 0.97\n",
 "\n",
@@ -162,7 +162,6 @@
 }
 ],
 "source": [
-"\n",
 "print('Best GS Score %.2f' % gs.best_score_)\n",
 "print('best GS Params %s' % gs.best_params_)\n",
 "\n",
@@ -179,13 +178,157 @@
 ]
 },
 {
-"cell_type": "code",
-"execution_count": null,
+"cell_type": "markdown",
 "metadata": {
 "collapsed": true
 },
-"outputs": [],
-"source": []
+"source": [
+"### A Note about `GridSearchCV`'s `best_score_` attribute"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Note that `gs.best_score_` is the average k-fold cross-validation score of the best model. That is, if we have a `GridSearchCV` object with 5-fold cross-validation (like the one above), the `best_score_` attribute returns the best model's score averaged over the 5 test folds. The following example illustrates this:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"metadata": {
+"collapsed": false
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"array([ 0.6, 0.4, 0.6, 0.2, 0.6])"
+]
+},
+"execution_count": 4,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"from sklearn.cross_validation import StratifiedKFold, cross_val_score\n",
+"from sklearn.linear_model import LogisticRegression\n",
+"import numpy as np\n",
+"\n",
+"np.random.seed(0)\n",
+"np.set_printoptions(precision=6)\n",
+"y = [np.random.randint(3) for i in range(25)]\n",
+"X = (y + np.random.randn(25)).reshape(-1, 1)\n",
+"\n",
+"cv5_idx = list(StratifiedKFold(y, n_folds=5, shuffle=False, random_state=0))\n",
+"cross_val_score(LogisticRegression(random_state=123), X, y, cv=cv5_idx)"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"By executing the code above, we created a simple data set: random integers that represent our class labels, and a single feature obtained by adding Gaussian noise to those labels. Next, we fed the indices of the 5 cross-validation folds (`cv5_idx`) to the `cross_val_score` function, which returned 5 accuracy scores, one for each of the 5 test folds. \n",
+"\n",
+"Next, let us use the `GridSearchCV` object and feed it the same 5 cross-validation sets (via the pre-generated `cv5_idx` indices):"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 5,
+"metadata": {
+"collapsed": false
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Fitting 5 folds for each of 1 candidates, totalling 5 fits\n",
+"[CV] ................................................................\n",
+"[CV] ....................................... , score=0.600000 - 0.0s\n",
+"[CV] ................................................................\n",
+"[CV] ....................................... , score=0.400000 - 0.0s\n",
+"[CV] ................................................................\n",
+"[CV] ....................................... , score=0.600000 - 0.0s\n",
+"[CV] ................................................................\n",
+"[CV] ....................................... , score=0.200000 - 0.0s\n",
+"[CV] ................................................................\n",
+"[CV] ....................................... , score=0.600000 - 0.0s\n"
+]
+},
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"[Parallel(n_jobs=1)]: Done 5 out of 5 | elapsed: 0.0s finished\n"
+]
+}
+],
+"source": [
+"from sklearn.grid_search import GridSearchCV\n",
+"gs = GridSearchCV(LogisticRegression(), {}, cv=cv5_idx, verbose=3).fit(X, y) "
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"As we can see, the scores for the 5 folds are exactly the same as the ones from `cross_val_score` earlier. \n",
+"Now, the `best_score_` attribute of the `GridSearchCV` object, which becomes available after fitting, returns the average accuracy score of the best model:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"metadata": {
+"collapsed": false
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"0.47999999999999998"
+]
+},
+"execution_count": 6,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"gs.best_score_"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"As we can see, the result above is consistent with the average score computed via `cross_val_score`:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"metadata": {
+"collapsed": false
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"0.47999999999999998"
+]
+},
+"execution_count": 7,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"cross_val_score(LogisticRegression(), X, y, cv=cv5_idx).mean()"
+]
 }
 ],
 "metadata": {
@@ -204,7 +347,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.5.0"
+"version": "3.5.1"
 }
 },
 "nbformat": 4,

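Note for readers on newer scikit-learn releases: the notebook cells added in this commit import from `sklearn.cross_validation` and `sklearn.grid_search`, which were later replaced by `sklearn.model_selection`. The block below is a minimal sketch of the same check against that newer module. It is not part of the commit. The toy data (balanced classes built with `np.repeat`, a local `RandomState(0)`) and the single-candidate grid `{"C": [1.0]}` are assumptions made for illustration, so the printed scores will differ from the notebook's outputs.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score

# Hypothetical toy data (not the notebook's): 3 balanced classes with
# 10 samples each, and one noisy feature derived from the label.
rng = np.random.RandomState(0)
y = np.repeat([0, 1, 2], 10)
X = (y + rng.randn(y.size)).reshape(-1, 1)

# Pre-generate the 5 (train, test) index pairs so that both tools
# evaluate on exactly the same folds.
cv5_idx = list(StratifiedKFold(n_splits=5, shuffle=False).split(X, y))

# Per-fold accuracy from cross_val_score.
fold_scores = cross_val_score(LogisticRegression(), X, y, cv=cv5_idx)

# Grid search with a single candidate (C=1.0 is the estimator's default),
# run over the same pre-generated folds.
gs = GridSearchCV(LogisticRegression(), param_grid={"C": [1.0]},
                  cv=cv5_idx, scoring="accuracy")
gs.fit(X, y)

print("per-fold scores:", fold_scores)
print("mean of folds:  ", fold_scores.mean())
print("gs.best_score_: ", gs.best_score_)
```

In this sketch every test fold has the same size, so the mean of the per-fold scores and `gs.best_score_` should agree, which mirrors the comparison made in the notebook.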