-
Notifications
You must be signed in to change notification settings - Fork 338
Update to authlib related to issue #113 - slightly better PR #118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,102 +1,36 @@ | ||
# encoding: utf-8 | ||
# pylint: disable=no-self-use | ||
""" | ||
OAuth2 provider setup. | ||
|
||
It is based on the code from the example: | ||
https://github.com/lepture/example-oauth2-server | ||
|
||
More details are available here: | ||
* http://flask-oauthlib.readthedocs.org/en/latest/oauth2.html | ||
* http://lepture.com/en/2013/create-oauth-server | ||
""" | ||
|
||
from datetime import datetime, timedelta | ||
import functools | ||
import logging | ||
|
||
from flask_login import current_user | ||
from flask_oauthlib import provider | ||
import functools, logging | ||
from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector, current_token | ||
from authlib.flask.oauth2.sqla import ( | ||
create_query_client_func, | ||
create_save_token_func, | ||
create_revocation_endpoint, | ||
create_bearer_token_validator, | ||
) | ||
from authlib.specs.rfc6749 import grants | ||
from werkzeug.security import gen_salt | ||
from app.extensions import api, login_manager | ||
from app.modules.users.models import User | ||
from app.modules.auth.models import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token | ||
from flask_restplus_patched._http import HTTPStatus | ||
import sqlalchemy | ||
|
||
from app.extensions import api, db | ||
|
||
from authlib.specs.rfc6750 import BearerTokenValidator as _BearerTokenValidator | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class OAuth2RequestValidator(provider.OAuth2RequestValidator): | ||
# pylint: disable=abstract-method | ||
@login_manager.request_loader | ||
def load_user_from_request(request): | ||
""" | ||
A project-specific implementation of OAuth2RequestValidator, which connects | ||
our User and OAuth2* implementations together. | ||
Load user from OAuth2 Authentication header. | ||
""" | ||
|
||
def __init__(self): | ||
from app.modules.auth.models import OAuth2Client, OAuth2Grant, OAuth2Token | ||
self._client_class = OAuth2Client | ||
self._grant_class = OAuth2Grant | ||
self._token_class = OAuth2Token | ||
super(OAuth2RequestValidator, self).__init__( | ||
usergetter=self._usergetter, | ||
clientgetter=self._client_class.find, | ||
tokengetter=self._token_class.find, | ||
grantgetter=self._grant_class.find, | ||
tokensetter=self._tokensetter, | ||
grantsetter=self._grantsetter, | ||
) | ||
|
||
def _usergetter(self, username, password, client, request): | ||
# pylint: disable=method-hidden,unused-argument | ||
# Avoid circular dependencies | ||
from app.modules.users.models import User | ||
return User.find_with_password(username, password) | ||
|
||
def _tokensetter(self, token, request, *args, **kwargs): | ||
# pylint: disable=method-hidden,unused-argument | ||
# TODO: review expiration time | ||
expires_in = token['expires_in'] | ||
expires = datetime.utcnow() + timedelta(seconds=expires_in) | ||
|
||
try: | ||
with db.session.begin(): | ||
token_instance = self._token_class( | ||
access_token=token['access_token'], | ||
refresh_token=token.get('refresh_token'), | ||
token_type=token['token_type'], | ||
scopes=[scope for scope in token['scope'].split(' ') if scope], | ||
expires=expires, | ||
client_id=request.client.client_id, | ||
user_id=request.user.id, | ||
) | ||
db.session.add(token_instance) | ||
except sqlalchemy.exc.IntegrityError: | ||
log.exception("Token-setter has failed.") | ||
return None | ||
return token_instance | ||
|
||
def _grantsetter(self, client_id, code, request, *args, **kwargs): | ||
# pylint: disable=method-hidden,unused-argument | ||
# TODO: review expiration time | ||
# decide the expires time yourself | ||
expires = datetime.utcnow() + timedelta(seconds=100) | ||
try: | ||
with db.session.begin(): | ||
grant_instance = self._grant_class( | ||
client_id=client_id, | ||
code=code['code'], | ||
redirect_uri=request.redirect_uri, | ||
scopes=request.scopes, | ||
user=current_user, | ||
expires=expires | ||
) | ||
db.session.add(grant_instance) | ||
except sqlalchemy.exc.IntegrityError: | ||
log.exception("Grant-setter has failed.") | ||
return None | ||
return grant_instance | ||
|
||
from app.modules.users.models import User | ||
if current_token: | ||
user = current_token.user | ||
if user: | ||
return user | ||
user_id = current_token.user.id | ||
if user_id: | ||
return User.query.get(user_id) | ||
return None | ||
|
||
def api_invalid_response(req): | ||
""" | ||
|
@@ -107,19 +41,97 @@ def api_invalid_response(req): | |
api.abort(code=HTTPStatus.UNAUTHORIZED.value) | ||
|
||
|
||
class OAuth2Provider(provider.OAuth2Provider): | ||
""" | ||
A helper class which connects OAuth2RequestValidator with OAuth2Provider. | ||
""" | ||
class BearerTokenValidator(_BearerTokenValidator): | ||
def authenticate_token(self, token_string): | ||
return OAuth2Token.query.filter_by(access_token=token_string).first() | ||
|
||
def request_invalid(self, request): | ||
return False | ||
|
||
def __init__(self, *args, **kwargs): | ||
super(OAuth2Provider, self).__init__(*args, **kwargs) | ||
self.invalid_response(api_invalid_response) | ||
def token_revoked(self, token): | ||
# TODO: return token.revoked | ||
return token.revoked | ||
|
||
class AuthorizationCodeGrant(grants.AuthorizationCodeGrant): | ||
def create_authorization_code(self, client, grant_user, request): | ||
from app.extensions import db | ||
code = gen_salt(48) | ||
item = OAuth2AuthorizationCode( | ||
code=code, | ||
client_id=client.client_id, | ||
redirect_uri=request.redirect_uri, | ||
scope=request.scope, | ||
user_id=grant_user.id, | ||
) | ||
db.session.add(item) | ||
db.session.commit() | ||
return code | ||
|
||
def parse_authorization_code(self, code, client): | ||
item = OAuth2AuthorizationCode.query.filter_by( | ||
code=code, client_id=client.client_id).first() | ||
if item and not item.is_expired(): | ||
return item | ||
|
||
def delete_authorization_code(self, authorization_code): | ||
from app.extensions import db | ||
db.session.delete(authorization_code) | ||
db.session.commit() | ||
|
||
def authenticate_user(self, authorization_code): | ||
return User.query.get(authorization_code.user_id) | ||
|
||
|
||
class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant): | ||
def authenticate_user(self, username, password): | ||
return User.find_with_password(username, password) | ||
|
||
def init_app(self, app): | ||
assert app.config['SECRET_KEY'], "SECRET_KEY must be configured!" | ||
super(OAuth2Provider, self).init_app(app) | ||
self._validator = OAuth2RequestValidator() | ||
|
||
class RefreshTokenGrant(grants.RefreshTokenGrant): | ||
def authenticate_refresh_token(self, refresh_token): | ||
item = OAuth2Token.query.filter_by(refresh_token=refresh_token).first() | ||
if item and not item.is_refresh_token_expired(): | ||
return item | ||
|
||
def authenticate_user(self, credential): | ||
return User.query.get(credential.user_id) | ||
|
||
|
||
class OAuth2ResourceProtector(ResourceProtector): | ||
def __init__( self ): | ||
super().__init__() | ||
|
||
|
||
class OAuth2Provider(AuthorizationServer): | ||
def __init__(self): | ||
super().__init__() | ||
self._require_oauth = None | ||
|
||
def init_app( self, app, query_client=None, save_token=None ): | ||
from app.extensions import db | ||
if query_client is None: | ||
query_client = create_query_client_func(db.session, OAuth2Client) | ||
if save_token is None: | ||
save_token = create_save_token_func(db.session, OAuth2Token) | ||
|
||
super().init_app( | ||
app, query_client=query_client, save_token=save_token) | ||
|
||
# support all grants | ||
self.register_grant(grants.ImplicitGrant) | ||
self.register_grant(grants.ClientCredentialsGrant) | ||
self.register_grant(AuthorizationCodeGrant) | ||
self.register_grant(PasswordGrant) | ||
self.register_grant(RefreshTokenGrant) | ||
|
||
# support revocation | ||
revocation_cls = create_revocation_endpoint(db.session, OAuth2Token) | ||
self.register_endpoint(revocation_cls) | ||
|
||
# protect resource | ||
bearer_cls = create_bearer_token_validator(db.session, OAuth2Token) | ||
OAuth2ResourceProtector.register_token_validator(bearer_cls()) | ||
self._require_oauth = OAuth2ResourceProtector() | ||
|
||
def require_oauth(self, *args, **kwargs): | ||
# pylint: disable=arguments-differ | ||
|
@@ -134,8 +146,8 @@ def require_oauth(self, *args, **kwargs): | |
Returns: | ||
function: a decorator. | ||
""" | ||
locations = kwargs.pop('locations', ('cookies',)) | ||
origin_decorator = super(OAuth2Provider, self).require_oauth(*args, **kwargs) | ||
locations = kwargs.get('locations', ('cookies',)) # don't want to pop - original decorator may need | ||
origin_decorator = self._require_oauth(*args, **kwargs) | ||
|
||
def decorator(func): | ||
# pylint: disable=missing-docstring | ||
|
@@ -148,11 +160,13 @@ def wrapper(*args, **kwargs): | |
# pylint: disable=missing-docstring | ||
if 'headers' not in locations: | ||
# Invalidate authorization if developer specifically | ||
# disables the lookup in the headers. | ||
# disables the lookup in the headers. (this may or may not be worth all the hassle) | ||
request.authorization = '!' | ||
if 'form' in locations: | ||
if 'access_token' in request.form: | ||
request.authorization = 'Bearer %s' % request.form['access_token'] | ||
# don't think we need below lines because bearer validator already registered | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean that this trick will not work anymore anyway or you just don't see a use-case of passing the token via POST form parameters? The use-case is designed for native HTML forms, where we don't control the request from JS, but, instead, rely on a browser to do all the work, and thus, we cannot set the access token in the header. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess I didn't understand what it was doing and didn't see a similar thing i the authlib example, so I thought it was related to flask-oauthlib. I will add it back in. |
||
# if 'form' in locations: | ||
# if 'access_token' in request.form: | ||
# request.authorization = 'Bearer %s' % request.form['access_token'] | ||
|
||
return origin_decorated_func(*args, **kwargs) | ||
|
||
return wrapper | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,9 +45,9 @@ class SQLAlchemy(BaseSQLAlchemy): | |
""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
if 'session_options' not in kwargs: | ||
kwargs['session_options'] = {} | ||
kwargs['session_options']['autocommit'] = True | ||
# if 'session_options' not in kwargs: | ||
# kwargs['session_options'] = {} | ||
# kwargs['session_options']['autocommit'] = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to sort this out one way or another since it is quite a breaking change. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I had a lot of issues with autocommit, namely with sqlalchemy complaining that a transaction already was started when |
||
# Configure Constraint Naming Conventions: | ||
# http://docs.sqlalchemy.org/en/latest/core/constraints.html#constraint-naming-conventions | ||
kwargs['metadata'] = MetaData( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from flask import g | ||
from flask.sessions import SecureCookieSessionInterface | ||
from flask_login import user_loaded_from_header | ||
from flask_login import LoginManager as OriginalLoginManager | ||
|
||
class CustomSessionInterface(SecureCookieSessionInterface): | ||
"""Prevent creating session from API requests.""" | ||
def save_session(self, *args, **kwargs): | ||
if g.get('login_via_header'): | ||
return | ||
return super(CustomSessionInterface, self).save_session(*args, | ||
**kwargs) | ||
|
||
|
||
@user_loaded_from_header.connect | ||
def user_loaded_from_header(self, user=None): | ||
g.login_via_header = True | ||
|
||
|
||
class LoginManager(OriginalLoginManager): | ||
def init_app(self, app): | ||
app.session_interface = CustomSessionInterface() | ||
super().init_app(app) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to
pop
it since the base OAuth2Provider didn't like extra arguments. If it is working fine withget
, I am fine with that.