Skip to content

Update to authlib related to issue #113 - slightly better PR #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@ __pycache__/

# C extensions
*.so
.pytest_cache/

# Distribution / packaging
.Python
/.venv/
/env/
/build/
*.egg-info/
.installed.cfg
*.egg
example.db

# Installer logs
pip-log.txt
Expand All @@ -34,6 +37,7 @@ coverage.xml
.project
.pydevproject


# Rope
.ropeproject

Expand All @@ -49,7 +53,6 @@ docs/_build/
*.bak
local_config.py
static/
example.db
.idea/
clients/*/swagger.json
clients/*/dist
2 changes: 1 addition & 1 deletion app/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
force_auto_coercion()
force_instant_defaults()

from flask_login import LoginManager
from .login import LoginManager
login_manager = LoginManager()

from flask_marshmallow import Marshmallow
Expand Down
32 changes: 19 additions & 13 deletions app/extensions/api/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def decorator(func_or_class):
else:
_oauth_scopes = oauth_scopes

oauth_protection_decorator = oauth2.require_oauth(*_oauth_scopes, locations=locations)
# oauth_protection_decorator = oauth2.require_oauth(*_oauth_scopes, locations=locations)
oauth_protection_decorator = oauth2.require_oauth( *_oauth_scopes )
self._register_access_restriction_decorator(protected_func, oauth_protection_decorator)
oauth_protected_func = oauth_protection_decorator(protected_func)

Expand Down Expand Up @@ -295,21 +296,26 @@ def commit_or_abort(self, session, default_error_message="The operation failed t
session: db.session instance
default_error_message: Custom error message

Exampple:
Example:
>>> with api.commit_or_abort(db.session):
... team = Team(**args)
... db.session.add(team)
... return team
"""
from werkzeug.exceptions import HTTPException
try:
with session.begin():
yield
except ValueError as exception:
log.info("Database transaction was rolled back due to: %r", exception)
http_exceptions.abort(code=HTTPStatus.CONFLICT, message=str(exception))
except sqlalchemy.exc.IntegrityError as exception:
log.info("Database transaction was rolled back due to: %r", exception)
http_exceptions.abort(
code=HTTPStatus.CONFLICT,
message=default_error_message
)
try:
yield session
session.commit()
except ValueError as exception:
log.info( "Database transaction was rolled back due to: %r", exception )
http_exceptions.abort( code=HTTPStatus.CONFLICT, message=str( exception ) )
except sqlalchemy.exc.IntegrityError as exception:
log.info( "Database transaction was rolled back due to: %r", exception )
http_exceptions.abort(
code=HTTPStatus.CONFLICT,
message=default_error_message
)
except HTTPException:
session.rollback()
raise
232 changes: 123 additions & 109 deletions app/extensions/auth/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,102 +1,36 @@
# encoding: utf-8
# pylint: disable=no-self-use
"""
OAuth2 provider setup.

It is based on the code from the example:
https://github.com/lepture/example-oauth2-server

More details are available here:
* http://flask-oauthlib.readthedocs.org/en/latest/oauth2.html
* http://lepture.com/en/2013/create-oauth-server
"""

from datetime import datetime, timedelta
import functools
import logging

from flask_login import current_user
from flask_oauthlib import provider
import functools, logging
from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector, current_token
from authlib.flask.oauth2.sqla import (
create_query_client_func,
create_save_token_func,
create_revocation_endpoint,
create_bearer_token_validator,
)
from authlib.specs.rfc6749 import grants
from werkzeug.security import gen_salt
from app.extensions import api, login_manager
from app.modules.users.models import User
from app.modules.auth.models import OAuth2Client, OAuth2AuthorizationCode, OAuth2Token
from flask_restplus_patched._http import HTTPStatus
import sqlalchemy

from app.extensions import api, db

from authlib.specs.rfc6750 import BearerTokenValidator as _BearerTokenValidator

log = logging.getLogger(__name__)


class OAuth2RequestValidator(provider.OAuth2RequestValidator):
# pylint: disable=abstract-method
@login_manager.request_loader
def load_user_from_request(request):
"""
A project-specific implementation of OAuth2RequestValidator, which connects
our User and OAuth2* implementations together.
Load user from OAuth2 Authentication header.
"""

def __init__(self):
from app.modules.auth.models import OAuth2Client, OAuth2Grant, OAuth2Token
self._client_class = OAuth2Client
self._grant_class = OAuth2Grant
self._token_class = OAuth2Token
super(OAuth2RequestValidator, self).__init__(
usergetter=self._usergetter,
clientgetter=self._client_class.find,
tokengetter=self._token_class.find,
grantgetter=self._grant_class.find,
tokensetter=self._tokensetter,
grantsetter=self._grantsetter,
)

def _usergetter(self, username, password, client, request):
# pylint: disable=method-hidden,unused-argument
# Avoid circular dependencies
from app.modules.users.models import User
return User.find_with_password(username, password)

def _tokensetter(self, token, request, *args, **kwargs):
# pylint: disable=method-hidden,unused-argument
# TODO: review expiration time
expires_in = token['expires_in']
expires = datetime.utcnow() + timedelta(seconds=expires_in)

try:
with db.session.begin():
token_instance = self._token_class(
access_token=token['access_token'],
refresh_token=token.get('refresh_token'),
token_type=token['token_type'],
scopes=[scope for scope in token['scope'].split(' ') if scope],
expires=expires,
client_id=request.client.client_id,
user_id=request.user.id,
)
db.session.add(token_instance)
except sqlalchemy.exc.IntegrityError:
log.exception("Token-setter has failed.")
return None
return token_instance

def _grantsetter(self, client_id, code, request, *args, **kwargs):
# pylint: disable=method-hidden,unused-argument
# TODO: review expiration time
# decide the expires time yourself
expires = datetime.utcnow() + timedelta(seconds=100)
try:
with db.session.begin():
grant_instance = self._grant_class(
client_id=client_id,
code=code['code'],
redirect_uri=request.redirect_uri,
scopes=request.scopes,
user=current_user,
expires=expires
)
db.session.add(grant_instance)
except sqlalchemy.exc.IntegrityError:
log.exception("Grant-setter has failed.")
return None
return grant_instance

from app.modules.users.models import User
if current_token:
user = current_token.user
if user:
return user
user_id = current_token.user.id
if user_id:
return User.query.get(user_id)
return None

def api_invalid_response(req):
"""
Expand All @@ -107,19 +41,97 @@ def api_invalid_response(req):
api.abort(code=HTTPStatus.UNAUTHORIZED.value)


class OAuth2Provider(provider.OAuth2Provider):
"""
A helper class which connects OAuth2RequestValidator with OAuth2Provider.
"""
class BearerTokenValidator(_BearerTokenValidator):
def authenticate_token(self, token_string):
return OAuth2Token.query.filter_by(access_token=token_string).first()

def request_invalid(self, request):
return False

def __init__(self, *args, **kwargs):
super(OAuth2Provider, self).__init__(*args, **kwargs)
self.invalid_response(api_invalid_response)
def token_revoked(self, token):
# TODO: return token.revoked
return token.revoked

class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
def create_authorization_code(self, client, grant_user, request):
from app.extensions import db
code = gen_salt(48)
item = OAuth2AuthorizationCode(
code=code,
client_id=client.client_id,
redirect_uri=request.redirect_uri,
scope=request.scope,
user_id=grant_user.id,
)
db.session.add(item)
db.session.commit()
return code

def parse_authorization_code(self, code, client):
item = OAuth2AuthorizationCode.query.filter_by(
code=code, client_id=client.client_id).first()
if item and not item.is_expired():
return item

def delete_authorization_code(self, authorization_code):
from app.extensions import db
db.session.delete(authorization_code)
db.session.commit()

def authenticate_user(self, authorization_code):
return User.query.get(authorization_code.user_id)


class PasswordGrant(grants.ResourceOwnerPasswordCredentialsGrant):
def authenticate_user(self, username, password):
return User.find_with_password(username, password)

def init_app(self, app):
assert app.config['SECRET_KEY'], "SECRET_KEY must be configured!"
super(OAuth2Provider, self).init_app(app)
self._validator = OAuth2RequestValidator()

class RefreshTokenGrant(grants.RefreshTokenGrant):
def authenticate_refresh_token(self, refresh_token):
item = OAuth2Token.query.filter_by(refresh_token=refresh_token).first()
if item and not item.is_refresh_token_expired():
return item

def authenticate_user(self, credential):
return User.query.get(credential.user_id)


class OAuth2ResourceProtector(ResourceProtector):
def __init__( self ):
super().__init__()


class OAuth2Provider(AuthorizationServer):
def __init__(self):
super().__init__()
self._require_oauth = None

def init_app( self, app, query_client=None, save_token=None ):
from app.extensions import db
if query_client is None:
query_client = create_query_client_func(db.session, OAuth2Client)
if save_token is None:
save_token = create_save_token_func(db.session, OAuth2Token)

super().init_app(
app, query_client=query_client, save_token=save_token)

# support all grants
self.register_grant(grants.ImplicitGrant)
self.register_grant(grants.ClientCredentialsGrant)
self.register_grant(AuthorizationCodeGrant)
self.register_grant(PasswordGrant)
self.register_grant(RefreshTokenGrant)

# support revocation
revocation_cls = create_revocation_endpoint(db.session, OAuth2Token)
self.register_endpoint(revocation_cls)

# protect resource
bearer_cls = create_bearer_token_validator(db.session, OAuth2Token)
OAuth2ResourceProtector.register_token_validator(bearer_cls())
self._require_oauth = OAuth2ResourceProtector()

def require_oauth(self, *args, **kwargs):
# pylint: disable=arguments-differ
Expand All @@ -134,8 +146,8 @@ def require_oauth(self, *args, **kwargs):
Returns:
function: a decorator.
"""
locations = kwargs.pop('locations', ('cookies',))
origin_decorator = super(OAuth2Provider, self).require_oauth(*args, **kwargs)
locations = kwargs.get('locations', ('cookies',)) # don't want to pop - original decorator may need
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to pop it since the base OAuth2Provider didn't like extra arguments. If it is working fine with get, I am fine with that.

origin_decorator = self._require_oauth(*args, **kwargs)

def decorator(func):
# pylint: disable=missing-docstring
Expand All @@ -148,11 +160,13 @@ def wrapper(*args, **kwargs):
# pylint: disable=missing-docstring
if 'headers' not in locations:
# Invalidate authorization if developer specifically
# disables the lookup in the headers.
# disables the lookup in the headers. (this may or may not be worth all the hassle)
request.authorization = '!'
if 'form' in locations:
if 'access_token' in request.form:
request.authorization = 'Bearer %s' % request.form['access_token']
# don't think we need below lines because bearer validator already registered
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean that this trick will not work anymore anyway or you just don't see a use-case of passing the token via POST form parameters? The use-case is designed for native HTML forms, where we don't control the request from JS, but, instead, rely on a browser to do all the work, and thus, we cannot set the access token in the header.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I didn't understand what it was doing and didn't see a similar thing i the authlib example, so I thought it was related to flask-oauthlib. I will add it back in.

# if 'form' in locations:
# if 'access_token' in request.form:
# request.authorization = 'Bearer %s' % request.form['access_token']

return origin_decorated_func(*args, **kwargs)

return wrapper
Expand Down
6 changes: 3 additions & 3 deletions app/extensions/flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ class SQLAlchemy(BaseSQLAlchemy):
"""

def __init__(self, *args, **kwargs):
if 'session_options' not in kwargs:
kwargs['session_options'] = {}
kwargs['session_options']['autocommit'] = True
# if 'session_options' not in kwargs:
# kwargs['session_options'] = {}
# kwargs['session_options']['autocommit'] = True
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to sort this out one way or another since it is quite a breaking change.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I had a lot of issues with autocommit, namely with sqlalchemy complaining that a transaction already was started when with db.session.begin(): was called. In authlib, the complaints were about no transaction has started. In my own project I turned it off, because I had my own transaction system, and also I was using postgresql.

# Configure Constraint Naming Conventions:
# http://docs.sqlalchemy.org/en/latest/core/constraints.html#constraint-naming-conventions
kwargs['metadata'] = MetaData(
Expand Down
25 changes: 25 additions & 0 deletions app/extensions/login/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from flask import g
from flask.sessions import SecureCookieSessionInterface
from flask_login import user_loaded_from_header
from flask_login import LoginManager as OriginalLoginManager

class CustomSessionInterface(SecureCookieSessionInterface):
"""Prevent creating session from API requests."""
def save_session(self, *args, **kwargs):
if g.get('login_via_header'):
return
return super(CustomSessionInterface, self).save_session(*args,
**kwargs)


@user_loaded_from_header.connect
def user_loaded_from_header(self, user=None):
g.login_via_header = True


class LoginManager(OriginalLoginManager):
def init_app(self, app):
app.session_interface = CustomSessionInterface()
super().init_app(app)


Loading