Skip to content

Commit f26a170

Browse files
committed
chore: only attach eula for model channel
1 parent 4d45973 commit f26a170

File tree

2 files changed

+94
-74
lines changed

2 files changed

+94
-74
lines changed

src/sagemaker/estimator.py

+9-18
Original file line numberDiff line numberDiff line change
@@ -2671,7 +2671,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula=None)
26712671
train_args["session_chaining_config"] = estimator.get_session_chaining_config()
26722672

26732673
if accept_eula is not None:
2674-
cls._set_accept_eula_for_input_data_config(train_args, accept_eula)
2674+
cls._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
26752675

26762676
return train_args
26772677

@@ -2696,11 +2696,11 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
26962696
train_args["checkpoint_local_path"] = estimator.checkpoint_local_path
26972697

26982698
@classmethod
2699-
def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula):
2700-
"""Set the AcceptEula flag for all input data configurations.
2699+
def _set_accept_eula_for_model_channel_input_data_config(cls, train_args, accept_eula):
2700+
"""Set the AcceptEula flag for model channel in input data configurations.
27012701
2702-
This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources
2703-
in the input_config array. It handles cases where keys might not exist in the
2702+
This method sets the AcceptEula flag in the ModelAccessConfig for the model channel
2703+
S3DataSource in the input_config array. It handles cases where keys might not exist in the
27042704
nested dictionary structure.
27052705
27062706
Args:
@@ -2713,26 +2713,17 @@ def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula):
27132713
if accept_eula is None:
27142714
return
27152715

2716-
eula_count = 0
2717-
s3_uris = []
2718-
27192716
for idx in range(len(train_args["input_config"])):
2720-
if "DataSource" in train_args["input_config"][idx]:
2717+
if (
2718+
"DataSource" in train_args["input_config"][idx]
2719+
and train_args["input_config"][idx]["ChannelName"].lower().strip() == "model"
2720+
):
27212721
data_source = train_args["input_config"][idx]["DataSource"]
27222722
if "S3DataSource" in data_source:
27232723
s3_data_source = data_source["S3DataSource"]
27242724
if "ModelAccessConfig" not in s3_data_source:
27252725
s3_data_source["ModelAccessConfig"] = {}
27262726
s3_data_source["ModelAccessConfig"]["AcceptEula"] = accept_eula
2727-
eula_count += 1
2728-
2729-
# Collect S3 URI if available
2730-
if "S3Uri" in s3_data_source:
2731-
s3_uris.append(s3_data_source["S3Uri"])
2732-
2733-
# Log info if more than one EULA needs to be accepted
2734-
if eula_count > 1:
2735-
logger.info("Accepting EULA for %d S3 data sources: %s", eula_count, ", ".join(s3_uris))
27362727

27372728
@classmethod
27382729
def _is_local_channel(cls, input_uri):

tests/unit/test_estimator.py

+85-56
Original file line numberDiff line numberDiff line change
@@ -336,93 +336,119 @@ def training_job_description(sagemaker_session):
336336
return returned_job_description
337337

338338

339-
def test_set_accept_eula_for_input_data_config_no_input_data_config():
339+
def test_set_accept_eula_for_model_channel_input_data_config_no_input_data_config():
340340
"""Test when input_config is not in train_args."""
341341
train_args = {}
342342
accept_eula = True
343343

344-
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
344+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
345345

346346
# Verify train_args remains unchanged
347347
assert train_args == {}
348348

349349

350-
def test_set_accept_eula_for_input_data_config_none_accept_eula():
350+
def test_set_accept_eula_for_model_channel_input_data_config_none_accept_eula():
351351
"""Test when accept_eula is None."""
352352
train_args = {"input_config": [{"DataSource": {"S3DataSource": {}}}]}
353353
accept_eula = None
354354

355-
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
355+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
356356

357357
# Verify train_args remains unchanged
358358
assert train_args == {"input_config": [{"DataSource": {"S3DataSource": {}}}]}
359359

360360

361-
def test_set_accept_eula_for_input_data_config_single_data_source():
361+
def test_set_accept_eula_for_model_channel_input_data_config_single_data_source():
362362
"""Test with a single S3DataSource."""
363363
with patch("sagemaker.estimator.logger") as logger:
364364
train_args = {
365-
"input_config": [{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}}]
365+
"input_config": [
366+
{
367+
"ChannelName": "model",
368+
"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}},
369+
}
370+
]
366371
}
367372
accept_eula = True
368373

369-
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
374+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
370375

371376
# Verify ModelAccessConfig and AcceptEula are set correctly
372377
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
373378
"AcceptEula": True
374379
}
375380

376-
# Verify no logging occurred since there's only one data source
381+
# Verify no logging occurred
377382
logger.info.assert_not_called()
378383

379384

380-
def test_set_accept_eula_for_input_data_config_multiple_data_sources():
381-
"""Test with multiple S3DataSources."""
385+
def test_set_accept_eula_for_model_channel_input_data_config_non_model_channel():
386+
"""Test that a channel not named "model" does not receive ModelAccessConfig."""
382387
with patch("sagemaker.estimator.logger") as logger:
383388
train_args = {
384389
"input_config": [
385-
{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model1"}}},
386-
{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model2"}}},
390+
{
391+
"ChannelName": "NotModel",
392+
"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/not-model"}},
393+
}
387394
]
388395
}
389396
accept_eula = True
390397

391-
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
398+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
399+
400+
# Verify ModelAccessConfig is not set for the non-model channel
401+
assert (
402+
train_args["input_config"][0]["DataSource"]["S3DataSource"].get("ModelAccessConfig")
403+
== None
404+
)
392405

393-
# Verify ModelAccessConfig and AcceptEula are set correctly for both data sources
394-
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
395-
"AcceptEula": True
396-
}
397-
assert train_args["input_config"][1]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
398-
"AcceptEula": True
399-
}
400406

401-
# Verify logging occurred with correct information
402-
logger.info.assert_called_once()
403-
args = logger.info.call_args[0]
404-
assert args[0] == "Accepting EULA for %d S3 data sources: %s"
405-
assert args[1] == 2
406-
assert args[2] == "s3://bucket/model1, s3://bucket/model2"
407+
def test_set_accept_eula_for_model_channel_input_data_config_multiple_model_channels():
408+
"""Test with multiple model channels."""
409+
train_args = {
410+
"input_config": [
411+
{
412+
"ChannelName": "model",
413+
"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model1"}},
414+
},
415+
{
416+
"ChannelName": "model",
417+
"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model2"}},
418+
},
419+
]
420+
}
421+
accept_eula = True
422+
423+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
424+
425+
# Verify ModelAccessConfig and AcceptEula are set correctly for both model channels
426+
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
427+
"AcceptEula": True
428+
}
429+
assert train_args["input_config"][1]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
430+
"AcceptEula": True
431+
}
407432

408433

409-
def test_set_accept_eula_for_input_data_config_existing_model_access_config():
434+
def test_set_accept_eula_for_model_channel_input_data_config_existing_model_access_config():
410435
"""Test when ModelAccessConfig already exists."""
411436
train_args = {
412437
"input_config": [
413438
{
439+
"ChannelName": "model",
414440
"DataSource": {
415441
"S3DataSource": {
416442
"S3Uri": "s3://bucket/model",
417443
"ModelAccessConfig": {"OtherSetting": "value"},
418444
}
419-
}
445+
},
420446
}
421447
]
422448
}
423449
accept_eula = True
424450

425-
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
451+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
426452

427453
# Verify AcceptEula is added to existing ModelAccessConfig
428454
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
@@ -431,51 +457,52 @@ def test_set_accept_eula_for_input_data_config_existing_model_access_config():
431457
}
432458

433459

434-
def test_set_accept_eula_for_input_data_config_missing_s3_data_source():
460+
def test_set_accept_eula_for_model_channel_input_data_config_missing_s3_data_source():
435461
"""Test when S3DataSource is missing."""
436-
train_args = {"input_config": [{"DataSource": {"OtherDataSource": {}}}]}
462+
train_args = {"input_config": [{"ChannelName": "model", "DataSource": {"OtherDataSource": {}}}]}
437463
accept_eula = True
438464

439-
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
465+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
440466

441467
# Verify train_args remains unchanged
442-
assert train_args == {"input_config": [{"DataSource": {"OtherDataSource": {}}}]}
468+
assert train_args == {
469+
"input_config": [{"ChannelName": "model", "DataSource": {"OtherDataSource": {}}}]
470+
}
443471

444472

445-
def test_set_accept_eula_for_input_data_config_missing_data_source():
473+
def test_set_accept_eula_for_model_channel_input_data_config_missing_data_source():
446474
"""Test when DataSource is missing."""
447475
train_args = {"input_config": [{"OtherKey": {}}]}
448476
accept_eula = True
449477

450-
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
478+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
451479

452480
# Verify train_args remains unchanged
453481
assert train_args == {"input_config": [{"OtherKey": {}}]}
454482

455483

456-
def test_set_accept_eula_for_input_data_config_mixed_data_sources():
484+
def test_set_accept_eula_for_model_channel_input_data_config_mixed_data_sources():
457485
"""Test with a mix of S3DataSource and other data sources."""
458-
with patch("sagemaker.estimator.logger") as logger:
459-
train_args = {
460-
"input_config": [
461-
{"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}}},
462-
{"DataSource": {"OtherDataSource": {}}},
463-
]
464-
}
465-
accept_eula = True
466-
467-
_TrainingJob._set_accept_eula_for_input_data_config(train_args, accept_eula)
486+
train_args = {
487+
"input_config": [
488+
{
489+
"ChannelName": "model",
490+
"DataSource": {"S3DataSource": {"S3Uri": "s3://bucket/model"}},
491+
},
492+
{"ChannelName": "model", "DataSource": {"OtherDataSource": {}}},
493+
]
494+
}
495+
accept_eula = True
468496

469-
# Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
470-
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
471-
"AcceptEula": True
472-
}
473-
assert "ModelAccessConfig" not in train_args["input_config"][1]["DataSource"].get(
474-
"OtherDataSource", {}
475-
)
497+
_TrainingJob._set_accept_eula_for_model_channel_input_data_config(train_args, accept_eula)
476498

477-
# Verify no logging occurred since there's only one S3 data source
478-
logger.info.assert_not_called()
499+
# Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
500+
assert train_args["input_config"][0]["DataSource"]["S3DataSource"]["ModelAccessConfig"] == {
501+
"AcceptEula": True
502+
}
503+
assert "ModelAccessConfig" not in train_args["input_config"][1]["DataSource"].get(
504+
"OtherDataSource", {}
505+
)
479506

480507

481508
def test_validate_smdistributed_unsupported_image_raises(sagemaker_session):
@@ -2709,6 +2736,8 @@ def test_fit_verify_job_name(strftime, sagemaker_session):
27092736
def test_fit_verify_accept_eula(strftime, sagemaker_session):
27102737
fw = DummyFramework(
27112738
entry_point=SCRIPT_PATH,
2739+
model_uri="s3://mybucket/model",
2740+
image_uri=IMAGE_URI,
27122741
role="DummyRole",
27132742
sagemaker_session=sagemaker_session,
27142743
instance_count=INSTANCE_COUNT,
@@ -2721,7 +2750,7 @@ def test_fit_verify_accept_eula(strftime, sagemaker_session):
27212750
_, _, train_kwargs = sagemaker_session.train.mock_calls[0]
27222751

27232752
assert (
2724-
train_kwargs["input_config"][0]
2753+
train_kwargs["input_config"][1]
27252754
.get("DataSource", {})
27262755
.get("S3DataSource", {})
27272756
.get("ModelAccessConfig", {})

0 commit comments

Comments
 (0)