39
39
from .models import BaseUser , Role , User , UserRoleLink
40
40
from .schemas import UserLoginOut
41
41
42
- _UserModelT = TypeVar ("_UserModelT " , bound = BaseUser )
42
+ UserModelT = TypeVar ("UserModelT " , bound = BaseUser )
43
43
44
44
45
- class AuthBackend (AuthenticationBackend , Generic [_UserModelT ]):
45
+ class AuthBackend (AuthenticationBackend , Generic [UserModelT ]):
46
46
def __init__ (self , auth : "Auth" , token_store : BaseTokenStore ):
47
47
self .auth = auth
48
48
self .token_store = token_store
@@ -53,23 +53,23 @@ def get_user_token(request: Request) -> Optional[str]:
53
53
scheme , token = get_authorization_scheme_param (authorization )
54
54
return None if not authorization or scheme .lower () != "bearer" else token
55
55
56
- async def authenticate (self , request : Request ) -> Tuple ["Auth" , Optional [_UserModelT ]]:
56
+ async def authenticate (self , request : Request ) -> Tuple ["Auth" , Optional [UserModelT ]]:
57
57
return self .auth , await self .auth .get_current_user (request )
58
58
59
59
def attach_middleware (self , app : FastAPI ):
60
60
app .add_middleware (AuthenticationMiddleware , backend = self ) # 添加auth中间件
61
61
62
62
63
- class Auth (Generic [_UserModelT ]):
64
- user_model : Type [_UserModelT ] = None
63
+ class Auth (Generic [UserModelT ]):
64
+ user_model : Type [UserModelT ] = None
65
65
db : Union [AsyncDatabase , Database ] = None
66
- backend : AuthBackend [_UserModelT ] = None
66
+ backend : AuthBackend [UserModelT ] = None
67
67
68
68
def __init__ (
69
69
self ,
70
70
db : Union [AsyncDatabase , Database ],
71
71
token_store : BaseTokenStore = None ,
72
- user_model : Type [_UserModelT ] = User ,
72
+ user_model : Type [UserModelT ] = User ,
73
73
pwd_context : CryptContext = CryptContext (schemes = ["bcrypt" ], deprecated = "auto" ),
74
74
):
75
75
self .user_model = user_model or self .user_model
@@ -78,7 +78,7 @@ def __init__(
78
78
self .backend = self .backend or AuthBackend (self , token_store or DbTokenStore (self .db ))
79
79
self .pwd_context = pwd_context
80
80
81
- async def authenticate_user (self , username : str , password : Union [str , SecretStr ]) -> Optional [_UserModelT ]:
81
+ async def authenticate_user (self , username : str , password : Union [str , SecretStr ]) -> Optional [UserModelT ]:
82
82
user = await self .db .async_scalar (select (self .user_model ).where (self .user_model .username == username ))
83
83
if user :
84
84
pwd = password .get_secret_value () if isinstance (password , SecretStr ) else password
@@ -87,21 +87,17 @@ async def authenticate_user(self, username: str, password: Union[str, SecretStr]
87
87
return user
88
88
return None
89
89
90
- @cached_property
91
- def get_current_user (self ):
92
- async def _get_current_user (request : Request ) -> Optional [_UserModelT ]:
93
- if request .scope .get ("auth" ): # 防止重复授权
94
- return request .scope .get ("user" )
95
- request .scope ["auth" ], request .scope ["user" ] = self , None
96
- token = self .backend .get_user_token (request )
97
- if not token :
98
- return None
99
- token_data = await self .backend .token_store .read_token (token )
100
- if token_data is not None :
101
- request .scope ["user" ]: _UserModelT = await self .db .async_get (self .user_model , token_data .id )
102
- return request .user
103
-
104
- return _get_current_user
90
+ async def get_current_user (self , request : Request ) -> Optional [UserModelT ]:
91
+ if request .scope .get ("auth" ): # 防止重复授权
92
+ return request .scope .get ("user" )
93
+ request .scope ["auth" ], request .scope ["user" ] = self , None
94
+ token = self .backend .get_user_token (request )
95
+ if not token :
96
+ return None
97
+ token_data = await self .backend .token_store .read_token (token )
98
+ if token_data is not None :
99
+ request .scope ["user" ]: UserModelT = await self .db .async_get (self .user_model , token_data .id )
100
+ return request .user
105
101
106
102
def requires (
107
103
self ,
@@ -116,12 +112,12 @@ def requires(
116
112
roles_ = (roles ,) if not roles or isinstance (roles , str ) else tuple (roles )
117
113
permissions_ = (permissions ,) if not permissions or isinstance (permissions , str ) else tuple (permissions )
118
114
119
- async def has_requires (user : _UserModelT ) -> bool :
115
+ async def has_requires (user : UserModelT ) -> bool :
120
116
return user and await self .db .async_run_sync (user .has_requires , roles = roles , groups = groups , permissions = permissions )
121
117
122
118
async def depend (
123
119
request : Request ,
124
- user : _UserModelT = Depends (self .get_current_user ),
120
+ user : UserModelT = Depends (self .get_current_user ),
125
121
) -> Union [bool , Response ]:
126
122
user_auth = request .scope .get ("__user_auth__" , None )
127
123
if user_auth is None :
0 commit comments