From d84c39dc6242e45e7f4f48cbefa76d5157f72422 Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Mon, 19 May 2025 16:33:52 -0300 Subject: [PATCH 1/2] add type annotations --- firebase_admin/_typing.py | 160 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 firebase_admin/_typing.py diff --git a/firebase_admin/_typing.py b/firebase_admin/_typing.py new file mode 100644 index 00000000..f54eb5d9 --- /dev/null +++ b/firebase_admin/_typing.py @@ -0,0 +1,160 @@ +# Copyright 2020 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This module adds some type annotations that refer to types defined in other +# submodules. To avoid circular import issues (NameError), the evaluation of +# these annotations is deferred by using string literals (forward references). +# This allows the annotations to be valid at runtime without requiring the immediate +# loading of the referenced symbols. + +import typing +import typing_extensions + +import google.auth.credentials +import requests + +import firebase_admin +from firebase_admin import credentials +from firebase_admin import exceptions +from firebase_admin import project_management + + +_KT = typing.TypeVar("_KT") +_VT_co = typing.TypeVar("_VT_co", covariant=True) + + +class SupportsKeysAndGetItem(typing.Protocol[_KT, _VT_co]): + def keys(self) -> typing.Iterable[_KT]: ... + def __getitem__(self, __key: _KT) -> _VT_co: ... + + +class SupportsTrunc(typing.Protocol): + def __trunc__(self) -> int: ... + + +ConvertibleToInt = typing.Union[ + str, + typing_extensions.Buffer, + typing.SupportsInt, + typing.SupportsIndex, + SupportsTrunc +] +ConvertibleToFloat: typing_extensions.TypeAlias = typing.Union[ + str, + typing_extensions.Buffer, + typing.SupportsFloat, + typing.SupportsIndex +] + +_AnyT = typing_extensions.TypeVar("_AnyT", default=typing.Any) +_AnyT_co = typing_extensions.TypeVar("_AnyT_co", covariant=True, default=typing.Any) + +_FirebaseErrorT_co = typing_extensions.TypeVar( + "_FirebaseErrorT_co", covariant=True, default="exceptions.FirebaseError") +_AppMetadataT_co = typing_extensions.TypeVar( + "_AppMetadataT_co", covariant=True, default="project_management._AppMetadata") + +CredentialLike = typing.Union["credentials.Base", google.auth.credentials.Credentials] +HeadersLike = typing.Union[ + SupportsKeysAndGetItem[str, typing.Union[bytes, str]], + typing.Iterable[typing.Tuple[str, typing.Union[bytes, str]]] +] +ServiceInitializer = typing.Callable[["firebase_admin.App"], _AnyT] +RequestErrorHandler = typing.Callable[ + [ + requests.RequestException, + str, + typing.Dict[str, typing.Any] + ], + typing.Optional["exceptions.FirebaseError"] +] +GoogleAPIErrorHandler = typing.Callable[ + [ + Exception, + str, + typing.Dict[str, typing.Any], + requests.Response, + ], + typing.Optional["exceptions.FirebaseError"], +] +Json = typing.Optional[typing.Union[ + typing.Dict[str, "Json"], + typing.List["Json"], + str, + float +]] +EmailActionType = typing.Literal[ + 'VERIFY_EMAIL', + 'EMAIL_SIGNIN', + 'PASSWORD_RESET', +] + +class FirebaseErrorFactory(typing.Protocol[_FirebaseErrorT_co]): + def __call__( + self, + message: str, + cause: typing.Optional[Exception], + http_response: typing.Optional[requests.Response], + ) -> _FirebaseErrorT_co: ... + + +class FirebaseErrorFactoryNoHttp(typing.Protocol[_FirebaseErrorT_co]): + def __call__( + self, + message: str, + cause: typing.Optional[Exception], + ) -> _FirebaseErrorT_co: ... + + +class FirebaseErrorFactoryWithDefaults(typing.Protocol[_FirebaseErrorT_co]): + def __call__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> _FirebaseErrorT_co: ... + + +class FirebaseErrorFactoryNoHttpWithDefaults(typing.Protocol[_FirebaseErrorT_co]): + def __call__( + self, + message: str, + cause: typing.Optional[Exception] = None, + ) -> _FirebaseErrorT_co: ... + + +class AppMetadataSubclass(typing.Protocol[_AppMetadataT_co]): + def __call__( + self, + __identifier: str, + name: str, + app_id: str, + display_name: typing.Optional[str], + project_id: str + ) -> _AppMetadataT_co: ... + + +class ProjectApp(typing.Protocol[_AnyT_co]): + def __call__( + self, + app_id: str, + service: "project_management._ProjectManagementService", + ) -> _AnyT_co: ... + + +class Page(typing.Protocol): + @property + def has_next_page(self) -> bool: ... + + def get_next_page(self) -> typing.Optional[typing_extensions.Self]: ... From 84b4c25d7517929ccadf4a74b869cf1f8b1ef423 Mon Sep 17 00:00:00 2001 From: ViktorSky Date: Mon, 19 May 2025 16:34:31 -0300 Subject: [PATCH 2/2] add type annotations --- firebase_admin/__init__.py | 65 +++-- firebase_admin/_auth_client.py | 195 +++++++++++--- firebase_admin/_auth_providers.py | 208 ++++++++++----- firebase_admin/_auth_utils.py | 272 ++++++++++++++++---- firebase_admin/_gapic_utils.py | 34 ++- firebase_admin/_http_client.py | 79 ++++-- firebase_admin/_messaging_encoder.py | 198 +++++++++----- firebase_admin/_messaging_utils.py | 193 +++++++++++--- firebase_admin/_rfc3339.py | 19 +- firebase_admin/_sseclient.py | 49 ++-- firebase_admin/_token_gen.py | 184 ++++++++----- firebase_admin/_user_identifier.py | 24 +- firebase_admin/_user_import.py | 168 +++++++----- firebase_admin/_user_mgt.py | 261 ++++++++++++------- firebase_admin/_utils.py | 65 +++-- firebase_admin/app_check.py | 52 ++-- firebase_admin/auth.py | 216 ++++++++++++---- firebase_admin/credentials.py | 89 ++++--- firebase_admin/db.py | 254 +++++++++++------- firebase_admin/exceptions.py | 125 +++++++-- firebase_admin/firestore.py | 25 +- firebase_admin/firestore_async.py | 30 +-- firebase_admin/functions.py | 143 +++++----- firebase_admin/instance_id.py | 14 +- firebase_admin/messaging.py | 177 +++++++++---- firebase_admin/ml.py | 372 ++++++++++++++++----------- firebase_admin/project_management.py | 218 ++++++++++------ firebase_admin/remote_config.py | 232 +++++++++++------ firebase_admin/storage.py | 21 +- firebase_admin/tenant_mgt.py | 106 +++++--- pyrightconfig.json | 29 +++ requirements.txt | 6 +- setup.py | 4 + 33 files changed, 2805 insertions(+), 1322 deletions(-) create mode 100644 pyrightconfig.json diff --git a/firebase_admin/__init__.py b/firebase_admin/__init__.py index 7bb9c59c..826c2de8 100644 --- a/firebase_admin/__init__.py +++ b/firebase_admin/__init__.py @@ -17,23 +17,31 @@ import json import os import threading +import typing from google.auth.credentials import Credentials as GoogleAuthCredentials from google.auth.exceptions import DefaultCredentialsError from firebase_admin import credentials from firebase_admin.__about__ import __version__ +from firebase_admin import _typing -_apps = {} +_T = typing.TypeVar("_T") + +_apps: typing.Dict[str, "App"] = {} _apps_lock = threading.RLock() -_clock = datetime.datetime.utcnow +_clock = lambda: datetime.datetime.now(datetime.timezone.utc) _DEFAULT_APP_NAME = '[DEFAULT]' _FIREBASE_CONFIG_ENV_VAR = 'FIREBASE_CONFIG' _CONFIG_VALID_KEYS = ['databaseAuthVariableOverride', 'databaseURL', 'httpTimeout', 'projectId', 'storageBucket'] -def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): +def initialize_app( + credential: typing.Optional[_typing.CredentialLike] = None, + options: typing.Optional[typing.Dict[str, typing.Any]] = None, + name: str = _DEFAULT_APP_NAME +) -> "App": """Initializes and returns a new App instance. Creates a new App instance using the specified options @@ -86,7 +94,7 @@ def initialize_app(credential=None, options=None, name=_DEFAULT_APP_NAME): 'you call initialize_app().').format(name)) -def delete_app(app): +def delete_app(app: "App") -> None: """Gracefully deletes an App instance. Args: @@ -114,7 +122,7 @@ def delete_app(app): 'second argument.').format(app.name)) -def get_app(name=_DEFAULT_APP_NAME): +def get_app(name: str = _DEFAULT_APP_NAME) -> "App": """Retrieves an App instance by name. Args: @@ -148,7 +156,7 @@ def get_app(name=_DEFAULT_APP_NAME): class _AppOptions: """A collection of configuration options for an App.""" - def __init__(self, options): + def __init__(self, options: typing.Optional[typing.Dict[str, typing.Any]]) -> None: if options is None: options = self._load_from_environment() @@ -157,11 +165,16 @@ def __init__(self, options): 'must be a dictionary.'.format(type(options))) self._options = options - def get(self, key, default=None): + @typing.overload + def get(self, key: str, default: None = None) -> typing.Optional[typing.Any]: ... + # possible issue: needs return Any | _T ? + @typing.overload + def get(self, key: str, default: _T) -> _T: ... + def get(self, key: str, default: typing.Any = None) -> typing.Optional[typing.Any]: """Returns the option identified by the provided key.""" return self._options.get(key, default) - def _load_from_environment(self): + def _load_from_environment(self) -> typing.Dict[str, typing.Any]: """Invoked when no options are passed to __init__, loads options from FIREBASE_CONFIG. If the value of the FIREBASE_CONFIG environment variable starts with "{" an attempt is made @@ -193,7 +206,12 @@ class App: common to all Firebase APIs. """ - def __init__(self, name, credential, options): + def __init__( + self, + name: str, + credential: _typing.CredentialLike, + options: typing.Optional[typing.Dict[str, typing.Any]] + ) -> None: """Constructs a new App using the provided name and options. Args: @@ -218,37 +236,37 @@ def __init__(self, name, credential, options): 'with a valid credential instance.') self._options = _AppOptions(options) self._lock = threading.RLock() - self._services = {} + self._services: typing.Optional[typing.Dict[str, typing.Any]] = {} App._validate_project_id(self._options.get('projectId')) self._project_id_initialized = False @classmethod - def _validate_project_id(cls, project_id): + def _validate_project_id(cls, project_id: typing.Optional[str]) -> None: if project_id is not None and not isinstance(project_id, str): raise ValueError( 'Invalid project ID: "{0}". project ID must be a string.'.format(project_id)) @property - def name(self): + def name(self) -> str: return self._name @property - def credential(self): + def credential(self) -> credentials.Base: return self._credential @property - def options(self): + def options(self) -> _AppOptions: return self._options @property - def project_id(self): + def project_id(self) -> typing.Optional[str]: if not self._project_id_initialized: self._project_id = self._lookup_project_id() self._project_id_initialized = True return self._project_id - def _lookup_project_id(self): + def _lookup_project_id(self) -> typing.Optional[str]: """Looks up the Firebase project ID associated with an App. If a ``projectId`` is specified in app options, it is returned. Then tries to @@ -259,10 +277,10 @@ def _lookup_project_id(self): Returns: str: A project ID string or None. """ - project_id = self._options.get('projectId') + project_id: typing.Optional[str] = self._options.get('projectId') if not project_id: try: - project_id = self._credential.project_id + project_id = getattr(self._credential, "project_id") except (AttributeError, DefaultCredentialsError): pass if not project_id: @@ -271,7 +289,7 @@ def _lookup_project_id(self): App._validate_project_id(self._options.get('projectId')) return project_id - def _get_service(self, name, initializer): + def _get_service(self, name: str, initializer: _typing.ServiceInitializer[_T]) -> _T: """Returns the service instance identified by the given name. Services are functional entities exposed by the Admin SDK (e.g. auth, database). Each @@ -301,7 +319,7 @@ def _get_service(self, name, initializer): self._services[name] = initializer(self) return self._services[name] - def _cleanup(self): + def _cleanup(self) -> None: """Cleans up any services associated with this App. Checks whether each service contains a close() method, and calls it if available. @@ -309,7 +327,8 @@ def _cleanup(self): any services started by the App. """ with self._lock: - for service in self._services.values(): - if hasattr(service, 'close') and hasattr(service.close, '__call__'): - service.close() + if self._services: + for service in self._services.values(): + if hasattr(service, 'close') and hasattr(service.close, '__call__'): + service.close() self._services = None diff --git a/firebase_admin/_auth_client.py b/firebase_admin/_auth_client.py index 38b42993..31a01c4e 100644 --- a/firebase_admin/_auth_client.py +++ b/firebase_admin/_auth_client.py @@ -15,22 +15,25 @@ """Firebase auth client sub module.""" import time +import typing import firebase_admin from firebase_admin import _auth_providers from firebase_admin import _auth_utils from firebase_admin import _http_client from firebase_admin import _token_gen +from firebase_admin import _typing from firebase_admin import _user_identifier from firebase_admin import _user_import from firebase_admin import _user_mgt from firebase_admin import _utils +from firebase_admin import exceptions class Client: """Firebase Authentication client scoped to a specific tenant.""" - def __init__(self, app, tenant_id=None): + def __init__(self, app: firebase_admin.App, tenant_id: typing.Optional[str] = None) -> None: if not app.project_id: raise ValueError("""A project ID is required to access the auth service. 1. Use a service account credential, or @@ -41,7 +44,7 @@ def __init__(self, app, tenant_id=None): version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) # Non-default endpoint URLs for emulator support are set in this dict later. - endpoint_urls = {} + endpoint_urls: typing.Dict[str, str] = {} self.emulated = False # If an emulator is present, check that the given value matches the expected format and set @@ -70,11 +73,15 @@ def __init__(self, app, tenant_id=None): http_client, app.project_id, tenant_id, url_override=endpoint_urls.get('v2')) @property - def tenant_id(self): + def tenant_id(self) -> typing.Optional[str]: """Tenant ID associated with this client.""" return self._tenant_id - def create_custom_token(self, uid, developer_claims=None): + def create_custom_token( + self, + uid: str, + developer_claims: typing.Optional[typing.Dict[str, typing.Any]] = None, + ) -> bytes: """Builds and signs a Firebase custom auth token. Args: @@ -92,7 +99,12 @@ def create_custom_token(self, uid, developer_claims=None): return self._token_generator.create_custom_token( uid, developer_claims, tenant_id=self.tenant_id) - def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): + def verify_id_token( + self, + id_token: typing.Union[bytes, str], + check_revoked: bool = False, + clock_skew_seconds: int = 0, + ) -> typing.Dict[str, typing.Any]: """Verifies the signature and data for the provided JWT. Accepts a signed token string, verifies that it is current, was issued @@ -138,7 +150,7 @@ def verify_id_token(self, id_token, check_revoked=False, clock_skew_seconds=0): verified_claims, _token_gen.RevokedIdTokenError, 'ID token') return verified_claims - def revoke_refresh_tokens(self, uid): + def revoke_refresh_tokens(self, uid: str) -> None: """Revokes all refresh tokens for an existing user. This method updates the user's ``tokens_valid_after_timestamp`` to the current UTC @@ -159,7 +171,7 @@ def revoke_refresh_tokens(self, uid): """ self._user_manager.update_user(uid, valid_since=int(time.time())) - def get_user(self, uid): + def get_user(self, uid: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user ID. Args: @@ -176,7 +188,7 @@ def get_user(self, uid): response = self._user_manager.get_user(uid=uid) return _user_mgt.UserRecord(response) - def get_user_by_email(self, email): + def get_user_by_email(self, email: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user email. Args: @@ -193,7 +205,7 @@ def get_user_by_email(self, email): response = self._user_manager.get_user(email=email) return _user_mgt.UserRecord(response) - def get_user_by_phone_number(self, phone_number): + def get_user_by_phone_number(self, phone_number: str) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified phone number. Args: @@ -210,7 +222,7 @@ def get_user_by_phone_number(self, phone_number): response = self._user_manager.get_user(phone_number=phone_number) return _user_mgt.UserRecord(response) - def get_users(self, identifiers): + def get_users(self, identifiers: typing.Sequence[_user_identifier.UserIdentifier]) -> _user_mgt.GetUsersResult: """Gets the user data corresponding to the specified identifiers. There are no ordering guarantees; in particular, the nth entry in the @@ -235,7 +247,7 @@ def get_users(self, identifiers): """ response = self._user_manager.get_users(identifiers=identifiers) - def _matches(identifier, user_record): + def _matches(identifier: _user_identifier.UserIdentifier, user_record: _user_mgt.UserRecord) -> bool: if isinstance(identifier, _user_identifier.UidIdentifier): return identifier.uid == user_record.uid if isinstance(identifier, _user_identifier.EmailIdentifier): @@ -251,7 +263,10 @@ def _matches(identifier, user_record): ), False) raise TypeError("Unexpected type: {}".format(type(identifier))) - def _is_user_found(identifier, user_records): + def _is_user_found( + identifier: _user_identifier.UserIdentifier, + user_records: typing.List[_user_mgt.UserRecord], + ) -> bool: return any(_matches(identifier, user_record) for user_record in user_records) users = [_user_mgt.UserRecord(user) for user in response] @@ -260,7 +275,11 @@ def _is_user_found(identifier, user_records): return _user_mgt.GetUsersResult(users=users, not_found=not_found) - def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS): + def list_users( + self, + page_token: typing.Optional[str] = None, + max_results: int = _user_mgt.MAX_LIST_USERS_RESULTS, + ) -> _user_mgt.ListUsersPage: """Retrieves a page of user accounts from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -282,11 +301,22 @@ def list_users(self, page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESUL ValueError: If max_results or page_token are invalid. FirebaseError: If an error occurs while retrieving the user accounts. """ - def download(page_token, max_results): + def download(page_token: typing.Optional[str], max_results: int) -> typing.Dict[str, _typing.Json]: return self._user_manager.list_users(page_token, max_results) return _user_mgt.ListUsersPage(download, page_token, max_results) - def create_user(self, **kwargs): # pylint: disable=differing-param-doc + def create_user( + self, + uid: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + email: typing.Optional[str] = None, + phone_number: typing.Optional[str] = None, + photo_url: typing.Optional[str] = None, + password: typing.Optional[str] = None, + disabled: typing.Optional[bool] = None, + email_verified: typing.Optional[bool] = None, + **kwargs: typing.Any, + ) -> _user_mgt.UserRecord: """Creates a new user account with the specified properties. Args: @@ -310,10 +340,26 @@ def create_user(self, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user properties are invalid. FirebaseError: If an error occurs while creating the user account. """ - uid = self._user_manager.create_user(**kwargs) + uid = self._user_manager.create_user(uid=uid, display_name=display_name, email=email, + phone_number=phone_number, photo_url=photo_url, password=password, disabled=disabled, + email_verified=email_verified, **kwargs) return self.get_user(uid=uid) - def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc + def update_user( + self, + uid: str, + display_name: typing.Optional[str] = None, + email: typing.Optional[str] = None, + phone_number: typing.Optional[str] = None, + photo_url: typing.Optional[str] = None, + password: typing.Optional[str] = None, + disabled: typing.Optional[bool] = None, + email_verified: typing.Optional[bool] = None, + valid_since: typing.Optional[_typing.ConvertibleToInt] = None, + custom_claims: typing.Optional[typing.Union[typing.Dict[str, typing.Any], str]] = None, + providers_to_delete: typing.Optional[typing.List[str]] = None, + **kwargs: typing.Any, + ) -> _user_mgt.UserRecord: """Updates an existing user account with the specified properties. Args: @@ -348,10 +394,16 @@ def update_user(self, uid, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user ID or properties are invalid. FirebaseError: If an error occurs while updating the user account. """ - self._user_manager.update_user(uid, **kwargs) + self._user_manager.update_user(uid, display_name=display_name, email=email, phone_number=phone_number, + photo_url=photo_url, password=password, disabled=disabled, email_verified=email_verified, + valid_since=valid_since, custom_claims=custom_claims, providers_to_delete=providers_to_delete, **kwargs) return self.get_user(uid=uid) - def set_custom_user_claims(self, uid, custom_claims): + def set_custom_user_claims( + self, + uid: str, + custom_claims: typing.Optional[typing.Union[typing.Dict[str, typing.Any], str]], + ) -> None: """Sets additional claims on an existing user account. Custom claims set via this function can be used to define user roles and privilege levels. @@ -374,7 +426,7 @@ def set_custom_user_claims(self, uid, custom_claims): custom_claims = _user_mgt.DELETE_ATTRIBUTE self._user_manager.update_user(uid, custom_claims=custom_claims) - def delete_user(self, uid): + def delete_user(self, uid: str) -> None: """Deletes the user identified by the specified user ID. Args: @@ -386,7 +438,7 @@ def delete_user(self, uid): """ self._user_manager.delete_user(uid) - def delete_users(self, uids): + def delete_users(self, uids: typing.Sequence[str]) -> _user_mgt.DeleteUsersResult: """Deletes the users specified by the given identifiers. Deleting a non-existing user does not generate an error (the method is @@ -413,7 +465,11 @@ def delete_users(self, uids): result = self._user_manager.delete_users(uids, force_delete=True) return _user_mgt.DeleteUsersResult(result, len(uids)) - def import_users(self, users, hash_alg=None): + def import_users( + self, + users: typing.Sequence[_user_import.ImportUserRecord], + hash_alg: typing.Optional[_user_import.UserImportHash] = None, + ) -> _user_import.UserImportResult: """Imports the specified list of users into Firebase Auth. At most 1000 users can be imported at a time. This operation is optimized for bulk imports @@ -437,7 +493,11 @@ def import_users(self, users, hash_alg=None): result = self._user_manager.import_users(users, hash_alg) return _user_import.UserImportResult(result, len(users)) - def generate_password_reset_link(self, email, action_code_settings=None): + def generate_password_reset_link( + self, + email: typing.Optional[str], + action_code_settings: typing.Optional[_user_mgt.ActionCodeSettings] = None, + ) -> str: """Generates the out-of-band email action link for password reset flows for the specified email address. @@ -458,7 +518,11 @@ def generate_password_reset_link(self, email, action_code_settings=None): return self._user_manager.generate_email_action_link( 'PASSWORD_RESET', email, action_code_settings=action_code_settings) - def generate_email_verification_link(self, email, action_code_settings=None): + def generate_email_verification_link( + self, + email: typing.Optional[str], + action_code_settings: typing.Optional[_user_mgt.ActionCodeSettings] = None, + ) -> str: """Generates the out-of-band email action link for email verification flows for the specified email address. @@ -479,7 +543,11 @@ def generate_email_verification_link(self, email, action_code_settings=None): return self._user_manager.generate_email_action_link( 'VERIFY_EMAIL', email, action_code_settings=action_code_settings) - def generate_sign_in_with_email_link(self, email, action_code_settings): + def generate_sign_in_with_email_link( + self, + email: typing.Optional[str], + action_code_settings: typing.Optional[_user_mgt.ActionCodeSettings], + ) -> str: """Generates the out-of-band email action link for email link sign-in flows, using the action code settings provided. @@ -499,7 +567,7 @@ def generate_sign_in_with_email_link(self, email, action_code_settings): return self._user_manager.generate_email_action_link( 'EMAIL_SIGNIN', email, action_code_settings=action_code_settings) - def get_oidc_provider_config(self, provider_id): + def get_oidc_provider_config(self, provider_id: str) -> _auth_providers.OIDCProviderConfig: """Returns the ``OIDCProviderConfig`` with the given ID. Args: @@ -516,8 +584,16 @@ def get_oidc_provider_config(self, provider_id): return self._provider_manager.get_oidc_provider_config(provider_id) def create_oidc_provider_config( - self, provider_id, client_id, issuer, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: str, + issuer: str, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + client_secret: typing.Optional[str] = None, + id_token_response_type: typing.Optional[bool] = None, + code_response_type: typing.Optional[bool] = None, + ) -> _auth_providers.OIDCProviderConfig: """Creates a new OIDC provider config from the given parameters. OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -555,8 +631,16 @@ def create_oidc_provider_config( id_token_response_type=id_token_response_type, code_response_type=code_response_type) def update_oidc_provider_config( - self, provider_id, client_id=None, issuer=None, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: typing.Optional[str] = None, + issuer: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + client_secret: typing.Optional[str] = None, + id_token_response_type: typing.Optional[bool] = None, + code_response_type: typing.Optional[bool] = None, + ) -> _auth_providers.OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters. Args: @@ -590,7 +674,7 @@ def update_oidc_provider_config( enabled=enabled, client_secret=client_secret, id_token_response_type=id_token_response_type, code_response_type=code_response_type) - def delete_oidc_provider_config(self, provider_id): + def delete_oidc_provider_config(self, provider_id: str) -> None: """Deletes the ``OIDCProviderConfig`` with the given ID. Args: @@ -604,7 +688,10 @@ def delete_oidc_provider_config(self, provider_id): self._provider_manager.delete_oidc_provider_config(provider_id) def list_oidc_provider_configs( - self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + self, + page_token: typing.Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + ) -> _auth_providers._ListOIDCProviderConfigsPage: """Retrieves a page of OIDC provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -628,7 +715,7 @@ def list_oidc_provider_configs( """ return self._provider_manager.list_oidc_provider_configs(page_token, max_results) - def get_saml_provider_config(self, provider_id): + def get_saml_provider_config(self, provider_id: str) -> _auth_providers.SAMLProviderConfig: """Returns the ``SAMLProviderConfig`` with the given ID. Args: @@ -645,8 +732,16 @@ def get_saml_provider_config(self, provider_id): return self._provider_manager.get_saml_provider_config(provider_id) def create_saml_provider_config( - self, provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, - callback_url, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: typing.List[str], + rp_entity_id: str, + callback_url: str, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + ) -> _auth_providers.SAMLProviderConfig: """Creates a new SAML provider config from the given parameters. SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -685,8 +780,16 @@ def create_saml_provider_config( callback_url=callback_url, display_name=display_name, enabled=enabled) def update_saml_provider_config( - self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: typing.Optional[str] = None, + sso_url: typing.Optional[str] = None, + x509_certificates: typing.Optional[typing.List[str]] = None, + rp_entity_id: typing.Optional[str] = None, + callback_url: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + ) -> _auth_providers.SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters. Args: @@ -714,7 +817,7 @@ def update_saml_provider_config( x509_certificates=x509_certificates, rp_entity_id=rp_entity_id, callback_url=callback_url, display_name=display_name, enabled=enabled) - def delete_saml_provider_config(self, provider_id): + def delete_saml_provider_config(self, provider_id: str) -> None: """Deletes the ``SAMLProviderConfig`` with the given ID. Args: @@ -728,7 +831,10 @@ def delete_saml_provider_config(self, provider_id): self._provider_manager.delete_saml_provider_config(provider_id) def list_saml_provider_configs( - self, page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS): + self, + page_token: typing.Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + ) -> _auth_providers._ListSAMLProviderConfigsPage: """Retrieves a page of SAML provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -752,9 +858,14 @@ def list_saml_provider_configs( """ return self._provider_manager.list_saml_provider_configs(page_token, max_results) - def _check_jwt_revoked_or_disabled(self, verified_claims, exc_type, label): - user = self.get_user(verified_claims.get('uid')) + def _check_jwt_revoked_or_disabled( + self, + verified_claims: typing.Dict[str, typing.Any], + exc_type: typing.Callable[[str], exceptions.FirebaseError], + label: str, + ) -> None: + user = self.get_user(verified_claims['uid']) if user.disabled: raise _auth_utils.UserDisabledError('The user record is disabled.') - if verified_claims.get('iat') * 1000 < user.tokens_valid_after_timestamp: + if verified_claims['iat'] * 1000 < user.tokens_valid_after_timestamp: raise exc_type('The Firebase {0} has been revoked.'.format(label)) diff --git a/firebase_admin/_auth_providers.py b/firebase_admin/_auth_providers.py index 31894a4d..14fb187a 100644 --- a/firebase_admin/_auth_providers.py +++ b/firebase_admin/_auth_providers.py @@ -14,34 +14,41 @@ """Firebase auth providers management sub module.""" +import typing +import typing_extensions from urllib import parse import requests from firebase_admin import _auth_utils +from firebase_admin import _http_client +from firebase_admin import _typing from firebase_admin import _user_mgt +_ProviderConfigT = typing_extensions.TypeVar("_ProviderConfigT", bound="ProviderConfig", default="ProviderConfig") + + MAX_LIST_CONFIGS_RESULTS = 100 class ProviderConfig: """Parent type for all authentication provider config types.""" - def __init__(self, data): + def __init__(self, data: typing.Dict[str, typing.Any]) -> None: self._data = data @property - def provider_id(self): + def provider_id(self) -> str: name = self._data['name'] return name.split('/')[-1] @property - def display_name(self): + def display_name(self) -> typing.Optional[str]: return self._data.get('displayName') @property - def enabled(self): + def enabled(self) -> bool: return self._data.get('enabled', False) @@ -80,55 +87,60 @@ class SAMLProviderConfig(ProviderConfig): @property def idp_entity_id(self): - return self._data.get('idpConfig', {})['idpEntityId'] + return self._data['idpConfig']['idpEntityId'] @property def sso_url(self): - return self._data.get('idpConfig', {})['ssoUrl'] + return self._data['idpConfig']['ssoUrl'] @property def x509_certificates(self): - certs = self._data.get('idpConfig', {})['idpCertificates'] + certs = self._data['idpConfig']['idpCertificates'] return [c['x509Certificate'] for c in certs] @property def callback_url(self): - return self._data.get('spConfig', {})['callbackUri'] + return self._data['spConfig']['callbackUri'] @property def rp_entity_id(self): - return self._data.get('spConfig', {})['spEntityId'] + return self._data['spConfig']['spEntityId'] -class ListProviderConfigsPage: - """Represents a page of AuthProviderConfig instances retrieved from a Firebase project. +class ListProviderConfigsPage(typing.Generic[_ProviderConfigT]): + """Represents a page of ProviderConfig instances retrieved from a Firebase project. Provides methods for traversing the provider configs included in this page, as well as retrieving subsequent pages. The iterator returned by ``iterate_all()`` can be used to iterate through all provider configs in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: typing.Callable[[typing.Optional[str], int], typing.Dict[str, _typing.Json]], + page_token: typing.Optional[str], + max_results: int + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def provider_configs(self): - """A list of ``AuthProviderConfig`` instances available in this page.""" + def provider_configs(self) -> typing.List[_ProviderConfigT]: + """A list of ``ProviderConfig`` instances available in this page.""" raise NotImplementedError @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" - return self._current.get('nextPageToken', '') + return typing.cast(str, self._current.get('nextPageToken', '')) @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> typing.Optional[typing_extensions.Self]: """Retrieves the next page of provider configs, if available. Returns: @@ -139,7 +151,7 @@ def get_next_page(self): return self.__class__(self._download, self.next_page_token, self._max_results) return None - def iterate_all(self): + def iterate_all(self) -> "_ProviderConfigIterator[_ProviderConfigT]": """Retrieves an iterator for provider configs. Returned iterator will iterate through all the provider configs in the Firebase project @@ -147,30 +159,39 @@ def iterate_all(self): in memory at a time. Returns: - iterator: An iterator of AuthProviderConfig instances. + iterator: An iterator of ProviderConfig instances. """ return _ProviderConfigIterator(self) -class _ListOIDCProviderConfigsPage(ListProviderConfigsPage): - +class _ListOIDCProviderConfigsPage(ListProviderConfigsPage[OIDCProviderConfig]): @property - def provider_configs(self): - return [OIDCProviderConfig(data) for data in self._current.get('oauthIdpConfigs', [])] + def provider_configs(self) -> typing.List[OIDCProviderConfig]: + return [ + OIDCProviderConfig(data) + for data in typing.cast( + typing.List[typing.Dict[str, typing.Any]], + self._current.get('oauthIdpConfigs', []), + ) + ] -class _ListSAMLProviderConfigsPage(ListProviderConfigsPage): - +class _ListSAMLProviderConfigsPage(ListProviderConfigsPage[SAMLProviderConfig]): @property - def provider_configs(self): - return [SAMLProviderConfig(data) for data in self._current.get('inboundSamlConfigs', [])] - + def provider_configs(self) -> typing.List[SAMLProviderConfig]: + return [ + SAMLProviderConfig(data) + for data in typing.cast( + typing.List[typing.Dict[str, typing.Any]], + self._current.get('inboundSamlConfigs', []), + ) + ] -class _ProviderConfigIterator(_auth_utils.PageIterator): +class _ProviderConfigIterator(_auth_utils.PageIterator[ListProviderConfigsPage[_ProviderConfigT]]): @property - def items(self): - return self._current_page.provider_configs + def items(self) -> typing.List[_ProviderConfigT]: + return self._current_page.provider_configs if self._current_page else [] class ProviderConfigClient: @@ -178,24 +199,38 @@ class ProviderConfigClient: PROVIDER_CONFIG_URL = 'https://identitytoolkit.googleapis.com/v2' - def __init__(self, http_client, project_id, tenant_id=None, url_override=None): + def __init__( + self, + http_client: _http_client.HttpClient[typing.Dict[str, _typing.Json]], + project_id: str, + tenant_id: typing.Optional[str] = None, + url_override: typing.Optional[str] = None, + ) -> None: self.http_client = http_client url_prefix = url_override or self.PROVIDER_CONFIG_URL self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) if tenant_id: self.base_url += '/tenants/{0}'.format(tenant_id) - def get_oidc_provider_config(self, provider_id): + def get_oidc_provider_config(self, provider_id: str) -> OIDCProviderConfig: _validate_oidc_provider_id(provider_id) body = self._make_request('get', '/oauthIdpConfigs/{0}'.format(provider_id)) return OIDCProviderConfig(body) def create_oidc_provider_config( - self, provider_id, client_id, issuer, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None): + self, + provider_id: str, + client_id: str, + issuer: str, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + client_secret: typing.Optional[str] = None, + id_token_response_type: typing.Optional[bool] = None, + code_response_type: typing.Optional[bool] = None, + ): """Creates a new OIDC provider config from the given parameters.""" _validate_oidc_provider_id(provider_id) - req = { + req: typing.Dict[str, _typing.Json] = { 'clientId': _validate_non_empty_string(client_id, 'client_id'), 'issuer': _validate_url(issuer, 'issuer'), } @@ -204,7 +239,7 @@ def create_oidc_provider_config( if enabled is not None: req['enabled'] = _auth_utils.validate_boolean(enabled, 'enabled') - response_type = {} + response_type: typing.Dict[str, _typing.Json] = {} if id_token_response_type is False and code_response_type is False: raise ValueError('At least one response type must be returned.') if id_token_response_type is not None: @@ -223,12 +258,19 @@ def create_oidc_provider_config( return OIDCProviderConfig(body) def update_oidc_provider_config( - self, provider_id, client_id=None, issuer=None, display_name=None, - enabled=None, client_secret=None, id_token_response_type=None, - code_response_type=None): + self, + provider_id: str, + client_id: typing.Optional[str] = None, + issuer: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + client_secret: typing.Optional[str] = None, + id_token_response_type: typing.Optional[bool] = None, + code_response_type: typing.Optional[bool] = None, + ) -> OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters.""" _validate_oidc_provider_id(provider_id) - req = {} + req: typing.Dict[str, _typing.Json] = {} if display_name is not None: if display_name == _user_mgt.DELETE_ATTRIBUTE: req['displayName'] = None @@ -264,28 +306,44 @@ def update_oidc_provider_config( body = self._make_request('patch', url, json=req, params=params) return OIDCProviderConfig(body) - def delete_oidc_provider_config(self, provider_id): + def delete_oidc_provider_config(self, provider_id: str) -> None: _validate_oidc_provider_id(provider_id) self._make_request('delete', '/oauthIdpConfigs/{0}'.format(provider_id)) - def list_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def list_oidc_provider_configs( + self, + page_token: typing.Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> _ListOIDCProviderConfigsPage: return _ListOIDCProviderConfigsPage( self._fetch_oidc_provider_configs, page_token, max_results) - def _fetch_oidc_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_oidc_provider_configs( + self, + page_token: typing.Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> typing.Dict[str, _typing.Json]: return self._fetch_provider_configs('/oauthIdpConfigs', page_token, max_results) - def get_saml_provider_config(self, provider_id): + def get_saml_provider_config(self, provider_id: str) -> SAMLProviderConfig: _validate_saml_provider_id(provider_id) body = self._make_request('get', '/inboundSamlConfigs/{0}'.format(provider_id)) return SAMLProviderConfig(body) def create_saml_provider_config( - self, provider_id, idp_entity_id, sso_url, x509_certificates, - rp_entity_id, callback_url, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: typing.List[str], + rp_entity_id: str, + callback_url: str, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + ) -> SAMLProviderConfig: """Creates a new SAML provider config from the given parameters.""" _validate_saml_provider_id(provider_id) - req = { + req: typing.Dict[str, typing.Any] = { 'idpConfig': { 'idpEntityId': _validate_non_empty_string(idp_entity_id, 'idp_entity_id'), 'ssoUrl': _validate_url(sso_url, 'sso_url'), @@ -306,11 +364,19 @@ def create_saml_provider_config( return SAMLProviderConfig(body) def update_saml_provider_config( - self, provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None): + self, + provider_id: str, + idp_entity_id: typing.Optional[str] = None, + sso_url: typing.Optional[str] = None, + x509_certificates: typing.Optional[typing.List[str]]=None, + rp_entity_id: typing.Optional[str] = None, + callback_url: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + ) -> SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters.""" _validate_saml_provider_id(provider_id) - idp_config = {} + idp_config: typing.Dict[str, typing.Any] = {} if idp_entity_id is not None: idp_config['idpEntityId'] = _validate_non_empty_string(idp_entity_id, 'idp_entity_id') if sso_url is not None: @@ -318,13 +384,13 @@ def update_saml_provider_config( if x509_certificates is not None: idp_config['idpCertificates'] = _validate_x509_certificates(x509_certificates) - sp_config = {} + sp_config: typing.Dict[str, _typing.Json] = {} if rp_entity_id is not None: sp_config['spEntityId'] = _validate_non_empty_string(rp_entity_id, 'rp_entity_id') if callback_url is not None: sp_config['callbackUri'] = _validate_url(callback_url, 'callback_url') - req = {} + req: typing.Dict[str, _typing.Json] = {} if display_name is not None: if display_name == _user_mgt.DELETE_ATTRIBUTE: req['displayName'] = None @@ -346,18 +412,31 @@ def update_saml_provider_config( body = self._make_request('patch', url, json=req, params=params) return SAMLProviderConfig(body) - def delete_saml_provider_config(self, provider_id): + def delete_saml_provider_config(self, provider_id: str) -> None: _validate_saml_provider_id(provider_id) self._make_request('delete', '/inboundSamlConfigs/{0}'.format(provider_id)) - def list_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def list_saml_provider_configs( + self, + page_token: typing.Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> _ListSAMLProviderConfigsPage: return _ListSAMLProviderConfigsPage( self._fetch_saml_provider_configs, page_token, max_results) - def _fetch_saml_provider_configs(self, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_saml_provider_configs( + self, + page_token: typing.Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> typing.Dict[str, _typing.Json]: return self._fetch_provider_configs('/inboundSamlConfigs', page_token, max_results) - def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CONFIGS_RESULTS): + def _fetch_provider_configs( + self, + path: str, + page_token: typing.Optional[str] = None, + max_results: int = MAX_LIST_CONFIGS_RESULTS, + ) -> typing.Dict[str, _typing.Json]: """Fetches a page of auth provider configs""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -374,7 +453,7 @@ def _fetch_provider_configs(self, path, page_token=None, max_results=MAX_LIST_CO params += '&pageToken={0}'.format(page_token) return self._make_request('get', path, params=params) - def _make_request(self, method, path, **kwargs): + def _make_request(self, method: str, path: str, **kwargs: typing.Any) -> typing.Dict[str, _typing.Json]: url = '{0}{1}'.format(self.base_url, path) try: return self.http_client.body(method, url, **kwargs) @@ -382,7 +461,7 @@ def _make_request(self, method, path, **kwargs): raise _auth_utils.handle_auth_backend_error(error) -def _validate_oidc_provider_id(provider_id): +def _validate_oidc_provider_id(provider_id: typing.Any) -> str: if not isinstance(provider_id, str): raise ValueError( 'Invalid OIDC provider ID: {0}. Provider ID must be a non-empty string.'.format( @@ -392,7 +471,7 @@ def _validate_oidc_provider_id(provider_id): return provider_id -def _validate_saml_provider_id(provider_id): +def _validate_saml_provider_id(provider_id: typing.Any) -> str: if not isinstance(provider_id, str): raise ValueError( 'Invalid SAML provider ID: {0}. Provider ID must be a non-empty string.'.format( @@ -402,7 +481,7 @@ def _validate_saml_provider_id(provider_id): return provider_id -def _validate_non_empty_string(value, label): +def _validate_non_empty_string(value: typing.Any, label: str) -> str: """Validates that the given value is a non-empty string.""" if not isinstance(value, str): raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) @@ -411,7 +490,7 @@ def _validate_non_empty_string(value, label): return value -def _validate_url(url, label): +def _validate_url(url: typing.Any, label: str) -> str: """Validates that the given value is a well-formed URL string.""" if not isinstance(url, str) or not url: raise ValueError( @@ -426,9 +505,10 @@ def _validate_url(url, label): raise ValueError('Malformed {0}: "{1}".'.format(label, url)) -def _validate_x509_certificates(x509_certificates): +def _validate_x509_certificates(x509_certificates: typing.Any) -> typing.List[typing.Dict[str, str]]: if not isinstance(x509_certificates, list) or not x509_certificates: raise ValueError('x509_certificates must be a non-empty list.') + x509_certificates = typing.cast(typing.List[typing.Any], x509_certificates) if not all([isinstance(cert, str) and cert for cert in x509_certificates]): raise ValueError('x509_certificates must only contain non-empty strings.') return [{'x509Certificate': cert} for cert in x509_certificates] diff --git a/firebase_admin/_auth_utils.py b/firebase_admin/_auth_utils.py index ac7b322f..7bb46d5b 100644 --- a/firebase_admin/_auth_utils.py +++ b/firebase_admin/_auth_utils.py @@ -17,12 +17,18 @@ import json import os import re +import typing from urllib import parse +import requests + from firebase_admin import exceptions +from firebase_admin import _typing from firebase_admin import _utils +_PageT = typing.TypeVar("_PageT", bound=_typing.Page) + EMULATOR_HOST_ENV_VAR = 'FIREBASE_AUTH_EMULATOR_HOST' MAX_CLAIMS_PAYLOAD_SIZE = 1000 RESERVED_CLAIMS = set([ @@ -32,7 +38,7 @@ VALID_EMAIL_ACTION_TYPES = set(['VERIFY_EMAIL', 'EMAIL_SIGNIN', 'PASSWORD_RESET']) -class PageIterator: +class PageIterator(typing.Generic[_PageT]): """An iterator that allows iterating over a sequence of items, one at a time. This implementation loads a page of items into memory, and iterates on them. When the whole @@ -40,21 +46,21 @@ class PageIterator: of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: _PageT) -> None: if not current_page: raise ValueError('Current page must not be None.') - self._current_page = current_page - self._iter = None + self._current_page: typing.Optional[_PageT] = current_page + self._iter: typing.Optional[typing.Iterator[_PageT]] = None - def __next__(self): + def __next__(self) -> _PageT: if self._iter is None: self._iter = iter(self.items) try: return next(self._iter) except StopIteration: - if self._current_page.has_next_page: + if self._current_page and self._current_page.has_next_page: self._current_page = self._current_page.get_next_page() self._iter = iter(self.items) @@ -62,15 +68,15 @@ def __next__(self): raise - def __iter__(self): + def __iter__(self) -> typing.Iterator[_PageT]: return self @property - def items(self): + def items(self) -> typing.Sequence[typing.Any]: raise NotImplementedError -def get_emulator_host(): +def get_emulator_host() -> str: emulator_host = os.getenv(EMULATOR_HOST_ENV_VAR, '') if emulator_host and '//' in emulator_host: raise ValueError( @@ -79,11 +85,15 @@ def get_emulator_host(): return emulator_host -def is_emulated(): +def is_emulated() -> bool: return get_emulator_host() != '' -def validate_uid(uid, required=False): +@typing.overload +def validate_uid(uid: typing.Optional[typing.Any], required: typing.Literal[True]) -> str: ... +@typing.overload +def validate_uid(uid: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: ... +def validate_uid(uid: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: if uid is None and not required: return None if not isinstance(uid, str) or not uid or len(uid) > 128: @@ -92,7 +102,12 @@ def validate_uid(uid, required=False): 'characters.'.format(uid)) return uid -def validate_email(email, required=False): + +@typing.overload +def validate_email(email: typing.Optional[typing.Any], required: typing.Literal[True]) -> str: ... +@typing.overload +def validate_email(email: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: ... +def validate_email(email: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: if email is None and not required: return None if not isinstance(email, str) or not email: @@ -103,7 +118,12 @@ def validate_email(email, required=False): raise ValueError('Malformed email address string: "{0}".'.format(email)) return email -def validate_phone(phone, required=False): + +@typing.overload +def validate_phone(phone: typing.Optional[typing.Any], required: typing.Literal[True]) -> str: ... +@typing.overload +def validate_phone(phone: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: ... +def validate_phone(phone: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: """Validates the specified phone number. Phone number vlidation is very lax here. Backend will enforce E.164 spec compliance, and @@ -120,7 +140,12 @@ def validate_phone(phone, required=False): 'compliant identifier.'.format(phone)) return phone -def validate_password(password, required=False): + +@typing.overload +def validate_password(password: typing.Optional[typing.Any], required: typing.Literal[True]) -> str: ... +@typing.overload +def validate_password(password: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: ... +def validate_password(password: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: if password is None and not required: return None if not isinstance(password, str) or len(password) < 6: @@ -128,14 +153,24 @@ def validate_password(password, required=False): 'Invalid password string. Password must be a string at least 6 characters long.') return password -def validate_bytes(value, label, required=False): + +@typing.overload +def validate_bytes(value: typing.Optional[typing.Any], label: typing.Any, required: typing.Literal[True]) -> bytes: ... +@typing.overload +def validate_bytes(value: typing.Optional[typing.Any], label: typing.Any, required: bool = False) -> typing.Optional[bytes]: ... +def validate_bytes(value: typing.Optional[typing.Any], label: typing.Any, required: bool = False) -> typing.Optional[bytes]: if value is None and not required: return None if not isinstance(value, bytes) or not value: raise ValueError('{0} must be a non-empty byte sequence.'.format(label)) return value -def validate_display_name(display_name, required=False): + +@typing.overload +def validate_display_name(display_name: typing.Optional[typing.Any], required: typing.Literal[True]) -> str: ... +@typing.overload +def validate_display_name(display_name: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: ... +def validate_display_name(display_name: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: if display_name is None and not required: return None if not isinstance(display_name, str) or not display_name: @@ -144,7 +179,12 @@ def validate_display_name(display_name, required=False): 'string.'.format(display_name)) return display_name -def validate_provider_id(provider_id, required=True): + +@typing.overload +def validate_provider_id(provider_id: typing.Optional[typing.Any], required: typing.Literal[True]) -> str: ... +@typing.overload +def validate_provider_id(provider_id: typing.Optional[typing.Any], required: bool = True) -> typing.Optional[str]: ... +def validate_provider_id(provider_id: typing.Optional[typing.Any], required: bool = True) -> typing.Optional[str]: if provider_id is None and not required: return None if not isinstance(provider_id, str) or not provider_id: @@ -153,7 +193,12 @@ def validate_provider_id(provider_id, required=True): 'string.'.format(provider_id)) return provider_id -def validate_provider_uid(provider_uid, required=True): + +@typing.overload +def validate_provider_uid(provider_uid: typing.Optional[typing.Any], required: typing.Literal[True] = True) -> str: ... +@typing.overload +def validate_provider_uid(provider_uid: typing.Optional[typing.Any], required: bool = True) -> typing.Optional[str]: ... +def validate_provider_uid(provider_uid: typing.Optional[typing.Any], required: bool = True) -> typing.Optional[str]: if provider_uid is None and not required: return None if not isinstance(provider_uid, str) or not provider_uid: @@ -162,7 +207,12 @@ def validate_provider_uid(provider_uid, required=True): 'string.'.format(provider_uid)) return provider_uid -def validate_photo_url(photo_url, required=False): + +@typing.overload +def validate_photo_url(photo_url: typing.Optional[typing.Any], required: typing.Literal[True]) -> str: ... +@typing.overload +def validate_photo_url(photo_url: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: ... +def validate_photo_url(photo_url: typing.Optional[typing.Any], required: bool = False) -> typing.Optional[str]: """Parses and validates the given URL string.""" if photo_url is None and not required: return None @@ -178,14 +228,31 @@ def validate_photo_url(photo_url, required=False): except Exception: raise ValueError('Malformed photo URL: "{0}".'.format(photo_url)) -def validate_timestamp(timestamp, label, required=False): + +@typing.overload +def validate_timestamp( + timestamp: typing.Optional[typing.Any], + label: typing.Any, + required: typing.Literal[True], +) -> int: ... +@typing.overload +def validate_timestamp( + timestamp: typing.Optional[typing.Any], + label: typing.Any, + required: bool = False, +) -> typing.Optional[int]: ... +def validate_timestamp( + timestamp: typing.Optional[typing.Any], + label: typing.Any, + required: bool = False, +) -> typing.Optional[int]: """Validates the given timestamp value. Timestamps must be positive integers.""" if timestamp is None and not required: return None if isinstance(timestamp, bool): raise ValueError('Boolean value specified as timestamp.') try: - timestamp_int = int(timestamp) + timestamp_int = int(timestamp) # type: ignore[reportArgumentType, arg-type] except TypeError: raise ValueError('Invalid type for timestamp value: {0}.'.format(timestamp)) else: @@ -195,7 +262,13 @@ def validate_timestamp(timestamp, label, required=False): raise ValueError('{0} timestamp must be a positive interger.'.format(label)) return timestamp_int -def validate_int(value, label, low=None, high=None): + +def validate_int( + value: typing.Any, + label: typing.Any, + low: typing.Optional[int] = None, + high: typing.Optional[int] = None, +) -> int: """Validates that the given value represents an integer. There are several ways to represent an integer in Python (e.g. 2, 2L, 2.0). This method allows @@ -219,19 +292,26 @@ def validate_int(value, label, low=None, high=None): raise ValueError('{0} must not be larger than {1}.'.format(label, high)) return val_int -def validate_string(value, label): + +def validate_string(value: typing.Any, label: typing.Any) -> str: """Validates that the given value is a string.""" if not isinstance(value, str): raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) return value -def validate_boolean(value, label): + +def validate_boolean(value: typing.Any, label: typing.Any) -> bool: """Validates that the given value is a boolean.""" if not isinstance(value, bool): raise ValueError('Invalid type for {0}: {1}.'.format(label, value)) return value -def validate_custom_claims(custom_claims, required=False): + +@typing.overload +def validate_custom_claims(custom_claims: typing.Any, required: typing.Literal[True]) -> str: ... +@typing.overload +def validate_custom_claims(custom_claims: typing.Any, required: bool = False) -> typing.Optional[str]: ... +def validate_custom_claims(custom_claims: typing.Any, required: bool = False) -> typing.Optional[str]: """Validates the specified custom claims. Custom claims must be specified as a JSON string. The string must not exceed 1000 @@ -251,7 +331,7 @@ def validate_custom_claims(custom_claims, required=False): if not isinstance(parsed, dict): raise ValueError('Custom claims must be parseable as a JSON object.') - invalid_claims = RESERVED_CLAIMS.intersection(set(parsed.keys())) + invalid_claims = RESERVED_CLAIMS.intersection(set(parsed.keys())) # type: ignore[reportUnknownArgumentType] if len(invalid_claims) > 1: joined = ', '.join(sorted(invalid_claims)) raise ValueError('Claims "{0}" are reserved, and must not be set.'.format(joined)) @@ -260,13 +340,15 @@ def validate_custom_claims(custom_claims, required=False): 'Claim "{0}" is reserved, and must not be set.'.format(invalid_claims.pop())) return claims_str -def validate_action_type(action_type): + +def validate_action_type(action_type: typing.Any) -> _typing.EmailActionType: if action_type not in VALID_EMAIL_ACTION_TYPES: raise ValueError('Invalid action type provided action_type: {0}. \ Valid values are {1}'.format(action_type, ', '.join(VALID_EMAIL_ACTION_TYPES))) return action_type -def validate_provider_ids(provider_ids, required=False): + +def validate_provider_ids(provider_ids: typing.Any, required: bool = False) -> typing.List[str]: if not provider_ids: if required: raise ValueError('Invalid provider IDs. Provider ids should be provided') @@ -275,12 +357,13 @@ def validate_provider_ids(provider_ids, required=False): validate_provider_id(provider_id, True) return provider_ids -def build_update_mask(params): + +def build_update_mask(params: typing.Dict[str, typing.Any]) -> typing.List[str]: """Creates an update mask list from the given dictionary.""" - mask = [] + mask: typing.List[str] = [] for key, value in params.items(): if isinstance(value, dict): - child_mask = build_update_mask(value) + child_mask = build_update_mask(value) # type: ignore[reportUnknownArgumentType] for child in child_mask: mask.append('{0}.{1}'.format(key, child)) else: @@ -294,7 +377,12 @@ class UidAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided uid already exists' - def __init__(self, message, cause, http_response): + def __init__( + self, + message: str, + cause: typing.Optional[Exception], + http_response: typing.Optional[requests.Response] + ) -> None: exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) @@ -303,7 +391,12 @@ class EmailAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided email already exists' - def __init__(self, message, cause, http_response): + def __init__( + self, + message: str, + cause: typing.Optional[Exception], + http_response: typing.Optional[requests.Response] + ) -> None: exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) @@ -315,7 +408,12 @@ class InsufficientPermissionError(exceptions.PermissionDeniedError): 'https://firebase.google.com/docs/admin/setup for details ' 'on how to initialize the Admin SDK with appropriate permissions') - def __init__(self, message, cause, http_response): + def __init__( + self, + message: str, + cause: typing.Optional[Exception], + http_response: typing.Optional[requests.Response] + ) -> None: exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) @@ -324,7 +422,12 @@ class InvalidDynamicLinkDomainError(exceptions.InvalidArgumentError): default_message = 'Dynamic link domain specified in ActionCodeSettings is not authorized' - def __init__(self, message, cause, http_response): + def __init__( + self, + message: str, + cause: typing.Optional[Exception], + http_response: typing.Optional[requests.Response] + ) -> None: exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) @@ -333,7 +436,12 @@ class InvalidIdTokenError(exceptions.InvalidArgumentError): default_message = 'The provided ID token is invalid' - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) @@ -342,14 +450,24 @@ class PhoneNumberAlreadyExistsError(exceptions.AlreadyExistsError): default_message = 'The user with the provided phone number already exists' - def __init__(self, message, cause, http_response): + def __init__( + self, + message: str, + cause: typing.Optional[Exception], + http_response: typing.Optional[requests.Response], + ) -> None: exceptions.AlreadyExistsError.__init__(self, message, cause, http_response) class UnexpectedResponseError(exceptions.UnknownError): """Backend service responded with an unexpected or malformed response.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.UnknownError.__init__(self, message, cause, http_response) @@ -358,7 +476,12 @@ class UserNotFoundError(exceptions.NotFoundError): default_message = 'No user record found for the given identifier' - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.NotFoundError.__init__(self, message, cause, http_response) @@ -367,7 +490,12 @@ class EmailNotFoundError(exceptions.NotFoundError): default_message = 'No user record found for the given email' - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.NotFoundError.__init__(self, message, cause, http_response) @@ -376,14 +504,19 @@ class TenantNotFoundError(exceptions.NotFoundError): default_message = 'No tenant found for the given identifier' - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.NotFoundError.__init__(self, message, cause, http_response) class TenantIdMismatchError(exceptions.InvalidArgumentError): """Missing or invalid tenant ID field in the given JWT.""" - def __init__(self, message): + def __init__(self, message: str) -> None: exceptions.InvalidArgumentError.__init__(self, message) @@ -392,7 +525,12 @@ class ConfigurationNotFoundError(exceptions.NotFoundError): default_message = 'No auth provider found for the given identifier' - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.NotFoundError.__init__(self, message, cause, http_response) @@ -401,25 +539,40 @@ class UserDisabledError(exceptions.InvalidArgumentError): default_message = 'The user record is disabled' - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.InvalidArgumentError.__init__(self, message, cause, http_response) class TooManyAttemptsTryLaterError(exceptions.ResourceExhaustedError): """Rate limited because of too many attempts.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) class ResetPasswordExceedLimitError(exceptions.ResourceExhaustedError): """Reset password emails exceeded their limits.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) -_CODE_TO_EXC_TYPE = { +_CODE_TO_EXC_TYPE: typing.Dict[str, _typing.FirebaseErrorFactory] = { 'CONFIGURATION_NOT_FOUND': ConfigurationNotFoundError, 'DUPLICATE_EMAIL': EmailAlreadyExistsError, 'DUPLICATE_LOCAL_ID': UidAlreadyExistsError, @@ -436,7 +589,7 @@ def __init__(self, message, cause=None, http_response=None): } -def handle_auth_backend_error(error): +def handle_auth_backend_error(error: requests.RequestException) -> exceptions.FirebaseError: """Converts a requests error received from the Firebase Auth service into a FirebaseError.""" if error.response is None: return _utils.handle_requests_error(error) @@ -454,19 +607,21 @@ def handle_auth_backend_error(error): return exc_type(msg, cause=error, http_response=error.response) -def _parse_error_body(response): +def _parse_error_body(response: requests.Response) -> typing.Tuple[typing.Optional[str], typing.Optional[str]]: """Parses the given error response to extract Auth error code and message.""" - error_dict = {} + parsed_body = None try: parsed_body = response.json() - if isinstance(parsed_body, dict): - error_dict = parsed_body.get('error', {}) except ValueError: pass + if not isinstance(parsed_body, dict): + return None, None + # Auth error response format: {"error": {"message": "AUTH_ERROR_CODE: Optional text"}} - code = error_dict.get('message') if isinstance(error_dict, dict) else None - custom_message = None + parsed_body = typing.cast(typing.Dict[str, typing.Dict[str, str]], parsed_body) + error_dict = parsed_body.get('error', {}) + code, custom_message = error_dict.get('message'), None if code: separator = code.find(':') if separator != -1: @@ -476,8 +631,11 @@ def _parse_error_body(response): return code, custom_message -def _build_error_message(code, exc_type, custom_message): - default_message = exc_type.default_message if ( - exc_type and hasattr(exc_type, 'default_message')) else 'Error while calling Auth service' +def _build_error_message( + code: str, + exc_type: typing.Optional[_typing.FirebaseErrorFactory], + custom_message: typing.Optional[str] +) -> str: + default_message: str = getattr(exc_type, 'default_message', 'Error while calling Auth service') ext = ' {0}'.format(custom_message) if custom_message else '' return '{0} ({1}).{2}'.format(default_message, code, ext) diff --git a/firebase_admin/_gapic_utils.py b/firebase_admin/_gapic_utils.py index 3c975808..02cfb548 100644 --- a/firebase_admin/_gapic_utils.py +++ b/firebase_admin/_gapic_utils.py @@ -16,16 +16,20 @@ import io import socket +import typing -import googleapiclient +import googleapiclient.errors import httplib2 import requests from firebase_admin import exceptions +from firebase_admin import _typing from firebase_admin import _utils - -def handle_platform_error_from_googleapiclient(error, handle_func=None): +def handle_platform_error_from_googleapiclient( + error: Exception, + handle_func: typing.Optional[_typing.GoogleAPIErrorHandler] = None +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given googleapiclient error. This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. @@ -43,7 +47,7 @@ def handle_platform_error_from_googleapiclient(error, handle_func=None): return handle_googleapiclient_error(error) content = error.content.decode() - status_code = error.resp.status + status_code = typing.cast(int, error.resp.status) # type: ignore[reportUnknownMemberType] error_dict, message = _utils._parse_platform_error(content, status_code) # pylint: disable=protected-access http_response = _http_response_from_googleapiclient_error(error) exc = None @@ -53,7 +57,12 @@ def handle_platform_error_from_googleapiclient(error, handle_func=None): return exc if exc else _handle_func_googleapiclient(error, message, error_dict, http_response) -def _handle_func_googleapiclient(error, message, error_dict, http_response): +def _handle_func_googleapiclient( + error: Exception, + message: str, + error_dict: typing.Dict[str, typing.Any], + http_response: requests.Response +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given GCP error. Args: @@ -69,7 +78,12 @@ def _handle_func_googleapiclient(error, message, error_dict, http_response): return handle_googleapiclient_error(error, message, code, http_response) -def handle_googleapiclient_error(error, message=None, code=None, http_response=None): +def handle_googleapiclient_error( + error: Exception, + message: typing.Optional[str] = None, + code: typing.Optional[str] = None, + http_response: typing.Optional[requests.Response] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given googleapiclient error. This method is agnostic of the remote service that produced the error, whether it is a GCP @@ -104,7 +118,7 @@ def handle_googleapiclient_error(error, message=None, code=None, http_response=N cause=error) if not code: - code = _utils._http_status_to_error_code(error.resp.status) # pylint: disable=protected-access + code = _utils._http_status_to_error_code(error.resp.status) # pylint: disable=protected-access # type: ignore[reportUnknownMemberType] if not message: message = str(error) if not http_response: @@ -114,9 +128,9 @@ def handle_googleapiclient_error(error, message=None, code=None, http_response=N return err_type(message=message, cause=error, http_response=http_response) -def _http_response_from_googleapiclient_error(error): +def _http_response_from_googleapiclient_error(error: googleapiclient.errors.HttpError) -> requests.Response: """Creates a requests HTTP Response object from the given googleapiclient error.""" - resp = requests.models.Response() + resp = requests.Response() resp.raw = io.BytesIO(error.content) - resp.status_code = error.resp.status + resp.status_code = typing.cast(int, error.resp.status) # type: ignore[reportUnknownMemberType] return resp diff --git a/firebase_admin/_http_client.py b/firebase_admin/_http_client.py index 57c09e2e..2e11428e 100644 --- a/firebase_admin/_http_client.py +++ b/firebase_admin/_http_client.py @@ -17,16 +17,29 @@ This module provides utilities for making HTTP calls using the requests library. """ -from google.auth import transport -import requests -from requests.packages.urllib3.util import retry # pylint: disable=import-error +import typing +import typing_extensions +import google.auth.transport.requests +import google.auth.credentials +import requests.adapters +import requests.structures + +from firebase_admin import _typing from firebase_admin import _utils +if typing.TYPE_CHECKING: + from urllib3.util import retry +else: + from requests.packages.urllib3.util import retry # pylint: disable=import-error + +_AnyT = typing_extensions.TypeVar("_AnyT", default=typing.Any) + if hasattr(retry.Retry.DEFAULT, 'allowed_methods'): - _ANY_METHOD = {'allowed_methods': None} + _ANY_METHOD: typing.Dict[str, typing.Any] = {'allowed_methods': None} else: - _ANY_METHOD = {'method_whitelist': None} + _ANY_METHOD = {'method_whitelist': None} # type: ignore[reportConstantRedefinition] + # Default retry configuration: Retries once on low-level connection and socket read errors. # Retries up to 4 times on HTTP 500 and 503 errors, with exponential backoff. Returns the # last response upon exhausting all retries. @@ -41,7 +54,7 @@ 'x-goog-api-client': _utils.get_metrics_header(), } -class HttpClient: +class HttpClient(typing.Generic[_AnyT]): """Base HTTP client used to make HTTP calls. HttpClient maintains an HTTP session, and handles request authentication and retries if @@ -49,8 +62,14 @@ class HttpClient: """ def __init__( - self, credential=None, session=None, base_url='', headers=None, - retries=DEFAULT_RETRY_CONFIG, timeout=DEFAULT_TIMEOUT_SECONDS): + self, + credential: typing.Optional[google.auth.credentials.Credentials] = None, + session: typing.Optional[requests.Session] = None, + base_url: str = '', + headers: typing.Optional["_typing.HeadersLike"] = None, + retries: retry.Retry = DEFAULT_RETRY_CONFIG, + timeout: int = DEFAULT_TIMEOUT_SECONDS + ) -> None: """Creates a new HttpClient instance from the provided arguments. If a credential is provided, initializes a new HTTP session authorized with it. If neither @@ -67,8 +86,9 @@ def __init__( timeout: HTTP timeout in seconds. Defaults to 120 seconds when not specified. Set to None to disable timeouts (optional). """ + self._session: typing.Optional[requests.Session] if credential: - self._session = transport.requests.AuthorizedSession(credential) + self._session = google.auth.transport.requests.AuthorizedSession(credential) elif session: self._session = session else: @@ -83,21 +103,21 @@ def __init__( self._timeout = timeout @property - def session(self): + def session(self) -> typing.Optional[requests.Session]: return self._session @property - def base_url(self): + def base_url(self) -> str: return self._base_url @property - def timeout(self): + def timeout(self) -> int: return self._timeout - def parse_body(self, resp): + def parse_body(self, resp: requests.Response) -> _AnyT: raise NotImplementedError - def request(self, method, url, **kwargs): + def request(self, method: str, url: str, **kwargs: typing.Any) -> requests.Response: """Makes an HTTP call using the Python requests library. This is the sole entry point to the requests library. All other helper methods in this @@ -120,36 +140,39 @@ class call this method to send HTTP requests out. Refer to if 'timeout' not in kwargs: kwargs['timeout'] = self.timeout kwargs.setdefault('headers', {}).update(METRICS_HEADERS) - resp = self._session.request(method, self.base_url + url, **kwargs) + # possible issue: _session can be None + resp = self._session.request(method, self.base_url + url, **kwargs) # type: ignore[reportOptionalMemberAccess] resp.raise_for_status() return resp - def headers(self, method, url, **kwargs): + def headers(self, method: str, url: str, **kwargs: typing.Any) -> requests.structures.CaseInsensitiveDict[str]: resp = self.request(method, url, **kwargs) return resp.headers - def body_and_response(self, method, url, **kwargs): + def body_and_response(self, method: str, url: str, **kwargs: typing.Any) -> typing.Tuple[_AnyT, requests.Response]: resp = self.request(method, url, **kwargs) return self.parse_body(resp), resp - def body(self, method, url, **kwargs): + def body(self, method: str, url: str, **kwargs: typing.Any) -> _AnyT: resp = self.request(method, url, **kwargs) return self.parse_body(resp) - def headers_and_body(self, method, url, **kwargs): + def headers_and_body( + self, + method: str, + url: str, + **kwargs: typing.Any, + ) -> typing.Tuple[requests.structures.CaseInsensitiveDict[str], _AnyT]: resp = self.request(method, url, **kwargs) return resp.headers, self.parse_body(resp) - def close(self): - self._session.close() - self._session = None + def close(self) -> None: + if self._session is not None: + self._session.close() + self._session = None -class JsonHttpClient(HttpClient): +class JsonHttpClient(HttpClient[typing.Dict[str ,"_typing.Json"]]): """An HTTP client that parses response messages as JSON.""" - - def __init__(self, **kwargs): - HttpClient.__init__(self, **kwargs) - - def parse_body(self, resp): + def parse_body(self, resp: requests.Response) -> typing.Dict[str ,"_typing.Json"]: return resp.json() diff --git a/firebase_admin/_messaging_encoder.py b/firebase_admin/_messaging_encoder.py index d7f23328..136caaf3 100644 --- a/firebase_admin/_messaging_encoder.py +++ b/firebase_admin/_messaging_encoder.py @@ -19,8 +19,12 @@ import math import numbers import re +import typing -import firebase_admin._messaging_utils as _messaging_utils +from firebase_admin import _messaging_utils + +_K = typing.TypeVar("_K") +_V = typing.TypeVar("_V") class Message: @@ -35,7 +39,7 @@ class Message: notification: An instance of ``messaging.Notification`` (optional). android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). - apns: An instance of ``messaging.ApnsConfig`` (optional). + apns: An instance of ``messaging.APNSConfig`` (optional). fcm_options: An instance of ``messaging.FCMOptions`` (optional). token: The registration token of the device to which the message should be sent (optional). topic: Name of the FCM topic to which the message should be sent (optional). Topic name @@ -43,8 +47,18 @@ class Message: condition: The FCM condition to which the message should be sent (optional). """ - def __init__(self, data=None, notification=None, android=None, webpush=None, apns=None, - fcm_options=None, token=None, topic=None, condition=None): + def __init__( + self, + data: typing.Optional[typing.Dict[str, str]] = None, + notification: typing.Optional[_messaging_utils.Notification] = None, + android: typing.Optional[_messaging_utils.AndroidConfig] = None, + webpush: typing.Optional[_messaging_utils.WebpushConfig] = None, + apns: typing.Optional[_messaging_utils.APNSConfig] = None, + fcm_options: typing.Optional[_messaging_utils.FCMOptions] = None, + token: typing.Optional[str] = None, + topic: typing.Optional[str] = None, + condition: typing.Optional[str] = None, + ) -> None: self.data = data self.notification = notification self.android = android @@ -55,7 +69,7 @@ def __init__(self, data=None, notification=None, android=None, webpush=None, apn self.topic = topic self.condition = condition - def __str__(self): + def __str__(self) -> str: return json.dumps(self, cls=MessageEncoder, sort_keys=True) @@ -69,11 +83,19 @@ class MulticastMessage: notification: An instance of ``messaging.Notification`` (optional). android: An instance of ``messaging.AndroidConfig`` (optional). webpush: An instance of ``messaging.WebpushConfig`` (optional). - apns: An instance of ``messaging.ApnsConfig`` (optional). + apns: An instance of ``messaging.APNSConfig`` (optional). fcm_options: An instance of ``messaging.FCMOptions`` (optional). """ - def __init__(self, tokens, data=None, notification=None, android=None, webpush=None, apns=None, - fcm_options=None): + def __init__( + self, + tokens: typing.List[str], + data: typing.Optional[typing.Dict[str, str]] = None, + notification: typing.Optional[_messaging_utils.Notification] = None, + android: typing.Optional[_messaging_utils.AndroidConfig] = None, + webpush: typing.Optional[_messaging_utils.WebpushConfig] = None, + apns: typing.Optional[_messaging_utils.APNSConfig] = None, + fcm_options: typing.Optional[_messaging_utils.FCMOptions] = None + ) -> None: _Validators.check_string_list('MulticastMessage.tokens', tokens) if len(tokens) > 500: raise ValueError('MulticastMessage.tokens must not contain more than 500 tokens.') @@ -93,7 +115,7 @@ class _Validators: """ @classmethod - def check_string(cls, label, value, non_empty=False): + def check_string(cls, label: str, value: typing.Any, non_empty: bool = False) -> typing.Optional[str]: """Checks if the given value is a string.""" if value is None: return None @@ -106,7 +128,7 @@ def check_string(cls, label, value, non_empty=False): return value @classmethod - def check_number(cls, label, value): + def check_number(cls, label: str, value: typing.Any) -> typing.Optional[numbers.Number]: if value is None: return None if not isinstance(value, numbers.Number): @@ -114,12 +136,17 @@ def check_number(cls, label, value): return value @classmethod - def check_string_dict(cls, label, value): + def check_string_dict( + cls, + label: str, + value: typing.Union[typing.Dict[typing.Any, typing.Any], typing.Any], + ) -> typing.Optional[typing.Dict[str, str]]: """Checks if the given value is a dictionary comprised only of string keys and values.""" if value is None or value == {}: return None if not isinstance(value, dict): raise ValueError('{0} must be a dictionary.'.format(label)) + value = typing.cast(typing.Dict[typing.Any, typing.Any], value) non_str = [k for k in value if not isinstance(k, str)] if non_str: raise ValueError('{0} must not contain non-string keys.'.format(label)) @@ -129,39 +156,45 @@ def check_string_dict(cls, label, value): return value @classmethod - def check_string_list(cls, label, value): + def check_string_list( + cls, + label: str, + value: typing.Union[typing.List[typing.Any], typing.Any], + ) -> typing.Optional[typing.List[str]]: """Checks if the given value is a list comprised only of strings.""" if value is None or value == []: return None if not isinstance(value, list): raise ValueError('{0} must be a list of strings.'.format(label)) + value = typing.cast(typing.List[typing.Any], value) non_str = [k for k in value if not isinstance(k, str)] if non_str: raise ValueError('{0} must not contain non-string values.'.format(label)) return value @classmethod - def check_number_list(cls, label, value): + def check_number_list(cls, label: str, value: typing.Any) -> typing.Optional[typing.List[numbers.Number]]: """Checks if the given value is a list comprised only of numbers.""" if value is None or value == []: return None if not isinstance(value, list): raise ValueError('{0} must be a list of numbers.'.format(label)) + value = typing.cast(typing.List[typing.Any], value) non_number = [k for k in value if not isinstance(k, numbers.Number)] if non_number: raise ValueError('{0} must not contain non-number values.'.format(label)) return value @classmethod - def check_analytics_label(cls, label, value): + def check_analytics_label(cls, label: str, value: typing.Any) -> typing.Optional[str]: """Checks if the given value is a valid analytics label.""" - value = _Validators.check_string(label, value) + value = cls.check_string(label, value) if value is not None and not re.match(r'^[a-zA-Z0-9-_.~%]{1,50}$', value): raise ValueError('Malformed {}.'.format(label)) return value @classmethod - def check_boolean(cls, label, value): + def check_boolean(cls, label: str, value: typing.Any) -> typing.Optional[bool]: """Checks if the given value is boolean.""" if value is None: return None @@ -170,7 +203,7 @@ def check_boolean(cls, label, value): return value @classmethod - def check_datetime(cls, label, value): + def check_datetime(cls, label: str, value: typing.Any) -> typing.Optional[datetime.datetime]: """Checks if the given value is a datetime.""" if value is None: return None @@ -182,18 +215,21 @@ def check_datetime(cls, label, value): class MessageEncoder(json.JSONEncoder): """A custom ``JSONEncoder`` implementation for serializing Message instances into JSON.""" - @classmethod - def remove_null_values(cls, dict_value): - return {k: v for k, v in dict_value.items() if v not in [None, [], {}]} + @staticmethod + def remove_null_values(dict_value: typing.Dict[_K, typing.Optional[_V]]) -> typing.Dict[_K, _V]: + return {k: typing.cast(_V, v) for k, v in dict_value.items() if v not in [None, [], {}]} @classmethod - def encode_android(cls, android): + def encode_android( + cls, + android: typing.Optional[_messaging_utils.AndroidConfig], + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """Encodes an ``AndroidConfig`` instance into JSON.""" if android is None: return None if not isinstance(android, _messaging_utils.AndroidConfig): raise ValueError('Message.android must be an instance of AndroidConfig class.') - result = { + result: typing.Dict[str, typing.Any] = { 'collapse_key': _Validators.check_string( 'AndroidConfig.collapse_key', android.collapse_key), 'data': _Validators.check_string_dict( @@ -215,7 +251,10 @@ def encode_android(cls, android): return result @classmethod - def encode_android_fcm_options(cls, fcm_options): + def encode_android_fcm_options( + cls, + fcm_options: typing.Optional[_messaging_utils.AndroidFCMOptions], + ) -> typing.Optional[typing.Dict[str, str]]: """Encodes an ``AndroidFCMOptions`` instance into JSON.""" if fcm_options is None: return None @@ -230,12 +269,12 @@ def encode_android_fcm_options(cls, fcm_options): return result @classmethod - def encode_ttl(cls, ttl): + def encode_ttl(cls, ttl: typing.Optional[typing.Union[numbers.Real, datetime.timedelta]]) -> typing.Optional[str]: """Encodes an ``AndroidConfig`` ``TTL`` duration into a string.""" if ttl is None: return None - if isinstance(ttl, numbers.Number): - ttl = datetime.timedelta(seconds=ttl) + if isinstance(ttl, numbers.Real): + ttl = datetime.timedelta(seconds=float(ttl)) if not isinstance(ttl, datetime.timedelta): raise ValueError('AndroidConfig.ttl must be a duration in seconds or an instance of ' 'datetime.timedelta.') @@ -249,12 +288,16 @@ def encode_ttl(cls, ttl): return '{0}s'.format(seconds) @classmethod - def encode_milliseconds(cls, label, msec): + def encode_milliseconds( + cls, + label: str, + msec: typing.Optional[typing.Union[numbers.Real, datetime.timedelta]], + ) -> typing.Optional[str]: """Encodes a duration in milliseconds into a string.""" if msec is None: return None - if isinstance(msec, numbers.Number): - msec = datetime.timedelta(milliseconds=msec) + if isinstance(msec, numbers.Real): + msec = datetime.timedelta(milliseconds=float(msec)) if not isinstance(msec, datetime.timedelta): raise ValueError('{0} must be a duration in milliseconds or an instance of ' 'datetime.timedelta.'.format(label)) @@ -268,14 +311,17 @@ def encode_milliseconds(cls, label, msec): return '{0}s'.format(seconds) @classmethod - def encode_android_notification(cls, notification): + def encode_android_notification( + cls, + notification: typing.Optional[_messaging_utils.AndroidNotification], + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """Encodes an ``AndroidNotification`` instance into JSON.""" if notification is None: return None if not isinstance(notification, _messaging_utils.AndroidNotification): raise ValueError('AndroidConfig.notification must be an instance of ' 'AndroidNotification class.') - result = { + result: typing.Dict[str, typing.Any] = { 'body': _Validators.check_string( 'AndroidNotification.body', notification.body), 'body_loc_args': _Validators.check_string_list( @@ -324,7 +370,7 @@ def encode_android_notification(cls, notification): 'AndroidNotification.proxy', notification.proxy, non_empty=True) } result = cls.remove_null_values(result) - color = result.get('color') + color: typing.Optional[str] = result.get('color') if color and not re.match(r'^#[0-9a-fA-F]{6}$', color): raise ValueError( 'AndroidNotification.color must be in the form #RRGGBB.') @@ -335,7 +381,7 @@ def encode_android_notification(cls, notification): raise ValueError( 'AndroidNotification.title_loc_key is required when specifying title_loc_args.') - event_time = result.get('event_time') + event_time: typing.Optional[datetime.datetime] = result.get('event_time') if event_time: # if the datetime instance is not naive (tzinfo is present), convert to UTC # otherwise (tzinfo is None) assume the datetime instance is already in UTC @@ -357,9 +403,9 @@ def encode_android_notification(cls, notification): 'AndroidNotification.visibility must be "private", "public" or "secret".') result['visibility'] = visibility.upper() - vibrate_timings_millis = result.get('vibrate_timings') + vibrate_timings_millis: typing.Optional[typing.List[typing.Any]] = result.get('vibrate_timings') if vibrate_timings_millis: - vibrate_timing_strings = [] + vibrate_timing_strings: typing.List[typing.Optional[str]] = [] for msec in vibrate_timings_millis: formated_string = cls.encode_milliseconds( 'AndroidNotification.vibrate_timings_millis', msec) @@ -375,14 +421,17 @@ def encode_android_notification(cls, notification): return result @classmethod - def encode_light_settings(cls, light_settings): + def encode_light_settings( + cls, + light_settings: typing.Optional[_messaging_utils.LightSettings], + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """Encodes a ``LightSettings`` instance into JSON.""" if light_settings is None: return None if not isinstance(light_settings, _messaging_utils.LightSettings): raise ValueError( 'AndroidNotification.light_settings must be an instance of LightSettings class.') - result = { + result: typing.Dict[str, typing.Any] = { 'color': _Validators.check_string( 'LightSettings.color', light_settings.color, non_empty=True), 'light_on_duration': cls.encode_milliseconds( @@ -416,7 +465,10 @@ def encode_light_settings(cls, light_settings): return result @classmethod - def encode_webpush(cls, webpush): + def encode_webpush( + cls, + webpush: typing.Optional[_messaging_utils.WebpushConfig], + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """Encodes a ``WebpushConfig`` instance into JSON.""" if webpush is None: return None @@ -433,7 +485,10 @@ def encode_webpush(cls, webpush): return cls.remove_null_values(result) @classmethod - def encode_webpush_notification(cls, notification): + def encode_webpush_notification( + cls, + notification: typing.Optional[_messaging_utils.WebpushNotification], + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """Encodes a ``WebpushNotification`` instance into JSON.""" if notification is None: return None @@ -441,7 +496,7 @@ def encode_webpush_notification(cls, notification): raise ValueError('WebpushConfig.notification must be an instance of ' 'WebpushNotification class.') result = { - 'actions': cls.encode_webpush_notification_actions(notification.actions), + 'actions': MessageEncoder.encode_webpush_notification_actions(notification.actions), 'badge': _Validators.check_string( 'WebpushNotification.badge', notification.badge), 'body': _Validators.check_string( @@ -477,17 +532,20 @@ def encode_webpush_notification(cls, notification): raise ValueError( 'Multiple specifications for {0} in WebpushNotification.'.format(key)) result[key] = value - return cls.remove_null_values(result) + return MessageEncoder.remove_null_values(result) @classmethod - def encode_webpush_notification_actions(cls, actions): + def encode_webpush_notification_actions( + cls, + actions: typing.Optional[typing.List[_messaging_utils.WebpushNotificationAction]] + ) -> typing.Optional[typing.List[typing.Dict[str, str]]]: """Encodes a list of ``WebpushNotificationActions`` into JSON.""" if actions is None: return None if not isinstance(actions, list): raise ValueError('WebpushConfig.notification.actions must be a list of ' 'WebpushNotificationAction instances.') - results = [] + results: typing.List[typing.Dict[str, str]] = [] for action in actions: if not isinstance(action, _messaging_utils.WebpushNotificationAction): raise ValueError('WebpushConfig.notification.actions must be a list of ' @@ -504,7 +562,10 @@ def encode_webpush_notification_actions(cls, actions): return results @classmethod - def encode_webpush_fcm_options(cls, options): + def encode_webpush_fcm_options( + cls, + options: typing.Optional[_messaging_utils.WebpushFCMOptions], + ) -> typing.Optional[typing.Dict[str, str]]: """Encodes a ``WebpushFCMOptions`` instance into JSON.""" if options is None: return None @@ -518,7 +579,10 @@ def encode_webpush_fcm_options(cls, options): return result @classmethod - def encode_apns(cls, apns): + def encode_apns( + cls, + apns: typing.Optional[_messaging_utils.APNSConfig], + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """Encodes an ``APNSConfig`` instance into JSON.""" if apns is None: return None @@ -533,13 +597,16 @@ def encode_apns(cls, apns): return cls.remove_null_values(result) @classmethod - def encode_apns_payload(cls, payload): + def encode_apns_payload( + cls, + payload: typing.Optional[_messaging_utils.APNSPayload], + ) -> typing.Optional[typing.Dict[str, typing.Any]]: """Encodes an ``APNSPayload`` instance into JSON.""" if payload is None: return None if not isinstance(payload, _messaging_utils.APNSPayload): raise ValueError('APNSConfig.payload must be an instance of APNSPayload class.') - result = { + result: typing.Dict[str, typing.Any] = { 'aps': cls.encode_aps(payload.aps) } for key, value in payload.custom_data.items(): @@ -547,7 +614,10 @@ def encode_apns_payload(cls, payload): return cls.remove_null_values(result) @classmethod - def encode_apns_fcm_options(cls, fcm_options): + def encode_apns_fcm_options( + cls, + fcm_options: typing.Optional[_messaging_utils.APNSFCMOptions], + ) -> typing.Optional[typing.Dict[str, str]]: """Encodes an ``APNSFCMOptions`` instance into JSON.""" if fcm_options is None: return None @@ -562,11 +632,11 @@ def encode_apns_fcm_options(cls, fcm_options): return result @classmethod - def encode_aps(cls, aps): + def encode_aps(cls, aps: _messaging_utils.Aps) -> typing.Dict[str, typing.Any]: """Encodes an ``Aps`` instance into JSON.""" if not isinstance(aps, _messaging_utils.Aps): raise ValueError('APNSPayload.aps must be an instance of Aps class.') - result = { + result: typing.Dict[str, typing.Any] = { 'alert': cls.encode_aps_alert(aps.alert), 'badge': _Validators.check_number('Aps.badge', aps.badge), 'sound': cls.encode_aps_sound(aps.sound), @@ -588,7 +658,10 @@ def encode_aps(cls, aps): return cls.remove_null_values(result) @classmethod - def encode_aps_sound(cls, sound): + def encode_aps_sound( + cls, + sound: typing.Optional[typing.Union[str, _messaging_utils.CriticalSound]], + ) -> typing.Optional[typing.Union[str, typing.Dict[str, typing.Any]]]: """Encodes an APNs sound configuration into JSON.""" if sound is None: return None @@ -597,7 +670,7 @@ def encode_aps_sound(cls, sound): if not isinstance(sound, _messaging_utils.CriticalSound): raise ValueError( 'Aps.sound must be a non-empty string or an instance of CriticalSound class.') - result = { + result: typing.Dict[str, typing.Any] = { 'name': _Validators.check_string('CriticalSound.name', sound.name, non_empty=True), 'volume': _Validators.check_number('CriticalSound.volume', sound.volume), } @@ -611,7 +684,10 @@ def encode_aps_sound(cls, sound): return cls.remove_null_values(result) @classmethod - def encode_aps_alert(cls, alert): + def encode_aps_alert( + cls, + alert: typing.Optional[typing.Union[_messaging_utils.ApsAlert, str]], + ) -> typing.Optional[typing.Union[str, typing.Dict[str, typing.Any]]]: """Encodes an ``ApsAlert`` instance into JSON.""" if alert is None: return None @@ -653,7 +729,10 @@ def encode_aps_alert(cls, alert): return cls.remove_null_values(result) @classmethod - def encode_notification(cls, notification): + def encode_notification( + cls, + notification: typing.Optional[_messaging_utils.Notification], + ) -> typing.Optional[typing.Dict[str, str]]: """Encodes a ``Notification`` instance into JSON.""" if notification is None: return None @@ -667,7 +746,7 @@ def encode_notification(cls, notification): return cls.remove_null_values(result) @classmethod - def sanitize_topic_name(cls, topic): + def sanitize_topic_name(cls, topic: typing.Optional[str]) -> typing.Optional[str]: """Removes the /topics/ prefix from the topic name, if present.""" if not topic: return None @@ -679,10 +758,10 @@ def sanitize_topic_name(cls, topic): raise ValueError('Malformed topic name.') return topic - def default(self, o): # pylint: disable=method-hidden + def default(self, o: typing.Any) -> typing.Dict[str, typing.Any]: # pylint: disable=method-hidden if not isinstance(o, Message): return json.JSONEncoder.default(self, o) - result = { + result: typing.Dict[str, typing.Any] = { 'android': MessageEncoder.encode_android(o.android), 'apns': MessageEncoder.encode_apns(o.apns), 'condition': _Validators.check_string( @@ -702,7 +781,10 @@ def default(self, o): # pylint: disable=method-hidden return result @classmethod - def encode_fcm_options(cls, fcm_options): + def encode_fcm_options( + cls, + fcm_options: typing.Optional[_messaging_utils.FCMOptions], + ) -> typing.Optional[typing.Dict[str, str]]: """Encodes an ``FCMOptions`` instance into JSON.""" if fcm_options is None: return None diff --git a/firebase_admin/_messaging_utils.py b/firebase_admin/_messaging_utils.py index ae1f5cc5..b097a2a0 100644 --- a/firebase_admin/_messaging_utils.py +++ b/firebase_admin/_messaging_utils.py @@ -14,8 +14,19 @@ """Types and utilities used by the messaging (FCM) module.""" +import datetime +import numbers +import typing + +import requests + from firebase_admin import exceptions +if typing.TYPE_CHECKING: + from _typeshed import Incomplete +else: + Incomplete = typing.Any + class Notification: """A notification that can be included in a message. @@ -26,7 +37,12 @@ class Notification: image: Image url of the notification (optional) """ - def __init__(self, title=None, body=None, image=None): + def __init__( + self, + title: typing.Optional[str] = None, + body: typing.Optional[str] = None, + image: typing.Optional[str] = None, + ) -> None: self.title = title self.body = body self.image = image @@ -53,8 +69,17 @@ class AndroidConfig: the app while the device is in direct boot mode (optional). """ - def __init__(self, collapse_key=None, priority=None, ttl=None, restricted_package_name=None, - data=None, notification=None, fcm_options=None, direct_boot_ok=None): + def __init__( + self, + collapse_key: typing.Optional[str] = None, + priority: typing.Optional[typing.Literal["high", "normal"]] = None, + ttl: typing.Optional[typing.Union[numbers.Real, datetime.timedelta]] = None, + restricted_package_name: typing.Optional[str] = None, + data: typing.Optional[typing.Dict[str, str]] = None, + notification: typing.Optional["AndroidNotification"] = None, + fcm_options: typing.Optional["AndroidFCMOptions"] = None, + direct_boot_ok: typing.Optional[bool] = None, + ) -> None: self.collapse_key = collapse_key self.priority = priority self.ttl = ttl @@ -153,13 +178,35 @@ class AndroidNotification: """ - def __init__(self, title=None, body=None, icon=None, color=None, sound=None, tag=None, - click_action=None, body_loc_key=None, body_loc_args=None, title_loc_key=None, - title_loc_args=None, channel_id=None, image=None, ticker=None, sticky=None, - event_timestamp=None, local_only=None, priority=None, vibrate_timings_millis=None, - default_vibrate_timings=None, default_sound=None, light_settings=None, - default_light_settings=None, visibility=None, notification_count=None, - proxy=None): + def __init__( + self, + title: typing.Optional[str] = None, + body: typing.Optional[str] = None, + icon: typing.Optional[str] = None, + color: typing.Optional[str] = None, + sound: typing.Optional[str] = None, + tag: typing.Optional[str] = None, + click_action: typing.Optional[Incomplete] = None, + body_loc_key: typing.Optional[str] = None, + body_loc_args: typing.Optional[typing.List[str]] = None, + title_loc_key: typing.Optional[str] = None, + title_loc_args: typing.Optional[typing.List[str]] = None, + channel_id: typing.Optional[Incomplete] = None, + image: typing.Optional[str] = None, + ticker: typing.Optional[Incomplete] = None, + sticky: typing.Optional[bool] = None, + event_timestamp: typing.Optional[datetime.datetime] = None, + local_only: typing.Optional[Incomplete] = None, + priority: typing.Optional[typing.Literal["default", "min", "low", "high", "max", "normal"]] = None, + vibrate_timings_millis: typing.Optional[float] = None, + default_vibrate_timings: typing.Optional[bool] = None, + default_sound: typing.Optional[bool] = None, + light_settings: typing.Optional["LightSettings"] = None, + default_light_settings: typing.Optional[bool] = None, + visibility: typing.Optional[typing.Literal["private", "public", "secret"]] = None, + notification_count: typing.Optional[int] = None, + proxy: typing.Optional[typing.Literal["allow", "deny"]] = None, + ) -> None: self.title = title self.body = body self.icon = icon @@ -199,8 +246,12 @@ class LightSettings: light_off_duration_millis: Along with ``light_on_duration``, defines the blink rate of LED flashes. """ - def __init__(self, color, light_on_duration_millis, - light_off_duration_millis): + def __init__( + self, + color: str, + light_on_duration_millis: typing.Union[numbers.Real, datetime.timedelta], + light_off_duration_millis: typing.Union[numbers.Real, datetime.timedelta], + ) -> None: self.color = color self.light_on_duration_millis = light_on_duration_millis self.light_off_duration_millis = light_off_duration_millis @@ -214,7 +265,7 @@ class AndroidFCMOptions: (optional). """ - def __init__(self, analytics_label=None): + def __init__(self, analytics_label: typing.Optional[Incomplete] = None) -> None: self.analytics_label = analytics_label @@ -233,7 +284,13 @@ class WebpushConfig: .. _Webpush Specification: https://tools.ietf.org/html/rfc8030#section-5 """ - def __init__(self, headers=None, data=None, notification=None, fcm_options=None): + def __init__( + self, + headers: typing.Optional[typing.Dict[str, str]] = None, + data: typing.Optional[typing.Dict[str, str]] = None, + notification: typing.Optional["WebpushNotification"] = None, + fcm_options: typing.Optional["WebpushFCMOptions"] = None, + ) -> None: self.headers = headers self.data = data self.notification = notification @@ -249,7 +306,7 @@ class WebpushNotificationAction: icon: Icon URL for the action (optional). """ - def __init__(self, action, title, icon=None): + def __init__(self, action: str, title: str, icon: typing.Optional[str] = None) -> None: self.action = action self.title = title self.icon = icon @@ -290,10 +347,25 @@ class WebpushNotification: /notification/Notification """ - def __init__(self, title=None, body=None, icon=None, actions=None, badge=None, data=None, - direction=None, image=None, language=None, renotify=None, - require_interaction=None, silent=None, tag=None, timestamp_millis=None, - vibrate=None, custom_data=None): + def __init__( + self, + title: typing.Optional[str] = None, + body: typing.Optional[str] = None, + icon: typing.Optional[str] = None, + actions: typing.Optional[typing.List[WebpushNotificationAction]] = None, + badge: typing.Optional[str] = None, + data: typing.Optional[typing.Any] = None, + direction: typing.Optional[typing.Literal["auto", "ltr", "rtl"]] = None, + image: typing.Optional[str] = None, + language: typing.Optional[str] = None, + renotify: typing.Optional[bool] = None, + require_interaction: typing.Optional[bool] = None, + silent: typing.Optional[bool] = None, + tag: typing.Optional[str] = None, + timestamp_millis: typing.Optional[int] = None, + vibrate: typing.Optional[typing.List[int]] = None, + custom_data: typing.Optional[typing.Dict[str, typing.Any]] = None, + ) -> None: self.title = title self.body = body self.icon = icon @@ -320,7 +392,7 @@ class WebpushFCMOptions: (optional). """ - def __init__(self, link=None): + def __init__(self, link: typing.Optional[str] = None) -> None: self.link = link @@ -339,7 +411,12 @@ class APNSConfig: /NetworkingInternet/Conceptual/RemoteNotificationsPG/CommunicatingwithAPNs.html """ - def __init__(self, headers=None, payload=None, fcm_options=None): + def __init__( + self, + headers: typing.Optional[typing.Dict[str, str]] = None, + payload: typing.Optional["APNSPayload"] = None, + fcm_options: typing.Optional["APNSFCMOptions"] = None, + ) -> None: self.headers = headers self.payload = payload self.fcm_options = fcm_options @@ -354,7 +431,7 @@ class APNSPayload: (optional). """ - def __init__(self, aps, **kwargs): + def __init__(self, aps: "Aps", **kwargs: typing.Any) -> None: self.aps = aps self.custom_data = kwargs @@ -377,8 +454,17 @@ class Aps: (optional). """ - def __init__(self, alert=None, badge=None, sound=None, content_available=None, category=None, - thread_id=None, mutable_content=None, custom_data=None): + def __init__( + self, + alert: typing.Optional[typing.Union["ApsAlert", str]] = None, + badge: typing.Optional[float] = None, # should it be int? + sound: typing.Optional[typing.Union[str, "CriticalSound"]] = None, + content_available: typing.Optional[bool] = None, + category: typing.Optional[str] = None, + thread_id: typing.Optional[str] = None, + mutable_content: typing.Optional[bool] = None, + custom_data: typing.Optional[typing.Dict[str, typing.Any]] = None, + ) -> None: self.alert = alert self.badge = badge self.sound = sound @@ -402,7 +488,12 @@ class CriticalSound: and 1.0 (full volume) (optional). """ - def __init__(self, name, critical=None, volume=None): + def __init__( + self, + name: str, + critical: typing.Optional[bool] = None, + volume: typing.Optional[float] = None, + ) -> None: self.name = name self.critical = critical self.volume = volume @@ -432,9 +523,19 @@ class ApsAlert: (optional) """ - def __init__(self, title=None, subtitle=None, body=None, loc_key=None, loc_args=None, - title_loc_key=None, title_loc_args=None, action_loc_key=None, launch_image=None, - custom_data=None): + def __init__( + self, + title: typing.Optional[str] = None, + subtitle: typing.Optional[str] = None, + body: typing.Optional[str] = None, + loc_key: typing.Optional[str] = None, + loc_args: typing.Optional[typing.List[str]] = None, + title_loc_key: typing.Optional[str] = None, + title_loc_args: typing.Optional[typing.List[str]] = None, + action_loc_key: typing.Optional[str] = None, + launch_image: typing.Optional[str] = None, + custom_data: typing.Optional[typing.Dict[str, typing.Any]] = None, + ) -> None: self.title = title self.subtitle = subtitle self.body = body @@ -457,7 +558,11 @@ class APNSFCMOptions: (optional). """ - def __init__(self, analytics_label=None, image=None): + def __init__( + self, + analytics_label: typing.Optional[Incomplete] = None, + image: typing.Optional[str] = None, + ) -> None: self.analytics_label = analytics_label self.image = image @@ -469,28 +574,43 @@ class FCMOptions: analytics_label: contains additional options to use across all platforms (optional). """ - def __init__(self, analytics_label=None): + def __init__(self, analytics_label: typing.Optional[Incomplete] = None) -> None: self.analytics_label = analytics_label class ThirdPartyAuthError(exceptions.UnauthenticatedError): """APNs certificate or web push auth key was invalid or missing.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.UnauthenticatedError.__init__(self, message, cause, http_response) class QuotaExceededError(exceptions.ResourceExhaustedError): """Sending limit exceeded for the message target.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.ResourceExhaustedError.__init__(self, message, cause, http_response) class SenderIdMismatchError(exceptions.PermissionDeniedError): """The authenticated sender ID is different from the sender ID for the registration token.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.PermissionDeniedError.__init__(self, message, cause, http_response) @@ -499,5 +619,10 @@ class UnregisteredError(exceptions.NotFoundError): This usually means that the token used is no longer valid and a new one must be used.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None, + ) -> None: exceptions.NotFoundError.__init__(self, message, cause, http_response) diff --git a/firebase_admin/_rfc3339.py b/firebase_admin/_rfc3339.py index 2c720bdd..32bafccf 100644 --- a/firebase_admin/_rfc3339.py +++ b/firebase_admin/_rfc3339.py @@ -14,10 +14,11 @@ """Parse RFC3339 date strings""" -from datetime import datetime, timezone +import datetime import re -def parse_to_epoch(datestr): + +def parse_to_epoch(datestr: str) -> float: """Parse an RFC3339 date string and return the number of seconds since the epoch (as a float). @@ -37,7 +38,7 @@ def parse_to_epoch(datestr): return _parse_to_datetime(datestr).timestamp() -def _parse_to_datetime(datestr): +def _parse_to_datetime(datestr: str) -> datetime.datetime: """Parse an RFC3339 date string and return a python datetime instance. Args: @@ -55,16 +56,16 @@ def _parse_to_datetime(datestr): # This format is the one we actually expect to occur from our backend. The # others are only present because the spec says we *should* accept them. try: - return datetime.strptime( + return datetime.datetime.strptime( datestr_modified, '%Y-%m-%dT%H:%M:%S.%fZ' - ).replace(tzinfo=timezone.utc) + ).replace(tzinfo=datetime.timezone.utc) except ValueError: pass try: - return datetime.strptime( + return datetime.datetime.strptime( datestr_modified, '%Y-%m-%dT%H:%M:%SZ' - ).replace(tzinfo=timezone.utc) + ).replace(tzinfo=datetime.timezone.utc) except ValueError: pass @@ -75,12 +76,12 @@ def _parse_to_datetime(datestr): datestr_modified = re.sub(r'(\d\d):(\d\d)$', r'\1\2', datestr_modified) try: - return datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S.%f%z') + return datetime.datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S.%f%z') except ValueError: pass try: - return datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S%z') + return datetime.datetime.strptime(datestr_modified, '%Y-%m-%dT%H:%M:%S%z') except ValueError: pass diff --git a/firebase_admin/_sseclient.py b/firebase_admin/_sseclient.py index 6585dfc8..be1fde18 100644 --- a/firebase_admin/_sseclient.py +++ b/firebase_admin/_sseclient.py @@ -19,9 +19,12 @@ import re import time +import typing +import typing_extensions import warnings -from google.auth import transport +import google.auth.credentials +import google.auth.transport.requests import requests @@ -30,48 +33,48 @@ end_of_field = re.compile(r'\r\n\r\n|\r\r|\n\n') -class KeepAuthSession(transport.requests.AuthorizedSession): +class KeepAuthSession(google.auth.transport.requests.AuthorizedSession): """A session that does not drop authentication on redirects between domains.""" - def __init__(self, credential): - super(KeepAuthSession, self).__init__(credential) + def __init__(self, credential: typing.Optional[google.auth.credentials.Credentials]) -> None: + super(KeepAuthSession, self).__init__(credential) # type: ignore[reportUnknownMemberType] - def rebuild_auth(self, prepared_request, response): + def rebuild_auth(self, prepared_request: requests.PreparedRequest, response: requests.Response) -> None: pass class _EventBuffer: """A helper class for buffering and parsing raw SSE data.""" - def __init__(self): - self._buffer = [] + def __init__(self) -> None: + self._buffer: typing.List[str] = [] self._tail = '' - def append(self, char): + def append(self, char: str) -> None: self._buffer.append(char) self._tail += char self._tail = self._tail[-4:] - def truncate(self): + def truncate(self) -> None: head, sep, _ = self.buffer_string.rpartition('\n') rem = head + sep self._buffer = list(rem) self._tail = rem[-4:] @property - def is_end_of_field(self): + def is_end_of_field(self) -> bool: last_two_chars = self._tail[-2:] return last_two_chars == '\n\n' or last_two_chars == '\r\r' or self._tail == '\r\n\r\n' @property - def buffer_string(self): + def buffer_string(self) -> str: return ''.join(self._buffer) class SSEClient: """SSE client implementation.""" - def __init__(self, url, session, retry=3000, **kwargs): + def __init__(self, url: str, session: requests.Session, retry: int = 3000, **kwargs: typing.Any) -> None: """Initializes the SSEClient. Args: @@ -85,7 +88,7 @@ def __init__(self, url, session, retry=3000, **kwargs): self.retry = retry self.requests_kwargs = kwargs self.should_connect = True - self.last_id = None + self.last_id: typing.Optional[str] = None self.buf = u'' # Keep data here as it streams in headers = self.requests_kwargs.get('headers', {}) @@ -96,13 +99,13 @@ def __init__(self, url, session, retry=3000, **kwargs): self.requests_kwargs['headers'] = headers self._connect() - def close(self): + def close(self) -> None: """Closes the SSEClient instance.""" self.should_connect = False self.retry = 0 self.resp.close() - def _connect(self): + def _connect(self) -> None: """Connects to the server using requests.""" if self.should_connect: if self.last_id: @@ -113,10 +116,10 @@ def _connect(self): else: raise StopIteration() - def __iter__(self): + def __iter__(self) -> typing.Iterator[typing.Optional["Event"]]: return self - def __next__(self): + def __next__(self) -> typing.Optional["Event"]: if not re.search(end_of_field, self.buf): temp_buffer = _EventBuffer() while not temp_buffer.is_end_of_field: @@ -153,7 +156,7 @@ def __next__(self): self.last_id = event.event_id return event - def next(self): + def next(self) -> typing.Optional["Event"]: return self.__next__() @@ -162,14 +165,20 @@ class Event: sse_line_pattern = re.compile('(?P[^:]*):?( ?(?P.*))?') - def __init__(self, data='', event_type='message', event_id=None, retry=None): + def __init__( + self, + data: str = '', + event_type: str = 'message', + event_id: typing.Optional[str] = None, + retry: typing.Optional[int] = None, + ) -> None: self.data = data self.event_type = event_type self.event_id = event_id self.retry = retry @classmethod - def parse(cls, raw): + def parse(cls, raw: str) -> typing_extensions.Self: """Given a possibly-multiline string representing an SSE message, parses it and returns an Event object. diff --git a/firebase_admin/_token_gen.py b/firebase_admin/_token_gen.py index a2fc725e..e7d2f981 100644 --- a/firebase_admin/_token_gen.py +++ b/firebase_admin/_token_gen.py @@ -16,20 +16,29 @@ import datetime import time +import typing import cachecontrol import requests from google.auth import credentials from google.auth import iam from google.auth import jwt -from google.auth import transport +import google.auth.transport.requests +import google.auth.crypt import google.auth.exceptions import google.oauth2.id_token import google.oauth2.service_account +import firebase_admin from firebase_admin import exceptions from firebase_admin import _auth_utils from firebase_admin import _http_client +from firebase_admin import _typing + +if typing.TYPE_CHECKING: + from _typeshed import Incomplete +else: + Incomplete = typing.Any # ID token constants @@ -61,19 +70,26 @@ class _EmulatedSigner(google.auth.crypt.Signer): - key_id = None + @property + def key_id(self) -> typing.Optional[str]: + return None - def __init__(self): + def __init__(self) -> None: pass - def sign(self, message): + def sign(self, message: typing.Union[str, bytes]) -> bytes: return b'' class _SigningProvider: """Stores a reference to a google.auth.crypto.Signer.""" - def __init__(self, signer, signer_email, alg=ALGORITHM_RS256): + def __init__( + self, + signer: google.auth.crypt.Signer, + signer_email: typing.Optional[str], + alg: str = ALGORITHM_RS256, + ) -> None: self._signer = signer self._signer_email = signer_email self._alg = alg @@ -87,20 +103,28 @@ def signer_email(self): return self._signer_email @property - def alg(self): + def alg(self) -> str: return self._alg @classmethod - def from_credential(cls, google_cred): - return _SigningProvider(google_cred.signer, google_cred.signer_email) + def from_credential( + cls, + google_cred: typing.Union[google.oauth2.service_account.Credentials, credentials.Signing] + ) -> "_SigningProvider": + return _SigningProvider(google_cred.signer, google_cred.signer_email) # type: ignore[reportUnknownMemberType] @classmethod - def from_iam(cls, request, google_cred, service_account): + def from_iam( + cls, + request: google.auth.transport.Request, + google_cred: credentials.Credentials, + service_account: str, + ) -> "_SigningProvider": signer = iam.Signer(request, google_cred, service_account) return _SigningProvider(signer, service_account) @classmethod - def for_emulator(cls): + def for_emulator(cls) -> "_SigningProvider": return _SigningProvider(_EmulatedSigner(), AUTH_EMULATOR_EMAIL, ALGORITHM_NONE) @@ -109,15 +133,20 @@ class TokenGenerator: ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, app, http_client, url_override=None): + def __init__( + self, + app: firebase_admin.App, + http_client: _http_client.HttpClient[typing.Dict[str, "_typing.Json"]], + url_override: typing.Optional[str] = None, + ) -> None: self.app = app self.http_client = http_client - self.request = transport.requests.Request() + self.request = google.auth.transport.requests.Request() url_prefix = url_override or self.ID_TOOLKIT_URL self.base_url = '{0}/projects/{1}'.format(url_prefix, app.project_id) - self._signing_provider = None + self._signing_provider: typing.Optional[_SigningProvider] = None - def _init_signing_provider(self): + def _init_signing_provider(self) -> _SigningProvider: """Initializes a signing provider by following the go/firebase-admin-sign protocol.""" if _auth_utils.is_emulated(): return _SigningProvider.for_emulator() @@ -140,14 +169,14 @@ def _init_signing_provider(self): # Attempt to discover a service account email from the local Metadata service. Use it # with the IAM service to sign bytes. resp = self.request(url=METADATA_SERVICE_URL, headers={'Metadata-Flavor': 'Google'}) - if resp.status != 200: + if resp.status != 200: # type: ignore[reportUnknownMemberType] raise ValueError( - 'Failed to contact the local metadata service: {0}.'.format(resp.data.decode())) - service_account = resp.data.decode() + 'Failed to contact the local metadata service: {0}.'.format(resp.data.decode())) # type: ignore[reportUnknownMemberType] + service_account = typing.cast(str, resp.data.decode()) # type: ignore[reportUnknownMemberType] return _SigningProvider.from_iam(self.request, google_cred, service_account) @property - def signing_provider(self): + def signing_provider(self) -> _SigningProvider: """Initializes and returns the SigningProvider instance to be used.""" if not self._signing_provider: try: @@ -161,7 +190,12 @@ def signing_provider(self): 'details on creating custom tokens.'.format(error, url)) return self._signing_provider - def create_custom_token(self, uid, developer_claims=None, tenant_id=None): + def create_custom_token( + self, + uid: str, + developer_claims: typing.Optional[typing.Dict[str, typing.Any]] = None, + tenant_id: typing.Optional[str] = None + ): """Builds and signs a Firebase custom auth token.""" if developer_claims is not None: if not isinstance(developer_claims, dict): @@ -184,7 +218,7 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): signing_provider = self.signing_provider now = int(time.time()) - payload = { + payload: typing.Dict[str, typing.Any] = { 'iss': signing_provider.signer_email, 'sub': signing_provider.signer_email, 'aud': FIREBASE_AUDIENCE, @@ -200,13 +234,17 @@ def create_custom_token(self, uid, developer_claims=None, tenant_id=None): header = {'alg': signing_provider.alg} try: - return jwt.encode(signing_provider.signer, payload, header=header) + return jwt.encode(signing_provider.signer, payload, header=header) # type: ignore[reportUnknownMemberType] except google.auth.exceptions.TransportError as error: msg = 'Failed to sign custom token. {0}'.format(error) raise TokenSignError(msg, error) - def create_session_cookie(self, id_token, expires_in): + def create_session_cookie( + self, + id_token: typing.Union[bytes, str], + expires_in: typing.Union[datetime.timedelta, int] + ) -> str: """Creates a session cookie from the provided ID token.""" id_token = id_token.decode('utf-8') if isinstance(id_token, bytes) else id_token if not isinstance(id_token, str) or not id_token: @@ -238,38 +276,46 @@ def create_session_cookie(self, id_token, expires_in): if not body or not body.get('sessionCookie'): raise _auth_utils.UnexpectedResponseError( 'Failed to create session cookie.', http_response=http_resp) - return body.get('sessionCookie') + return typing.cast(str, body['sessionCookie']) -class CertificateFetchRequest(transport.Request): +class CertificateFetchRequest(google.auth.transport.Request): """A google-auth transport that supports HTTP cache-control. Also injects a timeout to each outgoing HTTP request. """ - def __init__(self, timeout_seconds=None): + def __init__(self, timeout_seconds: typing.Optional[float] = None) -> None: self._session = cachecontrol.CacheControl(requests.Session()) - self._delegate = transport.requests.Request(self.session) + self._delegate = google.auth.transport.requests.Request(self.session) self._timeout_seconds = timeout_seconds @property - def session(self): + def session(self) -> requests.Session: return self._session @property - def timeout_seconds(self): + def timeout_seconds(self) -> typing.Optional[float]: return self._timeout_seconds - def __call__(self, url, method='GET', body=None, headers=None, timeout=None, **kwargs): + def __call__( + self, + url: str, + method: str = 'GET', + body: typing.Optional[Incomplete] = None, + headers: typing.Optional[typing.Mapping[str, str]] = None, + timeout: typing.Optional[float] = None, + **kwargs: Incomplete + ) -> google.auth.transport.Response: timeout = timeout or self.timeout_seconds return self._delegate( - url, method=method, body=body, headers=headers, timeout=timeout, **kwargs) + url, method=method, body=body, headers=headers, timeout=timeout, **kwargs) # type: ignore[reportArgumentType] class TokenVerifier: """Verifies ID tokens and session cookies.""" - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self.request = CertificateFetchRequest(timeout) self.id_token_verifier = _JWTVerifier( @@ -289,31 +335,48 @@ def __init__(self, app): invalid_token_error=InvalidSessionCookieError, expired_token_error=ExpiredSessionCookieError) - def verify_id_token(self, id_token, clock_skew_seconds=0): + def verify_id_token(self, id_token: typing.Union[bytes, str], clock_skew_seconds: int = 0): return self.id_token_verifier.verify(id_token, self.request, clock_skew_seconds) - def verify_session_cookie(self, cookie, clock_skew_seconds=0): + def verify_session_cookie(self, cookie: typing.Union[bytes, str], clock_skew_seconds: int = 0): return self.cookie_verifier.verify(cookie, self.request, clock_skew_seconds) class _JWTVerifier: """Verifies Firebase JWTs (ID tokens or session cookies).""" - def __init__(self, **kwargs): - self.project_id = kwargs.pop('project_id') - self.short_name = kwargs.pop('short_name') - self.operation = kwargs.pop('operation') - self.url = kwargs.pop('doc_url') - self.cert_url = kwargs.pop('cert_url') - self.issuer = kwargs.pop('issuer') + def __init__( + self, + *, + project_id: typing.Optional[str], + short_name: str, + operation: str, + doc_url: str, + cert_url: str, + issuer: str, + invalid_token_error: _typing.FirebaseErrorFactoryNoHttpWithDefaults, + expired_token_error: _typing.FirebaseErrorFactoryNoHttp, + **kwargs: typing.Any, + ) -> None: + self.project_id = project_id + self.short_name = short_name + self.operation = operation + self.url = doc_url + self.cert_url = cert_url + self.issuer = issuer if self.short_name[0].lower() in 'aeiou': self.articled_short_name = 'an {0}'.format(self.short_name) else: self.articled_short_name = 'a {0}'.format(self.short_name) - self._invalid_token_error = kwargs.pop('invalid_token_error') - self._expired_token_error = kwargs.pop('expired_token_error') - - def verify(self, token, request, clock_skew_seconds=0): + self._invalid_token_error = invalid_token_error + self._expired_token_error = expired_token_error + + def verify( + self, + token: typing.Union[bytes, str], + request: google.auth.transport.Request, + clock_skew_seconds: int = 0, + ) -> typing.Dict[str, typing.Any]: """Verifies the signature and data for the provided JWT.""" token = token.encode('utf-8') if isinstance(token, str) else token if not isinstance(token, bytes) or not token: @@ -392,9 +455,9 @@ def verify(self, token, request, clock_skew_seconds=0): try: if emulated: - verified_claims = payload + verified_claims: typing.Dict[str, typing.Any] = payload else: - verified_claims = google.oauth2.id_token.verify_token( + verified_claims = google.oauth2.id_token.verify_token( # type: ignore[reportUnknownMemberType] token, request=request, audience=self.project_id, @@ -407,61 +470,64 @@ def verify(self, token, request, clock_skew_seconds=0): except ValueError as error: if 'Token expired' in str(error): raise self._expired_token_error(str(error), cause=error) - raise self._invalid_token_error(str(error), cause=error) + raise self._invalid_token_error(str(error), error) - def _decode_unverified(self, token): + def _decode_unverified( + self, + token: typing.Union[bytes, str], + ) -> typing.Tuple[typing.Dict[str, str], typing.Dict[str, typing.Any]]: try: - header = jwt.decode_header(token) - payload = jwt.decode(token, verify=False) - return header, payload + header = typing.cast(typing.Mapping[str, str], jwt.decode_header(token)) # type: ignore[reportUnknownMemberType] + payload = typing.cast(typing.Mapping[str, str], jwt.decode(token, verify=False)) # type: ignore[reportUnknownMemberType] + return dict(header), dict(payload) except ValueError as error: - raise self._invalid_token_error(str(error), cause=error) + raise self._invalid_token_error(str(error), error) class TokenSignError(exceptions.UnknownError): """Unexpected error while signing a Firebase custom token.""" - def __init__(self, message, cause): + def __init__(self, message: str, cause: typing.Optional[Exception]) -> None: exceptions.UnknownError.__init__(self, message, cause) class CertificateFetchError(exceptions.UnknownError): """Failed to fetch some public key certificates required to verify a token.""" - def __init__(self, message, cause): + def __init__(self, message: str, cause: typing.Optional[Exception]) -> None: exceptions.UnknownError.__init__(self, message, cause) class ExpiredIdTokenError(_auth_utils.InvalidIdTokenError): """The provided ID token is expired.""" - def __init__(self, message, cause): + def __init__(self, message: str, cause: typing.Optional[Exception]) -> None: _auth_utils.InvalidIdTokenError.__init__(self, message, cause) class RevokedIdTokenError(_auth_utils.InvalidIdTokenError): """The provided ID token has been revoked.""" - def __init__(self, message): + def __init__(self, message: str) -> None: _auth_utils.InvalidIdTokenError.__init__(self, message) class InvalidSessionCookieError(exceptions.InvalidArgumentError): """The provided string is not a valid Firebase session cookie.""" - def __init__(self, message, cause=None): + def __init__(self, message: str, cause: typing.Optional[Exception] = None) -> None: exceptions.InvalidArgumentError.__init__(self, message, cause) class ExpiredSessionCookieError(InvalidSessionCookieError): """The provided session cookie is expired.""" - def __init__(self, message, cause): + def __init__(self, message: str, cause: typing.Optional[Exception]) -> None: InvalidSessionCookieError.__init__(self, message, cause) class RevokedSessionCookieError(InvalidSessionCookieError): """The provided session cookie has been revoked.""" - def __init__(self, message): + def __init__(self, message: str) -> None: InvalidSessionCookieError.__init__(self, message) diff --git a/firebase_admin/_user_identifier.py b/firebase_admin/_user_identifier.py index 85a224e0..9afb146b 100644 --- a/firebase_admin/_user_identifier.py +++ b/firebase_admin/_user_identifier.py @@ -26,7 +26,7 @@ class UidIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, uid): + def __init__(self, uid: str) -> None: """Constructs a new `UidIdentifier` object. Args: @@ -35,7 +35,7 @@ def __init__(self, uid): self._uid = _auth_utils.validate_uid(uid, required=True) @property - def uid(self): + def uid(self) -> str: return self._uid @@ -45,7 +45,7 @@ class EmailIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, email): + def __init__(self, email: str) -> None: """Constructs a new `EmailIdentifier` object. Args: @@ -54,7 +54,7 @@ def __init__(self, email): self._email = _auth_utils.validate_email(email, required=True) @property - def email(self): + def email(self) -> str: return self._email @@ -64,7 +64,7 @@ class PhoneIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, phone_number): + def __init__(self, phone_number: str) -> None: """Constructs a new `PhoneIdentifier` object. Args: @@ -73,7 +73,7 @@ def __init__(self, phone_number): self._phone_number = _auth_utils.validate_phone(phone_number, required=True) @property - def phone_number(self): + def phone_number(self) -> str: return self._phone_number @@ -83,21 +83,21 @@ class ProviderIdentifier(UserIdentifier): See ``auth.get_user()``. """ - def __init__(self, provider_id, provider_uid): + def __init__(self, provider_id: str, provider_uid: str) -> None: """Constructs a new `ProviderIdentifier` object. -   Args: -     provider_id: A provider ID string. -     provider_uid: A provider UID string. + Args: + provider_id: A provider ID string. + provider_uid: A provider UID string. """ self._provider_id = _auth_utils.validate_provider_id(provider_id, required=True) self._provider_uid = _auth_utils.validate_provider_uid( provider_uid, required=True) @property - def provider_id(self): + def provider_id(self) -> str: return self._provider_id @property - def provider_uid(self): + def provider_uid(self) -> str: return self._provider_uid diff --git a/firebase_admin/_user_import.py b/firebase_admin/_user_import.py index 659a6870..623ead0f 100644 --- a/firebase_admin/_user_import.py +++ b/firebase_admin/_user_import.py @@ -16,11 +16,14 @@ import base64 import json +import typing from firebase_admin import _auth_utils +from firebase_admin import _user_mgt +from firebase_admin import _typing -def b64_encode(bytes_value): +def b64_encode(bytes_value: bytes) -> str: return base64.urlsafe_b64encode(bytes_value).decode() @@ -39,7 +42,14 @@ class UserProvider: photo_url: User's photo URL (optional). """ - def __init__(self, uid, provider_id, email=None, display_name=None, photo_url=None): + def __init__( + self, + uid: str, + provider_id: str, + email: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + photo_url: typing.Optional[str] = None, + ) -> None: self.uid = uid self.provider_id = provider_id self.email = email @@ -47,46 +57,46 @@ def __init__(self, uid, provider_id, email=None, display_name=None, photo_url=No self.photo_url = photo_url @property - def uid(self): + def uid(self) -> str: return self._uid @uid.setter - def uid(self, uid): + def uid(self, uid: str) -> None: self._uid = _auth_utils.validate_uid(uid, required=True) @property - def provider_id(self): + def provider_id(self) -> str: return self._provider_id @provider_id.setter - def provider_id(self, provider_id): + def provider_id(self, provider_id: str) -> None: self._provider_id = _auth_utils.validate_provider_id(provider_id, required=True) @property - def email(self): + def email(self) -> typing.Optional[str]: return self._email @email.setter - def email(self, email): + def email(self, email: typing.Optional[str]) -> None: self._email = _auth_utils.validate_email(email) @property - def display_name(self): + def display_name(self) -> typing.Optional[str]: return self._display_name @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: typing.Optional[str]) -> None: self._display_name = _auth_utils.validate_display_name(display_name) @property - def photo_url(self): + def photo_url(self) -> typing.Optional[str]: return self._photo_url @photo_url.setter - def photo_url(self, photo_url): + def photo_url(self, photo_url: typing.Optional[str]): self._photo_url = _auth_utils.validate_photo_url(photo_url) - def to_dict(self): + def to_dict(self) -> typing.Dict[str, str]: payload = { 'rawId': self.uid, 'providerId': self.provider_id, @@ -123,9 +133,21 @@ class ImportUserRecord: ValueError: If provided arguments are invalid. """ - def __init__(self, uid, email=None, email_verified=None, display_name=None, phone_number=None, - photo_url=None, disabled=None, user_metadata=None, provider_data=None, - custom_claims=None, password_hash=None, password_salt=None): + def __init__( + self, + uid: str, + email: typing.Optional[str] = None, + email_verified: typing.Optional[bool] = None, + display_name: typing.Optional[str] = None, + phone_number: typing.Optional[str] = None, + photo_url: typing.Optional[str] = None, + disabled: typing.Optional[bool] = None, + user_metadata: typing.Optional[_user_mgt.UserMetadata] = None, + provider_data: typing.Optional[typing.List[UserProvider]] = None, + custom_claims: typing.Optional[typing.Dict[str, typing.Any]] = None, + password_hash: typing.Optional[bytes] = None, + password_salt: typing.Optional[bytes] = None, + ) -> None: self.uid = uid self.email = email self.display_name = display_name @@ -140,67 +162,67 @@ def __init__(self, uid, email=None, email_verified=None, display_name=None, phon self.custom_claims = custom_claims @property - def uid(self): + def uid(self) -> str: return self._uid @uid.setter - def uid(self, uid): + def uid(self, uid: str) -> None: self._uid = _auth_utils.validate_uid(uid, required=True) @property - def email(self): + def email(self) -> typing.Optional[str]: return self._email @email.setter - def email(self, email): + def email(self, email: typing.Optional[str]) -> None: self._email = _auth_utils.validate_email(email) @property - def display_name(self): + def display_name(self) -> typing.Optional[str]: return self._display_name @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: typing.Optional[str]) -> None: self._display_name = _auth_utils.validate_display_name(display_name) @property - def phone_number(self): + def phone_number(self) -> typing.Optional[str]: return self._phone_number @phone_number.setter - def phone_number(self, phone_number): + def phone_number(self, phone_number: typing.Optional[str]) -> None: self._phone_number = _auth_utils.validate_phone(phone_number) @property - def photo_url(self): + def photo_url(self) -> typing.Optional[str]: return self._photo_url @photo_url.setter - def photo_url(self, photo_url): + def photo_url(self, photo_url: typing.Optional[str]) -> None: self._photo_url = _auth_utils.validate_photo_url(photo_url) @property - def password_hash(self): + def password_hash(self) -> typing.Optional[bytes]: return self._password_hash @password_hash.setter - def password_hash(self, password_hash): + def password_hash(self, password_hash: typing.Optional[bytes]) -> None: self._password_hash = _auth_utils.validate_bytes(password_hash, 'password_hash') @property - def password_salt(self): + def password_salt(self) -> typing.Optional[bytes]: return self._password_salt @password_salt.setter - def password_salt(self, password_salt): + def password_salt(self, password_salt: typing.Optional[bytes]) -> None: self._password_salt = _auth_utils.validate_bytes(password_salt, 'password_salt') @property - def user_metadata(self): + def user_metadata(self) -> typing.Optional[_user_mgt.UserMetadata]: return self._user_metadata @user_metadata.setter - def user_metadata(self, user_metadata): + def user_metadata(self, user_metadata: typing.Optional[_user_mgt.UserMetadata]) -> None: created_at = user_metadata.creation_timestamp if user_metadata is not None else None last_login_at = user_metadata.last_sign_in_timestamp if user_metadata is not None else None self._created_at = _auth_utils.validate_timestamp(created_at, 'creation_timestamp') @@ -209,11 +231,11 @@ def user_metadata(self, user_metadata): self._user_metadata = user_metadata @property - def provider_data(self): + def provider_data(self) -> typing.Optional[typing.List[UserProvider]]: return self._provider_data @provider_data.setter - def provider_data(self, provider_data): + def provider_data(self, provider_data: typing.Optional[typing.List[UserProvider]]) -> None: if provider_data is not None: try: if any([not isinstance(p, UserProvider) for p in provider_data]): @@ -223,19 +245,19 @@ def provider_data(self, provider_data): self._provider_data = provider_data @property - def custom_claims(self): + def custom_claims(self) -> typing.Optional[typing.Dict[str, typing.Any]]: return self._custom_claims @custom_claims.setter - def custom_claims(self, custom_claims): + def custom_claims(self, custom_claims: typing.Optional[typing.Dict[str, typing.Any]]) -> None: json_claims = json.dumps(custom_claims) if isinstance( custom_claims, dict) else custom_claims self._custom_claims_str = _auth_utils.validate_custom_claims(json_claims) self._custom_claims = custom_claims - def to_dict(self): + def to_dict(self) -> typing.Dict[str, typing.Any]: """Returns a dict representation of the user. For internal use only.""" - payload = { + payload: typing.Dict[str, typing.Any] = { 'localId': self.uid, 'email': self.email, 'displayName': self.display_name, @@ -265,25 +287,25 @@ class UserImportHash: .. _documentation: https://firebase.google.com/docs/auth/admin/import-users """ - def __init__(self, name, data=None): + def __init__(self, name: str, data: typing.Optional[typing.Dict[str, typing.Any]] = None) -> None: self._name = name self._data = data - def to_dict(self): - payload = {'hashAlgorithm': self._name} + def to_dict(self) -> typing.Dict[str, typing.Any]: + payload: typing.Dict[str, typing.Any] = {'hashAlgorithm': self._name} if self._data: payload.update(self._data) return payload @classmethod - def _hmac(cls, name, key): + def _hmac(cls, name: str, key: bytes) -> "UserImportHash": data = { 'signerKey': b64_encode(_auth_utils.validate_bytes(key, 'key', required=True)) } return UserImportHash(name, data) @classmethod - def hmac_sha512(cls, key): + def hmac_sha512(cls, key: bytes) -> "UserImportHash": """Creates a new HMAC SHA512 algorithm instance. Args: @@ -295,7 +317,7 @@ def hmac_sha512(cls, key): return cls._hmac('HMAC_SHA512', key) @classmethod - def hmac_sha256(cls, key): + def hmac_sha256(cls, key: bytes) -> "UserImportHash": """Creates a new HMAC SHA256 algorithm instance. Args: @@ -307,7 +329,7 @@ def hmac_sha256(cls, key): return cls._hmac('HMAC_SHA256', key) @classmethod - def hmac_sha1(cls, key): + def hmac_sha1(cls, key: bytes) -> "UserImportHash": """Creates a new HMAC SHA1 algorithm instance. Args: @@ -319,7 +341,7 @@ def hmac_sha1(cls, key): return cls._hmac('HMAC_SHA1', key) @classmethod - def hmac_md5(cls, key): + def hmac_md5(cls, key: bytes) -> "UserImportHash": """Creates a new HMAC MD5 algorithm instance. Args: @@ -331,7 +353,7 @@ def hmac_md5(cls, key): return cls._hmac('HMAC_MD5', key) @classmethod - def md5(cls, rounds): + def md5(cls, rounds: int) -> "UserImportHash": """Creates a new MD5 algorithm instance. Args: @@ -345,7 +367,7 @@ def md5(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 8192)}) @classmethod - def sha1(cls, rounds): + def sha1(cls, rounds: int) -> "UserImportHash": """Creates a new SHA1 algorithm instance. Args: @@ -359,7 +381,7 @@ def sha1(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def sha256(cls, rounds): + def sha256(cls, rounds: int) -> "UserImportHash": """Creates a new SHA256 algorithm instance. Args: @@ -373,7 +395,7 @@ def sha256(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def sha512(cls, rounds): + def sha512(cls, rounds: int) -> "UserImportHash": """Creates a new SHA512 algorithm instance. Args: @@ -387,7 +409,7 @@ def sha512(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8192)}) @classmethod - def pbkdf_sha1(cls, rounds): + def pbkdf_sha1(cls, rounds: int) -> "UserImportHash": """Creates a new PBKDF SHA1 algorithm instance. Args: @@ -401,7 +423,7 @@ def pbkdf_sha1(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)}) @classmethod - def pbkdf2_sha256(cls, rounds): + def pbkdf2_sha256(cls, rounds: int) -> "UserImportHash": """Creates a new PBKDF2 SHA256 algorithm instance. Args: @@ -415,7 +437,13 @@ def pbkdf2_sha256(cls, rounds): {'rounds': _auth_utils.validate_int(rounds, 'rounds', 0, 120000)}) @classmethod - def scrypt(cls, key, rounds, memory_cost, salt_separator=None): + def scrypt( + cls, + key: bytes, + rounds: int, + memory_cost: int, + salt_separator: typing.Optional[bytes] = None, + ) -> "UserImportHash": """Creates a new Scrypt algorithm instance. This is the modified Scrypt algorithm used by Firebase Auth. See ``standard_scrypt()`` @@ -430,18 +458,18 @@ def scrypt(cls, key, rounds, memory_cost, salt_separator=None): Returns: UserImportHash: A new ``UserImportHash``. """ - data = { + data: typing.Dict[str, typing.Any] = { 'signerKey': b64_encode(_auth_utils.validate_bytes(key, 'key', required=True)), 'rounds': _auth_utils.validate_int(rounds, 'rounds', 1, 8), 'memoryCost': _auth_utils.validate_int(memory_cost, 'memory_cost', 1, 14), } if salt_separator: - data['saltSeparator'] = b64_encode(_auth_utils.validate_bytes( - salt_separator, 'salt_separator')) + _auth_utils.validate_bytes(salt_separator, 'salt_separator') + data['saltSeparator'] = b64_encode(salt_separator) return UserImportHash('SCRYPT', data) @classmethod - def bcrypt(cls): + def bcrypt(cls) -> "UserImportHash": """Creates a new Bcrypt algorithm instance. Returns: @@ -450,7 +478,13 @@ def bcrypt(cls): return UserImportHash('BCRYPT') @classmethod - def standard_scrypt(cls, memory_cost, parallelization, block_size, derived_key_length): + def standard_scrypt( + cls, + memory_cost: int, + parallelization: int, + block_size: int, + derived_key_length: int, + ) -> "UserImportHash": """Creates a new standard Scrypt algorithm instance. Args: @@ -479,16 +513,16 @@ class ErrorInfo: # it's home in _user_import.py). It's now also used by bulk deletion of # users. Move this to a more common location. - def __init__(self, error): - self._index = error['index'] - self._reason = error['message'] + def __init__(self, error: typing.Dict[str, _typing.Json]) -> None: + self._index = typing.cast(int, error['index']) + self._reason = typing.cast(str, error['message']) @property - def index(self): + def index(self) -> int: return self._index @property - def reason(self): + def reason(self) -> str: return self._reason @@ -498,23 +532,23 @@ class UserImportResult: See ``auth.import_users()`` API for more details. """ - def __init__(self, result, total): + def __init__(self, result: typing.Dict[str, typing.Any], total: int) -> None: errors = result.get('error', []) self._success_count = total - len(errors) self._failure_count = len(errors) self._errors = [ErrorInfo(err) for err in errors] @property - def success_count(self): + def success_count(self) -> int: """Returns the number of users successfully imported.""" return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Returns the number of users that failed to be imported.""" return self._failure_count @property - def errors(self): + def errors(self) -> typing.List[ErrorInfo]: """Returns a list of ``auth.ErrorInfo`` instances describing the errors encountered.""" return self._errors diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index aa0dfb0a..d9356b5c 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -15,14 +15,17 @@ """Firebase user management sub module.""" import base64 -from collections import defaultdict +import collections import json +import typing from urllib import parse import requests from firebase_admin import _auth_utils +from firebase_admin import _http_client from firebase_admin import _rfc3339 +from firebase_admin import _typing from firebase_admin import _user_identifier from firebase_admin import _user_import from firebase_admin._user_import import ErrorInfo @@ -34,19 +37,22 @@ class Sentinel: - - def __init__(self, description): + def __init__(self, description: str) -> None: self.description = description -DELETE_ATTRIBUTE = Sentinel('Value used to delete an attribute from a user profile') +DELETE_ATTRIBUTE: typing.Any = Sentinel('Value used to delete an attribute from a user profile') class UserMetadata: """Contains additional metadata associated with a user account.""" - def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None, - last_refresh_timestamp=None): + def __init__( + self, + creation_timestamp: typing.Optional[typing.Any] = None, + last_sign_in_timestamp: typing.Optional[typing.Any] = None, + last_refresh_timestamp: typing.Optional[typing.Any] = None, + ) -> None: self._creation_timestamp = _auth_utils.validate_timestamp( creation_timestamp, 'creation_timestamp') self._last_sign_in_timestamp = _auth_utils.validate_timestamp( @@ -55,7 +61,7 @@ def __init__(self, creation_timestamp=None, last_sign_in_timestamp=None, last_refresh_timestamp, 'last_refresh_timestamp') @property - def creation_timestamp(self): + def creation_timestamp(self) -> typing.Optional[int]: """ Creation timestamp in milliseconds since the epoch. Returns: @@ -64,7 +70,7 @@ def creation_timestamp(self): return self._creation_timestamp @property - def last_sign_in_timestamp(self): + def last_sign_in_timestamp(self) -> typing.Optional[int]: """ Last sign in timestamp in milliseconds since the epoch. Returns: @@ -73,7 +79,7 @@ def last_sign_in_timestamp(self): return self._last_sign_in_timestamp @property - def last_refresh_timestamp(self): + def last_refresh_timestamp(self) -> typing.Optional[int]: """The time at which the user was last active (ID token refreshed). Returns: @@ -90,32 +96,32 @@ class UserInfo: """ @property - def uid(self): + def uid(self) -> str: """Returns the user ID of this user.""" raise NotImplementedError @property - def display_name(self): + def display_name(self) -> typing.Optional[str]: """Returns the display name of this user.""" raise NotImplementedError @property - def email(self): + def email(self) -> typing.Optional[str]: """Returns the email address associated with this user.""" raise NotImplementedError @property - def phone_number(self): + def phone_number(self) -> typing.Optional[str]: """Returns the phone number associated with this user.""" raise NotImplementedError @property - def photo_url(self): + def photo_url(self) -> typing.Optional[str]: """Returns the photo URL of this user.""" raise NotImplementedError @property - def provider_id(self): + def provider_id(self) -> str: """Returns the ID of the identity provider. This can be a short domain name (e.g. google.com), or the identity of an OpenID @@ -127,7 +133,7 @@ def provider_id(self): class UserRecord(UserInfo): """Contains metadata associated with a Firebase user account.""" - def __init__(self, data): + def __init__(self, data: typing.Dict[str, typing.Any]) -> None: super(UserRecord, self).__init__() if not isinstance(data, dict): raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) @@ -136,16 +142,16 @@ def __init__(self, data): self._data = data @property - def uid(self): + def uid(self) -> str: """Returns the user ID of this user. Returns: string: A user ID string. This value is never None or empty. """ - return self._data.get('localId') + return self._data['localId'] @property - def display_name(self): + def display_name(self) -> typing.Optional[str]: """Returns the display name of this user. Returns: @@ -154,7 +160,7 @@ def display_name(self): return self._data.get('displayName') @property - def email(self): + def email(self) -> typing.Optional[str]: """Returns the email address associated with this user. Returns: @@ -163,7 +169,7 @@ def email(self): return self._data.get('email') @property - def phone_number(self): + def phone_number(self) -> typing.Optional[str]: """Returns the phone number associated with this user. Returns: @@ -172,7 +178,7 @@ def phone_number(self): return self._data.get('phoneNumber') @property - def photo_url(self): + def photo_url(self) -> typing.Optional[str]: """Returns the photo URL of this user. Returns: @@ -181,7 +187,7 @@ def photo_url(self): return self._data.get('photoUrl') @property - def provider_id(self): + def provider_id(self) -> str: """Returns the provider ID of this user. Returns: @@ -190,7 +196,7 @@ def provider_id(self): return 'firebase' @property - def email_verified(self): + def email_verified(self) -> bool: """Returns whether the email address of this user has been verified. Returns: @@ -199,7 +205,7 @@ def email_verified(self): return bool(self._data.get('emailVerified')) @property - def disabled(self): + def disabled(self) -> bool: """Returns whether this user account is disabled. Returns: @@ -208,7 +214,7 @@ def disabled(self): return bool(self._data.get('disabled')) @property - def tokens_valid_after_timestamp(self): + def tokens_valid_after_timestamp(self) -> int: """Returns the time, in milliseconds since the epoch, before which tokens are invalid. Note: this is truncated to 1 second accuracy. @@ -223,16 +229,17 @@ def tokens_valid_after_timestamp(self): return 0 @property - def user_metadata(self): + def user_metadata(self) -> UserMetadata: """Returns additional metadata associated with this user. Returns: UserMetadata: A UserMetadata instance. Does not return None. """ - def _int_or_none(key): + def _int_or_none(key: str) -> typing.Optional[int]: if key in self._data: return int(self._data[key]) return None + last_refresh_at_millis = None last_refresh_at_rfc3339 = self._data.get('lastRefreshAt', None) if last_refresh_at_rfc3339: @@ -241,7 +248,7 @@ def _int_or_none(key): _int_or_none('createdAt'), _int_or_none('lastLoginAt'), last_refresh_at_millis) @property - def provider_data(self): + def provider_data(self) -> typing.List["ProviderUserInfo"]: """Returns a list of UserInfo instances. Each object represents an identity from an identity provider that is linked to this user. @@ -253,7 +260,7 @@ def provider_data(self): return [ProviderUserInfo(entry) for entry in providers] @property - def custom_claims(self): + def custom_claims(self) -> typing.Optional[typing.Dict[str, typing.Any]]: """Returns any custom claims set on this user account. Returns: @@ -267,7 +274,7 @@ def custom_claims(self): return None @property - def tenant_id(self): + def tenant_id(self) -> typing.Optional[str]: """Returns the tenant ID of this user. Returns: @@ -280,7 +287,7 @@ class ExportedUserRecord(UserRecord): """Contains metadata associated with a user including password hash and salt.""" @property - def password_hash(self): + def password_hash(self) -> typing.Optional[str]: """The user's password hash as a base64-encoded string. If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this @@ -299,7 +306,7 @@ def password_hash(self): return password_hash @property - def password_salt(self): + def password_salt(self) -> typing.Optional[str]: """The user's password salt as a base64-encoded string. If the Firebase Auth hashing algorithm (SCRYPT) was used to create the user account, this @@ -314,7 +321,7 @@ def password_salt(self): class GetUsersResult: """Represents the result of the ``auth.get_users()`` API.""" - def __init__(self, users, not_found): + def __init__(self, users: typing.List[UserRecord], not_found: typing.List[_user_identifier.UserIdentifier]) -> None: """Constructs a `GetUsersResult` object. Args: @@ -325,7 +332,7 @@ def __init__(self, users, not_found): self._not_found = not_found @property - def users(self): + def users(self) -> typing.List[UserRecord]: """Set of `UserRecord` instances, corresponding to the set of users that were requested. Only users that were found are listed here. The result set is unordered. @@ -333,7 +340,7 @@ def users(self): return self._users @property - def not_found(self): + def not_found(self) -> typing.List[_user_identifier.UserIdentifier]: """Set of `UserIdentifier` instances that were requested, but not found. """ @@ -348,27 +355,38 @@ class ListUsersPage: through all users in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: typing.Callable[[typing.Optional[str], int], typing.Dict[str, _typing.Json]], + page_token: typing.Optional[str], + max_results: int, + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def users(self): + def users(self) -> typing.List[ExportedUserRecord]: """A list of ``ExportedUserRecord`` instances available in this page.""" - return [ExportedUserRecord(user) for user in self._current.get('users', [])] + return [ + ExportedUserRecord(user) + for user in typing.cast( + typing.List[typing.Dict[str, _typing.Json]], + self._current.get('users', []), + ) + ] @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" - return self._current.get('nextPageToken', '') + return typing.cast(str, self._current.get('nextPageToken', '')) @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> typing.Optional["ListUsersPage"]: """Retrieves the next page of user accounts, if available. Returns: @@ -378,7 +396,7 @@ def get_next_page(self): return ListUsersPage(self._download, self.next_page_token, self._max_results) return None - def iterate_all(self): + def iterate_all(self) -> "_UserIterator": """Retrieves an iterator for user accounts. Returned iterator will iterate through all the user accounts in the Firebase project @@ -394,7 +412,7 @@ def iterate_all(self): class DeleteUsersResult: """Represents the result of the ``auth.delete_users()`` API.""" - def __init__(self, result, total): + def __init__(self, result: "BatchDeleteAccountsResponse", total: int) -> None: """Constructs a `DeleteUsersResult` object. Args: @@ -408,7 +426,7 @@ def __init__(self, result, total): self._errors = errors @property - def success_count(self): + def success_count(self) -> int: """Returns the number of users that were deleted successfully (possibly zero). @@ -418,14 +436,14 @@ def success_count(self): return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Returns the number of users that failed to be deleted (possibly zero). """ return self._failure_count @property - def errors(self): + def errors(self) -> typing.List[ErrorInfo]: """A list of `auth.ErrorInfo` instances describing the errors that were encountered during the deletion. Length of this list is equal to `failure_count`. @@ -436,7 +454,7 @@ def errors(self): class BatchDeleteAccountsResponse: """Represents the results of a `delete_users()` call.""" - def __init__(self, errors=None): + def __init__(self, errors: typing.Optional[typing.List[typing.Dict[str, _typing.Json]]] = None) -> None: """Constructs a `BatchDeleteAccountsResponse` instance, corresponding to the JSON representing the `BatchDeleteAccountsResponse` proto. @@ -451,8 +469,7 @@ def __init__(self, errors=None): class ProviderUserInfo(UserInfo): """Contains metadata regarding how a user is known by a particular identity provider.""" - def __init__(self, data): - super(ProviderUserInfo, self).__init__() + def __init__(self, data: typing.Dict[str, typing.Any]) -> None: if not isinstance(data, dict): raise ValueError('Invalid data argument: {0}. Must be a dictionary.'.format(data)) if not data.get('rawId'): @@ -460,28 +477,29 @@ def __init__(self, data): self._data = data @property - def uid(self): - return self._data.get('rawId') + def uid(self) -> str: + return self._data['rawId'] @property - def display_name(self): + def display_name(self) -> typing.Optional[str]: return self._data.get('displayName') @property - def email(self): + def email(self) -> typing.Optional[str]: return self._data.get('email') @property - def phone_number(self): + def phone_number(self) -> typing.Optional[str]: return self._data.get('phoneNumber') @property - def photo_url(self): + def photo_url(self) -> typing.Optional[str]: return self._data.get('photoUrl') @property - def provider_id(self): - return self._data.get('providerId') + def provider_id(self) -> str: + # possible issue: can providerId be `None`? + return self._data.get('providerId') # type: ignore[reportReturnType] class ActionCodeSettings: @@ -489,8 +507,16 @@ class ActionCodeSettings: Used when invoking the email action link generation APIs. """ - def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_bundle_id=None, - android_package_name=None, android_install_app=None, android_minimum_version=None): + def __init__( + self, + url: str, + handle_code_in_app: typing.Optional[bool] = None, + dynamic_link_domain: typing.Optional[str] = None, + ios_bundle_id: typing.Optional[str] = None, + android_package_name: typing.Optional[str] = None, + android_install_app: typing.Optional[bool] = None, + android_minimum_version: typing.Optional[str] = None + ) -> None: self.url = url self.handle_code_in_app = handle_code_in_app self.dynamic_link_domain = dynamic_link_domain @@ -500,7 +526,7 @@ def __init__(self, url, handle_code_in_app=None, dynamic_link_domain=None, ios_b self.android_minimum_version = android_minimum_version -def encode_action_code_settings(settings): +def encode_action_code_settings(settings: ActionCodeSettings) -> typing.Dict[str, typing.Any]: """ Validates the provided action code settings for email link generation and populates the REST api parameters. @@ -508,7 +534,7 @@ def encode_action_code_settings(settings): returns - dict of parameters to be passed for link gereration. """ - parameters = {} + parameters: typing.Dict[str, typing.Any] = {} # url if not settings.url: raise ValueError("Dynamic action links url is mandatory") @@ -573,14 +599,20 @@ class UserManager: ID_TOOLKIT_URL = 'https://identitytoolkit.googleapis.com/v1' - def __init__(self, http_client, project_id, tenant_id=None, url_override=None): + def __init__( + self, + http_client: _http_client.HttpClient[typing.Dict[str, _typing.Json]], + project_id: str, + tenant_id: typing.Optional[str] = None, + url_override: typing.Optional[str] = None, + ) -> None: self.http_client = http_client url_prefix = url_override or self.ID_TOOLKIT_URL self.base_url = '{0}/projects/{1}'.format(url_prefix, project_id) if tenant_id: self.base_url += '/tenants/{0}'.format(tenant_id) - def get_user(self, **kwargs): + def get_user(self, **kwargs: typing.Any) -> typing.Dict[str, _typing.Json]: """Gets the user data corresponding to the provided key.""" if 'uid' in kwargs: key, key_type = kwargs.pop('uid'), 'user ID' @@ -599,9 +631,12 @@ def get_user(self, **kwargs): raise _auth_utils.UserNotFoundError( 'No user record found for the provided {0}: {1}.'.format(key_type, key), http_response=http_resp) - return body['users'][0] + return typing.cast(typing.List[typing.Dict[str, _typing.Json]], body['users'])[0] - def get_users(self, identifiers): + def get_users( + self, + identifiers: typing.Sequence[_user_identifier.UserIdentifier], + ) -> typing.List[typing.Dict[str, _typing.Json]]: """Looks up multiple users by their identifiers (uid, email, etc.) Args: @@ -623,7 +658,7 @@ def get_users(self, identifiers): if len(identifiers) > 100: raise ValueError('`identifiers` parameter must have <= 100 entries.') - payload = defaultdict(list) + payload: typing.Dict[str, typing.List[typing.Any]] = collections.defaultdict(list) for identifier in identifiers: if isinstance(identifier, _user_identifier.UidIdentifier): payload['localId'].append(identifier.uid) @@ -646,9 +681,13 @@ def get_users(self, identifiers): if not http_resp.ok: raise _auth_utils.UnexpectedResponseError( 'Failed to get users.', http_response=http_resp) - return body.get('users', []) + return typing.cast(typing.List[typing.Dict[str, _typing.Json]], body.get('users', [])) - def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): + def list_users( + self, + page_token: typing.Optional[str] = None, + max_results: int = MAX_LIST_USERS_RESULTS, + ): """Retrieves a batch of users.""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -660,14 +699,23 @@ def list_users(self, page_token=None, max_results=MAX_LIST_USERS_RESULTS): 'Max results must be a positive integer less than ' '{0}.'.format(MAX_LIST_USERS_RESULTS)) - payload = {'maxResults': max_results} + payload: typing.Dict[str, typing.Any] = {'maxResults': max_results} if page_token: payload['nextPageToken'] = page_token body, _ = self._make_request('get', '/accounts:batchGet', params=payload) return body - def create_user(self, uid=None, display_name=None, email=None, phone_number=None, - photo_url=None, password=None, disabled=None, email_verified=None): + def create_user( + self, + uid: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + email: typing.Optional[str] = None, + phone_number: typing.Optional[str] = None, + photo_url: typing.Optional[str] = None, + password: typing.Optional[str] = None, + disabled: typing.Optional[bool] = None, + email_verified: typing.Optional[bool] = None, + ) -> str: """Creates a new user account with the specified properties.""" payload = { 'localId': _auth_utils.validate_uid(uid), @@ -684,13 +732,24 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None if not body or not body.get('localId'): raise _auth_utils.UnexpectedResponseError( 'Failed to create new user.', http_response=http_resp) - return body.get('localId') - - def update_user(self, uid, display_name=None, email=None, phone_number=None, - photo_url=None, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=None, providers_to_delete=None): + return typing.cast(str, body['localId']) + + def update_user( + self, + uid: str, + display_name: typing.Optional[str] = None, + email: typing.Optional[str] = None, + phone_number: typing.Optional[str] = None, + photo_url: typing.Optional[str] = None, + password: typing.Optional[str] = None, + disabled: typing.Optional[bool] = None, + email_verified: typing.Optional[bool] = None, + valid_since: typing.Optional[_typing.ConvertibleToInt] = None, + custom_claims: typing.Optional[typing.Union[typing.Dict[str, typing.Any], str]] = None, + providers_to_delete: typing.Optional[typing.List[str]] = None, + ): """Updates an existing user account with the specified properties""" - payload = { + payload: typing.Dict[str, typing.Any] = { 'localId': _auth_utils.validate_uid(uid, required=True), 'email': _auth_utils.validate_email(email), 'password': _auth_utils.validate_password(password), @@ -699,7 +758,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, 'disableUser': bool(disabled) if disabled is not None else None, } - remove = [] + remove: typing.List[str] = [] remove_provider = _auth_utils.validate_provider_ids(providers_to_delete) if display_name is not None: if display_name is DELETE_ATTRIBUTE: @@ -737,7 +796,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, 'Failed to update user: {0}.'.format(uid), http_response=http_resp) return body.get('localId') - def delete_user(self, uid): + def delete_user(self, uid: str) -> None: """Deletes the user identified by the specified user ID.""" _auth_utils.validate_uid(uid, required=True) body, http_resp = self._make_request('post', '/accounts:delete', json={'localId' : uid}) @@ -745,7 +804,7 @@ def delete_user(self, uid): raise _auth_utils.UnexpectedResponseError( 'Failed to delete user: {0}.'.format(uid), http_response=http_resp) - def delete_users(self, uids, force_delete=False): + def delete_users(self, uids: typing.Sequence[str], force_delete: bool = False) -> BatchDeleteAccountsResponse: """Deletes the users identified by the specified user ids. Args: @@ -774,14 +833,19 @@ def delete_users(self, uids, force_delete=False): _auth_utils.validate_uid(uid, required=True) body, http_resp = self._make_request('post', '/accounts:batchDelete', - json={'localIds': uids, 'force': force_delete}) + json={'localIds': list(uids), 'force': force_delete}) if not isinstance(body, dict): raise _auth_utils.UnexpectedResponseError( 'Unexpected response from server while attempting to delete users.', http_response=http_resp) - return BatchDeleteAccountsResponse(body.get('errors', [])) - - def import_users(self, users, hash_alg=None): + return BatchDeleteAccountsResponse(typing.cast(typing.List[typing.Dict[str, _typing.Json]], + body.get('errors', []))) + + def import_users( + self, + users: typing.Sequence[_user_import.ImportUserRecord], + hash_alg: typing.Optional[_user_import.UserImportHash] = None, + ) -> typing.Dict[str, typing.Any]: """Imports the given list of users to Firebase Auth.""" try: if not users or len(users) > MAX_IMPORT_USERS_SIZE: @@ -804,7 +868,12 @@ def import_users(self, users, hash_alg=None): 'Failed to import users.', http_response=http_resp) return body - def generate_email_action_link(self, action_type, email, action_code_settings=None): + def generate_email_action_link( + self, + action_type: _typing.EmailActionType, + email: typing.Optional[str], + action_code_settings: typing.Optional[ActionCodeSettings] = None + ) -> str: """Fetches the email action links for types Args: @@ -834,9 +903,14 @@ def generate_email_action_link(self, action_type, email, action_code_settings=No if not body or not body.get('oobLink'): raise _auth_utils.UnexpectedResponseError( 'Failed to generate email action link.', http_response=http_resp) - return body.get('oobLink') - - def _make_request(self, method, path, **kwargs): + return typing.cast(str, body['oobLink']) + + def _make_request( + self, + method: str, + path: str, + **kwargs: typing.Any, + ) -> typing.Tuple[typing.Dict[str, _typing.Json], requests.Response]: url = '{0}{1}'.format(self.base_url, path) try: return self.http_client.body_and_response(method, url, **kwargs) @@ -844,8 +918,7 @@ def _make_request(self, method, path, **kwargs): raise _auth_utils.handle_auth_backend_error(error) -class _UserIterator(_auth_utils.PageIterator): - +class _UserIterator(_auth_utils.PageIterator[ListUsersPage]): @property - def items(self): - return self._current_page.users + def items(self) -> typing.List[ExportedUserRecord]: + return self._current_page.users if self._current_page else [] diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index b6e29254..6ee56826 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -15,16 +15,20 @@ """Internal utilities common to all modules.""" import json +import typing from platform import python_version -import google.auth +import google.auth.credentials +import google.auth.transport import requests import firebase_admin from firebase_admin import exceptions +from firebase_admin import _typing +_T = typing.TypeVar("_T") -_ERROR_CODE_TO_EXCEPTION_TYPE = { +_ERROR_CODE_TO_EXCEPTION_TYPE: typing.Dict[str, "_typing.FirebaseErrorFactoryWithDefaults"] = { exceptions.INVALID_ARGUMENT: exceptions.InvalidArgumentError, exceptions.FAILED_PRECONDITION: exceptions.FailedPreconditionError, exceptions.OUT_OF_RANGE: exceptions.OutOfRangeError, @@ -44,7 +48,7 @@ } -_HTTP_STATUS_TO_ERROR_CODE = { +_HTTP_STATUS_TO_ERROR_CODE: typing.Dict[int, str] = { 400: exceptions.INVALID_ARGUMENT, 401: exceptions.UNAUTHENTICATED, 403: exceptions.PERMISSION_DENIED, @@ -58,7 +62,7 @@ # See https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto -_RPC_CODE_TO_ERROR_CODE = { +_RPC_CODE_TO_ERROR_CODE: typing.Dict[int, str] = { 1: exceptions.CANCELLED, 2: exceptions.UNKNOWN, 3: exceptions.INVALID_ARGUMENT, @@ -76,10 +80,10 @@ 16: exceptions.UNAUTHENTICATED, } -def get_metrics_header(): +def get_metrics_header() -> str: return f'gl-python/{python_version()} fire-admin/{firebase_admin.__version__}' -def _get_initialized_app(app): +def _get_initialized_app(app: typing.Optional["firebase_admin.App"]) -> "firebase_admin.App": """Returns a reference to an initialized App instance.""" if app is None: return firebase_admin.get_app() @@ -95,13 +99,19 @@ def _get_initialized_app(app): ' firebase_admin.App, but given "{0}".'.format(type(app))) - -def get_app_service(app, name, initializer): +def get_app_service( + app: typing.Optional["firebase_admin.App"], + name: str, + initializer: "_typing.ServiceInitializer[_T]", +) -> _T: app = _get_initialized_app(app) return app._get_service(name, initializer) # pylint: disable=protected-access -def handle_platform_error_from_requests(error, handle_func=None): +def handle_platform_error_from_requests( + error: requests.RequestException, + handle_func: typing.Optional["_typing.RequestErrorHandler"] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given requests error. This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. @@ -129,7 +139,7 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) -def handle_operation_error(error): +def handle_operation_error(error: typing.Union[typing.Dict[str, typing.Any], Exception]) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given operation error. Args: @@ -143,14 +153,19 @@ def handle_operation_error(error): message='Unknown error while making a remote service call: {0}'.format(error), cause=error) - rpc_code = error.get('code') - message = error.get('message') + rpc_code = error.get('code', 0) + # possible issue: needs be str | None ? + message = typing.cast(str, error.get('message')) error_code = _rpc_code_to_error_code(rpc_code) err_type = _error_code_to_exception_type(error_code) return err_type(message=message) -def _handle_func_requests(error, message, error_dict): +def _handle_func_requests( + error: requests.RequestException, + message: str, + error_dict: typing.Dict[str, typing.Any], +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given GCP error. Args: @@ -165,7 +180,11 @@ def _handle_func_requests(error, message, error_dict): return handle_requests_error(error, message, code) -def handle_requests_error(error, message=None, code=None): +def handle_requests_error( + error: requests.RequestException, + message: typing.Optional[str] = None, + code: typing.Optional[str] = None, +) -> exceptions.FirebaseError: """Constructs a ``FirebaseError`` from the given requests error. This method is agnostic of the remote service that produced the error, whether it is a GCP @@ -205,20 +224,20 @@ def handle_requests_error(error, message=None, code=None): return err_type(message=message, cause=error, http_response=error.response) -def _http_status_to_error_code(status): +def _http_status_to_error_code(status: int) -> str: """Maps an HTTP status to a platform error code.""" return _HTTP_STATUS_TO_ERROR_CODE.get(status, exceptions.UNKNOWN) -def _rpc_code_to_error_code(rpc_code): +def _rpc_code_to_error_code(rpc_code: int) -> str: """Maps an RPC code to a platform error code.""" return _RPC_CODE_TO_ERROR_CODE.get(rpc_code, exceptions.UNKNOWN) -def _error_code_to_exception_type(code): +def _error_code_to_exception_type(code: str) -> "_typing.FirebaseErrorFactoryWithDefaults": """Maps a platform error code to an exception type.""" return _ERROR_CODE_TO_EXCEPTION_TYPE.get(code, exceptions.UnknownError) -def _parse_platform_error(content, status_code): +def _parse_platform_error(content: str, status_code: int) -> typing.Tuple[typing.Dict[str, typing.Any], str]: """Parses an HTTP error response from a Google Cloud Platform API and extracts the error code and message fields. @@ -229,15 +248,15 @@ def _parse_platform_error(content, status_code): Returns: tuple: A tuple containing error code and message. """ - data = {} + data: typing.Dict[str, typing.Any] = {} try: parsed_body = json.loads(content) if isinstance(parsed_body, dict): - data = parsed_body + data = typing.cast(typing.Dict[str, typing.Any], parsed_body) except ValueError: pass - error_dict = data.get('error', {}) + error_dict: typing.Dict[str, typing.Any] = data.get('error', {}) msg = error_dict.get('message') if not msg: msg = 'Unexpected HTTP response with status: {0}; body: {1}'.format(status_code, content) @@ -253,9 +272,9 @@ class EmulatorAdminCredentials(google.auth.credentials.Credentials): This is used instead of user-supplied credentials or ADC. It will silently do nothing when asked to refresh credentials. """ - def __init__(self): + def __init__(self) -> None: google.auth.credentials.Credentials.__init__(self) self.token = 'owner' - def refresh(self, request): + def refresh(self, request: google.auth.transport.Request) -> None: pass diff --git a/firebase_admin/app_check.py b/firebase_admin/app_check.py index 53686db3..49617b5b 100644 --- a/firebase_admin/app_check.py +++ b/firebase_admin/app_check.py @@ -14,18 +14,22 @@ """Firebase App Check module.""" -from typing import Any, Dict +import typing + import jwt -from jwt import PyJWKClient, ExpiredSignatureError, InvalidTokenError, DecodeError -from jwt import InvalidAudienceError, InvalidIssuerError, InvalidSignatureError + +import firebase_admin from firebase_admin import _utils + _APP_CHECK_ATTRIBUTE = '_app_check' -def _get_app_check_service(app) -> Any: + +def _get_app_check_service(app: typing.Optional[firebase_admin.App]) -> "_AppCheckService": return _utils.get_app_service(app, _APP_CHECK_ATTRIBUTE, _AppCheckService) -def verify_token(token: str, app=None) -> Dict[str, Any]: + +def verify_token(token: str, app: typing.Optional[firebase_admin.App] = None) -> typing.Dict[str, typing.Any]: """Verifies a Firebase App Check token. Args: @@ -42,35 +46,32 @@ def verify_token(token: str, app=None) -> Dict[str, Any]: """ return _get_app_check_service(app).verify_token(token) + class _AppCheckService: """Service class that implements Firebase App Check functionality.""" _APP_CHECK_ISSUER = 'https://firebaseappcheck.googleapis.com/' _JWKS_URL = 'https://firebaseappcheck.googleapis.com/v1/jwks' - _project_id = None - _scoped_project_id = None - _jwks_client = None _APP_CHECK_HEADERS = { 'x-goog-api-client': _utils.get_metrics_header(), } - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: # Validate and store the project_id to validate the JWT claims - self._project_id = app.project_id - if not self._project_id: + if not app.project_id: raise ValueError( 'A project ID must be specified to access the App Check ' 'service. Either set the projectId option, use service ' 'account credentials, or set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') + self._project_id = app.project_id self._scoped_project_id = 'projects/' + app.project_id # Default lifespan is 300 seconds (5 minutes) so we change it to 21600 seconds (6 hours). - self._jwks_client = PyJWKClient( + self._jwks_client = jwt.PyJWKClient( self._JWKS_URL, lifespan=21600, headers=self._APP_CHECK_HEADERS) - - def verify_token(self, token: str) -> Dict[str, Any]: + def verify_token(self, token: str) -> typing.Dict[str, typing.Any]: """Verifies a Firebase App Check token.""" _Validators.check_string("app check token", token) @@ -81,7 +82,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: signing_key = self._jwks_client.get_signing_key_from_jwt(token) self._has_valid_token_headers(jwt.get_unverified_header(token)) verified_claims = self._decode_and_verify(token, signing_key.key) - except (InvalidTokenError, DecodeError) as exception: + except (jwt.InvalidTokenError, jwt.DecodeError) as exception: raise ValueError( f'Verifying App Check token failed. Error: {exception}' ) @@ -89,7 +90,7 @@ def verify_token(self, token: str) -> Dict[str, Any]: verified_claims['app_id'] = verified_claims.get('sub') return verified_claims - def _has_valid_token_headers(self, headers: Any) -> None: + def _has_valid_token_headers(self, headers: typing.Dict[str, typing.Any]) -> None: """Checks whether the token has valid headers for App Check.""" # Ensure the token's header has type JWT if headers.get('typ') != 'JWT': @@ -102,9 +103,9 @@ def _has_valid_token_headers(self, headers: Any) -> None: f'Expected RS256 but got {algorithm}.' ) - def _decode_and_verify(self, token: str, signing_key: str): + def _decode_and_verify(self, token: str, signing_key: str) -> typing.Dict[str, typing.Any]: """Decodes and verifies the token from App Check.""" - payload = {} + payload: typing.Dict[str, typing.Any] = {} try: payload = jwt.decode( token, @@ -112,25 +113,25 @@ def _decode_and_verify(self, token: str, signing_key: str): algorithms=["RS256"], audience=self._scoped_project_id ) - except InvalidSignatureError: + except jwt.InvalidSignatureError: raise ValueError( 'The provided App Check token has an invalid signature.' ) - except InvalidAudienceError: + except jwt.InvalidAudienceError: raise ValueError( 'The provided App Check token has an incorrect "aud" (audience) claim. ' f'Expected payload to include {self._scoped_project_id}.' ) - except InvalidIssuerError: + except jwt.InvalidIssuerError: raise ValueError( 'The provided App Check token has an incorrect "iss" (issuer) claim. ' f'Expected claim to include {self._APP_CHECK_ISSUER}' ) - except ExpiredSignatureError: + except jwt.ExpiredSignatureError: raise ValueError( 'The provided App Check token has expired.' ) - except InvalidTokenError as exception: + except jwt.InvalidTokenError as exception: raise ValueError( f'Decoding App Check token failed. Error: {exception}' ) @@ -138,7 +139,7 @@ def _decode_and_verify(self, token: str, signing_key: str): audience = payload.get('aud') if not isinstance(audience, list) or self._scoped_project_id not in audience: raise ValueError('Firebase App Check token has incorrect "aud" (audience) claim.') - if not payload.get('iss').startswith(self._APP_CHECK_ISSUER): + if not typing.cast(str, payload['iss']).startswith(self._APP_CHECK_ISSUER): raise ValueError('Token does not contain the correct "iss" (issuer).') _Validators.check_string( 'The provided App Check token "sub" (subject) claim', @@ -146,6 +147,7 @@ def _decode_and_verify(self, token: str, signing_key: str): return payload + class _Validators: """A collection of data validation utilities. @@ -153,7 +155,7 @@ class _Validators: """ @classmethod - def check_string(cls, label: str, value: Any): + def check_string(cls, label: str, value: typing.Any) -> None: """Checks if the given value is a string.""" if value is None: raise ValueError('{0} "{1}" must be a non-empty string.'.format(label, value)) diff --git a/firebase_admin/auth.py b/firebase_admin/auth.py index ced14311..88d824f8 100644 --- a/firebase_admin/auth.py +++ b/firebase_admin/auth.py @@ -19,11 +19,16 @@ creating and managing user accounts in Firebase projects. """ +import datetime +import typing + +import firebase_admin from firebase_admin import _auth_client from firebase_admin import _auth_providers from firebase_admin import _auth_utils from firebase_admin import _user_identifier from firebase_admin import _token_gen +from firebase_admin import _typing from firebase_admin import _user_import from firebase_admin import _user_mgt from firebase_admin import _utils @@ -156,7 +161,7 @@ ProviderIdentifier = _user_identifier.ProviderIdentifier -def _get_client(app): +def _get_client(app: typing.Optional[firebase_admin.App]) -> Client: """Returns a client instance for an App. If the App already has a client associated with it, simply returns @@ -175,7 +180,11 @@ def _get_client(app): return _utils.get_app_service(app, _AUTH_ATTRIBUTE, Client) -def create_custom_token(uid, developer_claims=None, app=None): +def create_custom_token( + uid: str, + developer_claims: typing.Optional[typing.Dict[str, typing.Any]] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> bytes: """Builds and signs a Firebase custom auth token. Args: @@ -195,7 +204,12 @@ def create_custom_token(uid, developer_claims=None, app=None): return client.create_custom_token(uid, developer_claims) -def verify_id_token(id_token, app=None, check_revoked=False, clock_skew_seconds=0): +def verify_id_token( + id_token: typing.Union[bytes, str], + app: typing.Optional[firebase_admin.App] = None, + check_revoked: bool = False, + clock_skew_seconds: int = 0, +) -> typing.Dict[str, typing.Any]: """Verifies the signature and data for the provided JWT. Accepts a signed token string, verifies that it is current, and issued @@ -226,7 +240,11 @@ def verify_id_token(id_token, app=None, check_revoked=False, clock_skew_seconds= id_token, check_revoked=check_revoked, clock_skew_seconds=clock_skew_seconds) -def create_session_cookie(id_token, expires_in, app=None): +def create_session_cookie( + id_token: typing.Union[bytes, str], + expires_in: typing.Union[datetime.timedelta, int], + app: typing.Optional[firebase_admin.App] = None, +) -> str: """Creates a new Firebase session cookie from the given ID token and options. The returned JWT can be set as a server-side session cookie with a custom cookie policy. @@ -249,7 +267,12 @@ def create_session_cookie(id_token, expires_in, app=None): return client._token_generator.create_session_cookie(id_token, expires_in) -def verify_session_cookie(session_cookie, check_revoked=False, app=None, clock_skew_seconds=0): +def verify_session_cookie( + session_cookie: typing.Union[bytes, str], + check_revoked: bool = False, + app: typing.Optional[firebase_admin.App] = None, + clock_skew_seconds: int = 0, +) -> typing.Dict[str, typing.Any]: """Verifies a Firebase session cookie. Accepts a session cookie string, verifies that it is current, and issued @@ -285,7 +308,7 @@ def verify_session_cookie(session_cookie, check_revoked=False, app=None, clock_s return verified_claims -def revoke_refresh_tokens(uid, app=None): +def revoke_refresh_tokens(uid: str, app: typing.Optional[firebase_admin.App] = None) -> None: """Revokes all refresh tokens for an existing user. This function updates the user's ``tokens_valid_after_timestamp`` to the current UTC @@ -309,7 +332,7 @@ def revoke_refresh_tokens(uid, app=None): client.revoke_refresh_tokens(uid) -def get_user(uid, app=None): +def get_user(uid: str, app: typing.Optional[firebase_admin.App] = None) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user ID. Args: @@ -328,7 +351,7 @@ def get_user(uid, app=None): return client.get_user(uid=uid) -def get_user_by_email(email, app=None): +def get_user_by_email(email: str, app: typing.Optional[firebase_admin.App] = None) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified user email. Args: @@ -347,7 +370,10 @@ def get_user_by_email(email, app=None): return client.get_user_by_email(email=email) -def get_user_by_phone_number(phone_number, app=None): +def get_user_by_phone_number( + phone_number: str, + app: typing.Optional[firebase_admin.App] = None, +) -> _user_mgt.UserRecord: """Gets the user data corresponding to the specified phone number. Args: @@ -366,7 +392,10 @@ def get_user_by_phone_number(phone_number, app=None): return client.get_user_by_phone_number(phone_number=phone_number) -def get_users(identifiers, app=None): +def get_users( + identifiers: typing.Sequence[_user_identifier.UserIdentifier], + app: typing.Optional[firebase_admin.App] = None, +) -> _user_mgt.GetUsersResult: """Gets the user data corresponding to the specified identifiers. There are no ordering guarantees; in particular, the nth entry in the @@ -394,7 +423,11 @@ def get_users(identifiers, app=None): return client.get_users(identifiers) -def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, app=None): +def list_users( + page_token: typing.Optional[str] = None, + max_results: int = _user_mgt.MAX_LIST_USERS_RESULTS, + app: typing.Optional[firebase_admin.App] = None, +) -> _user_mgt.ListUsersPage: """Retrieves a page of user accounts from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -420,7 +453,18 @@ def list_users(page_token=None, max_results=_user_mgt.MAX_LIST_USERS_RESULTS, ap return client.list_users(page_token=page_token, max_results=max_results) -def create_user(**kwargs): # pylint: disable=differing-param-doc +def create_user( + uid: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + email: typing.Optional[str] = None, + email_verified: typing.Optional[bool] = None, + phone_number: typing.Optional[str] = None, + photo_url: typing.Optional[str] = None, + password: typing.Optional[str] = None, + disabled: typing.Optional[bool] = None, + app: typing.Optional[firebase_admin.App] = None, + **kwargs: typing.Any, +) -> _user_mgt.UserRecord: # pylint: disable=differing-param-doc """Creates a new user account with the specified properties. Args: @@ -445,12 +489,27 @@ def create_user(**kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user properties are invalid. FirebaseError: If an error occurs while creating the user account. """ - app = kwargs.pop('app', None) client = _get_client(app) - return client.create_user(**kwargs) - - -def update_user(uid, **kwargs): # pylint: disable=differing-param-doc + return client.create_user(uid=uid, display_name=display_name, email=email, + email_verified=email_verified, phone_number=phone_number, photo_url=photo_url, + password=password, disabled=disabled, **kwargs) + + +def update_user( + uid: str, + display_name: typing.Optional[str] = None, + email: typing.Optional[str] = None, + phone_number: typing.Optional[str] = None, + photo_url: typing.Optional[str] = None, + password: typing.Optional[str] = None, + disabled: typing.Optional[bool] = None, + email_verified: typing.Optional[bool] = None, + valid_since: typing.Optional[_typing.ConvertibleToInt] = None, + custom_claims: typing.Optional[typing.Union[typing.Dict[str, typing.Any], str]] = None, + providers_to_delete: typing.Optional[typing.List[str]] = None, + app: typing.Optional[firebase_admin.App] = None, + **kwargs: typing.Any, +) -> _user_mgt.UserRecord: # pylint: disable=differing-param-doc """Updates an existing user account with the specified properties. Args: @@ -482,12 +541,17 @@ def update_user(uid, **kwargs): # pylint: disable=differing-param-doc ValueError: If the specified user ID or properties are invalid. FirebaseError: If an error occurs while updating the user account. """ - app = kwargs.pop('app', None) client = _get_client(app) - return client.update_user(uid, **kwargs) + return client.update_user(uid, display_name=display_name, email=email, phone_number=phone_number, + photo_url=photo_url, password=password, disabled=disabled, email_verified=email_verified, + valid_since=valid_since, custom_claims=custom_claims, providers_to_delete=providers_to_delete, **kwargs) -def set_custom_user_claims(uid, custom_claims, app=None): +def set_custom_user_claims( + uid: str, + custom_claims: typing.Optional[typing.Union[typing.Dict[str, typing.Any], str]], + app: typing.Optional[firebase_admin.App] = None, +) -> None: """Sets additional claims on an existing user account. Custom claims set via this function can be used to define user roles and privilege levels. @@ -511,7 +575,7 @@ def set_custom_user_claims(uid, custom_claims, app=None): client.set_custom_user_claims(uid, custom_claims=custom_claims) -def delete_user(uid, app=None): +def delete_user(uid: str, app: typing.Optional[firebase_admin.App] = None) -> None: """Deletes the user identified by the specified user ID. Args: @@ -526,7 +590,10 @@ def delete_user(uid, app=None): client.delete_user(uid) -def delete_users(uids, app=None): +def delete_users( + uids: typing.Sequence[str], + app: typing.Optional[firebase_admin.App] = None, +) -> _user_mgt.DeleteUsersResult: """Deletes the users specified by the given identifiers. Deleting a non-existing user does not generate an error (the method is @@ -553,7 +620,11 @@ def delete_users(uids, app=None): return client.delete_users(uids) -def import_users(users, hash_alg=None, app=None): +def import_users( + users: typing.Sequence[_user_import.ImportUserRecord], + hash_alg: typing.Optional[_user_import.UserImportHash] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> _user_import.UserImportResult: """Imports the specified list of users into Firebase Auth. At most 1000 users can be imported at a time. This operation is optimized for bulk imports and @@ -579,7 +650,11 @@ def import_users(users, hash_alg=None, app=None): return client.import_users(users, hash_alg) -def generate_password_reset_link(email, action_code_settings=None, app=None): +def generate_password_reset_link( + email: typing.Optional[str], + action_code_settings: typing.Optional[_user_mgt.ActionCodeSettings] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for password reset flows for the specified email address. @@ -600,7 +675,11 @@ def generate_password_reset_link(email, action_code_settings=None, app=None): return client.generate_password_reset_link(email, action_code_settings=action_code_settings) -def generate_email_verification_link(email, action_code_settings=None, app=None): +def generate_email_verification_link( + email: typing.Optional[str], + action_code_settings: typing.Optional[_user_mgt.ActionCodeSettings] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for email verification flows for the specified email address. @@ -622,7 +701,11 @@ def generate_email_verification_link(email, action_code_settings=None, app=None) email, action_code_settings=action_code_settings) -def generate_sign_in_with_email_link(email, action_code_settings, app=None): +def generate_sign_in_with_email_link( + email: typing.Optional[str], + action_code_settings: typing.Optional[_user_mgt.ActionCodeSettings], + app: typing.Optional[firebase_admin.App] = None, +) -> str: """Generates the out-of-band email action link for email link sign-in flows, using the action code settings provided. @@ -645,7 +728,10 @@ def generate_sign_in_with_email_link(email, action_code_settings, app=None): email, action_code_settings=action_code_settings) -def get_oidc_provider_config(provider_id, app=None): +def get_oidc_provider_config( + provider_id: str, + app: typing.Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Returns the ``OIDCProviderConfig`` with the given ID. Args: @@ -663,9 +749,18 @@ def get_oidc_provider_config(provider_id, app=None): client = _get_client(app) return client.get_oidc_provider_config(provider_id) + def create_oidc_provider_config( - provider_id, client_id, issuer, display_name=None, enabled=None, client_secret=None, - id_token_response_type=None, code_response_type=None, app=None): + provider_id: str, + client_id: str, + issuer: str, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + client_secret: typing.Optional[str] = None, + id_token_response_type: typing.Optional[bool] = None, + code_response_type: typing.Optional[bool] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Creates a new OIDC provider config from the given parameters. OIDC provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -705,8 +800,16 @@ def create_oidc_provider_config( def update_oidc_provider_config( - provider_id, client_id=None, issuer=None, display_name=None, enabled=None, - client_secret=None, id_token_response_type=None, code_response_type=None, app=None): + provider_id: str, + client_id: typing.Optional[str] = None, + issuer: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + client_secret: typing.Optional[str] = None, + id_token_response_type: typing.Optional[bool] = None, + code_response_type: typing.Optional[bool] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> _auth_providers.OIDCProviderConfig: """Updates an existing OIDC provider config with the given parameters. Args: @@ -717,16 +820,16 @@ def update_oidc_provider_config( Pass ``auth.DELETE_ATTRIBUTE`` to delete the current display name. enabled: A boolean indicating whether the provider configuration is enabled or disabled (optional). - app: An App instance (optional). client_secret: A string which sets the client secret for the new provider. This is required for the code flow. + id_token_response_type: A boolean which sets whether to enable the ID token response flow + for the new provider. By default, this is enabled if no response type is specified. + Having both the code and ID token response flows is currently not supported. code_response_type: A boolean which sets whether to enable the code response flow for the new provider. By default, this is not enabled if no response type is specified. A client secret must be set for this response type. Having both the code and ID token response flows is currently not supported. - id_token_response_type: A boolean which sets whether to enable the ID token response flow - for the new provider. By default, this is enabled if no response type is specified. - Having both the code and ID token response flows is currently not supported. + app: An App instance (optional). Returns: OIDCProviderConfig: The updated OIDC provider config instance. @@ -742,7 +845,7 @@ def update_oidc_provider_config( code_response_type=code_response_type) -def delete_oidc_provider_config(provider_id, app=None): +def delete_oidc_provider_config(provider_id: str, app: typing.Optional[firebase_admin.App] = None) -> None: """Deletes the ``OIDCProviderConfig`` with the given ID. Args: @@ -759,7 +862,10 @@ def delete_oidc_provider_config(provider_id, app=None): def list_oidc_provider_configs( - page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + page_token: typing.Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + app: typing.Optional[firebase_admin.App] = None, +) -> _auth_providers._ListOIDCProviderConfigsPage: """Retrieves a page of OIDC provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -786,7 +892,10 @@ def list_oidc_provider_configs( return client.list_oidc_provider_configs(page_token, max_results) -def get_saml_provider_config(provider_id, app=None): +def get_saml_provider_config( + provider_id: str, + app: typing.Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Returns the ``SAMLProviderConfig`` with the given ID. Args: @@ -806,8 +915,16 @@ def get_saml_provider_config(provider_id, app=None): def create_saml_provider_config( - provider_id, idp_entity_id, sso_url, x509_certificates, rp_entity_id, callback_url, - display_name=None, enabled=None, app=None): + provider_id: str, + idp_entity_id: str, + sso_url: str, + x509_certificates: typing.List[str], + rp_entity_id: str, + callback_url: str, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Creates a new SAML provider config from the given parameters. SAML provider support requires Google Cloud's Identity Platform (GCIP). To learn more about @@ -848,8 +965,16 @@ def create_saml_provider_config( def update_saml_provider_config( - provider_id, idp_entity_id=None, sso_url=None, x509_certificates=None, - rp_entity_id=None, callback_url=None, display_name=None, enabled=None, app=None): + provider_id: str, + idp_entity_id: typing.Optional[str] = None, + sso_url: typing.Optional[str] = None, + x509_certificates: typing.Optional[typing.List[str]] = None, + rp_entity_id: typing.Optional[str] = None, + callback_url: typing.Optional[str] = None, + display_name: typing.Optional[str] = None, + enabled: typing.Optional[bool] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> _auth_providers.SAMLProviderConfig: """Updates an existing SAML provider config with the given parameters. Args: @@ -880,7 +1005,7 @@ def update_saml_provider_config( callback_url=callback_url, display_name=display_name, enabled=enabled) -def delete_saml_provider_config(provider_id, app=None): +def delete_saml_provider_config(provider_id: str, app: typing.Optional[firebase_admin.App] = None) -> None: """Deletes the ``SAMLProviderConfig`` with the given ID. Args: @@ -897,7 +1022,10 @@ def delete_saml_provider_config(provider_id, app=None): def list_saml_provider_configs( - page_token=None, max_results=_auth_providers.MAX_LIST_CONFIGS_RESULTS, app=None): + page_token: typing.Optional[str] = None, + max_results: int = _auth_providers.MAX_LIST_CONFIGS_RESULTS, + app: typing.Optional[firebase_admin.App] = None, +) -> _auth_providers._ListSAMLProviderConfigsPage: """Retrieves a page of SAML provider configs from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` diff --git a/firebase_admin/credentials.py b/firebase_admin/credentials.py index 75060028..32105eb4 100644 --- a/firebase_admin/credentials.py +++ b/firebase_admin/credentials.py @@ -13,16 +13,27 @@ # limitations under the License. """Firebase credentials module.""" -import collections +import datetime import json import pathlib +import typing +import typing_extensions import google.auth + from google.auth.credentials import Credentials as GoogleAuthCredentials from google.auth.transport import requests +from google.auth import crypt from google.oauth2 import credentials from google.oauth2 import service_account +if typing.TYPE_CHECKING: + from _typeshed import StrPath +else: + import os + + StrPath = typing.Union[str, os.PathLike[str]] + _request = requests.Request() _scopes = [ @@ -34,51 +45,56 @@ 'https://www.googleapis.com/auth/userinfo.email' ] -AccessTokenInfo = collections.namedtuple('AccessTokenInfo', ['access_token', 'expiry']) -"""Data included in an OAuth2 access token. -Contains the access token string and the expiry time. The expirty time is exposed as a -``datetime`` value. -""" +class AccessTokenInfo(typing.NamedTuple): + """Data included in an OAuth2 access token. + + Contains the access token string and the expiry time. The expirty time is exposed as a + ``datetime`` value. + """ + access_token: typing.Any + expiry: typing.Optional[datetime.datetime] class Base: """Provides OAuth2 access tokens for accessing Firebase services.""" - def get_access_token(self): + def get_access_token(self) -> AccessTokenInfo: """Fetches a Google OAuth2 access token using this credential instance. Returns: AccessTokenInfo: An access token obtained using the credential. """ google_cred = self.get_credential() - google_cred.refresh(_request) + google_cred.refresh(_request) # type: ignore[reportUnknownMemberType] return AccessTokenInfo(google_cred.token, google_cred.expiry) - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the Google credential instance used for authentication.""" raise NotImplementedError + class _ExternalCredentials(Base): """A wrapper for google.auth.credentials.Credentials typed credential instances""" - def __init__(self, credential: GoogleAuthCredentials): + def __init__(self, credential: GoogleAuthCredentials) -> None: super(_ExternalCredentials, self).__init__() self._g_credential = credential - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google Credential Returns: google.auth.credentials.Credentials: A Google Auth credential instance.""" return self._g_credential + class Certificate(Base): """A credential initialized from a JSON certificate keyfile.""" _CREDENTIAL_TYPE = 'service_account' - def __init__(self, cert): + def __init__(self, cert: typing.Union[StrPath, typing.Dict[str, typing.Any]]) -> None: """Initializes a credential from a Google service account certificate. Service account certificates can be downloaded as JSON files from the Firebase console. @@ -107,25 +123,25 @@ def __init__(self, cert): raise ValueError('Invalid service account certificate. Certificate must contain a ' '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) try: - self._g_credential = service_account.Credentials.from_service_account_info( + self._g_credential = service_account.Credentials.from_service_account_info( # type: ignore[reportUnknownMemberType] json_data, scopes=_scopes) except ValueError as error: raise ValueError('Failed to initialize a certificate credential. ' 'Caused by: "{0}"'.format(error)) @property - def project_id(self): - return self._g_credential.project_id + def project_id(self) -> typing.Optional[str]: + return self._g_credential.project_id # type: ignore[reportUnknownMemberType] @property - def signer(self): + def signer(self) -> crypt.Signer: return self._g_credential.signer @property - def service_account_email(self): + def service_account_email(self) -> str: return self._g_credential.service_account_email - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Returns: @@ -136,16 +152,17 @@ def get_credential(self): class ApplicationDefault(Base): """A Google Application Default credential.""" - def __init__(self): + def __init__(self) -> None: """Creates an instance that will use Application Default credentials. The credentials will be lazily initialized when get_credential() or project_id() is called. See those methods for possible errors raised. """ super(ApplicationDefault, self).__init__() - self._g_credential = None # Will be lazily-loaded via _load_credential(). + self._g_credential: typing.Optional[GoogleAuthCredentials] = None # Will be lazily-loaded via _load_credential(). + self._project_id: typing.Optional[str] - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Raises: @@ -154,10 +171,10 @@ def get_credential(self): Returns: google.auth.credentials.Credentials: A Google Auth credential instance.""" self._load_credential() - return self._g_credential + return typing.cast(GoogleAuthCredentials, self._g_credential) @property - def project_id(self): + def project_id(self) -> typing.Optional[str]: """Returns the project_id from the underlying Google credential. Raises: @@ -168,16 +185,17 @@ def project_id(self): self._load_credential() return self._project_id - def _load_credential(self): + def _load_credential(self) -> None: if not self._g_credential: - self._g_credential, self._project_id = google.auth.default(scopes=_scopes) + self._g_credential, self._project_id = google.auth.default(scopes=_scopes) # type: ignore[reportUnknownMemberType] + class RefreshToken(Base): """A credential initialized from an existing refresh token.""" _CREDENTIAL_TYPE = 'authorized_user' - def __init__(self, refresh_token): + def __init__(self, refresh_token: typing.Union[StrPath, typing.Dict[str, typing.Any]]) -> None: """Initializes a credential from a refresh token JSON file. The JSON must consist of client_id, client_secret and refresh_token fields. Refresh @@ -207,21 +225,22 @@ def __init__(self, refresh_token): if json_data.get('type') != self._CREDENTIAL_TYPE: raise ValueError('Invalid refresh token configuration. JSON must contain a ' '"type" field set to "{0}".'.format(self._CREDENTIAL_TYPE)) - self._g_credential = credentials.Credentials.from_authorized_user_info(json_data, _scopes) + self._g_credential = credentials.Credentials.from_authorized_user_info( # type: ignore[reportUnknownMemberType] + json_data, _scopes) @property - def client_id(self): - return self._g_credential.client_id + def client_id(self) -> typing.Optional[str]: + return self._g_credential.client_id # type: ignore[reportUnknownMemberType] @property - def client_secret(self): - return self._g_credential.client_secret + def client_secret(self) -> typing.Optional[str]: + return self._g_credential.client_secret # type: ignore[reportUnknownMemberType] @property - def refresh_token(self): - return self._g_credential.refresh_token + def refresh_token(self) -> typing.Optional[str]: + return self._g_credential.refresh_token # type: ignore[reportUnknownMemberType] - def get_credential(self): + def get_credential(self) -> GoogleAuthCredentials: """Returns the underlying Google credential. Returns: @@ -229,7 +248,7 @@ def get_credential(self): return self._g_credential -def _is_file_path(path): +def _is_file_path(path: typing.Any) -> typing_extensions.TypeGuard[StrPath]: try: pathlib.Path(path) return True diff --git a/firebase_admin/db.py b/firebase_admin/db.py index 1dec9865..9076bcd4 100644 --- a/firebase_admin/db.py +++ b/firebase_admin/db.py @@ -25,17 +25,25 @@ import os import sys import threading -from urllib import parse +import typing +import typing_extensions +import urllib.parse +import google.auth.credentials import requests import firebase_admin from firebase_admin import exceptions from firebase_admin import _http_client +from firebase_admin import _typing from firebase_admin import _sseclient from firebase_admin import _utils +_K = typing_extensions.TypeVar("_K", default=typing.Any) +_V = typing_extensions.TypeVar("_V", default=typing.Any) +_JsonT = typing_extensions.TypeVar("_JsonT", bound=_typing.Json, default=_typing.Json) + _DB_ATTRIBUTE = '_database' _INVALID_PATH_CHARACTERS = '[].?#$' _RESERVED_FILTERS = ('$key', '$value', '$priority') @@ -45,7 +53,11 @@ _EMULATOR_HOST_ENV_VAR = 'FIREBASE_DATABASE_EMULATOR_HOST' -def reference(path='/', app=None, url=None): +def reference( + path: str = '/', + app: typing.Optional[firebase_admin.App] = None, + url: typing.Optional[str] = None, +) -> "Reference": """Returns a database ``Reference`` representing the node at the specified path. If no path is specified, this function returns a ``Reference`` that represents the database @@ -69,7 +81,8 @@ def reference(path='/', app=None, url=None): client = service.get_client(url) return Reference(client=client, path=path) -def _parse_path(path): + +def _parse_path(path: typing.Any) -> typing.List[str]: """Parses a path string into a set of segments.""" if not isinstance(path, str): raise ValueError('Invalid path: "{0}". Path must be a string.'.format(path)) @@ -82,7 +95,7 @@ def _parse_path(path): class Event: """Represents a realtime update event received from the database.""" - def __init__(self, sse_event): + def __init__(self, sse_event: _sseclient.Event) -> None: self._sse_event = sse_event self._data = json.loads(sse_event.data) @@ -97,7 +110,7 @@ def path(self): return self._data['path'] @property - def event_type(self): + def event_type(self) -> str: """Event type string (put, patch).""" return self._sse_event.event_type @@ -105,7 +118,11 @@ def event_type(self): class ListenerRegistration: """Represents the addition of an event listener to a database reference.""" - def __init__(self, callback, sse): + def __init__( + self, + callback: typing.Callable[[Event], None], + sse: _sseclient.SSEClient, + ) -> None: """Initializes a new listener with given parameters. This is an internal API. Use the ``db.Reference.listen()`` method to start a @@ -120,14 +137,14 @@ def __init__(self, callback, sse): self._thread = threading.Thread(target=self._start_listen) self._thread.start() - def _start_listen(self): + def _start_listen(self) -> None: # iterate the sse client's generator for sse_event in self._sse: # only inject data events if sse_event: self._callback(Event(sse_event)) - def close(self): + def close(self) -> None: """Stops the event listener represented by this registration This closes the SSE HTTP connection, and joins the background thread. @@ -139,36 +156,59 @@ def close(self): class Reference: """Reference represents a node in the Firebase realtime database.""" - def __init__(self, **kwargs): + @typing.overload + def __init__( + self, + *, + client: "_Client", + path: str, + **kwargs: typing.Any, + ) -> None: ... + @typing.overload + def __init__( + self, + *, + client: "_Client", + segments: typing.List[str], + **kwargs: typing.Any, + ) -> None: ... + def __init__( + self, + *, + client: "_Client", + path: typing.Optional[str] = None, + segments: typing.Optional[typing.List[str]] = None, + **kwargs: typing.Any, + ) -> None: """Creates a new Reference using the provided parameters. This method is for internal use only. Use db.reference() to obtain an instance of Reference. """ - self._client = kwargs.get('client') - if 'segments' in kwargs: - self._segments = kwargs.get('segments') + self._client = client + if segments: + self._segments = segments else: - self._segments = _parse_path(kwargs.get('path')) + self._segments = _parse_path(path) self._pathurl = '/' + '/'.join(self._segments) @property - def key(self): + def key(self) -> typing.Optional[str]: if self._segments: return self._segments[-1] return None @property - def path(self): + def path(self) -> str: return self._pathurl @property - def parent(self): + def parent(self) -> typing.Optional["Reference"]: if self._segments: return Reference(client=self._client, segments=self._segments[:-1]) return None - def child(self, path): + def child(self, path: typing.Optional[str]) -> "Reference": """Returns a Reference to the specified child node. The path may point to an immediate child of the current Reference, or a deeply nested @@ -192,7 +232,23 @@ def child(self, path): full_path = self._pathurl + '/' + path return Reference(client=self._client, path=full_path) - def get(self, etag=False, shallow=False): + @typing.overload + def get( # type: ignore[reportOverlappingOverload] + self, + etag: typing.Literal[True], + shallow: bool = False, + ) -> typing.Tuple[typing.Dict[str, _typing.Json], str]: ... + @typing.overload + def get( + self, + etag: bool = False, + shallow: bool = False, + ) -> typing.Dict[str, _typing.Json]: ... + def get( + self, + etag: bool = False, + shallow: bool = False, + ) -> typing.Union[typing.Tuple[typing.Dict[str, _typing.Json], str], typing.Dict[str, _typing.Json]]: """Returns the value, and optionally the ETag, at the current location of the database. Args: @@ -215,12 +271,12 @@ def get(self, etag=False, shallow=False): raise ValueError('etag and shallow cannot both be set to True.') headers, data = self._client.headers_and_body( 'get', self._add_suffix(), headers={'X-Firebase-ETag' : 'true'}) - return data, headers.get('ETag') + return data, typing.cast(str, headers.get('ETag')) params = 'shallow=true' if shallow else None return self._client.body('get', self._add_suffix(), params=params) - def get_if_changed(self, etag): + def get_if_changed(self, etag: str) -> typing.Tuple[bool, typing.Optional[typing.Any], typing.Optional[str]]: """Gets data in this location only if the specified ETag does not match. Args: @@ -246,7 +302,7 @@ def get_if_changed(self, etag): return True, resp.json(), resp.headers.get('ETag') - def set(self, value): + def set(self, value: _typing.Json) -> None: """Sets the data at this location to the given value. The value must be JSON-serializable and not None. @@ -263,7 +319,11 @@ def set(self, value): raise ValueError('Value must not be None.') self._client.request('put', self._add_suffix(), json=value, params='print=silent') - def set_if_unchanged(self, expected_etag, value): + def set_if_unchanged( + self, + expected_etag: str, + value: _JsonT + ) -> typing.Tuple[bool, _JsonT, str]: """Conditonally sets the data at this location to the given value. Sets the data at this location to the given value only if ``expected_etag`` is same as the @@ -291,7 +351,7 @@ def set_if_unchanged(self, expected_etag, value): try: headers = self._client.headers( 'put', self._add_suffix(), json=value, headers={'if-match': expected_etag}) - return True, value, headers.get('ETag') + return True, value, typing.cast(str, headers.get('ETag')) except exceptions.FailedPreconditionError as error: http_response = error.http_response if http_response is not None and 'ETag' in http_response.headers: @@ -301,7 +361,7 @@ def set_if_unchanged(self, expected_etag, value): raise error - def push(self, value=''): + def push(self, value: _typing.Json = '') -> "Reference": """Creates a new child node. The optional value argument can be used to provide an initial value for the child node. If @@ -321,10 +381,10 @@ def push(self, value=''): if value is None: raise ValueError('Value must not be None.') output = self._client.body('post', self._add_suffix(), json=value) - push_id = output.get('name') + push_id = typing.cast(typing.Optional[str], output.get('name')) return self.child(push_id) - def update(self, value): + def update(self, value: _typing.Json) -> None: """Updates the specified child keys of this Reference to the provided values. Args: @@ -340,7 +400,7 @@ def update(self, value): raise ValueError('Dictionary must not contain None keys.') self._client.request('patch', self._add_suffix(), json=value, params='print=silent') - def delete(self): + def delete(self) -> None: """Deletes this node from the database. Raises: @@ -348,7 +408,7 @@ def delete(self): """ self._client.request('delete', self._add_suffix()) - def listen(self, callback): + def listen(self, callback: typing.Callable[[Event], None]) -> ListenerRegistration: """Registers the ``callback`` function to receive realtime updates. The specified callback function will get invoked with ``db.Event`` objects for each @@ -374,7 +434,7 @@ def listen(self, callback): """ return self._listen_with_session(callback) - def transaction(self, transaction_update): + def transaction(self, transaction_update: typing.Callable[[_typing.Json], _typing.Json]) -> _typing.Json: """Atomically modifies the data at this location. Unlike a normal ``set()``, which just overwrites the data regardless of its previous state, @@ -417,7 +477,7 @@ def transaction(self, transaction_update): raise TransactionAbortedError('Transaction aborted after failed retries.') - def order_by_child(self, path): + def order_by_child(self, path: str) -> "Query": """Returns a Query that orders data by child values. Returned Query can be used to set additional parameters, and execute complex database @@ -436,7 +496,7 @@ def order_by_child(self, path): raise ValueError('Illegal child path: {0}'.format(path)) return Query(order_by=path, client=self._client, pathurl=self._add_suffix()) - def order_by_key(self): + def order_by_key(self) -> "Query": """Creates a Query that orderes data by key. Returned Query can be used to set additional parameters, and execute complex database @@ -447,7 +507,7 @@ def order_by_key(self): """ return Query(order_by='$key', client=self._client, pathurl=self._add_suffix()) - def order_by_value(self): + def order_by_value(self) -> "Query": """Creates a Query that orderes data by value. Returned Query can be used to set additional parameters, and execute complex database @@ -458,16 +518,20 @@ def order_by_value(self): """ return Query(order_by='$value', client=self._client, pathurl=self._add_suffix()) - def _add_suffix(self, suffix='.json'): + def _add_suffix(self, suffix: str = '.json') -> str: return self._pathurl + suffix - def _listen_with_session(self, callback, session=None): + def _listen_with_session( + self, + callback: typing.Callable[[Event], None], + session: typing.Optional[requests.Session] = None, + ) -> ListenerRegistration: url = self._client.base_url + self._add_suffix() if not session: session = self._client.create_listener_session() try: - sse = _sseclient.SSEClient(url, session, **{"params": self._client.params}) + sse = _sseclient.SSEClient(url, session, params=self._client.params) return ListenerRegistration(callback, sse) except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) @@ -486,8 +550,7 @@ class Query: OrderedDict. """ - def __init__(self, **kwargs): - order_by = kwargs.pop('order_by') + def __init__(self, *, client: "_Client", order_by: str, pathurl: str, **kwargs: typing.Any) -> None: if not order_by or not isinstance(order_by, str): raise ValueError('order_by field must be a non-empty string') if order_by not in _RESERVED_FILTERS: @@ -496,14 +559,14 @@ def __init__(self, **kwargs): 'with "/"'.format(order_by)) segments = _parse_path(order_by) order_by = '/'.join(segments) - self._client = kwargs.pop('client') - self._pathurl = kwargs.pop('pathurl') + self._client = client + self._pathurl = pathurl self._order_by = order_by - self._params = {'orderBy' : json.dumps(order_by)} + self._params: typing.Dict[str, typing.Any] = {'orderBy' : json.dumps(order_by)} if kwargs: raise ValueError('Unexpected keyword arguments: {0}'.format(kwargs)) - def limit_to_first(self, limit): + def limit_to_first(self, limit: int) -> typing_extensions.Self: """Creates a query with limit, and anchors it to the start of the window. Args: @@ -522,7 +585,7 @@ def limit_to_first(self, limit): self._params['limitToFirst'] = limit return self - def limit_to_last(self, limit): + def limit_to_last(self, limit: int) -> typing_extensions.Self: """Creates a query with limit, and anchors it to the end of the window. Args: @@ -541,7 +604,7 @@ def limit_to_last(self, limit): self._params['limitToLast'] = limit return self - def start_at(self, start): + def start_at(self, start: _typing.Json) -> typing_extensions.Self: """Sets the lower bound for a range query. The Query will only return child nodes with a value greater than or equal to the specified @@ -561,7 +624,7 @@ def start_at(self, start): self._params['startAt'] = json.dumps(start) return self - def end_at(self, end): + def end_at(self, end: _typing.Json) -> typing_extensions.Self: """Sets the upper bound for a range query. The Query will only return child nodes with a value less than or equal to the specified @@ -581,7 +644,7 @@ def end_at(self, end): self._params['endAt'] = json.dumps(end) return self - def equal_to(self, value): + def equal_to(self, value: _typing.Json) -> typing_extensions.Self: """Sets an equals constraint on the Query. The Query will only return child nodes whose value is equal to the specified value. @@ -601,13 +664,13 @@ def equal_to(self, value): return self @property - def _querystr(self): - params = [] + def _querystr(self) -> str: + params: typing.List[str] = [] for key in sorted(self._params): params.append('{0}={1}'.format(key, self._params[key])) return '&'.join(params) - def get(self): + def get(self) -> typing.Union[typing.Dict[str, _typing.Json], typing.List[_typing.Json]]: """Executes this Query and returns the results. The results will be returned as a sorted list or an OrderedDict. @@ -627,32 +690,36 @@ def get(self): class TransactionAbortedError(exceptions.AbortedError): """A transaction was aborted aftr exceeding the maximum number of retries.""" - def __init__(self, message): + def __init__(self, message: str) -> None: exceptions.AbortedError.__init__(self, message) -class _Sorter: +class _Sorter(typing.Generic[_K, _V]): """Helper class for sorting query results.""" - def __init__(self, results, order_by): + @typing.overload + def __init__(self, results: typing.Dict[_K, _V], order_by: str) -> None: ... + @typing.overload + def __init__(self: '_Sorter[int, _V]', results: typing.List[_V], order_by: str) -> None: ... # type: ignore[reportInvalidTypeVarUse] + def __init__(self, results: typing.Union[typing.Dict[_K, _V], typing.List[_V]], order_by: str) -> None: if isinstance(results, dict): self.dict_input = True entries = [_SortEntry(k, v, order_by) for k, v in results.items()] elif isinstance(results, list): self.dict_input = False - entries = [_SortEntry(k, v, order_by) for k, v in enumerate(results)] + entries = [_SortEntry(typing.cast(_K, k), v, order_by) for k, v in enumerate(results)] else: raise ValueError('Sorting not supported for "{0}" object.'.format(type(results))) self.sort_entries = sorted(entries) - def get(self): + def get(self) -> typing.Union['collections.OrderedDict[_K, _V]', typing.List[_V]]: if self.dict_input: return collections.OrderedDict([(e.key, e.value) for e in self.sort_entries]) return [e.value for e in self.sort_entries] -class _SortEntry: +class _SortEntry(typing.Generic[_K, _V]): """A wrapper that is capable of sorting items in a dictionary.""" _type_none = 0 @@ -662,7 +729,7 @@ class _SortEntry: _type_string = 4 _type_object = 5 - def __init__(self, key, value, order_by): + def __init__(self, key: _K, value: _V, order_by: str) -> None: self._key = key self._value = value if order_by in ('$key', '$priority'): @@ -674,23 +741,23 @@ def __init__(self, key, value, order_by): self._index_type = _SortEntry._get_index_type(self._index) @property - def key(self): + def key(self) -> _K: return self._key @property - def index(self): + def index(self) -> typing.Optional[typing.Any]: return self._index @property - def index_type(self): + def index_type(self) -> int: return self._index_type @property - def value(self): + def value(self) -> _V: return self._value @classmethod - def _get_index_type(cls, index): + def _get_index_type(cls, index: typing.Any) -> int: """Assigns an integer code to the type of the index. The index type determines how differently typed values are sorted. This ordering is based @@ -710,17 +777,18 @@ def _get_index_type(cls, index): return cls._type_object @classmethod - def _extract_child(cls, value, path): + def _extract_child(cls, value: typing.Any, path: str) -> typing.Optional[typing.Any]: segments = path.split('/') current = value for segment in segments: if isinstance(current, dict): + current = typing.cast(typing.Dict[str, typing.Any], current) current = current.get(segment) else: return None return current - def _compare(self, other): + def _compare(self, other: '_SortEntry') -> typing.Literal[-1, 0, 1]: """Compares two _SortEntry instances. If the indices have the same numeric or string type, compare them directly. Ties are @@ -735,39 +803,44 @@ def _compare(self, other): else: self_key, other_key = self.key, other.key - if self_key < other_key: + if self_key < other_key: # type: ignore[reportOperatorIssue] return -1 - if self_key > other_key: + if self_key > other_key: # type: ignore[reportOperatorIssue] return 1 return 0 - def __lt__(self, other): + def __lt__(self, other: '_SortEntry') -> bool: return self._compare(other) < 0 - def __le__(self, other): + def __le__(self, other: '_SortEntry') -> bool: return self._compare(other) <= 0 - def __gt__(self, other): + def __gt__(self, other: '_SortEntry') -> bool: return self._compare(other) > 0 - def __ge__(self, other): + def __ge__(self, other: '_SortEntry') -> bool: return self._compare(other) >= 0 - def __eq__(self, other): + def __eq__(self, other: '_SortEntry') -> bool: # type: ignore[reportIncompatibleMethodOverride] return self._compare(other) == 0 +class EmulatorConfig(typing.NamedTuple): + base_url: str + namespace: str + + class _DatabaseService: """Service that maintains a collection of database clients.""" _DEFAULT_AUTH_OVERRIDE = '_admin_' - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: self._credential = app.credential db_url = app.options.get('databaseURL') if db_url: - self._db_url = db_url + self._db_url: typing.Optional[str] = db_url else: self._db_url = None @@ -777,7 +850,7 @@ def __init__(self, app): else: self._auth_override = None self._timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) - self._clients = {} + self._clients: typing.Dict[typing.Tuple[str, str], _Client] = {} emulator_host = os.environ.get(_EMULATOR_HOST_ENV_VAR) if emulator_host: @@ -789,7 +862,7 @@ def __init__(self, app): else: self._emulator_host = None - def get_client(self, db_url=None): + def get_client(self, db_url: typing.Optional[str] = None) -> "_Client": """Creates a client based on the db_url. Clients may be cached.""" if db_url is None: db_url = self._db_url @@ -799,7 +872,7 @@ def get_client(self, db_url=None): 'Invalid database URL: "{0}". Database URL must be a non-empty ' 'URL string.'.format(db_url)) - parsed_url = parse.urlparse(db_url) + parsed_url = urllib.parse.urlparse(db_url) if not parsed_url.netloc: raise ValueError( 'Invalid database URL: "{0}". Database URL must be a wellformed ' @@ -816,7 +889,6 @@ def get_client(self, db_url=None): base_url = 'https://{0}'.format(parsed_url.netloc) params = {} - if self._auth_override: params['auth_variable_override'] = self._auth_override @@ -826,9 +898,8 @@ def get_client(self, db_url=None): self._clients[client_cache_key] = client return self._clients[client_cache_key] - def _get_emulator_config(self, parsed_url): + def _get_emulator_config(self, parsed_url: urllib.parse.ParseResult) -> typing.Optional[EmulatorConfig]: """Checks whether the SDK should connect to the RTDB emulator.""" - EmulatorConfig = collections.namedtuple('EmulatorConfig', ['base_url', 'namespace']) if parsed_url.scheme != 'https': # Emulator mode enabled by passing http URL via AppOptions base_url, namespace = _DatabaseService._parse_emulator_url(parsed_url) @@ -842,9 +913,9 @@ def _get_emulator_config(self, parsed_url): return None @classmethod - def _parse_emulator_url(cls, parsed_url): + def _parse_emulator_url(cls, parsed_url: urllib.parse.ParseResult) -> typing.Tuple[str, str]: """Parses emulator URL like http://localhost:8080/?ns=foo-bar""" - query_ns = parse.parse_qs(parsed_url.query).get('ns') + query_ns = urllib.parse.parse_qs(parsed_url.query).get('ns') if parsed_url.scheme != 'http' or (not query_ns or len(query_ns) != 1 or not query_ns[0]): raise ValueError( 'Invalid database URL: "{0}". Database URL must be a valid URL to a ' @@ -855,8 +926,9 @@ def _parse_emulator_url(cls, parsed_url): return base_url, namespace @classmethod - def _get_auth_override(cls, app): - auth_override = app.options.get('databaseAuthVariableOverride', cls._DEFAULT_AUTH_OVERRIDE) + def _get_auth_override(cls, app: firebase_admin.App) -> typing.Optional[typing.Union[typing.Dict[str, typing.Any], str]]: + auth_override = typing.cast(typing.Optional[str], app.options.get( + 'databaseAuthVariableOverride', cls._DEFAULT_AUTH_OVERRIDE)) if auth_override == cls._DEFAULT_AUTH_OVERRIDE or auth_override is None: return auth_override if not isinstance(auth_override, dict): @@ -865,7 +937,7 @@ def _get_auth_override(cls, app): return auth_override - def close(self): + def close(self) -> None: for value in self._clients.values(): value.close() self._clients = {} @@ -878,7 +950,13 @@ class _Client(_http_client.JsonHttpClient): marshalling and unmarshalling of JSON data. """ - def __init__(self, credential, base_url, timeout, params=None): + def __init__( + self, + credential: typing.Optional[google.auth.credentials.Credentials], + base_url: str, + timeout: int, + params: typing.Optional[typing.Dict[str, typing.Any]] = None, + ) -> None: """Creates a new _Client from the given parameters. This exists primarily to enable testing. For regular use, obtain _Client instances by @@ -898,7 +976,7 @@ def __init__(self, credential, base_url, timeout, params=None): self.credential = credential self.params = params if params else {} - def request(self, method, url, **kwargs): + def request(self, method: str, url: str, **kwargs: typing.Any) -> requests.Response: """Makes an HTTP call using the Python requests library. Extends the request() method of the parent JsonHttpClient class. Handles default @@ -930,11 +1008,11 @@ def request(self, method, url, **kwargs): except requests.exceptions.RequestException as error: raise _Client.handle_rtdb_error(error) - def create_listener_session(self): + def create_listener_session(self) -> _sseclient.KeepAuthSession: return _sseclient.KeepAuthSession(self.credential) @classmethod - def handle_rtdb_error(cls, error): + def handle_rtdb_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Converts an error encountered while calling RTDB into a FirebaseError.""" if error.response is None: return _utils.handle_requests_error(error) @@ -943,7 +1021,7 @@ def handle_rtdb_error(cls, error): return _utils.handle_requests_error(error, message=message) @classmethod - def _extract_error_message(cls, response): + def _extract_error_message(cls, response: requests.Response) -> str: """Extracts an error message from an error response. If the server has sent a JSON response with an 'error' field, which is the typical @@ -954,7 +1032,7 @@ def _extract_error_message(cls, response): message = None try: # RTDB error format: {"error": "text message"} - data = response.json() + data: typing.Dict[str, str] = response.json() if isinstance(data, dict): message = data.get('error') except ValueError: diff --git a/firebase_admin/exceptions.py b/firebase_admin/exceptions.py index 947f3680..fad1b4df 100644 --- a/firebase_admin/exceptions.py +++ b/firebase_admin/exceptions.py @@ -31,6 +31,9 @@ subtype error handlers. """ +import typing + +import requests #: Error code for ``InvalidArgumentError`` type. INVALID_ARGUMENT = 'INVALID_ARGUMENT' @@ -95,14 +98,20 @@ class FirebaseError(Exception): this object. """ - def __init__(self, code, message, cause=None, http_response=None): + def __init__( + self, + code: str, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: Exception.__init__(self, message) self._code = code self._cause = cause self._http_response = http_response @property - def code(self): + def code(self) -> str: return self._code @property @@ -117,7 +126,12 @@ def http_response(self): class InvalidArgumentError(FirebaseError): """Client specified an invalid argument.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, INVALID_ARGUMENT, message, cause, http_response) @@ -125,21 +139,36 @@ class FailedPreconditionError(FirebaseError): """Request can not be executed in the current system state, such as deleting a non-empty directory.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, FAILED_PRECONDITION, message, cause, http_response) class OutOfRangeError(FirebaseError): """Client specified an invalid range.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, OUT_OF_RANGE, message, cause, http_response) class UnauthenticatedError(FirebaseError): """Request not authenticated due to missing, invalid, or expired OAuth token.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, UNAUTHENTICATED, message, cause, http_response) @@ -150,7 +179,12 @@ class PermissionDeniedError(FirebaseError): have permission, or the API has not been enabled for the client project. """ - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, PERMISSION_DENIED, message, cause, http_response) @@ -158,70 +192,120 @@ class NotFoundError(FirebaseError): """A specified resource is not found, or the request is rejected by undisclosed reasons, such as whitelisting.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, NOT_FOUND, message, cause, http_response) class ConflictError(FirebaseError): """Concurrency conflict, such as read-modify-write conflict.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, CONFLICT, message, cause, http_response) class AbortedError(FirebaseError): """Concurrency conflict, such as read-modify-write conflict.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, ABORTED, message, cause, http_response) class AlreadyExistsError(FirebaseError): """The resource that a client tried to create already exists.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, ALREADY_EXISTS, message, cause, http_response) class ResourceExhaustedError(FirebaseError): """Either out of resource quota or reaching rate limiting.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, RESOURCE_EXHAUSTED, message, cause, http_response) class CancelledError(FirebaseError): """Request cancelled by the client.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, CANCELLED, message, cause, http_response) class DataLossError(FirebaseError): """Unrecoverable data loss or data corruption.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, DATA_LOSS, message, cause, http_response) class UnknownError(FirebaseError): """Unknown server error.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, UNKNOWN, message, cause, http_response) class InternalError(FirebaseError): """Internal server error.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, INTERNAL, message, cause, http_response) class UnavailableError(FirebaseError): """Service unavailable. Typically the server is down.""" - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, UNAVAILABLE, message, cause, http_response) @@ -233,5 +317,10 @@ class DeadlineExceededError(FirebaseError): request) and the request did not finish within the deadline. """ - def __init__(self, message, cause=None, http_response=None): + def __init__( + self, + message: str, + cause: typing.Optional[Exception] = None, + http_response: typing.Optional[requests.Response] = None + ) -> None: FirebaseError.__init__(self, DEADLINE_EXCEEDED, message, cause, http_response) diff --git a/firebase_admin/firestore.py b/firebase_admin/firestore.py index 52ea9067..2901d732 100644 --- a/firebase_admin/firestore.py +++ b/firebase_admin/firestore.py @@ -18,18 +18,15 @@ Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ -from __future__ import annotations -from typing import Optional, Dict -from firebase_admin import App +import typing + +import firebase_admin from firebase_admin import _utils try: - from google.cloud import firestore + # firestore defines __all__ for safe import * + from google.cloud.firestore import * # type: ignore[reportWildcardImportFromLibrary] from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE - existing = globals().keys() - for key, value in firestore.__dict__.items(): - if not key.startswith('_') and key not in existing: - globals()[key] = value except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' 'to install the "google-cloud-firestore" module.') from error @@ -38,7 +35,7 @@ _FIRESTORE_ATTRIBUTE = '_firestore' -def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.Client: +def client(app: typing.Optional[firebase_admin.App] = None, database_id: typing.Optional[str] = None) -> Client: """Returns a client that can be used to interact with Google Cloud Firestore. Args: @@ -68,11 +65,11 @@ def client(app: Optional[App] = None, database_id: Optional[str] = None) -> fire class _FirestoreService: """Service that maintains a collection of firestore clients.""" - def __init__(self, app: App) -> None: - self._app: App = app - self._clients: Dict[str, firestore.Client] = {} + def __init__(self, app: firebase_admin.App) -> None: + self._app = app + self._clients: typing.Dict[str, Client] = {} - def get_client(self, database_id: Optional[str]) -> firestore.Client: + def get_client(self, database_id: typing.Optional[str]) -> Client: """Creates a client based on the database_id. These clients are cached.""" database_id = database_id or DEFAULT_DATABASE if database_id not in self._clients: @@ -85,7 +82,7 @@ def get_client(self, database_id: Optional[str]) -> firestore.Client: 'or use service account credentials. Alternatively, set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - fs_client = firestore.Client( + fs_client = Client( credentials=credentials, project=project, database=database_id) self._clients[database_id] = fs_client diff --git a/firebase_admin/firestore_async.py b/firebase_admin/firestore_async.py index 4a197e9d..538e07f0 100644 --- a/firebase_admin/firestore_async.py +++ b/firebase_admin/firestore_async.py @@ -18,27 +18,27 @@ associated with Firebase apps. This requires the ``google-cloud-firestore`` Python module. """ -from __future__ import annotations -from typing import Optional, Dict -from firebase_admin import App +import typing + +import firebase_admin from firebase_admin import _utils try: - from google.cloud import firestore + # firestore defines __all__ for safe import * + from google.cloud.firestore import * # type: ignore[reportWildcardImportFromLibrary] from google.cloud.firestore_v1.base_client import DEFAULT_DATABASE - existing = globals().keys() - for key, value in firestore.__dict__.items(): - if not key.startswith('_') and key not in existing: - globals()[key] = value except ImportError as error: raise ImportError('Failed to import the Cloud Firestore library for Python. Make sure ' 'to install the "google-cloud-firestore" module.') from error -_FIRESTORE_ASYNC_ATTRIBUTE: str = '_firestore_async' +_FIRESTORE_ASYNC_ATTRIBUTE = '_firestore_async' -def client(app: Optional[App] = None, database_id: Optional[str] = None) -> firestore.AsyncClient: +def client( + app: typing.Optional[firebase_admin.App] = None, + database_id: typing.Optional[str] = None, +) -> AsyncClient: """Returns an async client that can be used to interact with Google Cloud Firestore. Args: @@ -68,11 +68,11 @@ def client(app: Optional[App] = None, database_id: Optional[str] = None) -> fire class _FirestoreAsyncService: """Service that maintains a collection of firestore async clients.""" - def __init__(self, app: App) -> None: - self._app: App = app - self._clients: Dict[str, firestore.AsyncClient] = {} + def __init__(self, app: firebase_admin.App) -> None: + self._app = app + self._clients: typing.Dict[str, AsyncClient] = {} - def get_client(self, database_id: Optional[str]) -> firestore.AsyncClient: + def get_client(self, database_id: typing.Optional[str]) -> AsyncClient: """Creates an async client based on the database_id. These clients are cached.""" database_id = database_id or DEFAULT_DATABASE if database_id not in self._clients: @@ -85,7 +85,7 @@ def get_client(self, database_id: Optional[str]) -> firestore.AsyncClient: 'or use service account credentials. Alternatively, set the ' 'GOOGLE_CLOUD_PROJECT environment variable.') - fs_client = firestore.AsyncClient( + fs_client = AsyncClient( credentials=credentials, project=project, database=database_id) self._clients[database_id] = fs_client diff --git a/firebase_admin/functions.py b/firebase_admin/functions.py index fa17dfc0..3a2f628e 100644 --- a/firebase_admin/functions.py +++ b/firebase_admin/functions.py @@ -14,21 +14,24 @@ """Firebase Functions module.""" -from __future__ import annotations -from datetime import datetime, timedelta -from urllib import parse -import re +import base64 +import dataclasses +import datetime import json -from base64 import b64encode -from typing import Any, Optional, Dict -from dataclasses import dataclass -from google.auth.compute_engine import Credentials as ComputeEngineCredentials +import re +import typing +import typing_extensions +import urllib.parse import requests +from google.auth.credentials import Credentials as GoogleAuthCredentials +from google.auth.compute_engine import Credentials as ComputeEngineCredentials + import firebase_admin -from firebase_admin import App from firebase_admin import _http_client +from firebase_admin import _typing from firebase_admin import _utils +from firebase_admin import exceptions _FUNCTIONS_ATTRIBUTE = '_functions' @@ -54,14 +57,14 @@ # Default canonical location ID of the task queue. _DEFAULT_LOCATION = 'us-central1' -def _get_functions_service(app) -> _FunctionsService: +def _get_functions_service(app: typing.Optional[firebase_admin.App]) -> "_FunctionsService": return _utils.get_app_service(app, _FUNCTIONS_ATTRIBUTE, _FunctionsService) def task_queue( - function_name: str, - extension_id: Optional[str] = None, - app: Optional[App] = None - ) -> TaskQueue: + function_name: str, + extension_id: typing.Optional[str] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> "TaskQueue": """Creates a reference to a TaskQueue for a given function name. The function name can be either: @@ -89,9 +92,10 @@ def task_queue( """ return _get_functions_service(app).task_queue(function_name, extension_id) + class _FunctionsService: """Service class that implements Firebase Functions functionality.""" - def __init__(self, app: App): + def __init__(self, app: firebase_admin.App) -> None: self._project_id = app.project_id if not self._project_id: raise ValueError( @@ -102,28 +106,27 @@ def __init__(self, app: App): self._credential = app.credential.get_credential() self._http_client = _http_client.JsonHttpClient(credential=self._credential) - def task_queue(self, function_name: str, extension_id: Optional[str] = None) -> TaskQueue: + def task_queue(self, function_name: str, extension_id: typing.Optional[str] = None) -> "TaskQueue": """Creates a TaskQueue instance.""" return TaskQueue( function_name, extension_id, self._project_id, self._credential, self._http_client) @classmethod - def handle_functions_error(cls, error: Any): + def handle_functions_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Cloud Functions API.""" - return _utils.handle_platform_error_from_requests(error) + class TaskQueue: """TaskQueue class that implements Firebase Cloud Tasks Queues functionality.""" def __init__( - self, - function_name: str, - extension_id: Optional[str], - project_id, - credential, - http_client - ) -> None: - + self, + function_name: str, + extension_id: typing.Optional[str], + project_id: typing.Optional[str], + credential: GoogleAuthCredentials, + http_client: _http_client.HttpClient[typing.Dict[str, _typing.Json]] + ) -> None: # Validate function_name _Validators.check_non_empty_string('function_name', function_name) @@ -144,8 +147,7 @@ def __init__( _Validators.check_non_empty_string('extension_id', self._extension_id) self._resource.resource_id = f'ext-{self._extension_id}-{self._resource.resource_id}' - - def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str: + def enqueue(self, task_data: typing.Any, opts: typing.Optional["TaskOptions"] = None) -> str: """Creates a task and adds it to the queue. Tasks cannot be updated after creation. This action requires `cloudtasks.tasks.create` IAM permission on the service account. @@ -172,7 +174,7 @@ def enqueue(self, task_data: Any, opts: Optional[TaskOptions] = None) -> str: headers=_FUNCTIONS_HEADERS, json={'task': task_payload.__dict__} ) - task_name = resp.get('name', None) + task_name = typing.cast(str, resp['name']) task_resource = \ self._parse_resource_name(task_name, f'queues/{self._resource.resource_id}/tasks') return task_resource.resource_id @@ -203,8 +205,7 @@ def delete(self, task_id: str) -> None: except requests.exceptions.RequestException as error: raise _FunctionsService.handle_functions_error(error) - - def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> Resource: + def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> "Resource": """Parses a full or partial resource path into a ``Resource``.""" if '/' not in resource_name: return Resource(resource_id=resource_name) @@ -215,7 +216,7 @@ def _parse_resource_name(self, resource_name: str, resource_id_key: str) -> Reso raise ValueError('Invalid resource name format.') return Resource(project_id=match[2], location_id=match[3], resource_id=match[4]) - def _get_url(self, resource: Resource, url_format: str) -> str: + def _get_url(self, resource: "Resource", url_format: str) -> str: """Generates url path from a ``Resource`` and url format string.""" return url_format.format( project_id=resource.project_id, @@ -223,18 +224,18 @@ def _get_url(self, resource: Resource, url_format: str) -> str: resource_id=resource.resource_id) def _validate_task_options( - self, - data: Any, - resource: Resource, - opts: Optional[TaskOptions] = None - ) -> Task: + self, + data: typing.Dict[str, typing.Any], + resource: "Resource", + opts: typing.Optional["TaskOptions"] = None, + ) -> "Task": """Validate and create a Task from optional ``TaskOptions``.""" task_http_request = { 'url': '', 'oidc_token': { 'service_account_email': '' }, - 'body': b64encode(json.dumps(data).encode()).decode(), + 'body': base64.b64encode(json.dumps(data).encode()).decode(), 'headers': { 'Content-Type': 'application/json', } @@ -255,7 +256,8 @@ def _validate_task_options( if not isinstance(opts.schedule_delay_seconds, int) \ or opts.schedule_delay_seconds < 0: raise ValueError('schedule_delay_seconds should be positive int.') - schedule_time = datetime.utcnow() + timedelta(seconds=opts.schedule_delay_seconds) + schedule_time = datetime.datetime.now(datetime.timezone.utc) + \ + datetime.timedelta(seconds=opts.schedule_delay_seconds) task.schedule_time = schedule_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') if opts.dispatch_deadline_seconds is not None: if not isinstance(opts.dispatch_deadline_seconds, int) \ @@ -279,7 +281,12 @@ def _validate_task_options( task.http_request['url'] = opts.uri return task - def _update_task_payload(self, task: Task, resource: Resource, extension_id: str) -> Task: + def _update_task_payload( + self, + task: "Task", + resource: "Resource", + extension_id: typing.Optional[str], + ) -> "Task": """Prepares task to be sent with credentials.""" # Get function url from task or generate from resources if not _Validators.is_non_empty_string(task.http_request['url']): @@ -289,49 +296,50 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str if _Validators.is_non_empty_string(extension_id) and \ isinstance(self._credential, ComputeEngineCredentials): - id_token = self._credential.token + id_token = typing.cast(str, self._credential.token) # type: ignore[reportUnknownMemberType] task.http_request['headers'] = \ {**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'} # Delete oidc token del task.http_request['oidc_token'] else: + # possible issue: _credential needs more specific annotation task.http_request['oidc_token'] = \ - {'service_account_email': self._credential.service_account_email} + {'service_account_email': self._credential.service_account_email} # type: ignore[reportUnknownMemberType] return task class _Validators: """A collection of data validation utilities.""" - @classmethod - def check_non_empty_string(cls, label: str, value: Any): + @staticmethod + def check_non_empty_string(label: str, value: typing.Any) -> None: """Checks if given value is a non-empty string and throws error if not.""" if not isinstance(value, str): raise ValueError('{0} "{1}" must be a string.'.format(label, value)) if value == '': raise ValueError('{0} "{1}" must be a non-empty string.'.format(label, value)) - @classmethod - def is_non_empty_string(cls, value: Any): + @staticmethod + def is_non_empty_string(value: typing.Any) -> typing_extensions.TypeGuard[str]: """Checks if given value is a non-empty string and returns bool.""" if not isinstance(value, str) or value == '': return False return True - @classmethod - def is_task_id(cls, task_id: Any): + @staticmethod + def is_task_id(task_id: str) -> bool: """Checks if given value is a valid task id.""" reg = '^[A-Za-z0-9_-]+$' if re.match(reg, task_id) is not None and len(task_id) <= 500: return True return False - @classmethod - def is_url(cls, url: Any): + @staticmethod + def is_url(url: typing.Any) -> typing_extensions.TypeGuard[str]: """Checks if given value is a valid url.""" if not isinstance(url, str): return False try: - parsed = parse.urlparse(url) + parsed = urllib.parse.urlparse(url) if not parsed.netloc or parsed.scheme not in ['http', 'https']: return False return True @@ -339,7 +347,7 @@ def is_url(cls, url: Any): return False -@dataclass +@dataclasses.dataclass class TaskOptions: """Task Options that can be applied to a Task. @@ -397,14 +405,15 @@ class TaskOptions: uri: The full URL that the request will be sent to. Must be a valid RFC3986 https or http URL. """ - schedule_delay_seconds: Optional[int] = None - schedule_time: Optional[datetime] = None - dispatch_deadline_seconds: Optional[int] = None - task_id: Optional[str] = None - headers: Optional[Dict[str, str]] = None - uri: Optional[str] = None - -@dataclass + schedule_delay_seconds: typing.Optional[int] = None + schedule_time: typing.Optional[datetime.datetime] = None + dispatch_deadline_seconds: typing.Optional[int] = None + task_id: typing.Optional[str] = None + headers: typing.Optional[typing.Dict[str, str]] = None + uri: typing.Optional[str] = None + + +@dataclasses.dataclass class Task: """Contains the relevant fields for enqueueing tasks that trigger Cloud Functions. @@ -418,13 +427,13 @@ class Task: schedule_time: The time when the task is scheduled to be attempted or retried. dispatch_deadline: The deadline for requests sent to the worker. """ - http_request: Dict[str, Optional[str | dict]] - name: Optional[str] = None - schedule_time: Optional[str] = None - dispatch_deadline: Optional[str] = None + http_request: typing.Dict[str, typing.Any] + name: typing.Optional[str] = None + schedule_time: typing.Optional[str] = None + dispatch_deadline: typing.Optional[str] = None -@dataclass +@dataclasses.dataclass class Resource: """Contains the parsed address of a resource. @@ -434,5 +443,5 @@ class Resource: location_id: The location ID of the resource. """ resource_id: str - project_id: Optional[str] = None - location_id: Optional[str] = None + project_id: typing.Optional[str] = None + location_id: typing.Optional[str] = None diff --git a/firebase_admin/instance_id.py b/firebase_admin/instance_id.py index 604158d9..d237aed3 100644 --- a/firebase_admin/instance_id.py +++ b/firebase_admin/instance_id.py @@ -16,9 +16,11 @@ This module enables deleting instance IDs associated with Firebase projects. """ +import typing import requests +import firebase_admin from firebase_admin import _http_client from firebase_admin import _utils @@ -27,11 +29,11 @@ _IID_ATTRIBUTE = '_iid' -def _get_iid_service(app): +def _get_iid_service(app: typing.Optional[firebase_admin.App]) -> "_InstanceIdService": return _utils.get_app_service(app, _IID_ATTRIBUTE, _InstanceIdService) -def delete_instance_id(instance_id, app=None): +def delete_instance_id(instance_id: str, app: typing.Optional[firebase_admin.App] = None) -> None: """Deletes the specified instance ID and the associated data from Firebase. Note that Google Analytics for Firebase uses its own form of Instance ID to @@ -55,7 +57,7 @@ def delete_instance_id(instance_id, app=None): class _InstanceIdService: """Provides methods for interacting with the remote instance ID service.""" - error_codes = { + error_codes: typing.Dict[int, str] = { 400: 'Malformed instance ID argument.', 401: 'Request not authorized.', 403: 'Project does not match instance ID or the client does not have ' @@ -67,7 +69,7 @@ class _InstanceIdService: 503: 'Backend servers are over capacity. Try again later.' } - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -78,7 +80,7 @@ def __init__(self, app): self._client = _http_client.JsonHttpClient( credential=app.credential.get_credential(), base_url=_IID_SERVICE_URL) - def delete_instance_id(self, instance_id): + def delete_instance_id(self, instance_id: str) -> None: if not isinstance(instance_id, str) or not instance_id: raise ValueError('Instance ID must be a non-empty string.') path = 'project/{0}/instanceId/{1}'.format(self._project_id, instance_id) @@ -88,7 +90,7 @@ def delete_instance_id(self, instance_id): msg = self._extract_message(instance_id, error) raise _utils.handle_requests_error(error, msg) - def _extract_message(self, instance_id, error): + def _extract_message(self, instance_id: str, error: requests.RequestException) -> typing.Optional[str]: if error.response is None: return None status = error.response.status_code diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index d2ad04a0..d34e2f2a 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -16,9 +16,12 @@ import concurrent.futures import json +import typing import warnings -import requests +import httplib2 +import requests +from google.auth import credentials from googleapiclient import http from googleapiclient import _auth @@ -27,9 +30,15 @@ from firebase_admin import _messaging_encoder from firebase_admin import _messaging_utils from firebase_admin import _gapic_utils +from firebase_admin import _typing from firebase_admin import _utils from firebase_admin import exceptions +if typing.TYPE_CHECKING: + from oauth2client.client import Credentials as OAuth2Credentials +else: + OAuth2Credentials = typing.Any + _MESSAGING_ATTRIBUTE = '_messaging' @@ -71,6 +80,13 @@ 'unsubscribe_from_topic', ] +_TransportBuilder = typing.Callable[[ + typing.Union[ + credentials.Credentials, + OAuth2Credentials, + ]], + httplib2.Http +] AndroidConfig = _messaging_utils.AndroidConfig AndroidFCMOptions = _messaging_utils.AndroidFCMOptions @@ -97,10 +113,11 @@ UnregisteredError = _messaging_utils.UnregisteredError -def _get_messaging_service(app): +def _get_messaging_service(app: typing.Optional[firebase_admin.App]) -> "_MessagingService": return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) -def send(message, dry_run=False, app=None): + +def send(message: Message, dry_run: bool = False, app: typing.Optional[firebase_admin.App] = None) -> str: """Sends the given message via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -120,7 +137,12 @@ def send(message, dry_run=False, app=None): """ return _get_messaging_service(app).send(message, dry_run) -def send_each(messages, dry_run=False, app=None): + +def send_each( + messages: typing.List[Message], + dry_run: bool = False, + app: typing.Optional[firebase_admin.App] = None, +) -> "BatchResponse": """Sends each message in the given list via Firebase Cloud Messaging. If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -140,7 +162,12 @@ def send_each(messages, dry_run=False, app=None): """ return _get_messaging_service(app).send_each(messages, dry_run) -def send_each_for_multicast(multicast_message, dry_run=False, app=None): + +def send_each_for_multicast( + multicast_message: MulticastMessage, + dry_run: bool = False, + app: typing.Optional[firebase_admin.App] = None, +): """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -171,7 +198,12 @@ def send_each_for_multicast(multicast_message, dry_run=False, app=None): ) for token in multicast_message.tokens] return _get_messaging_service(app).send_each(messages, dry_run) -def send_all(messages, dry_run=False, app=None): + +def send_all( + messages: typing.List[Message], + dry_run: bool = False, + app: typing.Optional[firebase_admin.App] = None, +) -> "BatchResponse": """Sends the given list of messages via Firebase Cloud Messaging as a single batch. If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -194,7 +226,12 @@ def send_all(messages, dry_run=False, app=None): warnings.warn('send_all() is deprecated. Use send_each() instead.', DeprecationWarning) return _get_messaging_service(app).send_all(messages, dry_run) -def send_multicast(multicast_message, dry_run=False, app=None): + +def send_multicast( + multicast_message: MulticastMessage, + dry_run: bool = False, + app: typing.Optional[firebase_admin.App] = None, +) -> "BatchResponse": """Sends the given mutlicast message to all tokens via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -229,7 +266,12 @@ def send_multicast(multicast_message, dry_run=False, app=None): ) for token in multicast_message.tokens] return _get_messaging_service(app).send_all(messages, dry_run) -def subscribe_to_topic(tokens, topic, app=None): + +def subscribe_to_topic( + tokens: typing.Union[typing.List[str], str], + topic: str, + app: typing.Optional[firebase_admin.App] = None, +) -> "TopicManagementResponse": """Subscribes a list of registration tokens to an FCM topic. Args: @@ -248,7 +290,12 @@ def subscribe_to_topic(tokens, topic, app=None): return _get_messaging_service(app).make_topic_management_request( tokens, topic, 'iid/v1:batchAdd') -def unsubscribe_from_topic(tokens, topic, app=None): + +def unsubscribe_from_topic( + tokens: typing.Union[typing.List[str], str], + topic: str, + app: typing.Optional[firebase_admin.App] = None, +) -> "TopicManagementResponse": """Unsubscribes a list of registration tokens from an FCM topic. Args: @@ -271,17 +318,17 @@ def unsubscribe_from_topic(tokens, topic, app=None): class ErrorInfo: """An error encountered when performing a topic management operation.""" - def __init__(self, index, reason): + def __init__(self, index: int, reason: str) -> None: self._index = index self._reason = reason @property - def index(self): + def index(self) -> int: """Index of the registration token to which this error is related to.""" return self._index @property - def reason(self): + def reason(self) -> str: """String describing the nature of the error.""" return self._reason @@ -289,12 +336,12 @@ def reason(self): class TopicManagementResponse: """The response received from a topic management operation.""" - def __init__(self, resp): + def __init__(self, resp: typing.Dict[str, typing.Any]) -> None: if not isinstance(resp, dict) or 'results' not in resp: raise ValueError('Unexpected topic management response: {0}.'.format(resp)) self._success_count = 0 self._failure_count = 0 - self._errors = [] + self._errors: typing.List[ErrorInfo] = [] for index, result in enumerate(resp['results']): if 'error' in result: self._failure_count += 1 @@ -303,17 +350,17 @@ def __init__(self, resp): self._success_count += 1 @property - def success_count(self): + def success_count(self) -> int: """Number of tokens that were successfully subscribed or unsubscribed.""" return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: """Number of tokens that could not be subscribed or unsubscribed due to errors.""" return self._failure_count @property - def errors(self): + def errors(self) -> typing.List[ErrorInfo]: """A list of ``messaging.ErrorInfo`` objects (possibly empty).""" return self._errors @@ -321,45 +368,49 @@ def errors(self): class BatchResponse: """The response received from a batch request to the FCM API.""" - def __init__(self, responses): + def __init__(self, responses: typing.List["SendResponse"]) -> None: self._responses = responses self._success_count = len([resp for resp in responses if resp.success]) @property - def responses(self): + def responses(self) -> typing.List["SendResponse"]: """A list of ``messaging.SendResponse`` objects (possibly empty).""" return self._responses @property - def success_count(self): + def success_count(self) -> int: return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: return len(self.responses) - self.success_count class SendResponse: """The response received from an individual batched request to the FCM API.""" - def __init__(self, resp, exception): + def __init__( + self, + resp: typing.Optional[typing.Dict[str, typing.Any]], + exception: typing.Optional[exceptions.FirebaseError], + ) -> None: self._exception = exception - self._message_id = None + self._message_id: typing.Optional[str] = None if resp: self._message_id = resp.get('name', None) @property - def message_id(self): + def message_id(self) -> typing.Optional[str]: """A message ID string that uniquely identifies the message.""" return self._message_id @property - def success(self): + def success(self) -> bool: """A boolean indicating if the request was successful.""" return self._message_id is not None and not self._exception @property - def exception(self): + def exception(self) -> typing.Optional[exceptions.FirebaseError]: """A ``FirebaseError`` if an error occurs while sending the message to the FCM service.""" return self._exception @@ -373,15 +424,15 @@ class _MessagingService: IID_HEADERS = {'access_token_auth': 'true'} JSON_ENCODER = _messaging_encoder.MessageEncoder() - FCM_ERROR_TYPES = { + FCM_ERROR_TYPES: typing.Dict[str, _typing.FirebaseErrorFactoryWithDefaults] = { 'APNS_AUTH_ERROR': ThirdPartyAuthError, 'QUOTA_EXCEEDED': QuotaExceededError, 'SENDER_ID_MISMATCH': SenderIdMismatchError, 'THIRD_PARTY_AUTH_ERROR': ThirdPartyAuthError, - 'UNREGISTERED': UnregisteredError, + 'UNREGISTERED': UnregisteredError } - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -396,15 +447,15 @@ def __init__(self, app): timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._credential = app.credential.get_credential() self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) - self._build_transport = _auth.authorized_http + self._build_transport: _TransportBuilder = _auth.authorized_http # type: ignore[reportUnknownMemberType] @classmethod - def encode_message(cls, message): + def encode_message(cls, message: Message) -> typing.Dict[str, typing.Any]: if not isinstance(message, Message): raise ValueError('Message must be an instance of messaging.Message class.') return cls.JSON_ENCODER.default(message) - def send(self, message, dry_run=False): + def send(self, message: Message, dry_run: bool = False) -> str: """Sends the given message to FCM via the FCM v1 API.""" data = self._message_data(message, dry_run) try: @@ -417,16 +468,16 @@ def send(self, message, dry_run=False): except requests.exceptions.RequestException as error: raise self._handle_fcm_error(error) else: - return resp['name'] + return typing.cast(str, resp['name']) - def send_each(self, messages, dry_run=False): + def send_each(self, messages: typing.List[Message], dry_run: bool = False) -> BatchResponse: """Sends the given messages to FCM via the FCM v1 API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') if len(messages) > 500: raise ValueError('messages must not contain more than 500 elements.') - def send_data(data): + def send_data(data: typing.Dict[str, typing.Any]) -> SendResponse: try: resp = self._client.body( 'post', @@ -448,16 +499,20 @@ def send_data(data): message='Unknown error while making remote service calls: {0}'.format(error), cause=error) - def send_all(self, messages, dry_run=False): + def send_all(self, messages: typing.List[Message], dry_run: bool = False) -> BatchResponse: """Sends the given messages to FCM via the batch API.""" if not isinstance(messages, list): raise ValueError('messages must be a list of messaging.Message instances.') if len(messages) > 500: raise ValueError('messages must not contain more than 500 elements.') - responses = [] + responses: typing.List[SendResponse] = [] - def batch_callback(_, response, error): + def batch_callback( + _: str, + response: typing.Optional[typing.Dict[str, typing.Any]], + error: typing.Optional[Exception], + ) -> None: exception = None if error: exception = self._handle_batch_error(error) @@ -477,16 +532,21 @@ def batch_callback(_, response, error): body=body, headers=self._fcm_headers ) - batch.add(req) + batch.add(req) # type: ignore[reportUnknownMemberType] try: - batch.execute() + batch.execute() # type: ignore[reportUnknownMemberType] except Exception as error: raise self._handle_batch_error(error) else: return BatchResponse(responses) - def make_topic_management_request(self, tokens, topic, operation): + def make_topic_management_request( + self, + tokens: typing.Union[typing.List[str], str], + topic: str, + operation: str, + ) -> TopicManagementResponse: """Invokes the IID service for topic management functionality.""" if isinstance(tokens, str): tokens = [tokens] @@ -517,30 +577,30 @@ def make_topic_management_request(self, tokens, topic, operation): else: return TopicManagementResponse(resp) - def _message_data(self, message, dry_run): - data = {'message': _MessagingService.encode_message(message)} + def _message_data(self, message: Message, dry_run: bool) -> typing.Dict[str, typing.Any]: + data: typing.Dict[str, typing.Any] = {'message': _MessagingService.encode_message(message)} if dry_run: data['validate_only'] = True return data - def _postproc(self, _, body): + def _postproc(self, _: httplib2.Response, body: bytes) -> typing.Any: """Handle response from batch API request.""" # This only gets called for 2xx responses. return json.loads(body.decode()) - def _handle_fcm_error(self, error): + def _handle_fcm_error(self, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the FCM API.""" return _utils.handle_platform_error_from_requests( error, _MessagingService._build_fcm_error_requests) - def _handle_iid_error(self, error): + def _handle_iid_error(self, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Instance ID API.""" if error.response is None: raise _utils.handle_requests_error(error) - data = {} + data: typing.Dict[str, typing.Any] = {} try: - parsed_body = error.response.json() + parsed_body: _typing.Json = error.response.json() if isinstance(parsed_body, dict): data = parsed_body except ValueError: @@ -557,30 +617,41 @@ def _handle_iid_error(self, error): return _utils.handle_requests_error(error, msg) - def _handle_batch_error(self, error): + def _handle_batch_error(self, error: Exception) -> exceptions.FirebaseError: """Handles errors received from the googleapiclient while making batch requests.""" return _gapic_utils.handle_platform_error_from_googleapiclient( error, _MessagingService._build_fcm_error_googleapiclient) @classmethod - def _build_fcm_error_requests(cls, error, message, error_dict): + def _build_fcm_error_requests( + cls, + error: requests.RequestException, + message: str, + error_dict: typing.Dict[str, typing.Any] + ) -> typing.Optional[exceptions.FirebaseError]: """Parses an error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) return exc_type(message, cause=error, http_response=error.response) if exc_type else None @classmethod - def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_response): + def _build_fcm_error_googleapiclient( + cls, + error: typing.Optional[Exception], + message: str, + error_dict: typing.Dict[str, typing.Any], + http_response: typing.Optional[requests.Response] + ) -> typing.Optional[exceptions.FirebaseError]: """Parses an error response from the FCM API and creates a FCM-specific exception if appropriate.""" exc_type = cls._build_fcm_error(error_dict) return exc_type(message, cause=error, http_response=http_response) if exc_type else None @classmethod - def _build_fcm_error(cls, error_dict): + def _build_fcm_error(cls, error_dict: typing.Dict[str, typing.Any]) -> typing.Optional[_typing.FirebaseErrorFactory]: if not error_dict: return None - fcm_code = None + fcm_code: typing.Any = None for detail in error_dict.get('details', []): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': fcm_code = detail.get('errorCode') diff --git a/firebase_admin/ml.py b/firebase_admin/ml.py index 98bdbb56..519211cc 100644 --- a/firebase_admin/ml.py +++ b/firebase_admin/ml.py @@ -18,34 +18,59 @@ deleting, publishing and unpublishing Firebase ML models. """ - import datetime +import os import re import time -import os -from urllib import parse +import typing +import urllib.parse import warnings import requests import firebase_admin from firebase_admin import _http_client +from firebase_admin import _typing from firebase_admin import _utils from firebase_admin import exceptions # pylint: disable=import-error,no-name-in-module -try: +if typing.TYPE_CHECKING: + from _typeshed import Incomplete from firebase_admin import storage - _GCS_ENABLED = True -except ImportError: - _GCS_ENABLED = False + + _GCS_ENABLED: bool +else: + Incomplete = typing.Any + + try: + from firebase_admin import storage + _GCS_ENABLED = True + except ImportError: + _GCS_ENABLED = False # pylint: disable=import-error,no-name-in-module -try: +if typing.TYPE_CHECKING: import tensorflow as tf - _TF_ENABLED = True -except ImportError: - _TF_ENABLED = False + + _TF_ENABLED: bool +else: + try: + import tensorflow as tf + _TF_ENABLED = True + except ImportError: + _TF_ENABLED = False # type: ignore[reportConstantRedefinition] + + +_DownloadCallback = typing.Callable[ + [ + typing.Optional[str], + typing.Optional[int], + typing.Optional[str] + ], + typing.Dict[str, _typing.Json] +] + _ML_ATTRIBUTE = '_ml' _MAX_PAGE_SIZE = 100 @@ -63,7 +88,7 @@ r'^projects/(?P[a-z0-9-]{6,30})/operations/[^/]+$') -def _get_ml_service(app): +def _get_ml_service(app: typing.Optional[firebase_admin.App]) -> "_MLService": """ Returns an _MLService instance for an App. Args: @@ -78,7 +103,7 @@ def _get_ml_service(app): return _utils.get_app_service(app, _ML_ATTRIBUTE, _MLService) -def create_model(model, app=None): +def create_model(model: "Model", app: typing.Optional[firebase_admin.App] = None) -> "Model": """Creates a model in the current Firebase project. Args: @@ -92,7 +117,7 @@ def create_model(model, app=None): return Model.from_dict(ml_service.create_model(model), app=app) -def update_model(model, app=None): +def update_model(model: "Model", app: typing.Optional[firebase_admin.App] = None) -> "Model": """Updates a model's metadata or model file. Args: @@ -106,7 +131,7 @@ def update_model(model, app=None): return Model.from_dict(ml_service.update_model(model), app=app) -def publish_model(model_id, app=None): +def publish_model(model_id: str, app: typing.Optional[firebase_admin.App] = None) -> "Model": """Publishes a Firebase ML model. A published model can be downloaded to client apps. @@ -122,7 +147,7 @@ def publish_model(model_id, app=None): return Model.from_dict(ml_service.set_published(model_id, publish=True), app=app) -def unpublish_model(model_id, app=None): +def unpublish_model(model_id: str, app: typing.Optional[firebase_admin.App] = None) -> "Model": """Unpublishes a Firebase ML model. Args: @@ -136,7 +161,7 @@ def unpublish_model(model_id, app=None): return Model.from_dict(ml_service.set_published(model_id, publish=False), app=app) -def get_model(model_id, app=None): +def get_model(model_id: str, app: typing.Optional[firebase_admin.App] = None) -> "Model": """Gets the model specified by the given ID. Args: @@ -150,7 +175,12 @@ def get_model(model_id, app=None): return Model.from_dict(ml_service.get_model(model_id), app=app) -def list_models(list_filter=None, page_size=None, page_token=None, app=None): +def list_models( + list_filter: typing.Optional[str] = None, + page_size: typing.Optional[int] = None, + page_token: typing.Optional[str] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> "ListModelsPage": """Lists the current project's models. Args: @@ -169,7 +199,7 @@ def list_models(list_filter=None, page_size=None, page_token=None, app=None): ml_service.list_models, list_filter, page_size, page_token, app=app) -def delete_model(model_id, app=None): +def delete_model(model_id: str, app: typing.Optional[firebase_admin.App] = None) -> None: """Deletes a model from the current project. Args: @@ -188,9 +218,14 @@ class Model: tags: Optional list of strings associated with your model. Can be used in list queries. model_format: A subclass of ModelFormat. (e.g. TFLiteFormat) Specifies the model details. """ - def __init__(self, display_name=None, tags=None, model_format=None): - self._app = None # Only needed for wait_for_unlo - self._data = {} + def __init__( + self, + display_name: typing.Optional[str] = None, + tags: typing.Optional[typing.List[str]] = None, + model_format: typing.Optional["ModelFormat"] = None, + ) -> None: + self._app: typing.Optional[firebase_admin.App] = None # Only needed for wait_for_unlo + self._data: typing.Dict[str, typing.Any] = {} self._model_format = None if display_name is not None: @@ -201,7 +236,7 @@ def __init__(self, display_name=None, tags=None, model_format=None): self.model_format = model_format @classmethod - def from_dict(cls, data, app=None): + def from_dict(cls, data: typing.Dict[str, typing.Any], app: typing.Optional[firebase_admin.App] = None) -> "Model": """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = None @@ -214,97 +249,95 @@ def from_dict(cls, data, app=None): model._app = app # pylint: disable=protected-access return model - def _update_from_dict(self, data): + def _update_from_dict(self, data: typing.Dict[str, typing.Any]) -> None: copy = Model.from_dict(data) self.model_format = copy.model_format self._data = copy._data # pylint: disable=protected-access - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_format == other._model_format return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @property - def model_id(self): + def model_id(self) -> typing.Optional[str]: """The model's ID, unique to the project.""" if not self._data.get('name'): return None - _, model_id = _validate_and_parse_name(self._data.get('name')) + _, model_id = _validate_and_parse_name(self._data['name']) return model_id @property - def display_name(self): + def display_name(self) -> typing.Optional[str]: """The model's display name, used to refer to the model in code and in the Firebase console.""" return self._data.get('displayName') @display_name.setter - def display_name(self, display_name): + def display_name(self, display_name: str) -> None: self._data['displayName'] = _validate_display_name(display_name) - return self @staticmethod - def _convert_to_millis(date_string): + def _convert_to_millis(date_string: typing.Optional[str]) -> typing.Optional[int]: if not date_string: return None format_str = '%Y-%m-%dT%H:%M:%S.%fZ' - epoch = datetime.datetime.utcfromtimestamp(0) + epoch = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) datetime_object = datetime.datetime.strptime(date_string, format_str) millis = int((datetime_object - epoch).total_seconds() * 1000) return millis @property - def create_time(self): + def create_time(self) -> typing.Optional[int]: """The time the model was created.""" return Model._convert_to_millis(self._data.get('createTime', None)) @property - def update_time(self): + def update_time(self) -> typing.Optional[int]: """The time the model was last updated.""" return Model._convert_to_millis(self._data.get('updateTime', None)) @property - def validation_error(self): + def validation_error(self) -> typing.Optional[str]: """Validation error message.""" return self._data.get('state', {}).get('validationError', {}).get('message') @property - def published(self): + def published(self) -> bool: """True if the model is published and available for clients to download.""" return bool(self._data.get('state', {}).get('published')) @property - def etag(self): + def etag(self) -> typing.Optional[Incomplete]: """The entity tag (ETag) of the model resource.""" return self._data.get('etag') @property - def model_hash(self): + def model_hash(self) -> typing.Optional[Incomplete]: """SHA256 hash of the model binary.""" return self._data.get('modelHash') @property - def tags(self): + def tags(self) -> typing.Optional[typing.List[str]]: """Tag strings, used for filtering query results.""" return self._data.get('tags') @tags.setter - def tags(self, tags): + def tags(self, tags: typing.List[str]) -> None: self._data['tags'] = _validate_tags(tags) - return self @property - def locked(self): + def locked(self) -> bool: """True if the Model object is locked by an active operation.""" return bool(self._data.get('activeOperations') and - len(self._data.get('activeOperations')) > 0) + len(self._data['activeOperations']) > 0) - def wait_for_unlocked(self, max_time_seconds=None): + def wait_for_unlocked(self, max_time_seconds: typing.Optional[float] = None) -> None: """Waits for the model to be unlocked. (All active operations complete) Args: @@ -317,7 +350,7 @@ def wait_for_unlocked(self, max_time_seconds=None): if not self.locked: return ml_service = _get_ml_service(self._app) - op_name = self._data.get('activeOperations')[0].get('name') + op_name = self._data['activeOperations'][0].get('name') model_dict = ml_service.handle_operation( ml_service.get_operation(op_name), wait_for_operation=True, @@ -325,19 +358,18 @@ def wait_for_unlocked(self, max_time_seconds=None): self._update_from_dict(model_dict) @property - def model_format(self): + def model_format(self) -> typing.Optional["ModelFormat"]: """The model's ``ModelFormat`` object, which represents the model's format and storage location.""" return self._model_format @model_format.setter - def model_format(self, model_format): + def model_format(self, model_format: typing.Optional["ModelFormat"]) -> None: if model_format is not None: _validate_model_format(model_format) self._model_format = model_format #Can be None - return self - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> typing.Dict[str, typing.Any]: """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_format: @@ -347,7 +379,7 @@ def as_dict(self, for_upload=False): class ModelFormat: """Abstract base class representing a Model Format such as TFLite.""" - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> typing.Dict[str, typing.Any]: """Returns a serializable representation of the object.""" raise NotImplementedError @@ -358,32 +390,32 @@ class TFLiteFormat(ModelFormat): Args: model_source: A TFLiteModelSource sub class. Specifies the details of the model source. """ - def __init__(self, model_source=None): - self._data = {} - self._model_source = None + def __init__(self, model_source: typing.Optional["TFLiteModelSource"] = None) -> None: + self._data: typing.Dict[str, typing.Any] = {} + self._model_source: typing.Optional[TFLiteModelSource] = None if model_source is not None: self.model_source = model_source @classmethod - def from_dict(cls, data): + def from_dict(cls, data: typing.Dict[str, typing.Any]) -> "TFLiteFormat": """Create an instance of the object from a dict.""" data_copy = dict(data) tflite_format = TFLiteFormat(model_source=cls._init_model_source(data_copy)) tflite_format._data = data_copy # pylint: disable=protected-access return tflite_format - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): # pylint: disable=protected-access return self._data == other._data and self._model_source == other._model_source return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @staticmethod - def _init_model_source(data): + def _init_model_source(data: typing.Dict[str, typing.Any]) -> typing.Optional["TFLiteModelSource"]: """Initialize the ML model source.""" gcs_tflite_uri = data.pop('gcsTfliteUri', None) if gcs_tflite_uri: @@ -396,23 +428,23 @@ def _init_model_source(data): return None @property - def model_source(self): + def model_source(self) -> typing.Optional["TFLiteModelSource"]: """The TF Lite model's location.""" return self._model_source @model_source.setter - def model_source(self, model_source): + def model_source(self, model_source: typing.Optional["TFLiteModelSource"]) -> None: if model_source is not None: if not isinstance(model_source, TFLiteModelSource): raise TypeError('Model source must be a TFLiteModelSource object.') self._model_source = model_source # Can be None @property - def size_bytes(self): + def size_bytes(self) -> typing.Optional[Incomplete]: """The size in bytes of the TF Lite model.""" return self._data.get('sizeBytes') - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> typing.Dict[str, typing.Any]: """Returns a serializable representation of the object.""" copy = dict(self._data) if self._model_source: @@ -422,7 +454,7 @@ def as_dict(self, for_upload=False): class TFLiteModelSource: """Abstract base class representing a model source for TFLite format models.""" - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> typing.Dict[str, typing.Any]: """Returns a serializable representation of the object.""" raise NotImplementedError @@ -434,13 +466,13 @@ class _CloudStorageClient: BLOB_NAME = 'Firebase/ML/Models/{0}' @staticmethod - def _assert_gcs_enabled(): + def _assert_gcs_enabled() -> None: if not _GCS_ENABLED: raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') @staticmethod - def _parse_gcs_tflite_uri(uri): + def _parse_gcs_tflite_uri(uri: str) -> typing.Tuple[str, str]: # GCS Bucket naming rules are complex. The regex is not comprehensive. # See https://cloud.google.com/storage/docs/naming for full details. matcher = _GCS_TFLITE_URI_PATTERN.match(uri) @@ -449,25 +481,28 @@ def _parse_gcs_tflite_uri(uri): return matcher.group('bucket_name'), matcher.group('blob_name') @staticmethod - def upload(bucket_name, model_file_name, app): + def upload( + bucket_name: typing.Optional[str], + model_file_name: typing.Union[str, os.PathLike[str]], + app: typing.Optional[firebase_admin.App], + ) -> str: """Upload a model file to the specified Storage bucket.""" _CloudStorageClient._assert_gcs_enabled() - file_name = os.path.basename(model_file_name) bucket = storage.bucket(bucket_name, app=app) blob_name = _CloudStorageClient.BLOB_NAME.format(file_name) - blob = bucket.blob(blob_name) - blob.upload_from_filename(model_file_name) - return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) + blob = bucket.blob(blob_name) # type: ignore[reportUnknownMemberType] + blob.upload_from_filename(model_file_name) # type: ignore[reportUnknownMemberType] + return _CloudStorageClient.GCS_URI.format(bucket.name, blob_name) # type: ignore[reportUnknownMemberType] @staticmethod - def sign_uri(gcs_tflite_uri, app): + def sign_uri(gcs_tflite_uri: str, app: typing.Optional[firebase_admin.App]) -> str: """Makes the gcs_tflite_uri readable for GET for 10 minutes via signed_uri.""" _CloudStorageClient._assert_gcs_enabled() bucket_name, blob_name = _CloudStorageClient._parse_gcs_tflite_uri(gcs_tflite_uri) bucket = storage.bucket(bucket_name, app=app) - blob = bucket.blob(blob_name) - return blob.generate_signed_url( + blob = bucket.blob(blob_name) # type: ignore[reportUnknownMemberType] + return blob.generate_signed_url( # type: ignore[reportUnknownMemberType] version='v4', expiration=datetime.timedelta(minutes=10), method='GET' @@ -479,20 +514,29 @@ class TFLiteGCSModelSource(TFLiteModelSource): _STORAGE_CLIENT = _CloudStorageClient() - def __init__(self, gcs_tflite_uri, app=None): + def __init__( + self, + gcs_tflite_uri: str, + app: typing.Optional[firebase_admin.App] = None, + ) -> None: self._app = app self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): return self._gcs_tflite_uri == other._gcs_tflite_uri # pylint: disable=protected-access return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @classmethod - def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): + def from_tflite_model_file( + cls, + model_file_name: typing.Union[str, os.PathLike[str]], + bucket_name: typing.Optional[str] = None, + app: typing.Optional[firebase_admin.App] = None, + ) -> "TFLiteGCSModelSource": """Uploads the model file to an existing Google Cloud Storage bucket. Args: @@ -511,7 +555,7 @@ def from_tflite_model_file(cls, model_file_name, bucket_name=None, app=None): return TFLiteGCSModelSource(gcs_tflite_uri=gcs_uri, app=app) @staticmethod - def _assert_tf_enabled(): + def _assert_tf_enabled() -> None: if not _TF_ENABLED: raise ImportError('Failed to import the tensorflow library for Python. Make sure ' 'to install the tensorflow module.') @@ -520,27 +564,32 @@ def _assert_tf_enabled(): .format(tf.version.VERSION)) @staticmethod - def _tf_convert_from_saved_model(saved_model_dir): + def _tf_convert_from_saved_model(saved_model_dir: Incomplete) -> Incomplete: # Same for both v1.x and v2.x - converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) - return converter.convert() + converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # type: ignore[reportUnknownMemberType] + return converter.convert() # type: ignore[reportUnknownMemberType] @staticmethod - def _tf_convert_from_keras_model(keras_model): + def _tf_convert_from_keras_model(keras_model: Incomplete) -> Incomplete: """Converts the given Keras model into a TF Lite model.""" # Version 1.x conversion function takes a model file. Version 2.x takes the model itself. if tf.version.VERSION.startswith('1.'): keras_file = 'firebase_keras_model.h5' - tf.keras.models.save_model(keras_model, keras_file) - converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) + tf.keras.models.save_model(keras_model, keras_file) # type: ignore[reportUnknownMemberType] + converter = tf.lite.TFLiteConverter.from_keras_model_file(keras_file) # type: ignore[reportUnknownMemberType] else: - converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) + converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) # type: ignore[reportUnknownMemberType] - return converter.convert() + return converter.convert() # type: ignore[reportUnknownMemberType] @classmethod - def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tflite', - bucket_name=None, app=None): + def from_saved_model( + cls, + saved_model_dir: Incomplete, + model_file_name: typing.Union[str, os.PathLike[str]] = 'firebase_ml_model.tflite', + bucket_name: typing.Optional[str] = None, + app: typing.Optional[firebase_admin.App] = None, + ) -> "TFLiteGCSModelSource": """Creates a Tensor Flow Lite model from the saved model, and uploads the model to GCS. Args: @@ -563,8 +612,13 @@ def from_saved_model(cls, saved_model_dir, model_file_name='firebase_ml_model.tf return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @classmethod - def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite', - bucket_name=None, app=None): + def from_keras_model( + cls, + keras_model: os.PathLike[str], + model_file_name: str = 'firebase_ml_model.tflite', + bucket_name: typing.Optional[str] = None, + app: typing.Optional[firebase_admin.App] = None, + ) -> "TFLiteGCSModelSource": """Creates a Tensor Flow Lite model from the keras model, and uploads the model to GCS. Args: @@ -587,19 +641,19 @@ def from_keras_model(cls, keras_model, model_file_name='firebase_ml_model.tflite return TFLiteGCSModelSource.from_tflite_model_file(model_file_name, bucket_name, app) @property - def gcs_tflite_uri(self): + def gcs_tflite_uri(self) -> str: """URI of the model file in Cloud Storage.""" return self._gcs_tflite_uri @gcs_tflite_uri.setter - def gcs_tflite_uri(self, gcs_tflite_uri): + def gcs_tflite_uri(self, gcs_tflite_uri: str) -> None: self._gcs_tflite_uri = _validate_gcs_tflite_uri(gcs_tflite_uri) - def _get_signed_gcs_tflite_uri(self): + def _get_signed_gcs_tflite_uri(self) -> str: """Signs the GCS uri, so the model file can be uploaded to Firebase ML and verified.""" return TFLiteGCSModelSource._STORAGE_CLIENT.sign_uri(self._gcs_tflite_uri, self._app) - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> typing.Dict[str, typing.Any]: """Returns a serializable representation of the object.""" if for_upload: return {'gcsTfliteUri': self._get_signed_gcs_tflite_uri()} @@ -613,30 +667,30 @@ class TFLiteAutoMlSource(TFLiteModelSource): AutoML model support is deprecated and will be removed in the next major version. """ - def __init__(self, auto_ml_model, app=None): + def __init__(self, auto_ml_model: str, app: typing.Optional[firebase_admin.App] = None) -> None: warnings.warn('AutoML model support is deprecated and will be removed in the next ' 'major version.', DeprecationWarning) self._app = app self.auto_ml_model = auto_ml_model - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): return self.auto_ml_model == other.auto_ml_model return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) @property - def auto_ml_model(self): + def auto_ml_model(self) -> str: """Resource name of the model, created by the AutoML API or Cloud console.""" return self._auto_ml_model @auto_ml_model.setter - def auto_ml_model(self, auto_ml_model): + def auto_ml_model(self, auto_ml_model: str) -> None: self._auto_ml_model = _validate_auto_ml_model(auto_ml_model) - def as_dict(self, for_upload=False): + def as_dict(self, for_upload: bool = False) -> typing.Dict[str, typing.Any]: """Returns a serializable representation of the object.""" # Upload is irrelevant for auto_ml models return {'automlModel': self._auto_ml_model} @@ -650,7 +704,14 @@ class ListModelsPage: ``iterate_all()`` can be used to iterate through all the models in the Firebase project starting from this page. """ - def __init__(self, list_models_func, list_filter, page_size, page_token, app): + def __init__( + self, + list_models_func: _DownloadCallback, + list_filter: typing.Optional[str], + page_size: typing.Optional[int], + page_token: typing.Optional[str], + app: typing.Optional[firebase_admin.App], + ) -> None: self._list_models_func = list_models_func self._list_filter = list_filter self._page_size = page_size @@ -659,28 +720,32 @@ def __init__(self, list_models_func, list_filter, page_size, page_token, app): self._list_response = list_models_func(list_filter, page_size, page_token) @property - def models(self): + def models(self) -> typing.List[Model]: """A list of Models from this page.""" return [ - Model.from_dict(model, app=self._app) for model in self._list_response.get('models', []) + Model.from_dict(model, app=self._app) + for model in typing.cast( + typing.List[typing.Dict[str, _typing.Json]], + self._list_response.get('models', []), + ) ] @property - def list_filter(self): + def list_filter(self) -> typing.Optional[str]: """The filter string used to filter the models.""" return self._list_filter @property - def next_page_token(self): + def next_page_token(self) -> str: """Token identifying the next page of results.""" - return self._list_response.get('nextPageToken', '') + return typing.cast(str, self._list_response.get('nextPageToken', '')) @property - def has_next_page(self): + def has_next_page(self) -> bool: """True if more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> typing.Optional["ListModelsPage"]: """Retrieves the next page of models if available. Returns: @@ -695,7 +760,7 @@ def get_next_page(self): self._app) return None - def iterate_all(self): + def iterate_all(self) -> "_ModelIterator": """Retrieves an iterator for Models. Returned iterator will iterate through all the models in the Firebase @@ -715,16 +780,16 @@ class _ModelIterator: When the whole page has been traversed, it loads another page. This class never keeps more than one page of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: ListModelsPage) -> None: if not isinstance(current_page, ListModelsPage): raise TypeError('Current page must be a ListModelsPage') self._current_page = current_page - self._index = 0 + self._index: int = 0 - def next(self): + def next(self) -> Model: if self._index == len(self._current_page.models): if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() + self._current_page = typing.cast(ListModelsPage, self._current_page.get_next_page()) self._index = 0 if self._index < len(self._current_page.models): result = self._current_page.models[self._index] @@ -732,14 +797,14 @@ def next(self): return result raise StopIteration - def __next__(self): + def __next__(self) -> Model: return self.next() - def __iter__(self): + def __iter__(self) -> typing.Iterator[Model]: return self -def _validate_and_parse_name(name): +def _validate_and_parse_name(name: str) -> typing.Tuple[str, str]: # The resource name is added automatically from API call responses. # The only way it could be invalid is if someone tries to # create a model from a dictionary manually and does it incorrectly. @@ -749,65 +814,66 @@ def _validate_and_parse_name(name): return matcher.group('project_id'), matcher.group('model_id') -def _validate_model(model, update_mask=None): +def _validate_model(model: Model, update_mask: typing.Optional[str] = None) -> None: if not isinstance(model, Model): raise TypeError('Model must be an ml.Model.') if update_mask is None and not model.display_name: raise ValueError('Model must have a display name.') -def _validate_model_id(model_id): +def _validate_model_id(model_id: str) -> None: if not _MODEL_ID_PATTERN.match(model_id): raise ValueError('Model ID format is invalid.') -def _validate_operation_name(op_name): +def _validate_operation_name(op_name: str) -> str: if not _OPERATION_NAME_PATTERN.match(op_name): raise ValueError('Operation name format is invalid.') return op_name -def _validate_display_name(display_name): +def _validate_display_name(display_name: str) -> str: if not _DISPLAY_NAME_PATTERN.match(display_name): raise ValueError('Display name format is invalid.') return display_name -def _validate_tags(tags): +def _validate_tags(tags: typing.Any) -> typing.List[str]: if not isinstance(tags, list) or not \ - all(isinstance(tag, str) for tag in tags): + all(isinstance(tag, str) for tag in tags): # type: ignore[reportUnknownVariableType] raise TypeError('Tags must be a list of strings.') + tags = typing.cast(typing.List[str], tags) if not all(_TAG_PATTERN.match(tag) for tag in tags): raise ValueError('Tag format is invalid.') return tags -def _validate_gcs_tflite_uri(uri): +def _validate_gcs_tflite_uri(uri: str) -> str: # GCS Bucket naming rules are complex. The regex is not comprehensive. # See https://cloud.google.com/storage/docs/naming for full details. if not _GCS_TFLITE_URI_PATTERN.match(uri): raise ValueError('GCS TFLite URI format is invalid.') return uri -def _validate_auto_ml_model(model): +def _validate_auto_ml_model(model: str) -> str: if not _AUTO_ML_MODEL_PATTERN.match(model): raise ValueError('Model resource name format is invalid.') return model -def _validate_model_format(model_format): +def _validate_model_format(model_format: typing.Any) -> ModelFormat: if not isinstance(model_format, ModelFormat): raise TypeError('Model format must be a ModelFormat object.') return model_format -def _validate_list_filter(list_filter): +def _validate_list_filter(list_filter: typing.Optional[str]) -> None: if list_filter is not None: if not isinstance(list_filter, str): raise TypeError('List filter must be a string or None.') -def _validate_page_size(page_size): +def _validate_page_size(page_size: typing.Optional[int]) -> None: if page_size is not None: if type(page_size) is not int: # pylint: disable=unidiomatic-typecheck # Specifically type() to disallow boolean which is a subtype of int @@ -817,7 +883,7 @@ def _validate_page_size(page_size): '1 and {0}'.format(_MAX_PAGE_SIZE)) -def _validate_page_token(page_token): +def _validate_page_token(page_token: typing.Optional[str]) -> None: if page_token is not None: if not isinstance(page_token, str): raise TypeError('Page token must be a string or None.') @@ -831,7 +897,7 @@ class _MLService: POLL_EXPONENTIAL_BACKOFF_FACTOR = 1.5 POLL_BASE_WAIT_TIME_SECONDS = 3 - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: self._project_id = app.project_id if not self._project_id: raise ValueError( @@ -850,14 +916,14 @@ def __init__(self, app): headers=ml_headers, base_url=_MLService.OPERATION_URL) - def get_operation(self, op_name): + def get_operation(self, op_name: str) -> typing.Dict[str, _typing.Json]: _validate_operation_name(op_name) try: return self._operation_client.body('get', url=op_name) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def _exponential_backoff(self, current_attempt, stop_time): + def _exponential_backoff(self, current_attempt: int, stop_time: typing.Optional[datetime.datetime]) -> None: """Sleeps for the appropriate amount of time. Or throws deadline exceeded.""" delay_factor = pow(_MLService.POLL_EXPONENTIAL_BACKOFF_FACTOR, current_attempt) wait_time_seconds = delay_factor * _MLService.POLL_BASE_WAIT_TIME_SECONDS @@ -869,7 +935,12 @@ def _exponential_backoff(self, current_attempt, stop_time): wait_time_seconds = min(wait_time_seconds, max_seconds_left - 1) time.sleep(wait_time_seconds) - def handle_operation(self, operation, wait_for_operation=False, max_time_seconds=None): + def handle_operation( + self, + operation: typing.Dict[str, _typing.Json], + wait_for_operation: bool = False, + max_time_seconds: typing.Optional[float] = None, + ) -> typing.Dict[str, typing.Any]: """Handles long running operations. Args: @@ -894,17 +965,18 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds if operation.get('done'): # Operations which are immediately done don't have an operation name if operation.get('response'): - return operation.get('response') + return typing.cast(typing.Dict[str, _typing.Json], operation['response']) if operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) + error = typing.cast(typing.Dict[str, typing.Any], operation['error']) + raise _utils.handle_operation_error(error) raise exceptions.UnknownError(message='Internal Error: Malformed Operation.') - op_name = _validate_operation_name(operation.get('name')) - metadata = operation.get('metadata', {}) + op_name = _validate_operation_name(typing.cast(str, operation['name'])) + metadata = typing.cast(typing.Dict[str, typing.Any], operation.get('metadata', {})) metadata_type = metadata.get('@type', '') if not metadata_type.endswith('ModelOperationMetadata'): raise TypeError('Unknown type of operation metadata.') - _, model_id = _validate_and_parse_name(metadata.get('name')) + _, model_id = _validate_and_parse_name(metadata['name']) current_attempt = 0 start_time = datetime.datetime.now() stop_time = (None if max_time_seconds is None else @@ -918,15 +990,16 @@ def handle_operation(self, operation, wait_for_operation=False, max_time_seconds if operation.get('done'): if operation.get('response'): - return operation.get('response') + return typing.cast(typing.Dict[str, _typing.Json], operation['response']) if operation.get('error'): - raise _utils.handle_operation_error(operation.get('error')) + error = typing.cast(typing.Dict[str, typing.Any], operation['error']) + raise _utils.handle_operation_error(error) # If the operation is not complete or timed out, return a (locked) model instead return get_model(model_id).as_dict() - def create_model(self, model): + def create_model(self, model: Model) -> typing.Dict[str, typing.Any]: _validate_model(model) try: return self.handle_operation( @@ -934,7 +1007,7 @@ def create_model(self, model): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def update_model(self, model, update_mask=None): + def update_model(self, model: Model, update_mask: typing.Optional[str] = None) -> typing.Dict[str, typing.Any]: _validate_model(model, update_mask) path = 'models/{0}'.format(model.model_id) if update_mask is not None: @@ -945,7 +1018,7 @@ def update_model(self, model, update_mask=None): except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def set_published(self, model_id, publish): + def set_published(self, model_id: str, publish: bool) -> typing.Dict[str, typing.Any]: _validate_model_id(model_id) model_name = 'projects/{0}/models/{1}'.format(self._project_id, model_id) model = Model.from_dict({ @@ -956,19 +1029,24 @@ def set_published(self, model_id, publish): }) return self.update_model(model, update_mask='state.published') - def get_model(self, model_id): + def get_model(self, model_id: str) -> typing.Dict[str, _typing.Json]: _validate_model_id(model_id) try: return self._client.body('get', url='models/{0}'.format(model_id)) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def list_models(self, list_filter, page_size, page_token): + def list_models( + self, + list_filter: typing.Optional[str], + page_size: typing.Optional[int], + page_token: typing.Optional[str], + ) -> typing.Dict[str, _typing.Json]: """ lists Firebase ML models.""" _validate_list_filter(list_filter) _validate_page_size(page_size) _validate_page_token(page_token) - params = {} + params: typing.Dict[str, typing.Any] = {} if list_filter: params['filter'] = list_filter if page_size: @@ -977,14 +1055,14 @@ def list_models(self, list_filter, page_size, page_token): params['page_token'] = page_token path = 'models' if params: - param_str = parse.urlencode(sorted(params.items()), True) + param_str = urllib.parse.urlencode(sorted(params.items()), True) path = path + '?' + param_str try: return self._client.body('get', url=path) except requests.exceptions.RequestException as error: raise _utils.handle_platform_error_from_requests(error) - def delete_model(self, model_id): + def delete_model(self, model_id: str) -> None: _validate_model_id(model_id) try: self._client.body('delete', url='models/{0}'.format(model_id)) diff --git a/firebase_admin/project_management.py b/firebase_admin/project_management.py index ed292b80..2e50ac2a 100644 --- a/firebase_admin/project_management.py +++ b/firebase_admin/project_management.py @@ -20,23 +20,28 @@ import base64 import re import time +import typing import requests import firebase_admin from firebase_admin import exceptions from firebase_admin import _http_client +from firebase_admin import _typing from firebase_admin import _utils +_T = typing.TypeVar("_T") +_AppMetadataT = typing.TypeVar("_AppMetadataT", bound="_AppMetadata") + _PROJECT_MANAGEMENT_ATTRIBUTE = '_project_management' -def _get_project_management_service(app): +def _get_project_management_service(app: typing.Optional["firebase_admin.App"]) -> "_ProjectManagementService": return _utils.get_app_service(app, _PROJECT_MANAGEMENT_ATTRIBUTE, _ProjectManagementService) -def android_app(app_id, app=None): +def android_app(app_id: str, app: typing.Optional["firebase_admin.App"] = None) -> "AndroidApp": """Obtains a reference to an Android app in the associated Firebase project. Args: @@ -49,7 +54,7 @@ def android_app(app_id, app=None): return AndroidApp(app_id=app_id, service=_get_project_management_service(app)) -def ios_app(app_id, app=None): +def ios_app(app_id: str, app: typing.Optional["firebase_admin.App"] = None) -> "IOSApp": """Obtains a reference to an iOS app in the associated Firebase project. Args: @@ -62,7 +67,7 @@ def ios_app(app_id, app=None): return IOSApp(app_id=app_id, service=_get_project_management_service(app)) -def list_android_apps(app=None): +def list_android_apps(app: typing.Optional["firebase_admin.App"] = None) -> typing.List["AndroidApp"]: """Lists all Android apps in the associated Firebase project. Args: @@ -75,7 +80,7 @@ def list_android_apps(app=None): return _get_project_management_service(app).list_android_apps() -def list_ios_apps(app=None): +def list_ios_apps(app: typing.Optional["firebase_admin.App"] = None) -> typing.List["IOSApp"]: """Lists all iOS apps in the associated Firebase project. Args: @@ -87,7 +92,11 @@ def list_ios_apps(app=None): return _get_project_management_service(app).list_ios_apps() -def create_android_app(package_name, display_name=None, app=None): +def create_android_app( + package_name: str, + display_name: typing.Optional[str] = None, + app: typing.Optional["firebase_admin.App"] = None, +) -> "AndroidApp": """Creates a new Android app in the associated Firebase project. Args: @@ -101,7 +110,11 @@ def create_android_app(package_name, display_name=None, app=None): return _get_project_management_service(app).create_android_app(package_name, display_name) -def create_ios_app(bundle_id, display_name=None, app=None): +def create_ios_app( + bundle_id: str, + display_name: typing.Optional[str] = None, + app: typing.Optional["firebase_admin.App"] = None, +) -> "IOSApp": """Creates a new iOS app in the associated Firebase project. Args: @@ -115,25 +128,29 @@ def create_ios_app(bundle_id, display_name=None, app=None): return _get_project_management_service(app).create_ios_app(bundle_id, display_name) -def _check_is_string_or_none(obj, field_name): +def _check_is_string_or_none(obj: typing.Any, field_name: str) -> typing.Optional[str]: if obj is None or isinstance(obj, str): return obj raise ValueError('{0} must be a string.'.format(field_name)) -def _check_is_nonempty_string(obj, field_name): +def _check_is_nonempty_string(obj: typing.Any, field_name: str) -> str: if isinstance(obj, str) and obj: return obj raise ValueError('{0} must be a non-empty string.'.format(field_name)) -def _check_is_nonempty_string_or_none(obj, field_name): +def _check_is_nonempty_string_or_none(obj: typing.Any, field_name: str) -> typing.Optional[str]: if obj is None: return None return _check_is_nonempty_string(obj, field_name) -def _check_not_none(obj, field_name): +@typing.overload +def _check_not_none(obj: None, field_name: str) -> typing.NoReturn: ... +@typing.overload +def _check_not_none(obj: _T, field_name: str) -> _T: ... +def _check_not_none(obj: typing.Optional[_T], field_name: str) -> _T: if obj is None: raise ValueError('{0} cannot be None.'.format(field_name)) return obj @@ -148,12 +165,12 @@ class AndroidApp: instead of instantiating it directly. """ - def __init__(self, app_id, service): + def __init__(self, app_id: str, service: "_ProjectManagementService") -> None: self._app_id = app_id self._service = service @property - def app_id(self): + def app_id(self) -> str: """Returns the app ID of the Android app to which this instance refers. Note: This method does not make an RPC. @@ -163,7 +180,7 @@ def app_id(self): """ return self._app_id - def get_metadata(self): + def get_metadata(self) -> "AndroidAppMetadata": """Retrieves detailed information about this Android app. Returns: @@ -175,7 +192,7 @@ def get_metadata(self): """ return self._service.get_android_app_metadata(self._app_id) - def set_display_name(self, new_display_name): + def set_display_name(self, new_display_name: typing.Optional[str]) -> None: """Updates the display name attribute of this Android app to the one given. Args: @@ -188,13 +205,13 @@ def set_display_name(self, new_display_name): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ - return self._service.set_android_app_display_name(self._app_id, new_display_name) + self._service.set_android_app_display_name(self._app_id, new_display_name) - def get_config(self): + def get_config(self) -> str: """Retrieves the configuration artifact associated with this Android app.""" return self._service.get_android_app_config(self._app_id) - def get_sha_certificates(self): + def get_sha_certificates(self) -> typing.List["SHACertificate"]: """Retrieves the entire list of SHA certificates associated with this Android app. Returns: @@ -206,7 +223,7 @@ def get_sha_certificates(self): """ return self._service.get_sha_certificates(self._app_id) - def add_sha_certificate(self, certificate_to_add): + def add_sha_certificate(self, certificate_to_add: "SHACertificate") -> None: """Adds a SHA certificate to this Android app. Args: @@ -219,9 +236,9 @@ def add_sha_certificate(self, certificate_to_add): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_add already exists.) """ - return self._service.add_sha_certificate(self._app_id, certificate_to_add) + self._service.add_sha_certificate(self._app_id, certificate_to_add) - def delete_sha_certificate(self, certificate_to_delete): + def delete_sha_certificate(self, certificate_to_delete: "SHACertificate") -> None: """Removes a SHA certificate from this Android app. Args: @@ -234,7 +251,7 @@ def delete_sha_certificate(self, certificate_to_delete): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. (For example, if the certificate_to_delete is not found.) """ - return self._service.delete_sha_certificate(certificate_to_delete) + self._service.delete_sha_certificate(certificate_to_delete) class IOSApp: @@ -246,12 +263,12 @@ class IOSApp: instead of instantiating it directly. """ - def __init__(self, app_id, service): + def __init__(self, app_id: str, service: "_ProjectManagementService") -> None: self._app_id = app_id self._service = service @property - def app_id(self): + def app_id(self) -> str: """Returns the app ID of the iOS app to which this instance refers. Note: This method does not make an RPC. @@ -261,7 +278,7 @@ def app_id(self): """ return self._app_id - def get_metadata(self): + def get_metadata(self) -> "IOSAppMetadata": """Retrieves detailed information about this iOS app. Returns: @@ -273,7 +290,7 @@ def get_metadata(self): """ return self._service.get_ios_app_metadata(self._app_id) - def set_display_name(self, new_display_name): + def set_display_name(self, new_display_name: typing.Optional[str]) -> None: """Updates the display name attribute of this iOS app to the one given. Args: @@ -286,9 +303,9 @@ def set_display_name(self, new_display_name): FirebaseError: If an error occurs while communicating with the Firebase Project Management Service. """ - return self._service.set_ios_app_display_name(self._app_id, new_display_name) + self._service.set_ios_app_display_name(self._app_id, new_display_name) - def get_config(self): + def get_config(self) -> str: """Retrieves the configuration artifact associated with this iOS app.""" return self._service.get_ios_app_config(self._app_id) @@ -296,7 +313,7 @@ def get_config(self): class _AppMetadata: """Detailed information about a Firebase Android or iOS app.""" - def __init__(self, name, app_id, display_name, project_id): + def __init__(self, name: str, app_id: str, display_name: typing.Optional[str], project_id: str) -> None: # _name is the fully qualified resource name of this Android or iOS app; currently it is not # exposed to client code. self._name = _check_is_nonempty_string(name, 'name') @@ -305,7 +322,7 @@ def __init__(self, name, app_id, display_name, project_id): self._project_id = _check_is_nonempty_string(project_id, 'project_id') @property - def app_id(self): + def app_id(self) -> str: """The globally unique, Firebase-assigned identifier of this Android or iOS app. This ID is unique even across apps of different platforms. @@ -313,18 +330,18 @@ def app_id(self): return self._app_id @property - def display_name(self): + def display_name(self) -> typing.Optional[str]: """The user-assigned display name of this Android or iOS app. Note that the display name can be None if it has never been set by the user.""" return self._display_name @property - def project_id(self): + def project_id(self) -> str: """The permanent, globally unique, user-assigned ID of the parent Firebase project.""" return self._project_id - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, type(self)): return False # pylint: disable=protected-access @@ -336,24 +353,31 @@ def __eq__(self, other): class AndroidAppMetadata(_AppMetadata): """Android-specific information about an Android Firebase app.""" - def __init__(self, package_name, name, app_id, display_name, project_id): + def __init__( + self, + package_name: str, + name: str, + app_id: str, + display_name: typing.Optional[str], + project_id: str, + ) -> None: """Clients should not instantiate this class directly.""" super(AndroidAppMetadata, self).__init__(name, app_id, display_name, project_id) self._package_name = _check_is_nonempty_string(package_name, 'package_name') @property - def package_name(self): + def package_name(self) -> str: """The canonical package name of this Android app as it would appear in the Play Store.""" return self._package_name - def __eq__(self, other): + def __eq__(self, other: typing.Any) -> bool: return (super(AndroidAppMetadata, self).__eq__(other) and self.package_name == other.package_name) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash( (self._name, self.app_id, self.display_name, self.project_id, self.package_name)) @@ -361,23 +385,30 @@ def __hash__(self): class IOSAppMetadata(_AppMetadata): """iOS-specific information about an iOS Firebase app.""" - def __init__(self, bundle_id, name, app_id, display_name, project_id): + def __init__( + self, + bundle_id: str, + name: str, + app_id: str, + display_name: typing.Optional[str], + project_id: str, + ) -> None: """Clients should not instantiate this class directly.""" super(IOSAppMetadata, self).__init__(name, app_id, display_name, project_id) self._bundle_id = _check_is_nonempty_string(bundle_id, 'bundle_id') @property - def bundle_id(self): + def bundle_id(self) -> str: """The canonical bundle ID of this iOS app as it would appear in the iOS AppStore.""" return self._bundle_id - def __eq__(self, other): + def __eq__(self, other: typing.Any) -> bool: return super(IOSAppMetadata, self).__eq__(other) and self.bundle_id == other.bundle_id - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash((self._name, self.app_id, self.display_name, self.project_id, self.bundle_id)) @@ -390,7 +421,7 @@ class SHACertificate: _SHA_1_RE = re.compile('^[0-9A-Fa-f]{40}$') _SHA_256_RE = re.compile('^[0-9A-Fa-f]{64}$') - def __init__(self, sha_hash, name=None): + def __init__(self, sha_hash: str, name: typing.Optional[str] = None) -> None: """Creates a new SHACertificate instance. Args: @@ -415,7 +446,7 @@ def __init__(self, sha_hash, name=None): 'The supplied certificate hash is neither a valid SHA-1 nor SHA_256 hash.') @property - def name(self): + def name(self) -> typing.Optional[str]: """Returns the fully qualified resource name of this certificate, if known. Returns: @@ -425,7 +456,7 @@ def name(self): return self._name @property - def sha_hash(self): + def sha_hash(self) -> str: """Returns the certificate hash. Returns: @@ -434,7 +465,7 @@ def sha_hash(self): return self._sha_hash @property - def cert_type(self): + def cert_type(self) -> str: """Returns the type of the SHA certificate encoded in the hash. Returns: @@ -442,16 +473,16 @@ def cert_type(self): """ return self._cert_type - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not isinstance(other, SHACertificate): return False return (self.name == other.name and self.sha_hash == other.sha_hash and self.cert_type == other.cert_type) - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash((self.name, self.sha_hash, self.cert_type)) @@ -469,7 +500,7 @@ class _ProjectManagementService: IOS_APPS_RESOURCE_NAME = 'iosApps' IOS_APP_IDENTIFIER_NAME = 'bundleId' - def __init__(self, app): + def __init__(self, app: "firebase_admin.App") -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -485,25 +516,31 @@ def __init__(self, app): headers={'X-Client-Version': version_header}, timeout=timeout) - def get_android_app_metadata(self, app_id): + def get_android_app_metadata(self, app_id: str) -> AndroidAppMetadata: return self._get_app_metadata( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.ANDROID_APP_IDENTIFIER_NAME, metadata_class=AndroidAppMetadata, app_id=app_id) - def get_ios_app_metadata(self, app_id): + def get_ios_app_metadata(self, app_id: str) -> IOSAppMetadata: return self._get_app_metadata( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, metadata_class=IOSAppMetadata, app_id=app_id) - def _get_app_metadata(self, platform_resource_name, identifier_name, metadata_class, app_id): + def _get_app_metadata( + self, + platform_resource_name: str, + identifier_name: str, + metadata_class: "_typing.AppMetadataSubclass[_AppMetadataT]", + app_id: str, + ) -> _AppMetadataT: """Retrieves detailed information about an Android or iOS app.""" _check_is_nonempty_string(app_id, 'app_id') path = '/v1beta1/projects/-/{0}/{1}'.format(platform_resource_name, app_id) - response = self._make_request('get', path) + response = typing.cast(typing.Dict[str, typing.Any], self._make_request('get', path)) return metadata_class( response[identifier_name], name=response['name'], @@ -511,45 +548,45 @@ def _get_app_metadata(self, platform_resource_name, identifier_name, metadata_cl display_name=response.get('displayName') or None, project_id=response['projectId']) - def set_android_app_display_name(self, app_id, new_display_name): + def set_android_app_display_name(self, app_id: str, new_display_name: typing.Optional[str]) -> None: self._set_display_name( app_id=app_id, new_display_name=new_display_name, platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME) - def set_ios_app_display_name(self, app_id, new_display_name): + def set_ios_app_display_name(self, app_id: str, new_display_name: typing.Optional[str]) -> None: self._set_display_name( app_id=app_id, new_display_name=new_display_name, platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME) - def _set_display_name(self, app_id, new_display_name, platform_resource_name): + def _set_display_name(self, app_id: str, new_display_name: typing.Optional[str], platform_resource_name: str) -> None: """Sets the display name of an Android or iOS app.""" path = '/v1beta1/projects/-/{0}/{1}?updateMask=displayName'.format( platform_resource_name, app_id) request_body = {'displayName': new_display_name} self._make_request('patch', path, json=request_body) - def list_android_apps(self): + def list_android_apps(self) -> typing.List[AndroidApp]: return self._list_apps( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, app_class=AndroidApp) - def list_ios_apps(self): + def list_ios_apps(self) -> typing.List[IOSApp]: return self._list_apps( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, app_class=IOSApp) - def _list_apps(self, platform_resource_name, app_class): + def _list_apps(self, platform_resource_name: str, app_class: "_typing.ProjectApp[_T]") -> typing.List[_T]: """Lists all the Android or iOS apps within the Firebase project.""" path = '/v1beta1/projects/{0}/{1}?pageSize={2}'.format( self._project_id, platform_resource_name, _ProjectManagementService.MAXIMUM_LIST_APPS_PAGE_SIZE) response = self._make_request('get', path) - apps_list = [] + apps_list: typing.List[_T] = [] while True: - apps = response.get('apps') + apps = typing.cast(typing.List[typing.Dict[str, typing.Any]], response.get('apps')) if not apps: break apps_list.extend(app_class(app_id=app['appId'], service=self) for app in apps) @@ -565,7 +602,7 @@ def _list_apps(self, platform_resource_name, app_class): response = self._make_request('get', path) return apps_list - def create_android_app(self, package_name, display_name=None): + def create_android_app(self, package_name: str, display_name: typing.Optional[str] = None) -> AndroidApp: return self._create_app( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.ANDROID_APP_IDENTIFIER_NAME, @@ -573,7 +610,7 @@ def create_android_app(self, package_name, display_name=None): display_name=display_name, app_class=AndroidApp) - def create_ios_app(self, bundle_id, display_name=None): + def create_ios_app(self, bundle_id: str, display_name: typing.Optional[str] = None) -> IOSApp: return self._create_app( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, identifier_name=_ProjectManagementService.IOS_APP_IDENTIFIER_NAME, @@ -582,12 +619,13 @@ def create_ios_app(self, bundle_id, display_name=None): app_class=IOSApp) def _create_app( - self, - platform_resource_name, - identifier_name, - identifier, - display_name, - app_class): + self, + platform_resource_name: str, + identifier_name: str, + identifier: str, + display_name: typing.Optional[str], + app_class: "_typing.ProjectApp[_T]", + ) -> _T: """Creates an Android or iOS app.""" _check_is_string_or_none(display_name, 'display_name') path = '/v1beta1/projects/{0}/{1}'.format(self._project_id, platform_resource_name) @@ -599,7 +637,7 @@ def _create_app( poll_response = self._poll_app_creation(operation_name) return app_class(app_id=poll_response['appId'], service=self) - def _poll_app_creation(self, operation_name): + def _poll_app_creation(self, operation_name: object) -> typing.Dict[str, typing.Any]: """Polls the Long-Running Operation repeatedly until it is done with exponential backoff.""" for current_attempt in range(_ProjectManagementService.MAXIMUM_POLLING_ATTEMPTS): delay_factor = pow( @@ -607,10 +645,11 @@ def _poll_app_creation(self, operation_name): wait_time_seconds = delay_factor * _ProjectManagementService.POLL_BASE_WAIT_TIME_SECONDS time.sleep(wait_time_seconds) path = '/v1/{0}'.format(operation_name) + poll_response: typing.Dict[str, typing.Any] poll_response, http_response = self._body_and_response('get', path) done = poll_response.get('done') if done: - response = poll_response.get('response') + response: typing.Optional[typing.Dict[str, typing.Any]] = poll_response.get('response') if response: return response @@ -619,45 +658,56 @@ def _poll_app_creation(self, operation_name): http_response=http_response) raise exceptions.DeadlineExceededError('Polling deadline exceeded.') - def get_android_app_config(self, app_id): + def get_android_app_config(self, app_id: str) -> str: return self._get_app_config( platform_resource_name=_ProjectManagementService.ANDROID_APPS_RESOURCE_NAME, app_id=app_id) - def get_ios_app_config(self, app_id): + def get_ios_app_config(self, app_id: str) -> str: return self._get_app_config( platform_resource_name=_ProjectManagementService.IOS_APPS_RESOURCE_NAME, app_id=app_id) - def _get_app_config(self, platform_resource_name, app_id): + def _get_app_config(self, platform_resource_name: str, app_id: str) -> str: path = '/v1beta1/projects/-/{0}/{1}/config'.format(platform_resource_name, app_id) response = self._make_request('get', path) # In Python 2.7, the base64 module works with strings, while in Python 3, it works with # bytes objects. This line works in both versions. - return base64.standard_b64decode(response['configFileContents']).decode(encoding='utf-8') + content = typing.cast(str, response['configFileContents']) + return base64.standard_b64decode(content).decode(encoding='utf-8') - def get_sha_certificates(self, app_id): + def get_sha_certificates(self, app_id: str) -> typing.List[SHACertificate]: path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) - response = self._make_request('get', path) - cert_list = response.get('certificates') or [] + response: typing.Dict[str, typing.Any] = self._make_request('get', path) + cert_list: typing.List[typing.Dict[str, typing.Any]] = response.get('certificates') or [] return [SHACertificate(sha_hash=cert['shaHash'], name=cert['name']) for cert in cert_list] - def add_sha_certificate(self, app_id, certificate_to_add): + def add_sha_certificate(self, app_id: str, certificate_to_add: SHACertificate) -> None: path = '/v1beta1/projects/-/androidApps/{0}/sha'.format(app_id) sha_hash = _check_not_none(certificate_to_add, 'certificate_to_add').sha_hash cert_type = certificate_to_add.cert_type request_body = {'shaHash': sha_hash, 'certType': cert_type} self._make_request('post', path, json=request_body) - def delete_sha_certificate(self, certificate_to_delete): + def delete_sha_certificate(self, certificate_to_delete: SHACertificate) -> None: name = _check_not_none(certificate_to_delete, 'certificate_to_delete').name path = '/v1beta1/{0}'.format(name) self._make_request('delete', path) - def _make_request(self, method, url, json=None): + def _make_request( + self, + method: str, + url: str, + json: typing.Optional[typing.Dict[str, typing.Any]] = None, + ) -> typing.Dict[str, '_typing.Json']: body, _ = self._body_and_response(method, url, json) return body - def _body_and_response(self, method, url, json=None): + def _body_and_response( + self, + method: str, + url: str, + json: typing.Optional[typing.Dict[str, typing.Any]] = None, + ) -> typing.Tuple[typing.Dict[str, '_typing.Json'], requests.Response]: try: return self._client.body_and_response(method=method, url=url, json=json) except requests.exceptions.RequestException as error: diff --git a/firebase_admin/remote_config.py b/firebase_admin/remote_config.py index 943141cc..5fb3f78a 100644 --- a/firebase_admin/remote_config.py +++ b/firebase_admin/remote_config.py @@ -20,13 +20,19 @@ import json import logging import threading -from typing import Dict, Optional, Literal, Union, Any +import typing from enum import Enum import re import hashlib + import requests -from firebase_admin import App, _http_client, _utils + import firebase_admin +from firebase_admin import _http_client +from firebase_admin import _typing +from firebase_admin import _utils +from firebase_admin import exceptions + # Set up logging (you can customize the level and output) logging.basicConfig(level=logging.INFO) @@ -34,7 +40,8 @@ _REMOTE_CONFIG_ATTRIBUTE = '_remoteconfig' MAX_CONDITION_RECURSION_DEPTH = 10 -ValueSource = Literal['default', 'remote', 'static'] # Define the ValueSource type +ValueSource = typing.Literal['default', 'remote', 'static'] # Define the ValueSource type + class PercentConditionOperator(Enum): """Enum representing the available operators for percent conditions. @@ -44,6 +51,7 @@ class PercentConditionOperator(Enum): BETWEEN = "BETWEEN" UNKNOWN = "UNKNOWN" + class CustomSignalOperator(Enum): """Enum representing the available operators for custom signal conditions. """ @@ -65,9 +73,10 @@ class CustomSignalOperator(Enum): SEMANTIC_VERSION_GREATER_EQUAL = "SEMANTIC_VERSION_GREATER_EQUAL" UNKNOWN = "UNKNOWN" + class _ServerTemplateData: """Parses, validates and encapsulates template data and metadata.""" - def __init__(self, template_data): + def __init__(self, template_data: typing.Dict[str, typing.Any]) -> None: """Initializes a new ServerTemplateData instance. Args: @@ -78,7 +87,7 @@ def __init__(self, template_data): """ if 'parameters' in template_data: if template_data['parameters'] is not None: - self._parameters = template_data['parameters'] + self._parameters: typing.Dict[str, typing.Dict[str, typing.Any]] = template_data['parameters'] else: raise ValueError('Remote Config parameters must be a non-null object') else: @@ -86,32 +95,32 @@ def __init__(self, template_data): if 'conditions' in template_data: if template_data['conditions'] is not None: - self._conditions = template_data['conditions'] + self._conditions: typing.List[typing.Dict[str, typing.Any]] = template_data['conditions'] else: raise ValueError('Remote Config conditions must be a non-null object') else: self._conditions = [] - self._version = '' + self._version: str = '' if 'version' in template_data: self._version = template_data['version'] - self._etag = '' + self._etag: str = '' if 'etag' in template_data and isinstance(template_data['etag'], str): self._etag = template_data['etag'] self._template_data_json = json.dumps(template_data) @property - def parameters(self): + def parameters(self) -> typing.Dict[str, typing.Dict[str, typing.Any]]: return self._parameters @property - def etag(self): + def etag(self) -> str: return self._etag @property - def version(self): + def version(self) -> str: return self._version @property @@ -119,13 +128,17 @@ def conditions(self): return self._conditions @property - def template_data_json(self): + def template_data_json(self) -> str: return self._template_data_json class ServerTemplate: """Represents a Server Template with implementations for loading and evaluating the template.""" - def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = None): + def __init__( + self, + app: typing.Optional[firebase_admin.App] = None, + default_config: typing.Optional[typing.Dict[str, str]] = None, + ) -> None: """Initializes a ServerTemplate instance. Args: @@ -137,8 +150,8 @@ def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = N _REMOTE_CONFIG_ATTRIBUTE, _RemoteConfigService) # This gets set when the template is # fetched from RC servers via the load API, or via the set API. - self._cache = None - self._stringified_default_config: Dict[str, str] = {} + self._cache: typing.Optional[_ServerTemplateData] = None + self._stringified_default_config: typing.Dict[str, str] = {} self._lock = threading.RLock() # RC stores all remote values as string, but it's more intuitive @@ -148,13 +161,13 @@ def __init__(self, app: App = None, default_config: Optional[Dict[str, str]] = N for key in default_config: self._stringified_default_config[key] = str(default_config[key]) - async def load(self): + async def load(self) -> None: """Fetches the server template and caches the data.""" rc_server_template = await self._rc_service.get_server_template() with self._lock: self._cache = rc_server_template - def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'ServerConfig': + def evaluate(self, context: typing.Optional[typing.Dict[str, typing.Union[str, int]]] = None) -> 'ServerConfig': """Evaluates the cached server template to produce a ServerConfig. Args: @@ -170,14 +183,14 @@ def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'Ser raise ValueError("""No Remote Config Server template in cache. Call load() before calling evaluate().""") context = context or {} - config_values = {} + config_values: typing.Dict[str, _Value] = {} with self._lock: template_conditions = self._cache.conditions template_parameters = self._cache.parameters # Initializes config Value objects with default values. - if self._stringified_default_config is not None: + if self._stringified_default_config: for key, value in self._stringified_default_config.items(): config_values[key] = _Value('default', value) self._evaluator = _ConditionEvaluator(template_conditions, @@ -185,7 +198,7 @@ def evaluate(self, context: Optional[Dict[str, Union[str, int]]] = None) -> 'Ser config_values) return ServerConfig(config_values=self._evaluator.evaluate()) - def set(self, template_data_json: str): + def set(self, template_data_json: str) -> None: """Updates the cache to store the given template is of type ServerTemplateData. Args: @@ -197,7 +210,7 @@ def set(self, template_data_json: str): with self._lock: self._cache = template_data - def to_json(self): + def to_json(self) -> str: """Provides the server template in a JSON format to be used for initialization later.""" if not self._cache: raise ValueError("""No Remote Config Server template in cache. @@ -209,30 +222,30 @@ def to_json(self): class ServerConfig: """Represents a Remote Config Server Side Config.""" - def __init__(self, config_values): + def __init__(self, config_values: typing.Dict[str, "_Value"]): self._config_values = config_values # dictionary of param key to values - def get_boolean(self, key): + def get_boolean(self, key: str) -> bool: """Returns the value as a boolean.""" return self._get_value(key).as_boolean() - def get_string(self, key): + def get_string(self, key: str) -> str: """Returns the value as a string.""" return self._get_value(key).as_string() - def get_int(self, key): + def get_int(self, key: str) -> int: """Returns the value as an integer.""" return self._get_value(key).as_int() - def get_float(self, key): + def get_float(self, key: str) -> float: """Returns the value as a float.""" return self._get_value(key).as_float() - def get_value_source(self, key): + def get_value_source(self, key: str) -> ValueSource: """Returns the source of the value.""" return self._get_value(key).get_source() - def _get_value(self, key): + def _get_value(self, key: str) -> "_Value": return self._config_values.get(key, _Value('static')) @@ -240,7 +253,7 @@ class _RemoteConfigService: """Internal class that facilitates sending requests to the Firebase Remote Config backend API. """ - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: """Initialize a JsonHttpClient with necessary inputs. Args: @@ -258,7 +271,7 @@ def __init__(self, app): base_url=remote_config_base_url, headers=rc_headers, timeout=timeout) - async def get_server_template(self): + async def get_server_template(self) -> _ServerTemplateData: """Requests for a server template and converts the response to an instance of ServerTemplateData for storing the template parameters and conditions.""" try: @@ -272,13 +285,13 @@ async def get_server_template(self): template_data['etag'] = headers.get('etag') return _ServerTemplateData(template_data) - def _get_url(self): + def _get_url(self) -> str: """Returns project prefix for url, in the format of /v1/projects/${projectId}""" return "/v1/projects/{0}/namespaces/firebase-server/serverRemoteConfig".format( self._project_id) @classmethod - def _handle_remote_config_error(cls, error: Any): + def _handle_remote_config_error(cls, error: requests.RequestException) -> exceptions.FirebaseError: """Handles errors received from the Cloud Functions API.""" return _utils.handle_platform_error_from_requests(error) @@ -286,13 +299,19 @@ def _handle_remote_config_error(cls, error: Any): class _ConditionEvaluator: """Internal class that facilitates sending requests to the Firebase Remote Config backend API.""" - def __init__(self, conditions, parameters, context, config_values): + def __init__( + self, + conditions: typing.List[typing.Dict[str, typing.Any]], + parameters: typing.Dict[str, typing.Dict[str, typing.Any]], + context: typing.Dict[str, typing.Any], + config_values: typing.Dict[str, "_Value"], + ) -> None: self._context = context self._conditions = conditions self._parameters = parameters self._config_values = config_values - def evaluate(self): + def evaluate(self) -> typing.Dict[str, "_Value"]: """Internal function that evaluates the cached server template to produce a ServerConfig""" evaluated_conditions = self.evaluate_conditions(self._conditions, self._context) @@ -300,9 +319,9 @@ def evaluate(self): # Overlays config Value objects derived by evaluating the template. if self._parameters: for key, parameter in self._parameters.items(): - conditional_values = parameter.get('conditionalValues', {}) - default_value = parameter.get('defaultValue', {}) - parameter_value_wrapper = None + conditional_values: typing.Dict[str, typing.Any] = parameter.get('conditionalValues', {}) + default_value: typing.Dict[str, typing.Any] = parameter.get('defaultValue', {}) + parameter_value_wrapper: typing.Optional[typing.Dict[str, typing.Any]] = None # Iterates in order over condition list. If there is a value associated # with a condition, this checks if the condition is true. if evaluated_conditions: @@ -316,8 +335,9 @@ def evaluate(self): continue if parameter_value_wrapper: + # possible issue: Is `None` a valid value for `_Value`? parameter_value = parameter_value_wrapper.get('value') - self._config_values[key] = _Value('remote', parameter_value) + self._config_values[key] = _Value('remote', parameter_value) # type: ignore[reportArgumentType] continue if not default_value: @@ -327,10 +347,14 @@ def evaluate(self): if default_value.get('useInAppDefault'): logger.info("Using in-app default value for key '%s'", key) continue - self._config_values[key] = _Value('remote', default_value.get('value')) + self._config_values[key] = _Value('remote', default_value.get('value')) # type: ignore[reportArgumentType] return self._config_values - def evaluate_conditions(self, conditions, context)-> Dict[str, bool]: + def evaluate_conditions( + self, + conditions: typing.List[typing.Dict[str, typing.Any]], + context: typing.Dict[str, typing.Any], + )-> typing.Dict[str, bool]: """Evaluates a list of conditions and returns a dictionary of results. Args: @@ -340,15 +364,20 @@ def evaluate_conditions(self, conditions, context)-> Dict[str, bool]: Returns: A dictionary that maps condition names to boolean evaluation results. """ - evaluated_conditions = {} + evaluated_conditions: typing.Dict[typing.Any, typing.Any] = {} for condition in conditions: + # possible issue: does condition always have `name`? evaluated_conditions[condition.get('name')] = self.evaluate_condition( - condition.get('condition'), context + condition['condition'], context ) return evaluated_conditions - def evaluate_condition(self, condition, context, - nesting_level: int = 0) -> bool: + def evaluate_condition( + self, + condition: typing.Dict[str, typing.Any], + context: typing.Dict[str, typing.Any], + nesting_level: int = 0, + ) -> bool: """Recursively evaluates a condition. Args: @@ -363,25 +392,28 @@ def evaluate_condition(self, condition, context, logger.warning("Maximum condition recursion depth exceeded.") return False if condition.get('orCondition') is not None: - return self.evaluate_or_condition(condition.get('orCondition'), + return self.evaluate_or_condition(condition['orCondition'], context, nesting_level + 1) if condition.get('andCondition') is not None: - return self.evaluate_and_condition(condition.get('andCondition'), + return self.evaluate_and_condition(condition['andCondition'], context, nesting_level + 1) if condition.get('true') is not None: return True if condition.get('false') is not None: return False if condition.get('percent') is not None: - return self.evaluate_percent_condition(condition.get('percent'), context) + return self.evaluate_percent_condition(condition['percent'], context) if condition.get('customSignal') is not None: - return self.evaluate_custom_signal_condition(condition.get('customSignal'), context) + return self.evaluate_custom_signal_condition(condition['customSignal'], context) logger.warning("Unknown condition type encountered.") return False - def evaluate_or_condition(self, or_condition, - context, - nesting_level: int = 0) -> bool: + def evaluate_or_condition( + self, + or_condition: typing.Dict[str, typing.Any], + context: typing.Dict[str, typing.Any], + nesting_level: int = 0, + ) -> bool: """Evaluates an OR condition. Args: @@ -392,16 +424,19 @@ def evaluate_or_condition(self, or_condition, Returns: True if any of the subconditions are true, False otherwise. """ - sub_conditions = or_condition.get('conditions') or [] + sub_conditions: typing.List[typing.Dict[str, typing.Any]] = or_condition.get('conditions') or [] for sub_condition in sub_conditions: result = self.evaluate_condition(sub_condition, context, nesting_level + 1) if result: return True return False - def evaluate_and_condition(self, and_condition, - context, - nesting_level: int = 0) -> bool: + def evaluate_and_condition( + self, + and_condition: typing.Dict[str, typing.Any], + context: typing.Dict[str, typing.Any], + nesting_level: int = 0, + ) -> bool: """Evaluates an AND condition. Args: @@ -412,15 +447,18 @@ def evaluate_and_condition(self, and_condition, Returns: True if all of the subconditions are met; False otherwise. """ - sub_conditions = and_condition.get('conditions') or [] + sub_conditions: typing.List[typing.Dict[str, typing.Any]] = and_condition.get('conditions') or [] for sub_condition in sub_conditions: result = self.evaluate_condition(sub_condition, context, nesting_level + 1) if not result: return False return True - def evaluate_percent_condition(self, percent_condition, - context) -> bool: + def evaluate_percent_condition( + self, + percent_condition: typing.Dict[str, typing.Any], + context: typing.Dict[str, typing.Any], + ) -> bool: """Evaluates a percent condition. Args: @@ -464,6 +502,7 @@ def evaluate_percent_condition(self, percent_condition, return norm_percent_lower_bound < instance_micro_percentile <= norm_percent_upper_bound logger.warning("Unknown percent operator: %s", percent_operator) return False + def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int: """Hashes a seeded randomization ID. @@ -478,8 +517,11 @@ def hash_seeded_randomization_id(self, seeded_randomization_id: str) -> int: hash64 = hash_object.hexdigest() return abs(int(hash64, 16)) - def evaluate_custom_signal_condition(self, custom_signal_condition, - context) -> bool: + def evaluate_custom_signal_condition( + self, + custom_signal_condition: typing.Dict[str, typing.Any], + context: typing.Dict[str, typing.Any], + ) -> bool: """Evaluates a custom signal condition. Args: @@ -489,18 +531,16 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, Returns: True if the condition is met, False otherwise. """ - custom_signal_operator = custom_signal_condition.get('customSignalOperator') or {} - custom_signal_key = custom_signal_condition.get('customSignalKey') or {} - target_custom_signal_values = ( - custom_signal_condition.get('targetCustomSignalValues') or {}) + custom_signal_operator: typing.Optional[str] = custom_signal_condition.get('customSignalOperator') + custom_signal_key: typing.Optional[str] = custom_signal_condition.get('customSignalKey') + target_custom_signal_values: typing.Optional[typing.List[typing.Any]] = ( + custom_signal_condition.get('targetCustomSignalValues')) - if not all([custom_signal_operator, custom_signal_key, target_custom_signal_values]): + if not (custom_signal_operator and custom_signal_key and target_custom_signal_values): logger.warning("Missing operator, key, or target values for custom signal condition.") return False - if not target_custom_signal_values: - return False - actual_custom_signal_value = context.get(custom_signal_key) or {} + actual_custom_signal_value: typing.Optional[typing.Any] = context.get(custom_signal_key) if not actual_custom_signal_value: logger.debug("Custom signal value not found in context: %s", custom_signal_key) @@ -521,7 +561,7 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, if custom_signal_operator == CustomSignalOperator.STRING_CONTAINS_REGEX.value: return self._compare_strings(target_custom_signal_values, actual_custom_signal_value, - re.search) + lambda pattern, string: bool(re.search(pattern, string))) # For numeric operators only one target value is allowed. if custom_signal_operator == CustomSignalOperator.NUMERIC_LESS_THAN.value: @@ -589,7 +629,12 @@ def evaluate_custom_signal_condition(self, custom_signal_condition, logger.warning("Unknown custom signal operator: %s", custom_signal_operator) return False - def _compare_strings(self, target_values, actual_value, predicate_fn) -> bool: + def _compare_strings( + self, + target_values: typing.List[str], + actual_value: str, + predicate_fn: typing.Callable[[str, str], bool], + ) -> bool: """Compares the actual string value of a signal against a list of target values. Args: @@ -609,7 +654,13 @@ def _compare_strings(self, target_values, actual_value, predicate_fn) -> bool: return True return False - def _compare_numbers(self, custom_signal_key, target_value, actual_value, predicate_fn) -> bool: + def _compare_numbers( + self, + custom_signal_key: str, + target_value: _typing.ConvertibleToFloat, + actual_value: _typing.ConvertibleToFloat, + predicate_fn: typing.Callable[[float], bool], + ) -> bool: try: target = float(target_value) actual = float(actual_value) @@ -620,8 +671,13 @@ def _compare_numbers(self, custom_signal_key, target_value, actual_value, predic custom_signal_key) return False - def _compare_semantic_versions(self, custom_signal_key, - target_value, actual_value, predicate_fn) -> bool: + def _compare_semantic_versions( + self, + custom_signal_key: str, + target_value: str, + actual_value: str, + predicate_fn: typing.Callable[[typing.Literal[-1, 0, 1]], bool] + ) -> bool: """Compares the actual semantic version value of a signal against a target value. Calls the predicate function with -1, 0, 1 if actual is less than, equal to, or greater than target. @@ -639,8 +695,13 @@ def _compare_semantic_versions(self, custom_signal_key, return self._compare_versions(custom_signal_key, str(actual_value), str(target_value), predicate_fn) - def _compare_versions(self, custom_signal_key, - sem_version_1, sem_version_2, predicate_fn) -> bool: + def _compare_versions( + self, + custom_signal_key: str, + sem_version_1: str, + sem_version_2: str, + predicate_fn: typing.Callable[[typing.Literal[-1, 0, 1]], bool] + ) -> bool: """Compares two semantic version strings. Args: @@ -673,7 +734,11 @@ def _compare_versions(self, custom_signal_key, custom_signal_key) return False -async def get_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None): + +async def get_server_template( + app: typing.Optional[firebase_admin.App] = None, + default_config: typing.Optional[typing.Dict[str, str]] = None +) -> ServerTemplate: """Initializes a new ServerTemplate instance and fetches the server template. Args: @@ -688,8 +753,12 @@ async def get_server_template(app: App = None, default_config: Optional[Dict[str await template.load() return template -def init_server_template(app: App = None, default_config: Optional[Dict[str, str]] = None, - template_data_json: Optional[str] = None): + +def init_server_template( + app: typing.Optional[firebase_admin.App] = None, + default_config: typing.Optional[typing.Dict[str, str]] = None, + template_data_json: typing.Optional[str] = None, +) -> ServerTemplate: """Initializes a new ServerTemplate instance. Args: @@ -707,6 +776,7 @@ def init_server_template(app: App = None, default_config: Optional[Dict[str, str template.set(template_data_json) return template + class _Value: """Represents a value fetched from Remote Config. """ @@ -716,7 +786,7 @@ class _Value: DEFAULT_VALUE_FOR_FLOAT_NUMBER = 0.0 BOOLEAN_TRUTHY_VALUES = ['1', 'true', 't', 'yes', 'y', 'on'] - def __init__(self, source: ValueSource, value: str = DEFAULT_VALUE_FOR_STRING): + def __init__(self, source: ValueSource, value: str = DEFAULT_VALUE_FOR_STRING) -> None: """Initializes a Value instance. Args: @@ -726,7 +796,7 @@ def __init__(self, source: ValueSource, value: str = DEFAULT_VALUE_FOR_STRING): "remote" indicates the value was defined by config produced by evaluating a template. value: The string value. """ - self.source = source + self.source: ValueSource = source self.value = value def as_string(self) -> str: @@ -741,7 +811,7 @@ def as_boolean(self) -> bool: return self.DEFAULT_VALUE_FOR_BOOLEAN return str(self.value).lower() in self.BOOLEAN_TRUTHY_VALUES - def as_int(self) -> float: + def as_int(self) -> int: """Returns the value as a number.""" if self.source == 'static': return self.DEFAULT_VALUE_FOR_INTEGER diff --git a/firebase_admin/storage.py b/firebase_admin/storage.py index b6084842..2abdec6e 100644 --- a/firebase_admin/storage.py +++ b/firebase_admin/storage.py @@ -18,6 +18,8 @@ Firebase apps. This requires the ``google-cloud-storage`` Python module. """ +import typing + # pylint: disable=import-error,no-name-in-module try: from google.cloud import storage @@ -25,12 +27,14 @@ raise ImportError('Failed to import the Cloud Storage library for Python. Make sure ' 'to install the "google-cloud-storage" module.') -from firebase_admin import _utils +from google.auth import credentials + +from firebase_admin import App, _utils _STORAGE_ATTRIBUTE = '_storage' -def bucket(name=None, app=None) -> storage.Bucket: +def bucket(name: typing.Optional[str] = None, app: typing.Optional[App] = None) -> storage.Bucket: """Returns a handle to a Google Cloud Storage bucket. If the name argument is not provided, uses the 'storageBucket' option specified when @@ -59,20 +63,25 @@ class _StorageClient: 'x-goog-api-client': _utils.get_metrics_header(), } - def __init__(self, credentials, project, default_bucket): + def __init__( + self, + credentials: credentials.Credentials, + project: typing.Optional[str], + default_bucket: typing.Optional[str] + ) -> None: self._client = storage.Client( credentials=credentials, project=project, extra_headers=self.STORAGE_HEADERS) self._default_bucket = default_bucket @classmethod - def from_app(cls, app): + def from_app(cls, app: App) -> "_StorageClient": credentials = app.credential.get_credential() default_bucket = app.options.get('storageBucket') # Specifying project ID is not required, but providing it when available # significantly speeds up the initialization of the storage client. return _StorageClient(credentials, app.project_id, default_bucket) - def bucket(self, name=None): + def bucket(self, name: typing.Optional[str] = None) -> storage.Bucket: """Returns a handle to the specified Cloud Storage Bucket.""" bucket_name = name if name is not None else self._default_bucket if bucket_name is None: @@ -84,4 +93,4 @@ def bucket(self, name=None): raise ValueError( 'Invalid storage bucket name: "{0}". Bucket name must be a non-empty ' 'string.'.format(bucket_name)) - return self._client.bucket(bucket_name) + return self._client.bucket(bucket_name) # type: ignore[reportUnknownMemberType] diff --git a/firebase_admin/tenant_mgt.py b/firebase_admin/tenant_mgt.py index 8c53e30a..db15de96 100644 --- a/firebase_admin/tenant_mgt.py +++ b/firebase_admin/tenant_mgt.py @@ -20,6 +20,7 @@ import re import threading +import typing import requests @@ -54,7 +55,7 @@ TenantNotFoundError = _auth_utils.TenantNotFoundError -def auth_for_tenant(tenant_id, app=None): +def auth_for_tenant(tenant_id: str, app: typing.Optional[firebase_admin.App] = None) -> auth.Client: """Gets an Auth Client instance scoped to the given tenant ID. Args: @@ -71,7 +72,7 @@ def auth_for_tenant(tenant_id, app=None): return tenant_mgt_service.auth_for_tenant(tenant_id) -def get_tenant(tenant_id, app=None): +def get_tenant(tenant_id: str, app: typing.Optional[firebase_admin.App] = None) -> "Tenant": """Gets the tenant corresponding to the given ``tenant_id``. Args: @@ -91,7 +92,11 @@ def get_tenant(tenant_id, app=None): def create_tenant( - display_name, allow_password_sign_up=None, enable_email_link_sign_in=None, app=None): + display_name: str, + allow_password_sign_up: typing.Optional[bool] = None, + enable_email_link_sign_in: typing.Optional[bool] = None, + app: typing.Optional[firebase_admin.App] = None +) -> "Tenant": """Creates a new tenant from the given options. Args: @@ -117,8 +122,12 @@ def create_tenant( def update_tenant( - tenant_id, display_name=None, allow_password_sign_up=None, enable_email_link_sign_in=None, - app=None): + tenant_id: str, + display_name: typing.Optional[str] = None, + allow_password_sign_up: typing.Optional[bool] = None, + enable_email_link_sign_in: typing.Optional[bool] = None, + app: typing.Optional[firebase_admin.App] = None, +) -> "Tenant": """Updates an existing tenant with the given options. Args: @@ -144,7 +153,7 @@ def update_tenant( enable_email_link_sign_in=enable_email_link_sign_in) -def delete_tenant(tenant_id, app=None): +def delete_tenant(tenant_id: str, app: typing.Optional[firebase_admin.App] = None) -> None: """Deletes the tenant corresponding to the given ``tenant_id``. Args: @@ -160,7 +169,11 @@ def delete_tenant(tenant_id, app=None): tenant_mgt_service.delete_tenant(tenant_id) -def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=None): +def list_tenants( + page_token: typing.Optional[str] = None, + max_results: int = _MAX_LIST_TENANTS_RESULTS, + app: typing.Optional[firebase_admin.App] = None +) -> "ListTenantsPage": """Retrieves a page of tenants from a Firebase project. The ``page_token`` argument governs the starting point of the page. The ``max_results`` @@ -183,12 +196,12 @@ def list_tenants(page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS, app=Non FirebaseError: If an error occurs while retrieving the user accounts. """ tenant_mgt_service = _get_tenant_mgt_service(app) - def download(page_token, max_results): + def download(page_token: typing.Optional[str], max_results: int) -> typing.Dict[str, typing.Any]: return tenant_mgt_service.list_tenants(page_token, max_results) return ListTenantsPage(download, page_token, max_results) -def _get_tenant_mgt_service(app): +def _get_tenant_mgt_service(app: typing.Optional[firebase_admin.App]) -> "_TenantManagementService": return _utils.get_app_service(app, _TENANT_MGT_ATTRIBUTE, _TenantManagementService) @@ -203,7 +216,7 @@ class Tenant: such as the display name, tenant identifier and email authentication configuration. """ - def __init__(self, data): + def __init__(self, data: typing.Dict[str, typing.Any]) -> None: if not isinstance(data, dict): raise ValueError('Invalid data argument in Tenant constructor: {0}'.format(data)) if not 'name' in data: @@ -212,7 +225,7 @@ def __init__(self, data): self._data = data @property - def tenant_id(self): + def tenant_id(self) -> str: name = self._data['name'] return name.split('/')[-1] @@ -221,11 +234,11 @@ def display_name(self): return self._data.get('displayName') @property - def allow_password_sign_up(self): + def allow_password_sign_up(self) -> bool: return self._data.get('allowPasswordSignup', False) @property - def enable_email_link_sign_in(self): + def enable_email_link_sign_in(self) -> bool: return self._data.get('enableEmailLinkSignin', False) @@ -234,17 +247,17 @@ class _TenantManagementService: TENANT_MGT_URL = 'https://identitytoolkit.googleapis.com/v2' - def __init__(self, app): + def __init__(self, app: firebase_admin.App) -> None: credential = app.credential.get_credential() version_header = 'Python/Admin/{0}'.format(firebase_admin.__version__) base_url = '{0}/projects/{1}'.format(self.TENANT_MGT_URL, app.project_id) self.app = app self.client = _http_client.JsonHttpClient( credential=credential, base_url=base_url, headers={'X-Client-Version': version_header}) - self.tenant_clients = {} + self.tenant_clients: typing.Dict[str, auth.Client] = {} self.lock = threading.RLock() - def auth_for_tenant(self, tenant_id): + def auth_for_tenant(self, tenant_id: str) -> auth.Client: """Gets an Auth Client instance scoped to the given tenant ID.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( @@ -256,9 +269,9 @@ def auth_for_tenant(self, tenant_id): client = auth.Client(self.app, tenant_id=tenant_id) self.tenant_clients[tenant_id] = client - return client + return client - def get_tenant(self, tenant_id): + def get_tenant(self, tenant_id: str) -> Tenant: """Gets the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( @@ -272,10 +285,14 @@ def get_tenant(self, tenant_id): return Tenant(body) def create_tenant( - self, display_name, allow_password_sign_up=None, enable_email_link_sign_in=None): + self, + display_name: str, + allow_password_sign_up: typing.Optional[bool] = None, + enable_email_link_sign_in: typing.Optional[bool] = None + ): """Creates a new tenant from the given parameters.""" - payload = {'displayName': _validate_display_name(display_name)} + payload: typing.Dict[str, typing.Any] = {'displayName': _validate_display_name(display_name)} if allow_password_sign_up is not None: payload['allowPasswordSignup'] = _auth_utils.validate_boolean( allow_password_sign_up, 'allowPasswordSignup') @@ -291,13 +308,17 @@ def create_tenant( return Tenant(body) def update_tenant( - self, tenant_id, display_name=None, allow_password_sign_up=None, - enable_email_link_sign_in=None): + self, + tenant_id: str, + display_name: typing.Optional[str] = None, + allow_password_sign_up: typing.Optional[bool] = None, + enable_email_link_sign_in: typing.Optional[bool] = None + ): """Updates the specified tenant with the given parameters.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError('Tenant ID must be a non-empty string.') - payload = {} + payload: typing.Dict[str, typing.Any] = {} if display_name is not None: payload['displayName'] = _validate_display_name(display_name) if allow_password_sign_up is not None: @@ -320,7 +341,7 @@ def update_tenant( else: return Tenant(body) - def delete_tenant(self, tenant_id): + def delete_tenant(self, tenant_id: str) -> None: """Deletes the tenant corresponding to the given ``tenant_id``.""" if not isinstance(tenant_id, str) or not tenant_id: raise ValueError( @@ -331,7 +352,11 @@ def delete_tenant(self, tenant_id): except requests.exceptions.RequestException as error: raise _auth_utils.handle_auth_backend_error(error) - def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): + def list_tenants( + self, + page_token: typing.Optional[str] = None, + max_results: int = _MAX_LIST_TENANTS_RESULTS, + ) -> typing.Dict[str, typing.Any]: """Retrieves a batch of tenants.""" if page_token is not None: if not isinstance(page_token, str) or not page_token: @@ -343,7 +368,7 @@ def list_tenants(self, page_token=None, max_results=_MAX_LIST_TENANTS_RESULTS): 'Max results must be a positive integer less than or equal to ' '{0}.'.format(_MAX_LIST_TENANTS_RESULTS)) - payload = {'pageSize': max_results} + payload: typing.Dict[str, typing.Any] = {'pageSize': max_results} if page_token: payload['pageToken'] = page_token try: @@ -360,27 +385,32 @@ class ListTenantsPage: through all tenants in the Firebase project starting from this page. """ - def __init__(self, download, page_token, max_results): + def __init__( + self, + download: typing.Callable[[typing.Optional[str], int], typing.Dict[str, typing.Any]], + page_token: typing.Optional[str], + max_results: int, + ) -> None: self._download = download self._max_results = max_results self._current = download(page_token, max_results) @property - def tenants(self): + def tenants(self) -> typing.List[Tenant]: """A list of ``ExportedUserRecord`` instances available in this page.""" return [Tenant(data) for data in self._current.get('tenants', [])] @property - def next_page_token(self): + def next_page_token(self) -> str: """Page token string for the next page (empty string indicates no more pages).""" return self._current.get('nextPageToken', '') @property - def has_next_page(self): + def has_next_page(self) -> bool: """A boolean indicating whether more pages are available.""" return bool(self.next_page_token) - def get_next_page(self): + def get_next_page(self) -> typing.Optional["ListTenantsPage"]: """Retrieves the next page of tenants, if available. Returns: @@ -411,16 +441,16 @@ class _TenantIterator: of entries in memory. """ - def __init__(self, current_page): + def __init__(self, current_page: ListTenantsPage) -> None: if not current_page: raise ValueError('Current page must not be None.') - self._current_page = current_page + self._current_page: ListTenantsPage = current_page self._index = 0 - def next(self): + def next(self) -> Tenant: if self._index == len(self._current_page.tenants): if self._current_page.has_next_page: - self._current_page = self._current_page.get_next_page() + self._current_page = typing.cast(ListTenantsPage, self._current_page.get_next_page()) self._index = 0 if self._index < len(self._current_page.tenants): result = self._current_page.tenants[self._index] @@ -428,14 +458,14 @@ def next(self): return result raise StopIteration - def __next__(self): + def __next__(self) -> Tenant: return self.next() - def __iter__(self): + def __iter__(self) -> typing.Iterator[Tenant]: return self -def _validate_display_name(display_name): +def _validate_display_name(display_name: typing.Any) -> str: if not isinstance(display_name, str): raise ValueError('Invalid type for displayName') if not _DISPLAY_NAME_PATTERN.search(display_name): diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..ed8806ea --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,29 @@ +{ + "pythonVersion": "3.7", + "typeCheckingMode": "strict", + + "include": ["firebase_admin"], + + "ignore": [ + "integration", + "snippets", + "tests", + "setup.py", + ], + + // Suppress import cycle errors (using forward references as needed) + "reportImportCycles": "none", + + // Allow dependencies without type annotations or stubs + "reportIncompleteStub": "none", + "reportMissingTypeStubs": "none", + + // Permit usage of private members across modules + "reportPrivateUsage": "none", + + // Allow `isinstance` for type assertions and runtime checks + "reportUnnecessaryIsInstance": "none", + + // Warn when a previously ignored type check is no longer needed + "reportUnnecessaryTypeIgnoreComment": "warning", +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index fd5b0b39..c6117ef2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,10 +5,14 @@ pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 pytest-asyncio >= 0.16.0 pytest-mock >= 3.6.1 +oauth2client cachecontrol >= 0.12.14 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 -pyjwt[crypto] >= 2.5.0 \ No newline at end of file +pyjwt[crypto] >= 2.5.0 +typing-extensions >= 4.12.0 +types-requests +types-httplib2 \ No newline at end of file diff --git a/setup.py b/setup.py index 23be6d48..ae9e402a 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,9 @@ 'google-cloud-firestore>=2.19.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', 'pyjwt[crypto] >= 2.5.0', + 'typing-extensions >= 4.12.0', + 'types-requests', + 'types-httplib2' ] setup( @@ -73,5 +76,6 @@ 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', 'License :: OSI Approved :: Apache Software License', + 'Typing :: Typed', ], )