Skip to content

Commit 402fbc9

Browse files
committed
fix: support legacy training models, fix cache override for unsupported files
1 parent 5de1e55 commit 402fbc9

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

src/sagemaker/jumpstart/cache.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,18 @@ def _get_json_file(
372372
object and None when reading from the local file system.
373373
"""
374374
if self._is_local_metadata_mode():
375-
file_content, etag = self._get_json_file_from_local_override(key, filetype), None
376-
else:
377-
file_content, etag = self._get_json_file_and_etag_from_s3(key)
378-
return file_content, etag
375+
if filetype in {
376+
JumpStartS3FileType.OPEN_WEIGHT_MANIFEST,
377+
JumpStartS3FileType.OPEN_WEIGHT_SPECS,
378+
}:
379+
return self._get_json_file_from_local_override(key, filetype), None
380+
else:
381+
JUMPSTART_LOGGER.warning(
382+
"Local metadata mode is enabled, but the file type %s is not supported "
383+
"for local override. Falling back to s3.",
384+
filetype,
385+
)
386+
return self._get_json_file_and_etag_from_s3(key)
379387

380388
def _get_json_md5_hash(self, key: str):
381389
"""Retrieves md5 object hash for s3 objects, using `s3.head_object`.

src/sagemaker/jumpstart/types.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1940,9 +1940,20 @@ def use_inference_script_uri(self) -> bool:
19401940

19411941
def use_training_model_artifact(self) -> bool:
19421942
"""Returns True if the model should use a model uri when kicking off training job."""
1943+
# old models with this environment variable present don't use model channel
1944+
if any(
1945+
self.training_instance_type_variants.get_instance_specific_gated_model_key_env_var_value(
1946+
instance_type
1947+
)
1948+
for instance_type in self.supported_training_instance_types
1949+
):
1950+
return False
1951+
1952+
# even older models with training model package artifact uris present also don't use model channel
1953+
if len(self.training_model_package_artifact_uris or {}) > 0:
1954+
return False
19431955

1944-
# otherwise, return true if a training model package is not set
1945-
return len(self.training_model_package_artifact_uris or {}) == 0
1956+
return getattr(self, "training_artifact_key", None) is not None
19461957

19471958
def is_gated_model(self) -> bool:
19481959
"""Returns True if the model has a EULA key or the model bucket is gated."""

0 commit comments

Comments
 (0)