Skip to content

Commit f7bf444

Browse files
committed
perf: Optimized code.
1 parent 4757230 commit f7bf444

File tree

1 file changed

+21
-25
lines changed

1 file changed

+21
-25
lines changed

fastapi_user_auth/auth/auth.py

+21-25
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@
3939
from .models import BaseUser, Role, User, UserRoleLink
4040
from .schemas import UserLoginOut
4141

42-
_UserModelT = TypeVar("_UserModelT", bound=BaseUser)
42+
UserModelT = TypeVar("UserModelT", bound=BaseUser)
4343

4444

45-
class AuthBackend(AuthenticationBackend, Generic[_UserModelT]):
45+
class AuthBackend(AuthenticationBackend, Generic[UserModelT]):
4646
def __init__(self, auth: "Auth", token_store: BaseTokenStore):
4747
self.auth = auth
4848
self.token_store = token_store
@@ -53,23 +53,23 @@ def get_user_token(request: Request) -> Optional[str]:
5353
scheme, token = get_authorization_scheme_param(authorization)
5454
return None if not authorization or scheme.lower() != "bearer" else token
5555

56-
async def authenticate(self, request: Request) -> Tuple["Auth", Optional[_UserModelT]]:
56+
async def authenticate(self, request: Request) -> Tuple["Auth", Optional[UserModelT]]:
5757
return self.auth, await self.auth.get_current_user(request)
5858

5959
def attach_middleware(self, app: FastAPI):
6060
app.add_middleware(AuthenticationMiddleware, backend=self) # 添加auth中间件
6161

6262

63-
class Auth(Generic[_UserModelT]):
64-
user_model: Type[_UserModelT] = None
63+
class Auth(Generic[UserModelT]):
64+
user_model: Type[UserModelT] = None
6565
db: Union[AsyncDatabase, Database] = None
66-
backend: AuthBackend[_UserModelT] = None
66+
backend: AuthBackend[UserModelT] = None
6767

6868
def __init__(
6969
self,
7070
db: Union[AsyncDatabase, Database],
7171
token_store: BaseTokenStore = None,
72-
user_model: Type[_UserModelT] = User,
72+
user_model: Type[UserModelT] = User,
7373
pwd_context: CryptContext = CryptContext(schemes=["bcrypt"], deprecated="auto"),
7474
):
7575
self.user_model = user_model or self.user_model
@@ -78,7 +78,7 @@ def __init__(
7878
self.backend = self.backend or AuthBackend(self, token_store or DbTokenStore(self.db))
7979
self.pwd_context = pwd_context
8080

81-
async def authenticate_user(self, username: str, password: Union[str, SecretStr]) -> Optional[_UserModelT]:
81+
async def authenticate_user(self, username: str, password: Union[str, SecretStr]) -> Optional[UserModelT]:
8282
user = await self.db.async_scalar(select(self.user_model).where(self.user_model.username == username))
8383
if user:
8484
pwd = password.get_secret_value() if isinstance(password, SecretStr) else password
@@ -87,21 +87,17 @@ async def authenticate_user(self, username: str, password: Union[str, SecretStr]
8787
return user
8888
return None
8989

90-
@cached_property
91-
def get_current_user(self):
92-
async def _get_current_user(request: Request) -> Optional[_UserModelT]:
93-
if request.scope.get("auth"): # 防止重复授权
94-
return request.scope.get("user")
95-
request.scope["auth"], request.scope["user"] = self, None
96-
token = self.backend.get_user_token(request)
97-
if not token:
98-
return None
99-
token_data = await self.backend.token_store.read_token(token)
100-
if token_data is not None:
101-
request.scope["user"]: _UserModelT = await self.db.async_get(self.user_model, token_data.id)
102-
return request.user
103-
104-
return _get_current_user
90+
async def get_current_user(self, request: Request) -> Optional[UserModelT]:
91+
if request.scope.get("auth"): # 防止重复授权
92+
return request.scope.get("user")
93+
request.scope["auth"], request.scope["user"] = self, None
94+
token = self.backend.get_user_token(request)
95+
if not token:
96+
return None
97+
token_data = await self.backend.token_store.read_token(token)
98+
if token_data is not None:
99+
request.scope["user"]: UserModelT = await self.db.async_get(self.user_model, token_data.id)
100+
return request.user
105101

106102
def requires(
107103
self,
@@ -116,12 +112,12 @@ def requires(
116112
roles_ = (roles,) if not roles or isinstance(roles, str) else tuple(roles)
117113
permissions_ = (permissions,) if not permissions or isinstance(permissions, str) else tuple(permissions)
118114

119-
async def has_requires(user: _UserModelT) -> bool:
115+
async def has_requires(user: UserModelT) -> bool:
120116
return user and await self.db.async_run_sync(user.has_requires, roles=roles, groups=groups, permissions=permissions)
121117

122118
async def depend(
123119
request: Request,
124-
user: _UserModelT = Depends(self.get_current_user),
120+
user: UserModelT = Depends(self.get_current_user),
125121
) -> Union[bool, Response]:
126122
user_auth = request.scope.get("__user_auth__", None)
127123
if user_auth is None:

0 commit comments

Comments
 (0)