Skip to content

fix: jumpstart estimator for gated uncompressed training #5175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

59 changes: 55 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,7 @@ def fit(
logs: str = "All",
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
accept_eula: Optional[bool] = None,
):
"""Train a model using the input training dataset.

Expand Down Expand Up @@ -1363,14 +1364,21 @@ def fit(
* Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
Returns:
None or pipeline step arguments in case the Estimator instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
"""
self._prepare_for_training(job_name=job_name)

experiment_config = check_and_get_run_experiment_config(experiment_config)
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
self.latest_training_job = _TrainingJob.start_new(
self, inputs, experiment_config, accept_eula
)
self.jobs.append(self.latest_training_job)
forward_to_mlflow_tracking_server = False
if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation():
Expand Down Expand Up @@ -2484,7 +2492,7 @@ class _TrainingJob(_Job):
"""Placeholder docstring"""

@classmethod
def start_new(cls, estimator, inputs, experiment_config):
def start_new(cls, estimator, inputs, experiment_config, accept_eula=None):
"""Create a new Amazon SageMaker training job from the estimator.

Args:
Expand All @@ -2504,19 +2512,24 @@ def start_new(cls, estimator, inputs, experiment_config):
will be unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
* `RunName` is used to record an experiment run.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
Returns:
sagemaker.estimator._TrainingJob: Constructed object that captures
all information about the started training job.
"""
train_args = cls._get_train_args(estimator, inputs, experiment_config)
train_args = cls._get_train_args(estimator, inputs, experiment_config, accept_eula)

logger.debug("Train args after processing defaults: %s", train_args)
estimator.sagemaker_session.train(**train_args)

return cls(estimator.sagemaker_session, estimator._current_job_name)

@classmethod
def _get_train_args(cls, estimator, inputs, experiment_config):
def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula=None):
"""Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.

Args:
Expand All @@ -2536,6 +2549,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
will be unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
* `RunName` is used to record an experiment run.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
Dict: dict for `sagemaker.session.Session.train` method
Expand Down Expand Up @@ -2652,6 +2670,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
if estimator.get_session_chaining_config() is not None:
train_args["session_chaining_config"] = estimator.get_session_chaining_config()

if accept_eula is not None:
cls._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)

return train_args

@classmethod
Expand All @@ -2674,6 +2695,36 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
raise ValueError("Setting checkpoint_local_path is not supported in local mode.")
train_args["checkpoint_local_path"] = estimator.checkpoint_local_path

@classmethod
def _set_accept_eula_for_model_channel_input_data_config(cls, train_args, accept_eula):
    """Set the AcceptEula flag for the model channel in the input data configuration.

    Walks ``train_args["input_config"]`` looking for the channel named ``model``
    (case-insensitive, surrounding whitespace ignored) and writes
    ``ModelAccessConfig.AcceptEula`` on its ``S3DataSource``, creating the
    ``ModelAccessConfig`` dict if absent. Channels without a ``DataSource``,
    without a ``ChannelName``, or without an ``S3DataSource`` are skipped, so
    missing keys in the nested structure never raise.

    Args:
        train_args (dict): The training job arguments dictionary, mutated in place.
        accept_eula (bool): The value to set for the AcceptEula flag. ``None``
            means "not specified" and leaves ``train_args`` untouched.
    """
    if accept_eula is None or "input_config" not in train_args:
        return

    for channel in train_args["input_config"]:
        if "DataSource" not in channel:
            continue
        # Use .get(): a channel may carry a DataSource but no ChannelName, and
        # direct indexing would raise KeyError here.
        if channel.get("ChannelName", "").lower().strip() != "model":
            continue
        s3_data_source = channel["DataSource"].get("S3DataSource")
        if s3_data_source is not None:
            s3_data_source.setdefault("ModelAccessConfig", {})["AcceptEula"] = accept_eula

@classmethod
def _is_local_channel(cls, input_uri):
"""Placeholder docstring"""
Expand Down
16 changes: 12 additions & 4 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,18 @@ def _get_json_file(
object and None when reading from the local file system.
"""
if self._is_local_metadata_mode():
file_content, etag = self._get_json_file_from_local_override(key, filetype), None
else:
file_content, etag = self._get_json_file_and_etag_from_s3(key)
return file_content, etag
if filetype in {
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
}:
return self._get_json_file_from_local_override(key, filetype), None
else:
JUMPSTART_LOGGER.warning(
"Local metadata mode is enabled, but the file type %s is not supported "
"for local override. Falling back to s3.",
filetype,
)
return self._get_json_file_and_etag_from_s3(key)

def _get_json_md5_hash(self, key: str):
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`.
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def fit(
sagemaker_session=self.sagemaker_session,
config_name=self.config_name,
hub_access_config=self.hub_access_config,
accept_eula=accept_eula,
)
remove_env_var_from_estimator_kwargs_if_model_access_config_present(
self.init_kwargs, self.model_access_config
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def get_fit_kwargs(
sagemaker_session: Optional[Session] = None,
config_name: Optional[str] = None,
hub_access_config: Optional[Dict] = None,
accept_eula: Optional[bool] = None,
) -> JumpStartEstimatorFitKwargs:
"""Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object."""

Expand All @@ -283,6 +284,7 @@ def get_fit_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
accept_eula=accept_eula,
)

estimator_fit_kwargs, _ = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs)
Expand Down
19 changes: 15 additions & 4 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,12 +1940,20 @@ def use_inference_script_uri(self) -> bool:

def use_training_model_artifact(self) -> bool:
    """Return True if a model artifact (model channel) should be used for training.

    Returns:
        bool: True only when none of the legacy mechanisms below apply and a
        training artifact key is present on the spec.
    """
    # Older gated models expose an instance-specific environment variable that
    # points at the gated artifact; those models don't use the model channel.
    if any(
        self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value(
            instance_type
        )
        for instance_type in self.supported_training_instance_types
    ):
        return False

    # Even older models ship training model package artifact URIs instead;
    # they don't use the model channel either.
    if len(self.training_model_package_artifact_uris or {}) > 0:
        return False

    # Otherwise, use the model channel only when a training artifact key exists.
    # getattr() guards specs that predate the training_artifact_key field.
    return getattr(self, "training_artifact_key", None) is not None

def is_gated_model(self) -> bool:
"""Returns True if the model has a EULA key or the model bucket is gated."""
Expand Down Expand Up @@ -2595,6 +2603,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
"sagemaker_session",
"config_name",
"specs",
"accept_eula",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2625,6 +2634,7 @@ def __init__(
tolerate_vulnerable_model: Optional[bool] = None,
sagemaker_session: Optional[Session] = None,
config_name: Optional[str] = None,
accept_eula: Optional[bool] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand All @@ -2642,6 +2652,7 @@ def __init__(
self.tolerate_vulnerable_model = tolerate_vulnerable_model
self.sagemaker_session = sagemaker_session
self.config_name = config_name
self.accept_eula = accept_eula


class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/sagemaker/experiments/test_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker
assert _RunContext.get_current_run() == run_obj

expected_exp_config = run_obj.experiment_config
mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config)
mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config, None)

# _RunContext is cleaned up after exiting the with statement
assert not _RunContext.get_current_run()
Expand Down Expand Up @@ -94,7 +94,7 @@ def test_auto_pass_in_exp_config_under_load_run(
assert loaded_run.experiment_config == run_obj.experiment_config

expected_exp_config = run_obj.experiment_config
mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config)
mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config, None)

# _RunContext is cleaned up after exiting the with statement
assert not _RunContext.get_current_run()
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_

assert _RunContext.get_current_run() == run_obj

mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg)
mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg, None)

# _RunContext is cleaned up after exiting the with statement
assert not _RunContext.get_current_run()
Expand Down
Loading