Skip to content

Commit 7c29b96

Browse files
pintaoz-awspintaoz
authored andcommitted
Ensure Model.is_repack() returns a boolean (aws#5060)
* Ensure Model.is_repack() returns a boolean * update test --------- Co-authored-by: pintaoz <pintaoz@amazon.com>
1 parent 604fae7 commit 7c29b96

File tree

3 files changed

+32
-0
lines changed

3 files changed

+32
-0
lines changed

src/sagemaker/model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,8 @@ def is_repack(self) -> bool:
745745
Returns:
746746
bool: if the source need to be repacked or not
747747
"""
748+
if self.source_dir is None or self.entry_point is None:
749+
return False
748750
return self.source_dir and self.entry_point and not self.git_config
749751

750752
def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
@@ -2143,6 +2145,8 @@ def is_repack(self) -> bool:
21432145
Returns:
21442146
bool: if the source need to be repacked or not
21452147
"""
2148+
if self.source_dir is None or self.entry_point is None:
2149+
return False
21462150
return self.source_dir and self.entry_point and not (self.key_prefix or self.git_config)
21472151

21482152

tests/unit/sagemaker/model/test_framework_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,20 @@ def test_is_repack_with_code_location(repack_model, sagemaker_session):
511511
assert not model.is_repack()
512512

513513

514+
@patch("sagemaker.utils.repack_model")
515+
def test_is_repack_with_none_type(repack_model, sagemaker_session):
516+
"""Test is_repack() returns a boolean value when source_dir and entry_point are None"""
517+
518+
model = FrameworkModel(
519+
role=ROLE,
520+
sagemaker_session=sagemaker_session,
521+
image_uri=IMAGE_URI,
522+
model_data=MODEL_DATA,
523+
)
524+
525+
assert model.is_repack() is False
526+
527+
514528
@patch("sagemaker.git_utils.git_clone_repo")
515529
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
516530
def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session):

tests/unit/sagemaker/model/test_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,6 +1046,20 @@ def test_is_repack_with_code_location(repack_model, sagemaker_session):
10461046
assert model.is_repack()
10471047

10481048

1049+
@patch("sagemaker.utils.repack_model")
1050+
def test_is_repack_with_none_type(repack_model, sagemaker_session):
1051+
"""Test is_repack() returns a boolean value when source_dir and entry_point are None"""
1052+
1053+
model = Model(
1054+
role=ROLE,
1055+
sagemaker_session=sagemaker_session,
1056+
image_uri=IMAGE_URI,
1057+
model_data=MODEL_DATA,
1058+
)
1059+
1060+
assert model.is_repack() is False
1061+
1062+
10491063
@patch("sagemaker.git_utils.git_clone_repo")
10501064
@patch("sagemaker.model.fw_utils.tar_and_upload_dir")
10511065
def test_is_repack_with_git_config(tar_and_upload_dir, git_clone_repo, sagemaker_session):

0 commit comments

Comments
 (0)