Skip to content

Commit 64ec456

Browse files
authored
Merge pull request #66 from psafont/sub-decode
Allow changing subject claim
2 parents 17c3254 + f8d83f2 commit 64ec456

9 files changed

+116
-48
lines changed

docs/options.rst

+3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ General Options:
3030
such as ``RS*`` or ``ES*``. PEM format expected.
3131
``JWT_PRIVATE_KEY`` The private key needed for asymmetric based signing algorithms,
3232
such as ``RS*`` or ``ES*``. PEM format expected.
33+
``JWT_IDENTITY_CLAIM`` Claim in the tokens that is used as source of identity.
34+
For interoperativity, the JWT RFC recommends using ``'sub'``.
35+
Defaults to ``'identity'``.
3336
================================= =========================================
3437

3538

flask_jwt_extended/config.py

+4
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,10 @@ def cookie_max_age(self):
223223
# seconds a long ways in the future
224224
return None if self.session_cookie else 2147483647 # 2^31
225225

226+
@property
227+
def identity_claim(self):
228+
return current_app.config['JWT_IDENTITY_CLAIM']
229+
226230
config = _Config()
227231

228232

flask_jwt_extended/jwt_manager.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def _set_default_configuration_options(app):
164164
app.config.setdefault('JWT_BLACKLIST_ENABLED', False)
165165
app.config.setdefault('JWT_BLACKLIST_TOKEN_CHECKS', ['access', 'refresh'])
166166

167+
app.config.setdefault('JWT_IDENTITY_CLAIM', 'identity')
168+
167169
def user_claims_loader(self, callback):
168170
"""
169171
This sets the callback method for adding custom user claims to a JWT.
@@ -319,7 +321,8 @@ def create_refresh_token(self, identity, expires_delta=None):
319321
secret=config.encode_key,
320322
algorithm=config.algorithm,
321323
expires_delta=expires_delta,
322-
csrf=config.csrf_protect
324+
csrf=config.csrf_protect,
325+
identity_claim=config.identity_claim
323326
)
324327
return refresh_token
325328

@@ -352,7 +355,8 @@ def create_access_token(self, identity, fresh=False, expires_delta=None):
352355
expires_delta=expires_delta,
353356
fresh=fresh,
354357
user_claims=self._user_claims_callback(identity),
355-
csrf=config.csrf_protect
358+
csrf=config.csrf_protect,
359+
identity_claim=config.identity_claim
356360
)
357361
return access_token
358362

flask_jwt_extended/tokens.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _encode_jwt(additional_token_data, expires_delta, secret, algorithm):
2525

2626

2727
def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
28-
user_claims, csrf):
28+
user_claims, csrf, identity_claim):
2929
"""
3030
Creates a new encoded (utf-8) access token.
3131
@@ -40,11 +40,12 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
4040
be json serializable
4141
:param csrf: Whether to include a csrf double submit claim in this token
4242
(boolean)
43+
:param identity_claim: Which claim should be used to store the identity in
4344
:return: Encoded access token
4445
"""
4546
# Create the jwt
4647
token_data = {
47-
'identity': identity,
48+
identity_claim: identity,
4849
'fresh': fresh,
4950
'type': 'access',
5051
'user_claims': user_claims,
@@ -54,7 +55,7 @@ def encode_access_token(identity, secret, algorithm, expires_delta, fresh,
5455
return _encode_jwt(token_data, expires_delta, secret, algorithm)
5556

5657

57-
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
58+
def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf, identity_claim):
5859
"""
5960
Creates a new encoded (utf-8) refresh token.
6061
@@ -65,18 +66,19 @@ def encode_refresh_token(identity, secret, algorithm, expires_delta, csrf):
6566
(datetime.timedelta)
6667
:param csrf: Whether to include a csrf double submit claim in this token
6768
(boolean)
69+
:param identity_claim: Which claim should be used to store the identity in
6870
:return: Encoded refresh token
6971
"""
7072
token_data = {
71-
'identity': identity,
73+
identity_claim: identity,
7274
'type': 'refresh',
7375
}
7476
if csrf:
7577
token_data['csrf'] = _create_csrf_token()
7678
return _encode_jwt(token_data, expires_delta, secret, algorithm)
7779

7880

79-
def decode_jwt(encoded_token, secret, algorithm, csrf):
81+
def decode_jwt(encoded_token, secret, algorithm, csrf, identity_claim):
8082
"""
8183
Decodes an encoded JWT
8284
@@ -85,6 +87,7 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
8587
:param algorithm: Algorithm used to encode the JWT
8688
:param csrf: If this token is expected to have a CSRF double submit
8789
value present (boolean)
90+
:param identity_claim: expected claim that is used to identify the subject
8891
:return: Dictionary containing contents of the JWT
8992
"""
9093
# This call verifies the ext, iat, and nbf claims
@@ -93,8 +96,8 @@ def decode_jwt(encoded_token, secret, algorithm, csrf):
9396
# Make sure that any custom claims we expect in the token are present
9497
if 'jti' not in data:
9598
raise JWTDecodeError("Missing claim: jti")
96-
if 'identity' not in data:
97-
raise JWTDecodeError("Missing claim: identity")
99+
if identity_claim not in data:
100+
raise JWTDecodeError("Missing claim: {}".format(identity_claim))
98101
if 'type' not in data or data['type'] not in ('refresh', 'access'):
99102
raise JWTDecodeError("Missing or invalid claim: type")
100103
if data['type'] == 'access':

flask_jwt_extended/utils.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def get_jwt_identity():
2727
Returns the identity of the JWT in this context. If no JWT is present,
2828
None is returned.
2929
"""
30-
return get_raw_jwt().get('identity', None)
30+
return get_raw_jwt().get(config.identity_claim, None)
3131

3232

3333
def get_jwt_claims():
@@ -63,7 +63,8 @@ def decode_token(encoded_token):
6363
encoded_token=encoded_token,
6464
secret=config.decode_key,
6565
algorithm=config.algorithm,
66-
csrf=config.csrf_protect
66+
csrf=config.csrf_protect,
67+
identity_claim=config.identity_claim
6768
)
6869

6970

@@ -106,7 +107,13 @@ def token_in_blacklist(*args, **kwargs):
106107

107108

108109
def get_csrf_token(encoded_token):
109-
token = decode_jwt(encoded_token, config.decode_key, config.algorithm, csrf=True)
110+
token = decode_jwt(
111+
encoded_token,
112+
config.decode_key,
113+
config.algorithm,
114+
csrf=True,
115+
identity_claim=config.identity_claim
116+
)
110117
return token['csrf']
111118

112119

flask_jwt_extended/view_decorators.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,13 @@ def _decode_jwt_from_headers():
144144
raise InvalidHeaderError(msg)
145145
token = parts[1]
146146

147-
return decode_jwt(token, config.decode_key, config.algorithm, csrf=False)
147+
return decode_jwt(
148+
encoded_token=token,
149+
secret=config.decode_key,
150+
algorithm=config.algorithm,
151+
csrf=False,
152+
identity_claim=config.identity_claim
153+
)
148154

149155

150156
def _decode_jwt_from_cookies(request_type):
@@ -163,7 +169,8 @@ def _decode_jwt_from_cookies(request_type):
163169
encoded_token=encoded_token,
164170
secret=config.decode_key,
165171
algorithm=config.algorithm,
166-
csrf=config.csrf_protect
172+
csrf=config.csrf_protect,
173+
identity_claim=config.identity_claim
167174
)
168175

169176
# Verify csrf double submit tokens match if required

tests/test_config.py

+6
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def test_default_configs(self):
5454
self.assertEqual(config.decode_key, self.app.secret_key)
5555
self.assertEqual(config.cookie_max_age, None)
5656

57+
self.assertEqual(config.identity_claim, 'identity')
58+
5759
def test_override_configs(self):
5860
self.app.config['JWT_TOKEN_LOCATION'] = ['cookies']
5961
self.app.config['JWT_HEADER_NAME'] = 'TestHeader'
@@ -86,6 +88,8 @@ def test_override_configs(self):
8688

8789
self.app.secret_key = 'banana'
8890

91+
self.app.config['JWT_IDENTITY_CLAIM'] = 'foo'
92+
8993
with self.app.test_request_context():
9094
self.assertEqual(config.token_location, ['cookies'])
9195
self.assertEqual(config.jwt_in_cookies, True)
@@ -122,6 +126,8 @@ def test_override_configs(self):
122126

123127
self.assertEqual(config.cookie_max_age, 2147483647)
124128

129+
self.assertEqual(config.identity_claim, 'foo')
130+
125131
def test_invalid_config_options(self):
126132
with self.app.test_request_context():
127133
self.app.config['JWT_TOKEN_LOCATION'] = 'banana'

0 commit comments

Comments
 (0)