@@ -336,6 +336,148 @@ def training_job_description(sagemaker_session):
336
336
return returned_job_description
337
337
338
338
339
+ def test_set_accept_eula_for_input_data_config_no_input_data_config ():
340
+ """Test when InputDataConfig is not in train_args."""
341
+ train_args = {}
342
+ accept_eula = True
343
+
344
+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
345
+
346
+ # Verify train_args remains unchanged
347
+ assert train_args == {}
348
+
349
+
350
+ def test_set_accept_eula_for_input_data_config_none_accept_eula ():
351
+ """Test when accept_eula is None."""
352
+ train_args = {"InputDataConfig" : [{"DataSource" : {"S3DataSource" : {}}}]}
353
+ accept_eula = None
354
+
355
+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
356
+
357
+ # Verify train_args remains unchanged
358
+ assert train_args == {"InputDataConfig" : [{"DataSource" : {"S3DataSource" : {}}}]}
359
+
360
+
361
+ def test_set_accept_eula_for_input_data_config_single_data_source ():
362
+ """Test with a single S3DataSource."""
363
+ with patch ("sagemaker.estimator.logger" ) as logger :
364
+ train_args = {
365
+ "InputDataConfig" : [{"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }}}]
366
+ }
367
+ accept_eula = True
368
+
369
+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
370
+
371
+ # Verify ModelAccessConfig and AcceptEula are set correctly
372
+ assert train_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ][
373
+ "ModelAccessConfig"
374
+ ] == {"AcceptEula" : True }
375
+
376
+ # Verify no logging occurred since there's only one data source
377
+ logger .info .assert_not_called ()
378
+
379
+
380
+ def test_set_accept_eula_for_input_data_config_multiple_data_sources ():
381
+ """Test with multiple S3DataSources."""
382
+ with patch ("sagemaker.estimator.logger" ) as logger :
383
+ train_args = {
384
+ "InputDataConfig" : [
385
+ {"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model1" }}},
386
+ {"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model2" }}},
387
+ ]
388
+ }
389
+ accept_eula = True
390
+
391
+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
392
+
393
+ # Verify ModelAccessConfig and AcceptEula are set correctly for both data sources
394
+ assert train_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ][
395
+ "ModelAccessConfig"
396
+ ] == {"AcceptEula" : True }
397
+ assert train_args ["InputDataConfig" ][1 ]["DataSource" ]["S3DataSource" ][
398
+ "ModelAccessConfig"
399
+ ] == {"AcceptEula" : True }
400
+
401
+ # Verify logging occurred with correct information
402
+ logger .info .assert_called_once ()
403
+ args = logger .info .call_args [0 ]
404
+ assert args [0 ] == "Accepting EULA for %d S3 data sources: %s"
405
+ assert args [1 ] == 2
406
+ assert args [2 ] == "s3://bucket/model1, s3://bucket/model2"
407
+
408
+
409
+ def test_set_accept_eula_for_input_data_config_existing_model_access_config ():
410
+ """Test when ModelAccessConfig already exists."""
411
+ train_args = {
412
+ "InputDataConfig" : [
413
+ {
414
+ "DataSource" : {
415
+ "S3DataSource" : {
416
+ "S3Uri" : "s3://bucket/model" ,
417
+ "ModelAccessConfig" : {"OtherSetting" : "value" },
418
+ }
419
+ }
420
+ }
421
+ ]
422
+ }
423
+ accept_eula = True
424
+
425
+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
426
+
427
+ # Verify AcceptEula is added to existing ModelAccessConfig
428
+ assert train_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ]["ModelAccessConfig" ] == {
429
+ "OtherSetting" : "value" ,
430
+ "AcceptEula" : True ,
431
+ }
432
+
433
+
434
+ def test_set_accept_eula_for_input_data_config_missing_s3_data_source ():
435
+ """Test when S3DataSource is missing."""
436
+ train_args = {"InputDataConfig" : [{"DataSource" : {"OtherDataSource" : {}}}]}
437
+ accept_eula = True
438
+
439
+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
440
+
441
+ # Verify train_args remains unchanged
442
+ assert train_args == {"InputDataConfig" : [{"DataSource" : {"OtherDataSource" : {}}}]}
443
+
444
+
445
+ def test_set_accept_eula_for_input_data_config_missing_data_source ():
446
+ """Test when DataSource is missing."""
447
+ train_args = {"InputDataConfig" : [{"OtherKey" : {}}]}
448
+ accept_eula = True
449
+
450
+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
451
+
452
+ # Verify train_args remains unchanged
453
+ assert train_args == {"InputDataConfig" : [{"OtherKey" : {}}]}
454
+
455
+
456
+ def test_set_accept_eula_for_input_data_config_mixed_data_sources ():
457
+ """Test with a mix of S3DataSource and other data sources."""
458
+ with patch ("sagemaker.estimator.logger" ) as logger :
459
+ train_args = {
460
+ "InputDataConfig" : [
461
+ {"DataSource" : {"S3DataSource" : {"S3Uri" : "s3://bucket/model" }}},
462
+ {"DataSource" : {"OtherDataSource" : {}}},
463
+ ]
464
+ }
465
+ accept_eula = True
466
+
467
+ EstimatorBase ._set_accept_eula_for_input_data_config (train_args , accept_eula )
468
+
469
+ # Verify ModelAccessConfig and AcceptEula are set correctly for S3DataSource only
470
+ assert train_args ["InputDataConfig" ][0 ]["DataSource" ]["S3DataSource" ][
471
+ "ModelAccessConfig"
472
+ ] == {"AcceptEula" : True }
473
+ assert "ModelAccessConfig" not in train_args ["InputDataConfig" ][1 ]["DataSource" ].get (
474
+ "OtherDataSource" , {}
475
+ )
476
+
477
+ # Verify no logging occurred since there's only one S3 data source
478
+ logger .info .assert_not_called ()
479
+
480
+
339
481
def test_validate_smdistributed_unsupported_image_raises (sagemaker_session ):
340
482
# Test unsupported image raises error.
341
483
for unsupported_image in DummyFramework .UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM :
0 commit comments