-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: jumpstart estimator for gated uncompressed training #5175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8ae6835
68202a7
e9f2f65
5de1e55
402fbc9
c9ceff2
2830f39
4d45973
f26a170
b1ffab2
a7ff61d
4c46f01
87d752f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1310,6 +1310,7 @@ def fit( | |
logs: str = "All", | ||
job_name: Optional[str] = None, | ||
experiment_config: Optional[Dict[str, str]] = None, | ||
accept_eula: Optional[bool] = None, | ||
): | ||
"""Train a model using the input training dataset. | ||
|
||
|
@@ -1363,14 +1364,21 @@ def fit( | |
* Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance | ||
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. | ||
However, the value of `TrialComponentDisplayName` is honored for display in Studio. | ||
accept_eula (bool): For models that require a Model Access Config, specify True or | ||
False to indicate whether model terms of use have been accepted. | ||
The `accept_eula` value must be explicitly defined as `True` in order to | ||
accept the end-user license agreement (EULA) that some | ||
models require. (Default: None). | ||
Returns: | ||
None or pipeline step arguments in case the Estimator instance is built with | ||
:class:`~sagemaker.workflow.pipeline_context.PipelineSession` | ||
""" | ||
self._prepare_for_training(job_name=job_name) | ||
|
||
experiment_config = check_and_get_run_experiment_config(experiment_config) | ||
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config) | ||
self.latest_training_job = _TrainingJob.start_new( | ||
self, inputs, experiment_config, accept_eula | ||
) | ||
self.jobs.append(self.latest_training_job) | ||
forward_to_mlflow_tracking_server = False | ||
if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation(): | ||
|
@@ -2484,7 +2492,7 @@ class _TrainingJob(_Job): | |
"""Placeholder docstring""" | ||
|
||
@classmethod | ||
def start_new(cls, estimator, inputs, experiment_config): | ||
def start_new(cls, estimator, inputs, experiment_config, accept_eula=None): | ||
"""Create a new Amazon SageMaker training job from the estimator. | ||
|
||
Args: | ||
|
@@ -2504,19 +2512,24 @@ def start_new(cls, estimator, inputs, experiment_config): | |
will be unassociated. | ||
* `TrialComponentDisplayName` is used for display in Studio. | ||
* `RunName` is used to record an experiment run. | ||
accept_eula (bool): For models that require a Model Access Config, specify True or | ||
False to indicate whether model terms of use have been accepted. | ||
The `accept_eula` value must be explicitly defined as `True` in order to | ||
accept the end-user license agreement (EULA) that some | ||
models require. (Default: None). | ||
Returns: | ||
sagemaker.estimator._TrainingJob: Constructed object that captures | ||
all information about the started training job. | ||
""" | ||
train_args = cls._get_train_args(estimator, inputs, experiment_config) | ||
train_args = cls._get_train_args(estimator, inputs, experiment_config, accept_eula) | ||
|
||
logger.debug("Train args after processing defaults: %s", train_args) | ||
estimator.sagemaker_session.train(**train_args) | ||
|
||
return cls(estimator.sagemaker_session, estimator._current_job_name) | ||
|
||
@classmethod | ||
def _get_train_args(cls, estimator, inputs, experiment_config): | ||
def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula=None): | ||
"""Constructs a dict of arguments for an Amazon SageMaker training job from the estimator. | ||
|
||
Args: | ||
|
@@ -2536,6 +2549,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config): | |
will be unassociated. | ||
* `TrialComponentDisplayName` is used for display in Studio. | ||
* `RunName` is used to record an experiment run. | ||
accept_eula (bool): For models that require a Model Access Config, specify True or | ||
False to indicate whether model terms of use have been accepted. | ||
The `accept_eula` value must be explicitly defined as `True` in order to | ||
accept the end-user license agreement (EULA) that some | ||
models require. (Default: None). | ||
|
||
Returns: | ||
Dict: dict for `sagemaker.session.Session.train` method | ||
|
@@ -2652,6 +2670,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config): | |
if estimator.get_session_chaining_config() is not None: | ||
train_args["session_chaining_config"] = estimator.get_session_chaining_config() | ||
|
||
if accept_eula is not None: | ||
cls._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula) | ||
|
||
return train_args | ||
|
||
@classmethod | ||
|
@@ -2674,6 +2695,36 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args): | |
raise ValueError("Setting checkpoint_local_path is not supported in local mode.") | ||
train_args["checkpoint_local_path"] = estimator.checkpoint_local_path | ||
|
||
@classmethod | ||
def _set_accept_eula_for_model_channel_input_data_config(cls, train_args, accept_eula): | ||
"""Set the AcceptEula flag for model channel in input data configurations. | ||
|
||
This method sets the AcceptEula flag in the ModelAccessConfig for the model channel | ||
S3DataSource in the input_config array. It handles cases where keys might not exist in the | ||
nested dictionary structure. | ||
|
||
Args: | ||
train_args (dict): The training job arguments dictionary | ||
accept_eula (bool): The value to set for AcceptEula flag | ||
""" | ||
if "input_config" not in train_args: | ||
return | ||
|
||
if accept_eula is None: | ||
return | ||
|
||
for idx in range(len(train_args["input_config"])): | ||
if ( | ||
"DataSource" in train_args["input_config"][idx] | ||
and train_args["input_config"][idx]["ChannelName"].lower().strip() == "model" | ||
): | ||
data_source = train_args["input_config"][idx]["DataSource"] | ||
if "S3DataSource" in data_source: | ||
s3_data_source = data_source["S3DataSource"] | ||
if "ModelAccessConfig" not in s3_data_source: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Who sets this ModelAccessConfig? Is it the default artifacts set by JumpStart or is it something user would explicitly add to their inputs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's set by us when the customer inputs |
||
s3_data_source["ModelAccessConfig"] = {} | ||
s3_data_source["ModelAccessConfig"]["AcceptEula"] = accept_eula | ||
|
||
@classmethod | ||
def _is_local_channel(cls, input_uri): | ||
"""Placeholder docstring""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick question: I know @Narrohag added some code that does something very similar, and I want to make sure we're not doing the same work twice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I'm not fully following why this extra accept_eula logic is needed. For reference, Evan, this is where I added most of the model access config logic: https://github.com/aws/sagemaker-python-sdk/pull/5070/files. I'll take a closer look in a bit, and maybe we can chat about it today.