From 8ae6835581d0e36262335e77c74d043153163d64 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 13 May 2025 15:20:47 +0000 Subject: [PATCH 01/11] fix: jumpstart estimator for gated uncompressed training --- src/sagemaker/estimator.py | 65 ++++++++++++++++++-- src/sagemaker/jumpstart/estimator.py | 1 + src/sagemaker/jumpstart/factory/estimator.py | 2 + src/sagemaker/jumpstart/types.py | 6 +- 4 files changed, 67 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index fa40719c9f..7f0d00e76d 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1310,6 +1310,7 @@ def fit( logs: str = "All", job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, + accept_eula: Optional[bool] = None, ): """Train a model using the input training dataset. @@ -1363,6 +1364,11 @@ def fit( * Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. However, the value of `TrialComponentDisplayName` is honored for display in Studio. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). 
Returns: None or pipeline step arguments in case the Estimator instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` @@ -1370,7 +1376,9 @@ def fit( self._prepare_for_training(job_name=job_name) experiment_config = check_and_get_run_experiment_config(experiment_config) - self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config) + self.latest_training_job = _TrainingJob.start_new( + self, inputs, experiment_config, accept_eula + ) self.jobs.append(self.latest_training_job) forward_to_mlflow_tracking_server = False if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation(): @@ -2484,7 +2492,7 @@ class _TrainingJob(_Job): """Placeholder docstring""" @classmethod - def start_new(cls, estimator, inputs, experiment_config): + def start_new(cls, estimator, inputs, experiment_config, accept_eula): """Create a new Amazon SageMaker training job from the estimator. Args: @@ -2504,11 +2512,16 @@ def start_new(cls, estimator, inputs, experiment_config): will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. * `RunName` is used to record an experiment run. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). Returns: sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started training job. 
""" - train_args = cls._get_train_args(estimator, inputs, experiment_config) + train_args = cls._get_train_args(estimator, inputs, experiment_config, accept_eula) logger.debug("Train args after processing defaults: %s", train_args) estimator.sagemaker_session.train(**train_args) @@ -2516,7 +2529,7 @@ def start_new(cls, estimator, inputs, experiment_config): return cls(estimator.sagemaker_session, estimator._current_job_name) @classmethod - def _get_train_args(cls, estimator, inputs, experiment_config): + def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula): """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator. Args: @@ -2536,6 +2549,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config): will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. * `RunName` is used to record an experiment run. + accept_eula (bool): For models that require a Model Access Config, specify True or + False to indicate whether model terms of use have been accepted. + The `accept_eula` value must be explicitly defined as `True` in order to + accept the end-user license agreement (EULA) that some + models require. (Default: None). 
Returns: Dict: dict for `sagemaker.session.Session.train` method @@ -2652,6 +2670,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config): if estimator.get_session_chaining_config() is not None: train_args["session_chaining_config"] = estimator.get_session_chaining_config() + if accept_eula is not None: + cls._set_accept_eula_for_input_data_config(train_args, accept_eula) + return train_args @classmethod @@ -2674,6 +2695,42 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args): raise ValueError("Setting checkpoint_local_path is not supported in local mode.") train_args["checkpoint_local_path"] = estimator.checkpoint_local_path + @classmethod + def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula): + """Set the AcceptEula flag for all input data configurations. + + This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources + in the InputDataConfig array. It handles cases where keys might not exist in the + nested dictionary structure. 
+ + Args: + train_args (dict): The training job arguments dictionary + accept_eula (bool): The value to set for AcceptEula flag + """ + if "InputDataConfig" not in train_args: + return + + eula_count = 0 + s3_uris = [] + + for idx in range(len(train_args["InputDataConfig"])): + if "DataSource" in train_args["InputDataConfig"][idx]: + data_source = train_args["InputDataConfig"][idx]["DataSource"] + if "S3DataSource" in data_source: + s3_data_source = data_source["S3DataSource"] + if "ModelAccessConfig" not in s3_data_source: + s3_data_source["ModelAccessConfig"] = {} + s3_data_source["ModelAccessConfig"]["AcceptEula"] = accept_eula + eula_count += 1 + + # Collect S3 URI if available + if "S3Uri" in s3_data_source: + s3_uris.append(s3_data_source["S3Uri"]) + + # Log info if more than one EULA needs to be accepted + if eula_count > 1: + logger.info("Accepting EULA for %d S3 data sources: %s", eula_count, ", ".join(s3_uris)) + @classmethod def _is_local_channel(cls, input_uri): """Placeholder docstring""" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 4daf9b1810..6609999209 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -713,6 +713,7 @@ def fit( sagemaker_session=self.sagemaker_session, config_name=self.config_name, hub_access_config=self.hub_access_config, + accept_eula=accept_eula, ) remove_env_var_from_estimator_kwargs_if_model_access_config_present( self.init_kwargs, self.model_access_config diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 12eb30daaf..3ad9d61bac 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -266,6 +266,7 @@ def get_fit_kwargs( sagemaker_session: Optional[Session] = None, config_name: Optional[str] = None, hub_access_config: Optional[Dict] = None, + accept_eula: Optional[bool] = None, ) -> JumpStartEstimatorFitKwargs: """Returns 
kwargs required call `fit` on `sagemaker.estimator.Estimator` object.""" @@ -283,6 +284,7 @@ def get_fit_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, sagemaker_session=sagemaker_session, config_name=config_name, + accept_eula=accept_eula, ) estimator_fit_kwargs, _ = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs) diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 0cd4bcc902..b37d1317e6 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1940,9 +1940,6 @@ def use_inference_script_uri(self) -> bool: def use_training_model_artifact(self) -> bool: """Returns True if the model should use a model uri when kicking off training job.""" - # gated model never use training model artifact - if self.gated_bucket: - return False # otherwise, return true is a training model package is not set return len(self.training_model_package_artifact_uris or {}) == 0 @@ -2595,6 +2592,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "sagemaker_session", "config_name", "specs", + "accept_eula", ] SERIALIZATION_EXCLUSION_SET = { @@ -2625,6 +2623,7 @@ def __init__( tolerate_vulnerable_model: Optional[bool] = None, sagemaker_session: Optional[Session] = None, config_name: Optional[str] = None, + accept_eula: Optional[bool] = None, ) -> None: """Instantiates JumpStartEstimatorInitKwargs object.""" @@ -2642,6 +2641,7 @@ def __init__( self.tolerate_vulnerable_model = tolerate_vulnerable_model self.sagemaker_session = sagemaker_session self.config_name = config_name + self.accept_eula = accept_eula class JumpStartEstimatorDeployKwargs(JumpStartKwargs): From 68202a702e6ffd7fae52963e1f3c0a99bc2c765d Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Tue, 13 May 2025 16:22:59 +0000 Subject: [PATCH 02/11] fix: optional accept_eula arg --- src/sagemaker/estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py 
index 7f0d00e76d..31d9e983aa 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2492,7 +2492,7 @@ class _TrainingJob(_Job): """Placeholder docstring""" @classmethod - def start_new(cls, estimator, inputs, experiment_config, accept_eula): + def start_new(cls, estimator, inputs, experiment_config, accept_eula=None): """Create a new Amazon SageMaker training job from the estimator. Args: @@ -2529,7 +2529,7 @@ def start_new(cls, estimator, inputs, experiment_config, accept_eula): return cls(estimator.sagemaker_session, estimator._current_job_name) @classmethod - def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula): + def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula=None): """Constructs a dict of arguments for an Amazon SageMaker training job from the estimator. Args: From e9f2f65826266ef817f84b435eaafa057fef6965 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 14 May 2025 00:24:07 +0000 Subject: [PATCH 03/11] fix: unit tests --- src/sagemaker/estimator.py | 3 + src/sagemaker/model.py | 2 +- .../sagemaker/experiments/test_run_context.py | 6 +- .../jumpstart/estimator/test_estimator.py | 4 +- tests/unit/sagemaker/jumpstart/test_types.py | 2 +- tests/unit/test_estimator.py | 142 ++++++++++++++++++ 6 files changed, 153 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 31d9e983aa..d86a3f743d 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2710,6 +2710,9 @@ def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula): if "InputDataConfig" not in train_args: return + if accept_eula is None: + return + eula_count = 0 s3_uris = [] diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 3bfac0c8da..b281d9f489 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1805,7 +1805,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, 
explainer_config_dict=explainer_config_dict, async_inference_config_dict=async_inference_config_dict, - serverless_inference_config_dict=serverless_inference_config_dict, + serverless_inference_config=serverless_inference_config_dict, routing_config=routing_config, inference_ami_version=inference_ami_version, ) diff --git a/tests/unit/sagemaker/experiments/test_run_context.py b/tests/unit/sagemaker/experiments/test_run_context.py index 7026c48f41..7a63e5eaa2 100644 --- a/tests/unit/sagemaker/experiments/test_run_context.py +++ b/tests/unit/sagemaker/experiments/test_run_context.py @@ -54,7 +54,7 @@ def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker assert _RunContext.get_current_run() == run_obj expected_exp_config = run_obj.experiment_config - mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config) + mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config, None) # _RunContext is cleaned up after exiting the with statement assert not _RunContext.get_current_run() @@ -94,7 +94,7 @@ def test_auto_pass_in_exp_config_under_load_run( assert loaded_run.experiment_config == run_obj.experiment_config expected_exp_config = run_obj.experiment_config - mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config) + mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config, None) # _RunContext is cleaned up after exiting the with statement assert not _RunContext.get_current_run() @@ -174,7 +174,7 @@ def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_ assert _RunContext.get_current_run() == run_obj - mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg) + mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg, None) # _RunContext is cleaned up after exiting the with statement assert not _RunContext.get_current_run() diff --git 
a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 4a64b413f4..fba4aa8764 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -607,6 +607,7 @@ def test_gated_model_s3_uri_with_eula_in_fit( inputs=channels, wait=True, job_name="meta-textgeneration-llama-2-7b-f-8675309", + accept_eula=True, ) assert hasattr(estimator, "model_access_config") @@ -688,6 +689,7 @@ def test_gated_model_non_model_package_s3_uri( instance_count=1, image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt" "orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + model_uri="s3://jumpstart-private-cache-prod-us-west-2/some/dummy/key", source_dir="s3://jumpstart-cache-prod-us-west-2/source-d" "irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", entry_point="transfer_learning.py", @@ -1346,7 +1348,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set(["kwargs"]) - fit_args_to_skip: Set[str] = set(["accept_eula"]) + fit_args_to_skip: Set[str] = set([]) deploy_args_to_skip: Set[str] = set(["kwargs"]) parent_class_init = Estimator.__init__ diff --git a/tests/unit/sagemaker/jumpstart/test_types.py b/tests/unit/sagemaker/jumpstart/test_types.py index 0b5ef63947..3ebb1d0c9e 100644 --- a/tests/unit/sagemaker/jumpstart/test_types.py +++ b/tests/unit/sagemaker/jumpstart/test_types.py @@ -333,7 +333,7 @@ def test_use_training_model_artifact(): specs1 = JumpStartModelSpecs(BASE_SPEC) assert specs1.use_training_model_artifact() specs1.gated_bucket = True - assert not specs1.use_training_model_artifact() + assert specs1.use_training_model_artifact() specs1.gated_bucket = False specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"} assert not 
specs1.use_training_model_artifact() diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 8294eb0039..828e55f8a3 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -336,6 +336,148 @@ def training_job_description(sagemaker_session): return returned_job_description +def test_set_accept_eula_for_input_data_config_no_input_data_config(): + """Test when InputDataConfig is not in train_args.""" + train_args = {} + accept_eula = True + + EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + + # Verify train_args remains unchanged + assert train_args == {} + + +def test_set_accept_eula_for_input_data_config_none_accept_eula(): + """Test when accept_eula is None.""" + train_args = {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]} + accept_eula = None + + EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + + # Verify train_args remains unchanged + assert train_args == {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]} + + +def test_set_accept_eula_for_input_data_config_single_data_source(): + """Test with a single S3DataSource.""" + with patch("sagemaker.estimator.logger") as logger: + train_args = { + "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}] + } + accept_eula = True + + EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + + # Verify ModelAccessConfig and AcceptEula are set correctly + assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ + "ModelAccessConfig" + ] == {"AcceptEula": True} + + # Verify no logging occurred since there's only one data source + logger.info.assert_not_called() + + +def test_set_accept_eula_for_input_data_config_multiple_data_sources(): + """Test with multiple S3DataSources.""" + with patch("sagemaker.estimator.logger") as logger: + train_args = { + "InputDataConfig": [ + {"DataSource": {"S3DataSource": {"S3Uri": 
"s3://bucket/model1"}}}, + {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model2"}}}, + ] + } + accept_eula = True + + EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + + # Verify ModelAccessConfig and AcceptEula are set correctly for both data sources + assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ + "ModelAccessConfig" + ] == {"AcceptEula": True} + assert train_args["InputDataConfig"][1]["DataSource"]["S3DataSource"][ + "ModelAccessConfig" + ] == {"AcceptEula": True} + + # Verify logging occurred with correct information + logger.info.assert_called_once() + args = logger.info.call_args[0] + assert args[0] == "Accepting EULA for %d S3 data sources: %s" + assert args[1] == 2 + assert args[2] == "s3://bucket/model1, s3://bucket/model2" + + +def test_set_accept_eula_for_input_data_config_existing_model_access_config(): + """Test when ModelAccessConfig already exists.""" + train_args = { + "InputDataConfig": [ + { + "DataSource": { + "S3DataSource": { + "S3Uri": "s3://bucket/model", + "ModelAccessConfig": {"OtherSetting": "value"}, + } + } + } + ] + } + accept_eula = True + + EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + + # Verify AcceptEula is added to existing ModelAccessConfig + assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { + "OtherSetting": "value", + "AcceptEula": True, + } + + +def test_set_accept_eula_for_input_data_config_missing_s3_data_source(): + """Test when S3DataSource is missing.""" + train_args = {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]} + accept_eula = True + + EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + + # Verify train_args remains unchanged + assert train_args == {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]} + + +def test_set_accept_eula_for_input_data_config_missing_data_source(): + """Test when DataSource is missing.""" + 
train_args = {"InputDataConfig": [{"OtherKey": {}}]} + accept_eula = True + + EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + + # Verify train_args remains unchanged + assert train_args == {"InputDataConfig": [{"OtherKey": {}}]} + + +def test_set_accept_eula_for_input_data_config_mixed_data_sources(): + """Test with a mix of S3DataSource and other data sources.""" + with patch("sagemaker.estimator.logger") as logger: + train_args = { + "InputDataConfig": [ + {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}, + {"DataSource": {"OtherDataSource": {}}}, + ] + } + accept_eula = True + + EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + + # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only + assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ + "ModelAccessConfig" + ] == {"AcceptEula": True} + assert "ModelAccessConfig" not in train_args["InputDataConfig"][1]["DataSource"].get( + "OtherDataSource", {} + ) + + # Verify no logging occurred since there's only one S3 data source + logger.info.assert_not_called() + + def test_validate_smdistributed_unsupported_image_raises(sagemaker_session): # Test unsupported image raises error. 
for unsupported_image in DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM: From 5de1e55d4b8a9b8bec43c21dc63261868e84a1b4 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 14 May 2025 02:33:05 +0000 Subject: [PATCH 04/11] fix: unit tests --- tests/unit/test_estimator.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 828e55f8a3..7548eafe66 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -341,7 +341,7 @@ def test_set_accept_eula_for_input_data_config_no_input_data_config(): train_args = {} accept_eula = True - EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged assert train_args == {} @@ -352,7 +352,7 @@ def test_set_accept_eula_for_input_data_config_none_accept_eula(): train_args = {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]} accept_eula = None - EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged assert train_args == {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]} @@ -366,7 +366,7 @@ def test_set_accept_eula_for_input_data_config_single_data_source(): } accept_eula = True - EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify ModelAccessConfig and AcceptEula are set correctly assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ @@ -388,7 +388,7 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources(): } accept_eula = True - EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_input_data_config(train_args, 
accept_eula) # Verify ModelAccessConfig and AcceptEula are set correctly for both data sources assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ @@ -422,7 +422,7 @@ def test_set_accept_eula_for_input_data_config_existing_model_access_config(): } accept_eula = True - EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify AcceptEula is added to existing ModelAccessConfig assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { @@ -436,7 +436,7 @@ def test_set_accept_eula_for_input_data_config_missing_s3_data_source(): train_args = {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]} accept_eula = True - EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged assert train_args == {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]} @@ -447,7 +447,7 @@ def test_set_accept_eula_for_input_data_config_missing_data_source(): train_args = {"InputDataConfig": [{"OtherKey": {}}]} accept_eula = True - EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged assert train_args == {"InputDataConfig": [{"OtherKey": {}}]} @@ -464,7 +464,7 @@ def test_set_accept_eula_for_input_data_config_mixed_data_sources(): } accept_eula = True - EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ From 402fbc95ea0efeb86c65be79602864a97214bd3e Mon Sep 17 00:00:00 2001 From: Evan 
Kravitz Date: Wed, 14 May 2025 16:23:47 +0000 Subject: [PATCH 05/11] fix: support legacy training models, fix cache override for unsupported files --- src/sagemaker/jumpstart/cache.py | 16 ++++++++++++---- src/sagemaker/jumpstart/types.py | 15 +++++++++++++-- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 29a903e00b..5a4be3f53f 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -372,10 +372,18 @@ def _get_json_file( object and None when reading from the local file system. """ if self._is_local_metadata_mode(): - file_content, etag = self._get_json_file_from_local_override(key, filetype), None - else: - file_content, etag = self._get_json_file_and_etag_from_s3(key) - return file_content, etag + if filetype in { + JumpStartS3FileType.OPEN_WEIGHT_MANIFEST, + JumpStartS3FileType.OPEN_WEIGHT_SPECS, + }: + return self._get_json_file_from_local_override(key, filetype), None + else: + JUMPSTART_LOGGER.warning( + "Local metadata mode is enabled, but the file type %s is not supported " + "for local override. Falling back to s3.", + filetype, + ) + return self._get_json_file_and_etag_from_s3(key) def _get_json_md5_hash(self, key: str): """Retrieves md5 object hash for s3 objects, using `s3.head_object`. 
diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index b37d1317e6..44739669f1 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1940,9 +1940,20 @@ def use_inference_script_uri(self) -> bool: def use_training_model_artifact(self) -> bool: """Returns True if the model should use a model uri when kicking off training job.""" + # old models with this environment variable present don't use model channel + if any( + self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value( + instance_type + ) + for instance_type in self.supported_training_instance_types + ): + return False + + # even older models with training model package artifact uris present also don't use model channel + if len(self.training_model_package_artifact_uris or {}) > 0: + return False - # otherwise, return true is a training model package is not set - return len(self.training_model_package_artifact_uris or {}) == 0 + return getattr(self, "training_artifact_key", None) is not None def is_gated_model(self) -> bool: """Returns True if the model has a EULA key or the model bucket is gated.""" From c9ceff259b2bdcb36fc4fbe95449009dd15b9adc Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Wed, 14 May 2025 19:18:08 +0000 Subject: [PATCH 06/11] fix: unit tests --- tests/unit/sagemaker/jumpstart/constants.py | 465 +++++++++++++++++- .../jumpstart/estimator/test_estimator.py | 160 +++++- 2 files changed, 619 insertions(+), 6 deletions(-) diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py index ae02c597da..481b7b4b94 100644 --- a/tests/unit/sagemaker/jumpstart/constants.py +++ b/tests/unit/sagemaker/jumpstart/constants.py @@ -7821,7 +7821,7 @@ "bedrock_console_supported": True, "bedrock_io_mapping_id": "tgi_default_1.0.0", }, - "js-gated-artifact-non-model-package-trainable-model": { + "js-gated-artifact-gated-env-var-trainable-model": { "model_id": 
"meta-textgeneration-llama-2-7b", "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", "version": "3.0.0", @@ -8285,6 +8285,469 @@ }, "dynamic_container_deployment_supported": False, }, + "js-gated-artifact-use-model-channel": { + "model_id": "meta-textgeneration-llama-2-7b", + "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", + "version": "3.0.0", + "min_sdk_version": "2.189.0", + "training_supported": True, + "incremental_training_supported": False, + "hosting_ecr_specs": { + "framework": "huggingface-llm", + "framework_version": "1.1.0", + "py_version": "py39", + }, + "training_artifact_key": "some/dummy/key", + "hosting_artifact_key": "meta-textgeneration/meta-textgeneration-llama-2-7b/artifacts/inference/v1.0.0/", + "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.2.3/sourcedir.tar.gz", + "hosting_prepacked_artifact_key": "meta-textgeneration/meta-textgen" + "eration-llama-2-7b/artifacts/inference-prepack/v1.0.0/", + "hosting_prepacked_artifact_version": "1.0.0", + "hosting_use_script_uri": False, + "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt", + "inference_vulnerable": False, + "inference_dependencies": [ + "sagemaker_jumpstart_huggingface_script_utilities==1.0.8", + "sagemaker_jumpstart_script_utilities==1.1.8", + ], + "inference_vulnerabilities": [], + "training_vulnerable": False, + "training_dependencies": [ + "accelerate==0.21.0", + "bitsandbytes==0.39.1", + "black==23.7.0", + "brotli==1.0.9", + "datasets==2.14.1", + "fire==0.5.0", + "inflate64==0.3.1", + "loralib==0.1.1", + "multivolumefile==0.2.3", + "mypy-extensions==1.0.0", + "pathspec==0.11.1", + "peft==0.4.0", + "py7zr==0.20.5", + "pybcj==1.0.1", + "pycryptodomex==3.18.0", + "pyppmd==1.0.0", + "pytorch-triton==2.1.0+e6216047b8", + "pyzstd==0.15.9", + "safetensors==0.3.1", + "sagemaker_jumpstart_huggingface_script_utilities==1.1.3", + "sagemaker_jumpstart_script_utilities==1.1.9", + "scipy==1.11.1", + 
"termcolor==2.3.0", + "texttable==1.6.7", + "tokenize-rt==5.1.0", + "tokenizers==0.13.3", + "torch==2.1.0.dev20230905+cu118", + "transformers==4.31.0", + ], + "training_vulnerabilities": [], + "deprecated": False, + "hyperparameters": [ + { + "name": "int8_quantization", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "enable_fsdp", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "epoch", + "type": "int", + "default": 5, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "learning_rate", + "type": "float", + "default": 0.0001, + "min": 1e-08, + "max": 1, + "scope": "algorithm", + }, + {"name": "lora_r", "type": "int", "default": 8, "min": 1, "scope": "algorithm"}, + {"name": "lora_alpha", "type": "int", "default": 32, "min": 1, "scope": "algorithm"}, + { + "name": "lora_dropout", + "type": "float", + "default": 0.05, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "instruction_tuned", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "chat_dataset", + "type": "text", + "default": "False", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "add_input_output_demarcation_key", + "type": "text", + "default": "True", + "options": ["True", "False"], + "scope": "algorithm", + }, + { + "name": "per_device_train_batch_size", + "type": "int", + "default": 4, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "per_device_eval_batch_size", + "type": "int", + "default": 1, + "min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_train_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "max_val_samples", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "seed", + "type": "int", + "default": 10, + 
"min": 1, + "max": 1000, + "scope": "algorithm", + }, + { + "name": "max_input_length", + "type": "int", + "default": -1, + "min": -1, + "scope": "algorithm", + }, + { + "name": "validation_split_ratio", + "type": "float", + "default": 0.2, + "min": 0, + "max": 1, + "scope": "algorithm", + }, + { + "name": "train_data_split_seed", + "type": "int", + "default": 0, + "min": 0, + "scope": "algorithm", + }, + { + "name": "preprocessing_num_workers", + "type": "text", + "default": "None", + "scope": "algorithm", + }, + { + "name": "sagemaker_submit_directory", + "type": "text", + "default": "/opt/ml/input/data/code/sourcedir.tar.gz", + "scope": "container", + }, + { + "name": "sagemaker_program", + "type": "text", + "default": "transfer_learning.py", + "scope": "container", + }, + { + "name": "sagemaker_container_log_level", + "type": "text", + "default": "20", + "scope": "container", + }, + ], + "training_script_key": "source-directory-tarballs/" + "meta/transfer_learning/textgeneration/v1.0.4/sourcedir.tar.gz", + "training_prepacked_script_key": "source-directory-" + "tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", + "training_prepacked_script_version": "1.0.1", + "training_ecr_specs": { + "framework": "huggingface", + "framework_version": "2.0.0", + "py_version": "py310", + "huggingface_transformers_version": "4.28.1", + }, + "inference_environment_variables": [ + { + "name": "SAGEMAKER_PROGRAM", + "type": "text", + "default": "inference.py", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_SUBMIT_DIRECTORY", + "type": "text", + "default": "/opt/ml/model/code", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_CONTAINER_LOG_LEVEL", + "type": "text", + "default": "20", + "scope": "container", + "required_for_model_class": False, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_TIMEOUT", + "type": "text", + "default": "3600", + "scope": "container", + 
"required_for_model_class": False, + }, + { + "name": "ENDPOINT_SERVER_TIMEOUT", + "type": "int", + "default": 3600, + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MODEL_CACHE_ROOT", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_ENV", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "HF_MODEL_ID", + "type": "text", + "default": "/opt/ml/model", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_INPUT_LENGTH", + "type": "text", + "default": "4095", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "MAX_TOTAL_TOKENS", + "type": "text", + "default": "4096", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SM_NUM_GPUS", + "type": "text", + "default": "1", + "scope": "container", + "required_for_model_class": True, + }, + { + "name": "SAGEMAKER_MODEL_SERVER_WORKERS", + "type": "int", + "default": 1, + "scope": "container", + "required_for_model_class": True, + }, + ], + "metrics": [ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + "default_inference_instance_type": "ml.g5.2xlarge", + "supported_inference_instance_types": [ + "ml.g5.2xlarge", + "ml.g5.4xlarge", + "ml.g5.8xlarge", + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p4d.24xlarge", + ], + "default_training_instance_type": "ml.g5.12xlarge", + "supported_training_instance_types": [ + "ml.g5.12xlarge", + "ml.g5.24xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + ], + "model_kwargs": {}, + "deploy_kwargs": { + "model_data_download_timeout": 1200, + 
"container_startup_health_check_timeout": 1200, + }, + "estimator_kwargs": {"encrypt_inter_container_traffic": True, "max_run": 360000}, + "fit_kwargs": {}, + "predictor_specs": { + "supported_content_types": ["application/json"], + "supported_accept_types": ["application/json"], + "default_content_type": "application/json", + "default_accept_type": "application/json", + }, + "inference_volume_size": 256, + "training_volume_size": 256, + "inference_enable_network_isolation": True, + "training_enable_network_isolation": True, + "default_training_dataset_key": "training-datasets/sec_amazon/", + "validation_supported": True, + "fine_tuning_supported": True, + "resource_name_base": "meta-textgeneration-llama-2-7b", + "default_payloads": { + "meaningOfLife": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "I believe the meaning of life is", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "theoryOfRelativity": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "Simply put, the theory of relativity states that ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "teamMessage": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "A brief message congratulating the team on the launch:\n\nHi everyone,\n\nI just ", + "parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "englishToFrench": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": {"generated_text": "[0].generated_text"}, + "body": { + "inputs": "Translate English to French:\nsea o" + "tter => loutre de mer\npeppermint => ment" + "he poivr\u00e9e\nplush girafe => girafe peluche\ncheese =>", + 
"parameters": {"max_new_tokens": 64, "top_p": 0.9, "temperature": 0.6}, + }, + }, + "Story": { + "content_type": "application/json", + "prompt_key": "inputs", + "output_keys": { + "generated_text": "[0].generated_text", + "input_logprobs": "[0].details.prefill[*].logprob", + }, + "body": { + "inputs": "Please tell me a story.", + "parameters": { + "max_new_tokens": 64, + "top_p": 0.9, + "temperature": 0.2, + "decoder_input_details": True, + "details": True, + }, + }, + }, + }, + "gated_bucket": True, + "hosting_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazonaws.com/h" + "uggingface-pytorch-tgi-inference:2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "ml.g5.12xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "4"}}}, + "ml.g5.48xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + "ml.p4d.24xlarge": {"properties": {"environment_variables": {"SM_NUM_GPUS": "8"}}}, + }, + }, + "training_instance_type_variants": { + "regional_aliases": { + "us-west-2": { + "gpu_ecr_uri_1": "763104351884.dkr.ecr.us-west-2.amazon" + "aws.com/huggingface-pytorch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04" + }, + }, + "variants": { + "g4dn": 
{"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "g5": { + "regional_properties": {"image_uri": "$gpu_ecr_uri_1"}, + "properties": { + "environment_variables": {"SELF_DESTRUCT": "true"}, + }, + }, + "local_gpu": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p2": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p3dn": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4d": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p4de": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + "p5": {"regional_properties": {"image_uri": "$gpu_ecr_uri_1"}}, + }, + }, + "dynamic_container_deployment_supported": False, + }, "js-gated-artifact-trainable-model": { "model_id": "meta-textgeneration-llama-2-7b-f", "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/", diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index fba4aa8764..5aed17efa3 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -656,7 +656,7 @@ def test_gated_model_s3_uri_with_eula_in_fit( @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) - def test_gated_model_non_model_package_s3_uri( + def test_gated_model_gated_model_no_model_channel_due_to_gated_env_var( self, mock_estimator_deploy: mock.Mock, mock_estimator_fit: mock.Mock, @@ -675,7 +675,7 @@ def test_gated_model_non_model_package_s3_uri( mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS - model_id, _ = "js-gated-artifact-non-model-package-trainable-model", "*" + model_id, _ = 
"js-gated-artifact-gated-env-var-trainable-model", "*" mock_get_model_specs.side_effect = get_special_model_spec @@ -689,7 +689,6 @@ def test_gated_model_non_model_package_s3_uri( instance_count=1, image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt" "orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", - model_uri="s3://jumpstart-private-cache-prod-us-west-2/some/dummy/key", source_dir="s3://jumpstart-cache-prod-us-west-2/source-d" "irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", entry_point="transfer_learning.py", @@ -734,7 +733,7 @@ def test_gated_model_non_model_package_s3_uri( tags=[ { "Key": "sagemaker-sdk:jumpstart-model-id", - "Value": "js-gated-artifact-non-model-package-trainable-model", + "Value": "js-gated-artifact-gated-env-var-trainable-model", }, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "3.0.0"}, ], @@ -782,7 +781,158 @@ def test_gated_model_non_model_package_s3_uri( tags=[ { "Key": "sagemaker-sdk:jumpstart-model-id", - "Value": "js-gated-artifact-non-model-package-trainable-model", + "Value": "js-gated-artifact-gated-env-var-trainable-model", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "3.0.0"}, + ], + wait=True, + model_data_download_timeout=1200, + container_startup_health_check_timeout=1200, + role="fake role! 
do not use!", + enable_network_isolation=True, + model_name="meta-textgeneration-llama-2-7b-8675309", + use_compiled_model=False, + ) + + @mock.patch( + "sagemaker.jumpstart.artifacts.environment_variables.get_jumpstart_gated_content_bucket" + ) + @mock.patch("sagemaker.utils.sagemaker_timestamp") + @mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type") + @mock.patch( + "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch( + "sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix" + ) + @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.fit") + @mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy") + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region) + def test_gated_model_gated_model_with_model_channel( + self, + mock_estimator_deploy: mock.Mock, + mock_estimator_fit: mock.Mock, + mock_estimator_init: mock.Mock, + mock_get_model_specs: mock.Mock, + mock_session_estimator: mock.Mock, + mock_session_model: mock.Mock, + mock_validate_model_id_and_get_type: mock.Mock, + mock_timestamp: mock.Mock, + mock_get_jumpstart_gated_content_bucket: mock.Mock, + ): + mock_estimator_deploy.return_value = default_predictor + + mock_get_jumpstart_gated_content_bucket.return_value = "top-secret-private-models-bucket" + mock_timestamp.return_value = "8675309" + + mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS + + model_id, _ = "js-gated-artifact-use-model-channel", "*" + + mock_get_model_specs.side_effect = get_special_model_spec + + mock_session_estimator.return_value = sagemaker_session + mock_session_model.return_value = sagemaker_session + + 
estimator = JumpStartEstimator(model_id=model_id) + + mock_estimator_init.assert_called_once_with( + instance_type="ml.g5.12xlarge", + instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt" + "orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04", + model_uri="s3://jumpstart-private-cache-prod-us-west-2/some/dummy/key", + source_dir="s3://jumpstart-cache-prod-us-west-2/source-d" + "irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz", + entry_point="transfer_learning.py", + hyperparameters={ + "int8_quantization": "False", + "enable_fsdp": "True", + "epoch": "5", + "learning_rate": "0.0001", + "lora_r": "8", + "lora_alpha": "32", + "lora_dropout": "0.05", + "instruction_tuned": "False", + "chat_dataset": "False", + "add_input_output_demarcation_key": "True", + "per_device_train_batch_size": "4", + "per_device_eval_batch_size": "1", + "max_train_samples": "-1", + "max_val_samples": "-1", + "seed": "10", + "max_input_length": "-1", + "validation_split_ratio": "0.2", + "train_data_split_seed": "0", + "preprocessing_num_workers": "None", + }, + metric_definitions=[ + { + "Name": "huggingface-textgeneration:eval-loss", + "Regex": "eval_epoch_loss=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:eval-ppl", + "Regex": "eval_ppl=tensor\\(([0-9\\.]+)", + }, + { + "Name": "huggingface-textgeneration:train-loss", + "Regex": "train_epoch_loss=([0-9\\.]+)", + }, + ], + role="fake role! 
do not use!", + max_run=360000, + sagemaker_session=sagemaker_session, + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-use-model-channel", + }, + {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "3.0.0"}, + ], + encrypt_inter_container_traffic=True, + enable_network_isolation=True, + environment={"SELF_DESTRUCT": "true"}, + ) + + channels = { + "training": f"s3://{get_jumpstart_content_bucket(region)}/" + f"some-training-dataset-doesn't-matter", + } + + estimator.fit(channels) + + mock_estimator_fit.assert_called_once_with( + inputs=channels, wait=True, job_name="meta-textgeneration-llama-2-7b-8675309" + ) + + estimator.deploy() + + mock_estimator_deploy.assert_called_once_with( + instance_type="ml.g5.2xlarge", + initial_instance_count=1, + image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytor" + "ch-tgi-inference:2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04", + env={ + "SAGEMAKER_PROGRAM": "inference.py", + "ENDPOINT_SERVER_TIMEOUT": "3600", + "MODEL_CACHE_ROOT": "/opt/ml/model", + "SAGEMAKER_ENV": "1", + "HF_MODEL_ID": "/opt/ml/model", + "MAX_INPUT_LENGTH": "4095", + "MAX_TOTAL_TOKENS": "4096", + "SM_NUM_GPUS": "1", + "SAGEMAKER_MODEL_SERVER_WORKERS": "1", + }, + predictor_cls=Predictor, + endpoint_name="meta-textgeneration-llama-2-7b-8675309", + tags=[ + { + "Key": "sagemaker-sdk:jumpstart-model-id", + "Value": "js-gated-artifact-use-model-channel", }, {"Key": "sagemaker-sdk:jumpstart-model-version", "Value": "3.0.0"}, ], From 4d4597355e32e64133a0abcfb99d34f77e91ffb0 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 15 May 2025 15:01:59 +0000 Subject: [PATCH 07/11] chore: add unit test + minor fix --- src/sagemaker/estimator.py | 10 ++--- tests/unit/test_estimator.py | 75 ++++++++++++++++++++++++------------ 2 files changed, 55 insertions(+), 30 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index d86a3f743d..f6f3de3644 100644 --- 
a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2700,14 +2700,14 @@ def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula): """Set the AcceptEula flag for all input data configurations. This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources - in the InputDataConfig array. It handles cases where keys might not exist in the + in the input_config array. It handles cases where keys might not exist in the nested dictionary structure. Args: train_args (dict): The training job arguments dictionary accept_eula (bool): The value to set for AcceptEula flag """ - if "InputDataConfig" not in train_args: + if "input_config" not in train_args: return if accept_eula is None: @@ -2716,9 +2716,9 @@ def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula): eula_count = 0 s3_uris = [] - for idx in range(len(train_args["InputDataConfig"])): - if "DataSource" in train_args["InputDataConfig"][idx]: - data_source = train_args["InputDataConfig"][idx]["DataSource"] + for idx in range(len(train_args["input_config"])): + if "DataSource" in train_args["input_config"][idx]: + data_source = train_args["input_config"][idx]["DataSource"] if "S3DataSource" in data_source: s3_data_source = data_source["S3DataSource"] if "ModelAccessConfig" not in s3_data_source: diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 7548eafe66..52678ef5a9 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -337,7 +337,7 @@ def training_job_description(sagemaker_session): def test_set_accept_eula_for_input_data_config_no_input_data_config(): - """Test when InputDataConfig is not in train_args.""" + """Test when input_config is not in train_args.""" train_args = {} accept_eula = True @@ -349,29 +349,29 @@ def test_set_accept_eula_for_input_data_config_no_input_data_config(): def test_set_accept_eula_for_input_data_config_none_accept_eula(): """Test when accept_eula is None.""" - 
train_args = {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]} + train_args = {"input_config": [{"DataSource": {"S3DataSource": {}}}]} accept_eula = None _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged - assert train_args == {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]} + assert train_args == {"input_config": [{"DataSource": {"S3DataSource": {}}}]} def test_set_accept_eula_for_input_data_config_single_data_source(): """Test with a single S3DataSource.""" with patch("sagemaker.estimator.logger") as logger: train_args = { - "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}] + "input_config": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}] } accept_eula = True _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify ModelAccessConfig and AcceptEula are set correctly - assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ - "ModelAccessConfig" - ] == {"AcceptEula": True} + assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { + "AcceptEula": True + } # Verify no logging occurred since there's only one data source logger.info.assert_not_called() @@ -381,7 +381,7 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources(): """Test with multiple S3DataSources.""" with patch("sagemaker.estimator.logger") as logger: train_args = { - "InputDataConfig": [ + "input_config": [ {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model1"}}}, {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model2"}}}, ] @@ -391,12 +391,12 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources(): _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify ModelAccessConfig and AcceptEula are set correctly for both data sources - assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ - 
"ModelAccessConfig" - ] == {"AcceptEula": True} - assert train_args["InputDataConfig"][1]["DataSource"]["S3DataSource"][ - "ModelAccessConfig" - ] == {"AcceptEula": True} + assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { + "AcceptEula": True + } + assert train_args["input_config"][1]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { + "AcceptEula": True + } # Verify logging occurred with correct information logger.info.assert_called_once() @@ -409,7 +409,7 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources(): def test_set_accept_eula_for_input_data_config_existing_model_access_config(): """Test when ModelAccessConfig already exists.""" train_args = { - "InputDataConfig": [ + "input_config": [ { "DataSource": { "S3DataSource": { @@ -425,7 +425,7 @@ def test_set_accept_eula_for_input_data_config_existing_model_access_config(): _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify AcceptEula is added to existing ModelAccessConfig - assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { + assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { "OtherSetting": "value", "AcceptEula": True, } @@ -433,31 +433,31 @@ def test_set_accept_eula_for_input_data_config_existing_model_access_config(): def test_set_accept_eula_for_input_data_config_missing_s3_data_source(): """Test when S3DataSource is missing.""" - train_args = {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]} + train_args = {"input_config": [{"DataSource": {"OtherDataSource": {}}}]} accept_eula = True _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged - assert train_args == {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]} + assert train_args == {"input_config": [{"DataSource": {"OtherDataSource": {}}}]} def 
test_set_accept_eula_for_input_data_config_missing_data_source(): """Test when DataSource is missing.""" - train_args = {"InputDataConfig": [{"OtherKey": {}}]} + train_args = {"input_config": [{"OtherKey": {}}]} accept_eula = True _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged - assert train_args == {"InputDataConfig": [{"OtherKey": {}}]} + assert train_args == {"input_config": [{"OtherKey": {}}]} def test_set_accept_eula_for_input_data_config_mixed_data_sources(): """Test with a mix of S3DataSource and other data sources.""" with patch("sagemaker.estimator.logger") as logger: train_args = { - "InputDataConfig": [ + "input_config": [ {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}, {"DataSource": {"OtherDataSource": {}}}, ] @@ -467,10 +467,10 @@ def test_set_accept_eula_for_input_data_config_mixed_data_sources(): _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only - assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][ - "ModelAccessConfig" - ] == {"AcceptEula": True} - assert "ModelAccessConfig" not in train_args["InputDataConfig"][1]["DataSource"].get( + assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { + "AcceptEula": True + } + assert "ModelAccessConfig" not in train_args["input_config"][1]["DataSource"].get( "OtherDataSource", {} ) @@ -2705,6 +2705,31 @@ def test_fit_verify_job_name(strftime, sagemaker_session): assert fw.latest_training_job.name == JOB_NAME +@patch("time.strftime", return_value=TIMESTAMP) +def test_fit_verify_accept_eula(strftime, sagemaker_session): + fw = DummyFramework( + entry_point=SCRIPT_PATH, + role="DummyRole", + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + tags=TAGS, + encrypt_inter_container_traffic=True, + ) + 
fw.fit(inputs=TrainingInput("s3://mybucket/train"), accept_eula=True) + + _, _, train_kwargs = sagemaker_session.train.mock_calls[0] + + assert ( + train_kwargs["input_config"][0] + .get("DataSource", {}) + .get("S3DataSource", {}) + .get("ModelAccessConfig", {}) + .get("AcceptEula") + is True + ) + + @pytest.mark.parametrize( "debugger_hook_config_direct_input, sagemaker_config, expected_debugger_hook_config_output", [ From f26a1700e7239b84ea021c0944cf31351c8b030f Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 15 May 2025 16:19:55 +0000 Subject: [PATCH 08/11] chore: only attach eula for model channel --- src/sagemaker/estimator.py | 27 +++---- tests/unit/test_estimator.py | 141 +++++++++++++++++++++-------------- 2 files changed, 94 insertions(+), 74 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index f6f3de3644..e62b1b8c7a 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -2671,7 +2671,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula=None) train_args["session_chaining_config"] = estimator.get_session_chaining_config() if accept_eula is not None: - cls._set_accept_eula_for_input_data_config(train_args, accept_eula) + cls._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula) return train_args @@ -2696,11 +2696,11 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args): train_args["checkpoint_local_path"] = estimator.checkpoint_local_path @classmethod - def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula): - """Set the AcceptEula flag for all input data configurations. + def _set_accept_eula_for_model_channel_input_data_config(cls, train_args, accept_eula): + """Set the AcceptEula flag for model channel in input data configurations. - This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources - in the input_config array. 
It handles cases where keys might not exist in the + This method sets the AcceptEula flag in the ModelAccessConfig for the model channel + S3DataSource in the input_config array. It handles cases where keys might not exist in the nested dictionary structure. Args: @@ -2713,26 +2713,17 @@ def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula): if accept_eula is None: return - eula_count = 0 - s3_uris = [] - for idx in range(len(train_args["input_config"])): - if "DataSource" in train_args["input_config"][idx]: + if ( + "DataSource" in train_args["input_config"][idx] + and train_args["input_config"][idx]["ChannelName"].lower().strip() == "model" + ): data_source = train_args["input_config"][idx]["DataSource"] if "S3DataSource" in data_source: s3_data_source = data_source["S3DataSource"] if "ModelAccessConfig" not in s3_data_source: s3_data_source["ModelAccessConfig"] = {} s3_data_source["ModelAccessConfig"]["AcceptEula"] = accept_eula - eula_count += 1 - - # Collect S3 URI if available - if "S3Uri" in s3_data_source: - s3_uris.append(s3_data_source["S3Uri"]) - - # Log info if more than one EULA needs to be accepted - if eula_count > 1: - logger.info("Accepting EULA for %d S3 data sources: %s", eula_count, ", ".join(s3_uris)) @classmethod def _is_local_channel(cls, input_uri): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 52678ef5a9..5621723759 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -336,93 +336,119 @@ def training_job_description(sagemaker_session): return returned_job_description -def test_set_accept_eula_for_input_data_config_no_input_data_config(): +def test_set_accept_eula_for_model_channel_input_data_config_no_input_data_config(): """Test when input_config is not in train_args.""" train_args = {} accept_eula = True - _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, 
accept_eula)
+
+    # Verify train_args remains unchanged
+    assert train_args == {}
+
+
-def test_set_accept_eula_for_input_data_config_none_accept_eula():
+def test_set_accept_eula_for_model_channel_input_data_config_none_accept_eula():
     """Test when accept_eula is None."""
     train_args = {"input_config": [{"DataSource": {"S3DataSource": {}}}]}
     accept_eula = None
 
-    _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
+    _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
 
     # Verify train_args remains unchanged
     assert train_args == {"input_config": [{"DataSource": {"S3DataSource": {}}}]}
 
 
-def test_set_accept_eula_for_input_data_config_single_data_source():
+def test_set_accept_eula_for_model_channel_input_data_config_single_data_source():
     """Test with a single S3DataSource."""
     with patch("sagemaker.estimator.logger") as logger:
         train_args = {
-            "input_config": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}]
+            "input_config": [
+                {
+                    "ChannelName": "model",
+                    "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}},
+                }
+            ]
         }
         accept_eula = True
 
-        _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
+        _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
 
         # Verify ModelAccessConfig and AcceptEula are set correctly
         assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
             "AcceptEula": True
         }
 
-        # Verify no logging occurred since there's only one data source
+        # Verify no logging occurred
         logger.info.assert_not_called()
 
 
-def test_set_accept_eula_for_input_data_config_multiple_data_sources():
-    """Test with multiple S3DataSources."""
+def test_set_accept_eula_for_model_channel_input_data_config_non_model_channel():
+    """Test that a non-model channel is left unmodified."""
     with patch("sagemaker.estimator.logger") as logger:
         train_args = {
             "input_config": [
-                {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model1"}}},
-                {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model2"}}},
+                {
+                    "ChannelName": "NotModel",
+                    "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/not-model"}},
+                }
             ]
         }
         accept_eula = True
 
-        _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
+        _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
+
+        # Verify ModelAccessConfig is not set for a non-model channel
+        assert (
+            train_args["input_config"][0]["DataSource"]["S3DataSource"].get("ModelAccessConfig")
+            is None
+        )
 
-        # Verify ModelAccessConfig and AcceptEula are set correctly for both data sources
-        assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
-            "AcceptEula": True
-        }
-        assert train_args["input_config"][1]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
-            "AcceptEula": True
-        }
-
-        # Verify logging occurred with correct information
-        logger.info.assert_called_once()
-        args = logger.info.call_args[0]
-        assert args[0] == "Accepting EULA for %d S3 data sources: %s"
-        assert args[1] == 2
-        assert args[2] == "s3://bucket/model1, s3://bucket/model2"
+
+def test_set_accept_eula_for_model_channel_input_data_config_multiple_model_channels():
+    """Test with multiple model channels."""
+    train_args = {
+        "input_config": [
+            {
+                "ChannelName": "model",
+                "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model1"}},
+            },
+            {
+                "ChannelName": "model",
+                "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model2"}},
+            },
+        ]
+    }
+    accept_eula = True
+
+    _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
+
+    # Verify ModelAccessConfig and AcceptEula are set correctly for both model channels
+    assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
+        "AcceptEula": True
+    }
+    assert train_args["input_config"][1]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
+        "AcceptEula": True
+    }
 
 
-def 
test_set_accept_eula_for_input_data_config_existing_model_access_config(): +def test_set_accept_eula_for_model_channel_input_data_config_existing_model_access_config(): """Test when ModelAccessConfig already exists.""" train_args = { "input_config": [ { + "ChannelName": "model", "DataSource": { "S3DataSource": { "S3Uri": "s3://bucket/model", "ModelAccessConfig": {"OtherSetting": "value"}, } - } + }, } ] } accept_eula = True - _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula) # Verify AcceptEula is added to existing ModelAccessConfig assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { @@ -431,51 +457,52 @@ def test_set_accept_eula_for_input_data_config_existing_model_access_config(): } -def test_set_accept_eula_for_input_data_config_missing_s3_data_source(): +def test_set_accept_eula_for_model_channel_input_data_config_missing_s3_data_source(): """Test when S3DataSource is missing.""" - train_args = {"input_config": [{"DataSource": {"OtherDataSource": {}}}]} + train_args = {"input_config": [{"ChannelName": "model", "DataSource": {"OtherDataSource": {}}}]} accept_eula = True - _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged - assert train_args == {"input_config": [{"DataSource": {"OtherDataSource": {}}}]} + assert train_args == { + "input_config": [{"ChannelName": "model", "DataSource": {"OtherDataSource": {}}}] + } -def test_set_accept_eula_for_input_data_config_missing_data_source(): +def test_set_accept_eula_for_model_channel_input_data_config_missing_data_source(): """Test when DataSource is missing.""" train_args = {"input_config": [{"OtherKey": {}}]} accept_eula = True - _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) 
+ _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula) # Verify train_args remains unchanged assert train_args == {"input_config": [{"OtherKey": {}}]} -def test_set_accept_eula_for_input_data_config_mixed_data_sources(): +def test_set_accept_eula_for_model_channel_input_data_config_mixed_data_sources(): """Test with a mix of S3DataSource and other data sources.""" - with patch("sagemaker.estimator.logger") as logger: - train_args = { - "input_config": [ - {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}, - {"DataSource": {"OtherDataSource": {}}}, - ] - } - accept_eula = True - - _TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula) + train_args = { + "input_config": [ + { + "ChannelName": "model", + "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}, + }, + {"ChannelName": "model", "DataSource": {"OtherDataSource": {}}}, + ] + } + accept_eula = True - # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only - assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { - "AcceptEula": True - } - assert "ModelAccessConfig" not in train_args["input_config"][1]["DataSource"].get( - "OtherDataSource", {} - ) + _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula) - # Verify no logging occurred since there's only one S3 data source - logger.info.assert_not_called() + # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only + assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { + "AcceptEula": True + } + assert "ModelAccessConfig" not in train_args["input_config"][1]["DataSource"].get( + "OtherDataSource", {} + ) def test_validate_smdistributed_unsupported_image_raises(sagemaker_session): @@ -2709,6 +2736,8 @@ def test_fit_verify_job_name(strftime, sagemaker_session): def test_fit_verify_accept_eula(strftime, 
sagemaker_session): fw = DummyFramework( entry_point=SCRIPT_PATH, + model_uri="s3://mybucket/model", + image_uri=IMAGE_URI, role="DummyRole", sagemaker_session=sagemaker_session, instance_count=INSTANCE_COUNT, @@ -2721,7 +2750,7 @@ def test_fit_verify_accept_eula(strftime, sagemaker_session): _, _, train_kwargs = sagemaker_session.train.mock_calls[0] assert ( - train_kwargs["input_config"][0] + train_kwargs["input_config"][1] .get("DataSource", {}) .get("S3DataSource", {}) .get("ModelAccessConfig", {}) From b1ffab2aa505100fcee949c752e2027dc27e315b Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 15 May 2025 18:23:43 +0000 Subject: [PATCH 09/11] chore: undo changes to serverless_inference_config_dict --- src/sagemaker/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index b281d9f489..3bfac0c8da 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1805,7 +1805,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, explainer_config_dict=explainer_config_dict, async_inference_config_dict=async_inference_config_dict, - serverless_inference_config=serverless_inference_config_dict, + serverless_inference_config_dict=serverless_inference_config_dict, routing_config=routing_config, inference_ami_version=inference_ami_version, ) From a7ff61d0e8912dd69c1751ef114132112a3cd88e Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 15 May 2025 18:30:33 +0000 Subject: [PATCH 10/11] chore: cleanup unit tests --- tests/unit/test_estimator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 5621723759..876d79b67d 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -378,11 +378,8 @@ def test_set_accept_eula_for_model_channel_input_data_config_single_data_source( "AcceptEula": True } - # Verify no logging occurred - 
logger.info.assert_not_called() - -def test_set_accept_eula_for_model_channel_input_data_config_single_data_source(): +def test_set_accept_eula_for_nonmodel_channel_input_data_config_single_data_source(): """Test with a single S3DataSource.""" with patch("sagemaker.estimator.logger") as logger: train_args = { From 4c46f0196ec66c4ab77d9e8c53f03ebd95c7f7ba Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Thu, 15 May 2025 19:08:41 +0000 Subject: [PATCH 11/11] fix: unit tests --- tests/unit/test_estimator.py | 59 +++++++++++++++++------------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 876d79b67d..f01cb08284 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -360,45 +360,42 @@ def test_set_accept_eula_for_model_channel_input_data_config_none_accept_eula(): def test_set_accept_eula_for_model_channel_input_data_config_single_data_source(): """Test with a single S3DataSource.""" - with patch("sagemaker.estimator.logger") as logger: - train_args = { - "input_config": [ - { - "ChannelName": "model", - "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}, - } - ] - } - accept_eula = True + train_args = { + "input_config": [ + { + "ChannelName": "model", + "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}, + } + ] + } + accept_eula = True - _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula) + _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula) - # Verify ModelAccessConfig and AcceptEula are set correctly - assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { - "AcceptEula": True - } + # Verify ModelAccessConfig and AcceptEula are set correctly + assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == { + "AcceptEula": True + } def 
test_set_accept_eula_for_nonmodel_channel_input_data_config_single_data_source():
     """Test with a single S3DataSource."""
-    with patch("sagemaker.estimator.logger") as logger:
-        train_args = {
-            "input_config": [
-                {
-                    "ChannelName": "NotModel",
-                    "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/not-model"}},
-                }
-            ]
-        }
-        accept_eula = True
+    train_args = {
+        "input_config": [
+            {
+                "ChannelName": "NotModel",
+                "DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/not-model"}},
+            }
+        ]
+    }
+    accept_eula = True
 
-        _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
+    _TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
 
-        # Verify ModelAccessConfig and AcceptEula are set correctly
-        assert (
-            train_args["input_config"][0]["DataSource"]["S3DataSource"].get("ModelAccessConfig")
-            == None
-        )
+    # Verify ModelAccessConfig is NOT set for a non-model channel
+    assert (
+        train_args["input_config"][0]["DataSource"]["S3DataSource"].get("ModelAccessConfig") is None
+    )
 
 
 def test_set_accept_eula_for_model_channel_input_data_config_multiple_model_channels():