Commit e9f2f65

fix: unit tests
1 parent 68202a7 commit e9f2f65

6 files changed: +153, −6 lines changed

src/sagemaker/estimator.py (+3)

@@ -2710,6 +2710,9 @@ def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula):
         if "InputDataConfig" not in train_args:
             return
 
+        if accept_eula is None:
+            return
+
         eula_count = 0
         s3_uris = []
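For context, a minimal usage sketch of the guarded helper, mirroring what the new unit tests further down assert; the dictionary shape is taken from those tests, and this snippet is an illustration, not part of the commit:

from sagemaker.estimator import EstimatorBase

train_args = {
    "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}]
}

# accept_eula=None now short-circuits and leaves train_args untouched.
EstimatorBase._set_accept_eula_for_input_data_config(train_args, None)
assert "ModelAccessConfig" not in train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]

# accept_eula=True stamps a ModelAccessConfig onto each S3 data source.
EstimatorBase._set_accept_eula_for_input_data_config(train_args, True)
assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
    "AcceptEula": True
}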

src/sagemaker/model.py (+1, −1)

@@ -1805,7 +1805,7 @@ def deploy(
             container_startup_health_check_timeout=container_startup_health_check_timeout,
             explainer_config_dict=explainer_config_dict,
             async_inference_config_dict=async_inference_config_dict,
-            serverless_inference_config_dict=serverless_inference_config_dict,
+            serverless_inference_config=serverless_inference_config_dict,
             routing_config=routing_config,
             inference_ami_version=inference_ami_version,
         )

tests/unit/sagemaker/experiments/test_run_context.py (+3, −3)

@@ -54,7 +54,7 @@ def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker
         assert _RunContext.get_current_run() == run_obj
 
         expected_exp_config = run_obj.experiment_config
-        mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config)
+        mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config, None)
 
     # _RunContext is cleaned up after exiting the with statement
     assert not _RunContext.get_current_run()
@@ -94,7 +94,7 @@ def test_auto_pass_in_exp_config_under_load_run(
         assert loaded_run.experiment_config == run_obj.experiment_config
 
         expected_exp_config = run_obj.experiment_config
-        mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config)
+        mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config, None)
 
     # _RunContext is cleaned up after exiting the with statement
     assert not _RunContext.get_current_run()
@@ -174,7 +174,7 @@ def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_
 
         assert _RunContext.get_current_run() == run_obj
 
-        mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg)
+        mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg, None)
 
     # _RunContext is cleaned up after exiting the with statement
     assert not _RunContext.get_current_run()

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py (+3, −1)

@@ -607,6 +607,7 @@ def test_gated_model_s3_uri_with_eula_in_fit(
             inputs=channels,
             wait=True,
             job_name="meta-textgeneration-llama-2-7b-f-8675309",
+            accept_eula=True,
         )
 
         assert hasattr(estimator, "model_access_config")
@@ -688,6 +689,7 @@ def test_gated_model_non_model_package_s3_uri(
             instance_count=1,
             image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt"
             "orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04",
+            model_uri="s3://jumpstart-private-cache-prod-us-west-2/some/dummy/key",
             source_dir="s3://jumpstart-cache-prod-us-west-2/source-d"
             "irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz",
             entry_point="transfer_learning.py",
@@ -1346,7 +1348,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
         and reach out to JumpStart team."""
 
         init_args_to_skip: Set[str] = set(["kwargs"])
-        fit_args_to_skip: Set[str] = set(["accept_eula"])
+        fit_args_to_skip: Set[str] = set([])
         deploy_args_to_skip: Set[str] = set(["kwargs"])
 
         parent_class_init = Estimator.__init__

tests/unit/sagemaker/jumpstart/test_types.py (+1, −1)

@@ -333,7 +333,7 @@ def test_use_training_model_artifact():
     specs1 = JumpStartModelSpecs(BASE_SPEC)
     assert specs1.use_training_model_artifact()
     specs1.gated_bucket = True
-    assert not specs1.use_training_model_artifact()
+    assert specs1.use_training_model_artifact()
     specs1.gated_bucket = False
     specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"}
     assert not specs1.use_training_model_artifact()

tests/unit/test_estimator.py (+142)

@@ -336,6 +336,148 @@ def training_job_description(sagemaker_session):
     return returned_job_description
 
 
+def test_set_accept_eula_for_input_data_config_no_input_data_config():
+    """Test when InputDataConfig is not in train_args."""
+    train_args = {}
+    accept_eula = True
+
+    EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula)
+
+    # Verify train_args remains unchanged
+    assert train_args == {}
+
+
+def test_set_accept_eula_for_input_data_config_none_accept_eula():
+    """Test when accept_eula is None."""
+    train_args = {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]}
+    accept_eula = None
+
+    EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula)
+
+    # Verify train_args remains unchanged
+    assert train_args == {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]}
+
+
+def test_set_accept_eula_for_input_data_config_single_data_source():
+    """Test with a single S3DataSource."""
+    with patch("sagemaker.estimator.logger") as logger:
+        train_args = {
+            "InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}]
+        }
+        accept_eula = True
+
+        EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula)
+
+        # Verify ModelAccessConfig and AcceptEula are set correctly
+        assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+            "ModelAccessConfig"
+        ] == {"AcceptEula": True}
+
+        # Verify no logging occurred since there's only one data source
+        logger.info.assert_not_called()
+
+
+def test_set_accept_eula_for_input_data_config_multiple_data_sources():
+    """Test with multiple S3DataSources."""
+    with patch("sagemaker.estimator.logger") as logger:
+        train_args = {
+            "InputDataConfig": [
+                {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model1"}}},
+                {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model2"}}},
+            ]
+        }
+        accept_eula = True
+
+        EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula)
+
+        # Verify ModelAccessConfig and AcceptEula are set correctly for both data sources
+        assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+            "ModelAccessConfig"
+        ] == {"AcceptEula": True}
+        assert train_args["InputDataConfig"][1]["DataSource"]["S3DataSource"][
+            "ModelAccessConfig"
+        ] == {"AcceptEula": True}
+
+        # Verify logging occurred with correct information
+        logger.info.assert_called_once()
+        args = logger.info.call_args[0]
+        assert args[0] == "Accepting EULA for %d S3 data sources: %s"
+        assert args[1] == 2
+        assert args[2] == "s3://bucket/model1, s3://bucket/model2"
+
+
+def test_set_accept_eula_for_input_data_config_existing_model_access_config():
+    """Test when ModelAccessConfig already exists."""
+    train_args = {
+        "InputDataConfig": [
+            {
+                "DataSource": {
+                    "S3DataSource": {
+                        "S3Uri": "s3://bucket/model",
+                        "ModelAccessConfig": {"OtherSetting": "value"},
+                    }
+                }
+            }
+        ]
+    }
+    accept_eula = True
+
+    EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula)
+
+    # Verify AcceptEula is added to existing ModelAccessConfig
+    assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
+        "OtherSetting": "value",
+        "AcceptEula": True,
+    }
+
+
+def test_set_accept_eula_for_input_data_config_missing_s3_data_source():
+    """Test when S3DataSource is missing."""
+    train_args = {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]}
+    accept_eula = True
+
+    EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula)
+
+    # Verify train_args remains unchanged
+    assert train_args == {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]}
+
+
+def test_set_accept_eula_for_input_data_config_missing_data_source():
+    """Test when DataSource is missing."""
+    train_args = {"InputDataConfig": [{"OtherKey": {}}]}
+    accept_eula = True
+
+    EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula)
+
+    # Verify train_args remains unchanged
+    assert train_args == {"InputDataConfig": [{"OtherKey": {}}]}
+
+
+def test_set_accept_eula_for_input_data_config_mixed_data_sources():
+    """Test with a mix of S3DataSource and other data sources."""
+    with patch("sagemaker.estimator.logger") as logger:
+        train_args = {
+            "InputDataConfig": [
+                {"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}},
+                {"DataSource": {"OtherDataSource": {}}},
+            ]
+        }
+        accept_eula = True
+
+        EstimatorBase._set_accept_eula_for_input_data_config(train_args, accept_eula)
+
+        # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
+        assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+            "ModelAccessConfig"
+        ] == {"AcceptEula": True}
+        assert "ModelAccessConfig" not in train_args["InputDataConfig"][1]["DataSource"].get(
+            "OtherDataSource", {}
+        )
+
+        # Verify no logging occurred since there's only one S3 data source
+        logger.info.assert_not_called()
+
+
 def test_validate_smdistributed_unsupported_image_raises(sagemaker_session):
     # Test unsupported image raises error.
     for unsupported_image in DummyFramework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
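For readers without the SDK source at hand, here is a minimal, self-contained sketch of the behavior these tests pin down. The function name and the print-based logging are stand-ins for illustration only; the real helper is the EstimatorBase classmethod shown in the src/sagemaker/estimator.py hunk above and logs through sagemaker.estimator.logger.

def set_accept_eula_sketch(train_args, accept_eula):
    # Early exits: nothing to do without InputDataConfig, and None means "leave train_args untouched".
    if "InputDataConfig" not in train_args:
        return
    if accept_eula is None:
        return

    s3_uris = []
    for channel in train_args["InputDataConfig"]:
        s3_data_source = channel.get("DataSource", {}).get("S3DataSource")
        if s3_data_source is None:
            continue
        # Merge AcceptEula into any existing ModelAccessConfig rather than replacing it.
        s3_data_source.setdefault("ModelAccessConfig", {})["AcceptEula"] = accept_eula
        s3_uris.append(s3_data_source.get("S3Uri"))

    # The tests expect a log only when more than one S3 data source is touched.
    if len(s3_uris) > 1:
        print("Accepting EULA for %d S3 data sources: %s" % (len(s3_uris), ", ".join(s3_uris)))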
