@@ -336,93 +336,119 @@ def training_job_description(sagemaker_session):
336
336
return returned_job_description
337
337
338
338
339
- def test_set_accept_eula_for_input_data_config_no_input_data_config ():
339
+ def test_set_accept_eula_for_model_channel_input_data_config_no_input_data_config ():
340
340
"""Test when input_config is not in train_args."""
341
341
train_args = {}
342
342
accept_eula = True
343
343
344
- _TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
344
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
345
345
346
346
# Verify train_args remains unchanged
347
347
assert train_args == {}
348
348
349
349
350
- def test_set_accept_eula_for_input_data_config_none_accept_eula ():
350
+ def test_set_accept_eula_for_model_channel_input_data_config_none_accept_eula ():
351
351
"""Test when accept_eula is None."""
352
352
train_args = {"input_config" : [{"DataSource" : {"S3DataSource" : {}}}]}
353
353
accept_eula = None
354
354
355
- _TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
355
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
356
356
357
357
# Verify train_args remains unchanged
358
358
assert train_args == {"input_config" : [{"DataSource" : {"S3DataSource" : {}}}]}
359
359
360
360
361
- def test_set_accept_eula_for_input_data_config_single_data_source ():
361
+ def test_set_accept_eula_for_model_channel_input_data_config_single_data_source ():
362
362
"""Test with a single S3DataSource."""
363
363
with patch ("sagemaker.estimator.logger" ) as logger :
364
364
train_args = {
365
- "input_config" : [{"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }}}]
365
+ "input_config" : [
366
+ {
367
+ "ChannelName" : "model" ,
368
+ "DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }},
369
+ }
370
+ ]
366
371
}
367
372
accept_eula = True
368
373
369
- _TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
374
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
370
375
371
376
# Verify ModelAccessConfig and AcceptEula are set correctly
372
377
assert train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
373
378
"AcceptEula" : True
374
379
}
375
380
376
- # Verify no logging occurred since there's only one data source
381
+ # Verify no logging occurred
377
382
logger .info .assert_not_called ()
378
383
379
384
380
- def test_set_accept_eula_for_input_data_config_multiple_data_sources ():
381
- """Test with multiple S3DataSources ."""
385
+ def test_set_accept_eula_for_model_channel_input_data_config_single_data_source ():
386
+ """Test with a single S3DataSource ."""
382
387
with patch ("sagemaker.estimator.logger" ) as logger :
383
388
train_args = {
384
389
"input_config" : [
385
- {"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model1" }}},
386
- {"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model2" }}},
390
+ {
391
+ "ChannelName" : "NotModel" ,
392
+ "DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/not-model" }},
393
+ }
387
394
]
388
395
}
389
396
accept_eula = True
390
397
391
- _TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
398
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
399
+
400
+ # Verify ModelAccessConfig and AcceptEula are set correctly
401
+ assert (
402
+ train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ].get ("ModelAccessConfig" )
403
+ == None
404
+ )
392
405
393
- # Verify ModelAccessConfig and AcceptEula are set correctly for both data sources
394
- assert train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
395
- "AcceptEula" : True
396
- }
397
- assert train_args ["input_config" ][1 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
398
- "AcceptEula" : True
399
- }
400
406
401
- # Verify logging occurred with correct information
402
- logger .info .assert_called_once ()
403
- args = logger .info .call_args [0 ]
404
- assert args [0 ] == "Accepting EULA for %d S3 data sources: %s"
405
- assert args [1 ] == 2
406
- assert args [2 ] == "s3://bucket/model1, s3://bucket/model2"
407
+ def test_set_accept_eula_for_model_channel_input_data_config_multiple_model_channels ():
408
+ """Test with multiple model channels."""
409
+ train_args = {
410
+ "input_config" : [
411
+ {
412
+ "ChannelName" : "model" ,
413
+ "DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model1" }},
414
+ },
415
+ {
416
+ "ChannelName" : "model" ,
417
+ "DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model2" }},
418
+ },
419
+ ]
420
+ }
421
+ accept_eula = True
422
+
423
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
424
+
425
+ # Verify ModelAccessConfig and AcceptEula are set correctly for both model channels
426
+ assert train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
427
+ "AcceptEula" : True
428
+ }
429
+ assert train_args ["input_config" ][1 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
430
+ "AcceptEula" : True
431
+ }
407
432
408
433
409
- def test_set_accept_eula_for_input_data_config_existing_model_access_config ():
434
+ def test_set_accept_eula_for_model_channel_input_data_config_existing_model_access_config ():
410
435
"""Test when ModelAccessConfig already exists."""
411
436
train_args = {
412
437
"input_config" : [
413
438
{
439
+ "ChannelName" : "model" ,
414
440
"DataSource" : {
415
441
"S3DataSource" : {
416
442
"S3Uri" : "s3://bucket/model" ,
417
443
"ModelAccessConfig" : {"OtherSetting" : "value" },
418
444
}
419
- }
445
+ },
420
446
}
421
447
]
422
448
}
423
449
accept_eula = True
424
450
425
- _TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
451
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
426
452
427
453
# Verify AcceptEula is added to existing ModelAccessConfig
428
454
assert train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
@@ -431,51 +457,52 @@ def test_set_accept_eula_for_input_data_config_existing_model_access_config():
431
457
}
432
458
433
459
434
- def test_set_accept_eula_for_input_data_config_missing_s3_data_source ():
460
+ def test_set_accept_eula_for_model_channel_input_data_config_missing_s3_data_source ():
435
461
"""Test when S3DataSource is missing."""
436
- train_args = {"input_config" : [{"DataSource" : {"OtherDataSource" : {}}}]}
462
+ train_args = {"input_config" : [{"ChannelName" : "model" , " DataSource" : {"OtherDataSource" : {}}}]}
437
463
accept_eula = True
438
464
439
- _TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
465
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
440
466
441
467
# Verify train_args remains unchanged
442
- assert train_args == {"input_config" : [{"DataSource" : {"OtherDataSource" : {}}}]}
468
+ assert train_args == {
469
+ "input_config" : [{"ChannelName" : "model" , "DataSource" : {"OtherDataSource" : {}}}]
470
+ }
443
471
444
472
445
- def test_set_accept_eula_for_input_data_config_missing_data_source ():
473
+ def test_set_accept_eula_for_model_channel_input_data_config_missing_data_source ():
446
474
"""Test when DataSource is missing."""
447
475
train_args = {"input_config" : [{"OtherKey" : {}}]}
448
476
accept_eula = True
449
477
450
- _TrainingJob ._set_accept_eula_for_input_data_config (train_args , accept_eula )
478
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
451
479
452
480
# Verify train_args remains unchanged
453
481
assert train_args == {"input_config" : [{"OtherKey" : {}}]}
454
482
455
483
456
- def test_set_accept_eula_for_input_data_config_mixed_data_sources ():
484
+ def test_set_accept_eula_for_model_channel_input_data_config_mixed_data_sources ():
457
485
"""Test with a mix of S3DataSource and other data sources."""
458
- with patch ( "sagemaker.estimator.logger" ) as logger :
459
- train_args = {
460
- "input_config" : [
461
- { "DataSource " : { "S3DataSource" : { "S3Uri" : "s3://bucket/ model"}}} ,
462
- { "DataSource" : {"OtherDataSource " : {} }},
463
- ]
464
- }
465
- accept_eula = True
466
-
467
- _TrainingJob . _set_accept_eula_for_input_data_config ( train_args , accept_eula )
486
+ train_args = {
487
+ "input_config" : [
488
+ {
489
+ "ChannelName " : " model" ,
490
+ "DataSource" : {"S3DataSource " : {"S3Uri" : "s3://bucket/model" }},
491
+ },
492
+ { "ChannelName" : "model" , "DataSource" : { "OtherDataSource" : {}}},
493
+ ]
494
+ }
495
+ accept_eula = True
468
496
469
- # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
470
- assert train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
471
- "AcceptEula" : True
472
- }
473
- assert "ModelAccessConfig" not in train_args ["input_config" ][1 ]["DataSource" ].get (
474
- "OtherDataSource" , {}
475
- )
497
+ _TrainingJob ._set_accept_eula_for_model_channel_input_data_config (train_args , accept_eula )
476
498
477
- # Verify no logging occurred since there's only one S3 data source
478
- logger .info .assert_not_called ()
499
+ # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
500
+ assert train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
501
+ "AcceptEula" : True
502
+ }
503
+ assert "ModelAccessConfig" not in train_args ["input_config" ][1 ]["DataSource" ].get (
504
+ "OtherDataSource" , {}
505
+ )
479
506
480
507
481
508
def test_validate_smdistributed_unsupported_image_raises (sagemaker_session ):
@@ -2709,6 +2736,8 @@ def test_fit_verify_job_name(strftime, sagemaker_session):
2709
2736
def test_fit_verify_accept_eula (strftime , sagemaker_session ):
2710
2737
fw = DummyFramework (
2711
2738
entry_point = SCRIPT_PATH ,
2739
+ model_uri = "s3://mybucket/model" ,
2740
+ image_uri = IMAGE_URI ,
2712
2741
role = "DummyRole" ,
2713
2742
sagemaker_session = sagemaker_session ,
2714
2743
instance_count = INSTANCE_COUNT ,
@@ -2721,7 +2750,7 @@ def test_fit_verify_accept_eula(strftime, sagemaker_session):
2721
2750
_ , _ , train_kwargs = sagemaker_session .train .mock_calls [0 ]
2722
2751
2723
2752
assert (
2724
- train_kwargs ["input_config" ][0 ]
2753
+ train_kwargs ["input_config" ][1 ]
2725
2754
.get ("DataSource" , {})
2726
2755
.get ("S3DataSource" , {})
2727
2756
.get ("ModelAccessConfig" , {})
0 commit comments