Skip to content

Commit c0b740c

Browse files
pintaoz-awspintaoz
and
pintaoz
authored
Fix all type hint and docstrings for callable (#5035)
* Fix all type hint and docstrings for callable * Fix codestyle --------- Co-authored-by: pintaoz <pintaoz@amazon.com>
1 parent d08c294 commit c0b740c

File tree

24 files changed

+71
-69
lines changed

24 files changed

+71
-69
lines changed

src/sagemaker/amazon/hyperparameter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, name, validate=lambda _: True, validation_message="", data_ty
2828
"""Args:
2929
3030
name (str): The name of this hyperparameter validate
31-
(callable[object]->[bool]): A validation function or list of validation
31+
(Callable[object]->[bool]): A validation function or list of validation
3232
functions.
3333
3434
Each function validates an object and returns False if the object

src/sagemaker/amazon/ipinsights.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def __init__(
209209
chain.
210210
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
211211
serializes input data to text/csv.
212-
deserializer (callable): Optional. Default parses JSON responses
212+
deserializer (Callable): Optional. Default parses JSON responses
213213
using ``json.load(...)``.
214214
component_name (str): Optional. Name of the Amazon SageMaker inference
215215
component corresponding the predictor.

src/sagemaker/automl/automl.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def create_model(
478478
training cluster for distributed training. Default: False
479479
model_kms_key (str): KMS key ARN used to encrypt the repacked
480480
model archive file if the model is repacked
481-
predictor_cls (callable[string, sagemaker.session.Session]): A
481+
Callable[[string, sagemaker.session.Session], Any]: A
482482
function to call to create a predictor (default: None). If
483483
specified, ``deploy()`` returns the result of invoking this
484484
function on the created endpoint name.
@@ -591,7 +591,7 @@ def deploy(
591591
training cluster for distributed training. Default: False
592592
model_kms_key (str): KMS key ARN used to encrypt the repacked
593593
model archive file if the model is repacked
594-
predictor_cls (callable[string, sagemaker.session.Session]): A
594+
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
595595
function to call to create a predictor (default: None). If
596596
specified, ``deploy()`` returns the result of invoking this
597597
function on the created endpoint name.
@@ -609,7 +609,7 @@ def deploy(
609609
https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
610610
611611
Returns:
612-
callable[string, sagemaker.session.Session] or ``None``:
612+
Optional[Callable[[string, sagemaker.session.Session], Any]]:
613613
If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on
614614
the created endpoint name. Otherwise, ``None``.
615615
"""

src/sagemaker/automl/automlv2.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1022,7 +1022,7 @@ def create_model(
10221022
training cluster for distributed training. Default: False
10231023
model_kms_key (str): KMS key ARN used to encrypt the repacked
10241024
model archive file if the model is repacked
1025-
predictor_cls (callable[string, sagemaker.session.Session]): A
1025+
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
10261026
function to call to create a predictor (default: None). If
10271027
specified, ``deploy()`` returns the result of invoking this
10281028
function on the created endpoint name.
@@ -1130,7 +1130,7 @@ def deploy(
11301130
training cluster for distributed training. Default: False
11311131
model_kms_key (str): KMS key ARN used to encrypt the repacked
11321132
model archive file if the model is repacked
1133-
predictor_cls (callable[string, sagemaker.session.Session]): A
1133+
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
11341134
function to call to create a predictor (default: None). If
11351135
specified, ``deploy()`` returns the result of invoking this
11361136
function on the created endpoint name.
@@ -1148,7 +1148,7 @@ def deploy(
11481148
https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-algo-ping-requests
11491149
11501150
Returns:
1151-
callable[string, sagemaker.session.Session] or ``None``:
1151+
Optional[Callable[[string, sagemaker.session.Session], Any]]:
11521152
If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on
11531153
the created endpoint name. Otherwise, ``None``.
11541154
"""

src/sagemaker/chainer/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17-
from typing import Optional, Union, List, Dict
17+
from typing import Callable, Optional, Union, List, Dict
1818

1919
import sagemaker
2020
from sagemaker import image_uris, ModelMetrics
@@ -96,7 +96,7 @@ def __init__(
9696
image_uri: Optional[Union[str, PipelineVariable]] = None,
9797
framework_version: Optional[str] = None,
9898
py_version: Optional[str] = None,
99-
predictor_cls: callable = ChainerPredictor,
99+
predictor_cls: Optional[Callable] = ChainerPredictor,
100100
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
101101
**kwargs,
102102
):
@@ -125,7 +125,7 @@ def __init__(
125125
py_version (str): Python version you want to use for executing your
126126
model training code. Defaults to ``None``. Required unless
127127
``image_uri`` is provided.
128-
predictor_cls (callable[str, sagemaker.session.Session]): A function
128+
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function
129129
to call to create a predictor with an endpoint name and
130130
SageMaker ``Session``. If specified, ``deploy()`` returns the
131131
result of invoking this function on the created endpoint name.

src/sagemaker/djl_inference/model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17-
from typing import Optional, Dict, Any
17+
from typing import Callable, Optional, Dict, Any
1818

1919
from sagemaker import image_uris
2020
from sagemaker.model import Model
@@ -54,7 +54,7 @@ def __init__(
5454
parallel_loading: bool = False,
5555
model_loading_timeout: Optional[int] = None,
5656
prediction_timeout: Optional[int] = None,
57-
predictor_cls: callable = DJLPredictor,
57+
predictor_cls: Optional[Callable] = DJLPredictor,
5858
huggingface_hub_token: Optional[str] = None,
5959
**kwargs,
6060
):
@@ -97,10 +97,10 @@ def __init__(
9797
None. If not provided, the default is 240 seconds.
9898
prediction_timeout (int): The worker predict call (handler) timeout in seconds.
9999
Defaults to None. If not provided, the default is 120 seconds.
100-
predictor_cls (callable[str, sagemaker.session.Session]): A function to call to create a
101-
predictor with an endpoint name and SageMaker ``Session``. If specified,
102-
``deploy()`` returns
103-
the result of invoking this function on the created endpoint name.
100+
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function to call
101+
to create a predictor with an endpoint name and SageMaker ``Session``. If
102+
specified, ``deploy()`` returns the result of invoking this function on the created
103+
endpoint name.
104104
huggingface_hub_token (str): The HuggingFace Hub token to use for downloading the model
105105
artifacts for a model stored on the huggingface hub.
106106
Defaults to None. If not provided, the token must be specified in the

src/sagemaker/huggingface/model.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import logging
17-
from typing import Optional, Union, List, Dict
17+
from typing import Callable, Optional, Union, List, Dict
1818

1919
import sagemaker
2020
from sagemaker import image_uris, ModelMetrics
@@ -123,7 +123,7 @@ def __init__(
123123
pytorch_version: Optional[str] = None,
124124
py_version: Optional[str] = None,
125125
image_uri: Optional[Union[str, PipelineVariable]] = None,
126-
predictor_cls: callable = HuggingFacePredictor,
126+
predictor_cls: Optional[Callable] = HuggingFacePredictor,
127127
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
128128
**kwargs,
129129
):
@@ -158,7 +158,7 @@ def __init__(
158158
If not specified, a default image for PyTorch will be used. If ``framework_version``
159159
or ``py_version`` are ``None``, then ``image_uri`` is required. If
160160
also ``None``, then a ``ValueError`` will be raised.
161-
predictor_cls (callable[str, sagemaker.session.Session]): A function
161+
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function
162162
to call to create a predictor with an endpoint name and
163163
SageMaker ``Session``. If specified, ``deploy()`` returns the
164164
result of invoking this function on the created endpoint name.
@@ -304,7 +304,7 @@ def deploy(
304304
- If a wrong type of object is provided as serverless inference config or async
305305
inference config
306306
Returns:
307-
callable[string, sagemaker.session.Session] or None: Invocation of
307+
Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of
308308
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
309309
is not None. Otherwise, return None.
310310
"""

src/sagemaker/jumpstart/estimator.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616

17-
from typing import Dict, List, Optional, Union
17+
from typing import Callable, Dict, List, Optional, Union
1818
from sagemaker import session
1919
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2020
from sagemaker.base_deserializers import BaseDeserializer
@@ -817,7 +817,7 @@ def deploy(
817817
explainer_config: Optional[ExplainerConfig] = None,
818818
image_uri: Optional[Union[str, PipelineVariable]] = None,
819819
role: Optional[str] = None,
820-
predictor_cls: Optional[callable] = None,
820+
predictor_cls: Optional[Callable] = None,
821821
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
822822
model_name: Optional[str] = None,
823823
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
@@ -918,7 +918,7 @@ def deploy(
918918
It can be null if this is being used to create a Model to pass
919919
to a ``PipelineModel`` which has its own Role field. (Default:
920920
None).
921-
predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A
921+
predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A
922922
function to call to create a predictor (Default: None). If not
923923
None, ``deploy`` will return the result of invoking this
924924
function on the created endpoint name. (Default: None).

src/sagemaker/jumpstart/factory/estimator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import absolute_import
1515

1616

17-
from typing import Dict, List, Optional, Union
17+
from typing import Callable, Dict, List, Optional, Union
1818
from sagemaker import (
1919
environment_variables,
2020
hyperparameters as hyperparameters_utils,
@@ -330,7 +330,7 @@ def get_deploy_kwargs(
330330
explainer_config: Optional[ExplainerConfig] = None,
331331
image_uri: Optional[Union[str, PipelineVariable]] = None,
332332
role: Optional[str] = None,
333-
predictor_cls: Optional[callable] = None,
333+
predictor_cls: Optional[Callable] = None,
334334
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
335335
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
336336
sagemaker_session: Optional[Session] = None,

src/sagemaker/jumpstart/factory/model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import json
1616

1717

18-
from typing import Any, Dict, List, Optional, Union
18+
from typing import Any, Callable, Dict, List, Optional, Union
1919
from sagemaker_core.shapes import ModelAccessConfig
2020
from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris
2121
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
@@ -855,7 +855,7 @@ def get_init_kwargs(
855855
image_uri: Optional[Union[str, PipelineVariable]] = None,
856856
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
857857
role: Optional[str] = None,
858-
predictor_cls: Optional[callable] = None,
858+
predictor_cls: Optional[Callable] = None,
859859
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
860860
name: Optional[str] = None,
861861
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,

src/sagemaker/jumpstart/model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import absolute_import
1616

17-
from typing import Dict, List, Optional, Any, Union
17+
from typing import Callable, Dict, List, Optional, Any, Union
1818
import pandas as pd
1919
from botocore.exceptions import ClientError
2020

@@ -95,7 +95,7 @@ def __init__(
9595
image_uri: Optional[Union[str, PipelineVariable]] = None,
9696
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
9797
role: Optional[str] = None,
98-
predictor_cls: Optional[callable] = None,
98+
predictor_cls: Optional[Callable] = None,
9999
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
100100
name: Optional[str] = None,
101101
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
@@ -149,7 +149,7 @@ def __init__(
149149
It can be null if this is being used to create a Model to pass
150150
to a ``PipelineModel`` which has its own Role field. (Default:
151151
None).
152-
predictor_cls (Optional[callable[string, sagemaker.session.Session]]): A
152+
predictor_cls (Optional[Callable[[string, sagemaker.session.Session], Any]]): A
153153
function to call to create a predictor (Default: None). If not
154154
None, ``deploy`` will return the result of invoking this
155155
function on the created endpoint name. (Default: None).

src/sagemaker/jumpstart/types.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import re
1717
from copy import deepcopy
1818
from enum import Enum
19-
from typing import Any, Dict, List, Optional, Set, Union
19+
from typing import Any, Callable, Dict, List, Optional, Set, Union
2020
from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig
2121
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
2222
from sagemaker.utils import (
@@ -2150,7 +2150,7 @@ def __init__(
21502150
image_uri: Optional[Union[str, Any]] = None,
21512151
model_data: Optional[Union[str, Any, dict]] = None,
21522152
role: Optional[str] = None,
2153-
predictor_cls: Optional[callable] = None,
2153+
predictor_cls: Optional[Callable] = None,
21542154
env: Optional[Dict[str, Union[str, Any]]] = None,
21552155
name: Optional[str] = None,
21562156
vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,
@@ -2698,7 +2698,7 @@ def __init__(
26982698
explainer_config: Optional[Any] = None,
26992699
image_uri: Optional[Union[str, Any]] = None,
27002700
role: Optional[str] = None,
2701-
predictor_cls: Optional[callable] = None,
2701+
predictor_cls: Optional[Callable] = None,
27022702
env: Optional[Dict[str, Union[str, Any]]] = None,
27032703
model_name: Optional[str] = None,
27042704
vpc_config: Optional[Dict[str, List[Union[str, Any]]]] = None,

src/sagemaker/model.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import os
2121
import re
2222
import copy
23-
from typing import List, Dict, Optional, Union, Any
23+
from typing import Callable, List, Dict, Optional, Union, Any
2424

2525
import sagemaker
2626
from sagemaker import (
@@ -154,7 +154,7 @@ def __init__(
154154
image_uri: Optional[Union[str, PipelineVariable]] = None,
155155
model_data: Optional[Union[str, PipelineVariable, dict]] = None,
156156
role: Optional[str] = None,
157-
predictor_cls: Optional[callable] = None,
157+
predictor_cls: Optional[Callable] = None,
158158
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
159159
name: Optional[str] = None,
160160
vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None,
@@ -186,7 +186,7 @@ def __init__(
186186
It can be null if this is being used to create a Model to pass
187187
to a ``PipelineModel`` which has its own Role field. (default:
188188
None)
189-
predictor_cls (callable[string, sagemaker.session.Session]): A
189+
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
190190
function to call to create a predictor (default: None). If not
191191
None, ``deploy`` will return the result of invoking this
192192
function on the created endpoint name.
@@ -1501,7 +1501,7 @@ def deploy(
15011501
inference config or
15021502
- If inference recommendation id is specified along with incompatible parameters
15031503
Returns:
1504-
callable[string, sagemaker.session.Session] or None: Invocation of
1504+
Callable[[string, sagemaker.session.Session], Any] or None: Invocation of
15051505
``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls``
15061506
is not None. Otherwise, return None.
15071507
"""
@@ -1959,7 +1959,7 @@ def __init__(
19591959
role: Optional[str] = None,
19601960
entry_point: Optional[str] = None,
19611961
source_dir: Optional[str] = None,
1962-
predictor_cls: Optional[callable] = None,
1962+
predictor_cls: Optional[Callable] = None,
19631963
env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
19641964
name: Optional[str] = None,
19651965
container_log_level: Union[int, PipelineVariable] = logging.INFO,
@@ -2012,7 +2012,7 @@ def __init__(
20122012
>>> |----- test.py
20132013
20142014
You can assign entry_point='inference.py', source_dir='src'.
2015-
predictor_cls (callable[string, sagemaker.session.Session]): A
2015+
predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A
20162016
function to call to create a predictor (default: None). If not
20172017
None, ``deploy`` will return the result of invoking this
20182018
function on the created endpoint name.

src/sagemaker/multidatamodel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def deploy(
223223
Amazon SageMaker Model Monitoring. Default: None.
224224
225225
Returns:
226-
callable[string, sagemaker.session.Session] or None: Invocation of
226+
Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of
227227
``self.predictor_cls`` on the created endpoint name,
228228
if ``self.predictor_cls``
229229
is not None. Otherwise, return None.

0 commit comments

Comments
 (0)