Skip to content

Commit 4d45973

Browse files
committed
chore: add unit test + minor fix
1 parent 2830f39 commit 4d45973

File tree

2 files changed

+55
-30
lines changed

2 files changed

+55
-30
lines changed

src/sagemaker/estimator.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -2700,14 +2700,14 @@ def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula):
27002700
"""Set the AcceptEula flag for all input data configurations.
27012701
27022702
This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources
2703-
in the InputDataConfig array. It handles cases where keys might not exist in the
2703+
in the input_config array. It handles cases where keys might not exist in the
27042704
nested dictionary structure.
27052705
27062706
Args:
27072707
train_args (dict): The training job arguments dictionary
27082708
accept_eula (bool): The value to set for AcceptEula flag
27092709
"""
2710-
if "InputDataConfig" not in train_args:
2710+
if "input_config" not in train_args:
27112711
return
27122712

27132713
if accept_eula is None:
@@ -2716,9 +2716,9 @@ def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula):
27162716
eula_count = 0
27172717
s3_uris = []
27182718

2719-
for idx in range(len(train_args["InputDataConfig"])):
2720-
if "DataSource" in train_args["InputDataConfig"][idx]:
2721-
data_source = train_args["InputDataConfig"][idx]["DataSource"]
2719+
for idx in range(len(train_args["input_config"])):
2720+
if "DataSource" in train_args["input_config"][idx]:
2721+
data_source = train_args["input_config"][idx]["DataSource"]
27222722
if "S3DataSource" in data_source:
27232723
s3_data_source = data_source["S3DataSource"]
27242724
if "ModelAccessConfig" not in s3_data_source:

tests/unit/test_estimator.py

+50-25
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def training_job_description(sagemaker_session):
337337

338338

339339
def test_set_accept_eula_for_input_data_config_no_input_data_config():
340-
"""Test when InputDataConfig is not in train_args."""
340+
"""Test when input_config is not in train_args."""
341341
train_args = {}
342342
accept_eula = True
343343

@@ -349,29 +349,29 @@ def test_set_accept_eula_for_input_data_config_no_input_data_config():
349349

350350
def test_set_accept_eula_for_input_data_config_none_accept_eula():
351351
"""Test when accept_eula is None."""
352-
train_args = {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]}
352+
train_args = {"input_config": [{"DataSource": {"S3DataSource": {}}}]}
353353
accept_eula = None
354354

355355
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
356356

357357
# Verify train_args remains unchanged
358-
assert train_args == {"InputDataConfig": [{"DataSource": {"S3DataSource": {}}}]}
358+
assert train_args == {"input_config": [{"DataSource": {"S3DataSource": {}}}]}
359359

360360

361361
def test_set_accept_eula_for_input_data_config_single_data_source():
362362
"""Test with a single S3DataSource."""
363363
with patch("sagemaker.estimator.logger") as logger:
364364
train_args = {
365-
"InputDataConfig": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}]
365+
"input_config": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}]
366366
}
367367
accept_eula = True
368368

369369
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
370370

371371
# Verify ModelAccessConfig and AcceptEula are set correctly
372-
assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][
373-
"ModelAccessConfig"
374-
] == {"AcceptEula": True}
372+
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
373+
"AcceptEula": True
374+
}
375375

376376
# Verify no logging occurred since there's only one data source
377377
logger.info.assert_not_called()
@@ -381,7 +381,7 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources():
381381
"""Test with multiple S3DataSources."""
382382
with patch("sagemaker.estimator.logger") as logger:
383383
train_args = {
384-
"InputDataConfig": [
384+
"input_config": [
385385
{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model1"}}},
386386
{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model2"}}},
387387
]
@@ -391,12 +391,12 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources():
391391
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
392392

393393
# Verify ModelAccessConfig and AcceptEula are set correctly for both data sources
394-
assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][
395-
"ModelAccessConfig"
396-
] == {"AcceptEula": True}
397-
assert train_args["InputDataConfig"][1]["DataSource"]["S3DataSource"][
398-
"ModelAccessConfig"
399-
] == {"AcceptEula": True}
394+
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
395+
"AcceptEula": True
396+
}
397+
assert train_args["input_config"][1]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
398+
"AcceptEula": True
399+
}
400400

401401
# Verify logging occurred with correct information
402402
logger.info.assert_called_once()
@@ -409,7 +409,7 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources():
409409
def test_set_accept_eula_for_input_data_config_existing_model_access_config():
410410
"""Test when ModelAccessConfig already exists."""
411411
train_args = {
412-
"InputDataConfig": [
412+
"input_config": [
413413
{
414414
"DataSource": {
415415
"S3DataSource": {
@@ -425,39 +425,39 @@ def test_set_accept_eula_for_input_data_config_existing_model_access_config():
425425
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
426426

427427
# Verify AcceptEula is added to existing ModelAccessConfig
428-
assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
428+
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
429429
"OtherSetting": "value",
430430
"AcceptEula": True,
431431
}
432432

433433

434434
def test_set_accept_eula_for_input_data_config_missing_s3_data_source():
435435
"""Test when S3DataSource is missing."""
436-
train_args = {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]}
436+
train_args = {"input_config": [{"DataSource": {"OtherDataSource": {}}}]}
437437
accept_eula = True
438438

439439
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
440440

441441
# Verify train_args remains unchanged
442-
assert train_args == {"InputDataConfig": [{"DataSource": {"OtherDataSource": {}}}]}
442+
assert train_args == {"input_config": [{"DataSource": {"OtherDataSource": {}}}]}
443443

444444

445445
def test_set_accept_eula_for_input_data_config_missing_data_source():
446446
"""Test when DataSource is missing."""
447-
train_args = {"InputDataConfig": [{"OtherKey": {}}]}
447+
train_args = {"input_config": [{"OtherKey": {}}]}
448448
accept_eula = True
449449

450450
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
451451

452452
# Verify train_args remains unchanged
453-
assert train_args == {"InputDataConfig": [{"OtherKey": {}}]}
453+
assert train_args == {"input_config": [{"OtherKey": {}}]}
454454

455455

456456
def test_set_accept_eula_for_input_data_config_mixed_data_sources():
457457
"""Test with a mix of S3DataSource and other data sources."""
458458
with patch("sagemaker.estimator.logger") as logger:
459459
train_args = {
460-
"InputDataConfig": [
460+
"input_config": [
461461
{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}},
462462
{"DataSource": {"OtherDataSource": {}}},
463463
]
@@ -467,10 +467,10 @@ def test_set_accept_eula_for_input_data_config_mixed_data_sources():
467467
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
468468

469469
# Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
470-
assert train_args["InputDataConfig"][0]["DataSource"]["S3DataSource"][
471-
"ModelAccessConfig"
472-
] == {"AcceptEula": True}
473-
assert "ModelAccessConfig" not in train_args["InputDataConfig"][1]["DataSource"].get(
470+
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
471+
"AcceptEula": True
472+
}
473+
assert "ModelAccessConfig" not in train_args["input_config"][1]["DataSource"].get(
474474
"OtherDataSource", {}
475475
)
476476

@@ -2705,6 +2705,31 @@ def test_fit_verify_job_name(strftime, sagemaker_session):
27052705
assert fw.latest_training_job.name == JOB_NAME
27062706

27072707

2708+
@patch("time.strftime", return_value=TIMESTAMP)
2709+
def test_fit_verify_accept_eula(strftime, sagemaker_session):
2710+
fw = DummyFramework(
2711+
entry_point=SCRIPT_PATH,
2712+
role="DummyRole",
2713+
sagemaker_session=sagemaker_session,
2714+
instance_count=INSTANCE_COUNT,
2715+
instance_type=INSTANCE_TYPE,
2716+
tags=TAGS,
2717+
encrypt_inter_container_traffic=True,
2718+
)
2719+
fw.fit(inputs=TrainingInput("s3://mybucket/train"), accept_eula=True)
2720+
2721+
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
2722+
2723+
assert (
2724+
train_kwargs["input_config"][0]
2725+
.get("DataSource", {})
2726+
.get("S3DataSource", {})
2727+
.get("ModelAccessConfig", {})
2728+
.get("AcceptEula")
2729+
is True
2730+
)
2731+
2732+
27082733
@pytest.mark.parametrize(
27092734
"debugger_hook_config_direct_input, sagemaker_config, expected_debugger_hook_config_output",
27102735
[

0 commit comments

Comments
 (0)