Skip to content

Commit f5cfdd2

Browse files
committed
Allow different authorization header strings (closes #2)
1 parent 5b6aafd commit f5cfdd2

File tree

5 files changed

+47
-12
lines changed

5 files changed

+47
-12
lines changed

Diff for: README.md

+1
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ The available options are:
359359

360360
| Name | Description | Options | Default|
361361
| ------ | ----------- | ------- | ------ |
362+
|JWT_AUTH_HEADER | What to use in the authorization header (ex: Bearer <access_token>) | Any string (empty string to have it just be the access token in the authorization header) | 'Bearer' |
362363
|JWT_ACCESS_TOKEN_EXPIRES | How long an access token should live | datetime.timedelta | 15 minutes|
363364
|JWT_REFRESH_TOKEN_EXPIRES | How long a refresh token should live | datetime.timedelta | 30 days |
364365
|JWT_ALGORITHM | Which algorithm to use with the JWT. | [See here] (https://pyjwt.readthedocs.io/en/latest/algorithms.html) | HS256 |

Diff for: flask_jwt_extended/config.py

+7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
# Defaults
55

6+
# Authorize header type, what we are expecting to see in the auth header
7+
AUTH_HEADER = 'Bearer'
8+
69
# How long an access token will live before it expires.
710
ACCESS_TOKEN_EXPIRES = datetime.timedelta(minutes=15)
811

@@ -28,6 +31,10 @@
2831
BLACKLIST_TOKEN_CHECKS = 'refresh'
2932

3033

34+
def get_auth_header():
35+
return current_app.config.get('JWT_AUTH_HEADER', AUTH_HEADER)
36+
37+
3138
def get_access_expires():
3239
return current_app.config.get('JWT_ACCESS_TOKEN_EXPIRES', ACCESS_TOKEN_EXPIRES)
3340

Diff for: flask_jwt_extended/utils.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from flask import _request_ctx_stack as ctx_stack
1515

1616
from flask_jwt_extended.config import get_access_expires, get_refresh_expires, \
17-
get_algorithm, get_blacklist_enabled, get_blacklist_checks
17+
get_algorithm, get_blacklist_enabled, get_blacklist_checks, get_auth_header
1818
from flask_jwt_extended.exceptions import JWTEncodeError, JWTDecodeError, \
1919
InvalidHeaderError, NoAuthHeaderError, WrongTokenError, RevokedTokenError, \
2020
FreshTokenRequired
@@ -143,15 +143,19 @@ def _decode_jwt_from_request():
143143
raise NoAuthHeaderError("Missing Authorization Header")
144144

145145
# Make sure the header is valid
146+
expected_header = get_auth_header()
146147
parts = auth_header.split()
147-
if parts[0] != 'Bearer':
148-
msg = "Badly formatted authorization header. Should be 'Bearer <JWT>'"
149-
raise InvalidHeaderError(msg)
150-
elif len(parts) != 2:
151-
msg = "Badly formatted authorization header. Should be 'Bearer <JWT>'"
152-
raise InvalidHeaderError(msg)
153-
154-
token = parts[1]
148+
if not expected_header:
149+
if len(parts) != 1:
150+
msg = "Badly formatted authorization header. Should be '<JWT>'"
151+
raise InvalidHeaderError(msg)
152+
token = parts[0]
153+
else:
154+
if parts[0] != expected_header or len(parts) != 2:
155+
msg = "Bad authorization header. Expected '{} <JWT>'".format(expected_header)
156+
raise InvalidHeaderError(msg)
157+
token = parts[1]
158+
155159
secret = _get_secret_key()
156160
algorithm = get_algorithm()
157161
return _decode_jwt(token, secret, algorithm)

Diff for: tests/test_config.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from flask_jwt_extended.config import get_access_expires, get_refresh_expires, \
88
get_algorithm, get_blacklist_enabled, get_blacklist_store, \
9-
get_blacklist_checks
9+
get_blacklist_checks, get_auth_header
1010
from flask_jwt_extended import JWTManager
1111

1212

@@ -26,6 +26,7 @@ def test_default_configs(self):
2626
self.assertEqual(get_blacklist_enabled(), False)
2727
self.assertEqual(get_blacklist_store(), None)
2828
self.assertEqual(get_blacklist_checks(), 'refresh')
29+
self.assertEqual(get_auth_header(), 'Bearer')
2930

3031
def test_override_configs(self):
3132
self.app.config['JWT_ACCESS_TOKEN_EXPIRES'] = timedelta(minutes=5)
@@ -34,6 +35,7 @@ def test_override_configs(self):
3435
self.app.config['JWT_BLACKLIST_ENABLED'] = True
3536
self.app.config['JWT_BLACKLIST_STORE'] = simplekv.memory.DictStore()
3637
self.app.config['JWT_BLACKLIST_TOKEN_CHECKS'] = 'all'
38+
self.app.config['JWT_AUTH_HEADER'] = 'JWT'
3739

3840
with self.app.test_request_context():
3941
self.assertEqual(get_access_expires(), timedelta(minutes=5))
@@ -42,3 +44,4 @@ def test_override_configs(self):
4244
self.assertEqual(get_blacklist_enabled(), True)
4345
self.assertIsInstance(get_blacklist_store(), simplekv.memory.DictStore)
4446
self.assertEqual(get_blacklist_checks(), 'all')
47+
self.assertEqual(get_auth_header(), 'JWT')

Diff for: tests/test_protected_endpoints.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def _jwt_post(self, url, jwt):
6262
data = json.loads(response.get_data(as_text=True))
6363
return status_code, data
6464

65-
def _jwt_get(self, url, jwt):
66-
auth_header = 'Bearer {}'.format(jwt)
65+
def _jwt_get(self, url, jwt, auth_header='Bearer'):
66+
auth_header = '{} {}'.format(auth_header, jwt).strip()
6767
response = self.client.get(url, headers={'Authorization': auth_header})
6868
status_code = response.status_code
6969
data = json.loads(response.get_data(as_text=True))
@@ -278,3 +278,23 @@ def claims():
278278
status, data = self._jwt_get('/claims', access_token)
279279
self.assertEqual(status, 200)
280280
self.assertEqual(data, {'username': 'test', 'claims': {'foo': 'bar'}})
281+
282+
def test_different_auth_header(self):
283+
response = self.client.post('/auth/login')
284+
data = json.loads(response.get_data(as_text=True))
285+
access_token = data['access_token']
286+
287+
self.app.config['JWT_AUTH_HEADER'] = 'JWT'
288+
status, data = self._jwt_get('/protected', access_token, auth_header='JWT')
289+
self.assertEqual(data, {'msg': 'hello world'})
290+
self.assertEqual(status, 200)
291+
292+
self.app.config['JWT_AUTH_HEADER'] = ''
293+
status, data = self._jwt_get('/protected', access_token, auth_header='')
294+
self.assertEqual(data, {'msg': 'hello world'})
295+
self.assertEqual(status, 200)
296+
297+
self.app.config['JWT_AUTH_HEADER'] = ''
298+
status, data = self._jwt_get('/protected', access_token, auth_header='Bearer')
299+
self.assertIn('msg', data)
300+
self.assertEqual(status, 422)

0 commit comments

Comments
 (0)