@@ -337,7 +337,7 @@ def training_job_description(sagemaker_session):
337
337
338
338
339
339
def test_set_accept_eula_for_input_data_config_no_input_data_config ():
340
- """Test when InputDataConfig is not in train_args."""
340
+ """Test when input_config is not in train_args."""
341
341
train_args = {}
342
342
accept_eula = True
343
343
@@ -349,29 +349,29 @@ def test_set_accept_eula_for_input_data_config_no_input_data_config():
349
349
350
350
def test_set_accept_eula_for_input_data_config_none_accept_eula ():
351
351
"""Test when accept_eula is None."""
352
- train_args = {"InputDataConfig " : [{"DataSource" : {"S3DataSource" : {}}}]}
352
+ train_args = {"input_config " : [{"DataSource" : {"S3DataSource" : {}}}]}
353
353
accept_eula = None
354
354
355
355
_TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
356
356
357
357
# Verify train_args remains unchanged
358
- assert train_args == {"InputDataConfig " : [{"DataSource" : {"S3DataSource" : {}}}]}
358
+ assert train_args == {"input_config " : [{"DataSource" : {"S3DataSource" : {}}}]}
359
359
360
360
361
361
def test_set_accept_eula_for_input_data_config_single_data_source ():
362
362
"""Test with a single S3DataSource."""
363
363
with patch ("sagemaker.estimator.logger" ) as logger :
364
364
train_args = {
365
- "InputDataConfig " : [{"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }}}]
365
+ "input_config " : [{"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }}}]
366
366
}
367
367
accept_eula = True
368
368
369
369
_TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
370
370
371
371
# Verify ModelAccessConfig and AcceptEula are set correctly
372
- assert train_args ["InputDataConfig " ][0 ]["DataSource" ]["S3DataSource" ][
373
- "ModelAccessConfig"
374
- ] == { "AcceptEula" : True }
372
+ assert train_args ["input_config " ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
373
+ "AcceptEula" : True
374
+ }
375
375
376
376
# Verify no logging occurred since there's only one data source
377
377
logger .info .assert_not_called ()
@@ -381,7 +381,7 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources():
381
381
"""Test with multiple S3DataSources."""
382
382
with patch ("sagemaker.estimator.logger" ) as logger :
383
383
train_args = {
384
- "InputDataConfig " : [
384
+ "input_config " : [
385
385
{"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model1" }}},
386
386
{"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model2" }}},
387
387
]
@@ -391,12 +391,12 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources():
391
391
_TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
392
392
393
393
# Verify ModelAccessConfig and AcceptEula are set correctly for both data sources
394
- assert train_args ["InputDataConfig " ][0 ]["DataSource" ]["S3DataSource" ][
395
- "ModelAccessConfig"
396
- ] == { "AcceptEula" : True }
397
- assert train_args ["InputDataConfig " ][1 ]["DataSource" ]["S3DataSource" ][
398
- "ModelAccessConfig"
399
- ] == { "AcceptEula" : True }
394
+ assert train_args ["input_config " ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
395
+ "AcceptEula" : True
396
+ }
397
+ assert train_args ["input_config " ][1 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
398
+ "AcceptEula" : True
399
+ }
400
400
401
401
# Verify logging occurred with correct information
402
402
logger .info .assert_called_once ()
@@ -409,7 +409,7 @@ def test_set_accept_eula_for_input_data_config_multiple_data_sources():
409
409
def test_set_accept_eula_for_input_data_config_existing_model_access_config ():
410
410
"""Test when ModelAccessConfig already exists."""
411
411
train_args = {
412
- "InputDataConfig " : [
412
+ "input_config " : [
413
413
{
414
414
"DataSource" : {
415
415
"S3DataSource" : {
@@ -425,39 +425,39 @@ def test_set_accept_eula_for_input_data_config_existing_model_access_config():
425
425
_TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
426
426
427
427
# Verify AcceptEula is added to existing ModelAccessConfig
428
- assert train_args ["InputDataConfig " ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
428
+ assert train_args ["input_config " ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
429
429
"OtherSetting" : "value" ,
430
430
"AcceptEula" : True ,
431
431
}
432
432
433
433
434
434
def test_set_accept_eula_for_input_data_config_missing_s3_data_source ():
435
435
"""Test when S3DataSource is missing."""
436
- train_args = {"InputDataConfig " : [{"DataSource" : {"OtherDataSource" : {}}}]}
436
+ train_args = {"input_config " : [{"DataSource" : {"OtherDataSource" : {}}}]}
437
437
accept_eula = True
438
438
439
439
_TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
440
440
441
441
# Verify train_args remains unchanged
442
- assert train_args == {"InputDataConfig " : [{"DataSource" : {"OtherDataSource" : {}}}]}
442
+ assert train_args == {"input_config " : [{"DataSource" : {"OtherDataSource" : {}}}]}
443
443
444
444
445
445
def test_set_accept_eula_for_input_data_config_missing_data_source ():
446
446
"""Test when DataSource is missing."""
447
- train_args = {"InputDataConfig " : [{"OtherKey" : {}}]}
447
+ train_args = {"input_config " : [{"OtherKey" : {}}]}
448
448
accept_eula = True
449
449
450
450
_TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
451
451
452
452
# Verify train_args remains unchanged
453
- assert train_args == {"InputDataConfig " : [{"OtherKey" : {}}]}
453
+ assert train_args == {"input_config " : [{"OtherKey" : {}}]}
454
454
455
455
456
456
def test_set_accept_eula_for_input_data_config_mixed_data_sources ():
457
457
"""Test with a mix of S3DataSource and other data sources."""
458
458
with patch ("sagemaker.estimator.logger" ) as logger :
459
459
train_args = {
460
- "InputDataConfig " : [
460
+ "input_config " : [
461
461
{"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }}},
462
462
{"DataSource" : {"OtherDataSource" : {}}},
463
463
]
@@ -467,10 +467,10 @@ def test_set_accept_eula_for_input_data_config_mixed_data_sources():
467
467
_TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
468
468
469
469
# Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
470
- assert train_args ["InputDataConfig " ][0 ]["DataSource" ]["S3DataSource" ][
471
- "ModelAccessConfig"
472
- ] == { "AcceptEula" : True }
473
- assert "ModelAccessConfig" not in train_args ["InputDataConfig " ][1 ]["DataSource" ].get (
470
+ assert train_args ["input_config " ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
471
+ "AcceptEula" : True
472
+ }
473
+ assert "ModelAccessConfig" not in train_args ["input_config " ][1 ]["DataSource" ].get (
474
474
"OtherDataSource" , {}
475
475
)
476
476
@@ -2705,6 +2705,31 @@ def test_fit_verify_job_name(strftime, sagemaker_session):
2705
2705
assert fw .latest_training_job .name == JOB_NAME
2706
2706
2707
2707
2708
+ @patch ("time.strftime" , return_value = TIMESTAMP )
2709
+ def test_fit_verify_accept_eula (strftime , sagemaker_session ):
2710
+ fw = DummyFramework (
2711
+ entry_point = SCRIPT_PATH ,
2712
+ role = "DummyRole" ,
2713
+ sagemaker_session = sagemaker_session ,
2714
+ instance_count = INSTANCE_COUNT ,
2715
+ instance_type = INSTANCE_TYPE ,
2716
+ tags = TAGS ,
2717
+ encrypt_inter_container_traffic = True ,
2718
+ )
2719
+ fw .fit (inputs = TrainingInput ("s3://mybucket/train" ), accept_eula = True )
2720
+
2721
+ _ , _ , train_kwargs = sagemaker_session .train .mock_calls [0 ]
2722
+
2723
+ assert (
2724
+ train_kwargs ["input_config" ][0 ]
2725
+ .get ("DataSource" , {})
2726
+ .get ("S3DataSource" , {})
2727
+ .get ("ModelAccessConfig" , {})
2728
+ .get ("AcceptEula" )
2729
+ is True
2730
+ )
2731
+
2732
+
2708
2733
@pytest .mark .parametrize (
2709
2734
"debugger_hook_config_direct_input, sagemaker_config, expected_debugger_hook_config_output" ,
2710
2735
[
0 commit comments