Skip to content

aws_sagemaker_image_version add support for aliases #42609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .changelog/42606.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
resource/aws_sagemaker_image_version: Use version as part of resource id, avoiding mix up between similar resources
```
3 changes: 3 additions & 0 deletions .changelog/42609.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_image_version: Add `aliases` argument
```
218 changes: 207 additions & 11 deletions internal/service/sagemaker/image_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ package sagemaker

import (
"context"
"fmt"
"log"
"strconv"
"strings"

"github.com/YakDriver/regexache"
"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -40,6 +43,13 @@ func resourceImageVersion() *schema.Resource {
Type: schema.TypeString,
Computed: true,
},
"aliases": {
Type: schema.TypeSet,
Optional: true,
Elem: &schema.Schema{
Type: schema.TypeString,
},
},
"base_image": {
Type: schema.TypeString,
Required: true,
Expand Down Expand Up @@ -139,15 +149,35 @@ func resourceImageVersionCreate(ctx context.Context, d *schema.ResourceData, met
input.ProgrammingLang = aws.String(v.(string))
}

if v, ok := d.GetOk("aliases"); ok {
aliases := v.(*schema.Set).List()
input.Aliases = make([]string, len(aliases))
for i, alias := range aliases {
input.Aliases[i] = alias.(string)
}
}

_, err := conn.CreateImageVersion(ctx, input)
if err != nil {
return sdkdiag.AppendErrorf(diags, "creating SageMaker AI Image Version %s: %s", name, err)
}

d.SetId(name)
// Get the version from the API response
output, err := conn.DescribeImageVersion(ctx, &sagemaker.DescribeImageVersionInput{
ImageName: aws.String(name),
})
if err != nil {
return sdkdiag.AppendErrorf(diags, "describing SageMaker AI Image Version %s after creation: %s", name, err)
}

if _, err := waitImageVersionCreated(ctx, conn, d.Id()); err != nil {
return sdkdiag.AppendErrorf(diags, "waiting for SageMaker AI Image Version (%s) to be created: %s", d.Id(), err)
// Set the ID to be a combination of name and version
versionNumber := aws.ToInt32(output.Version)
id := fmt.Sprintf("%s:%d", name, versionNumber)
d.SetId(id)

// Wait for the image version to be created
if _, err := waitImageVersionCreated(ctx, conn, id); err != nil {
return sdkdiag.AppendErrorf(diags, "waiting for SageMaker AI Image Version (%s) to be created: %s", id, err)
}

return append(diags, resourceImageVersionRead(ctx, d, meta)...)
Expand All @@ -157,24 +187,46 @@ func resourceImageVersionRead(ctx context.Context, d *schema.ResourceData, meta
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)

image, err := findImageVersionByName(ctx, conn, d.Id())
id := d.Id()
var image *sagemaker.DescribeImageVersionOutput
var err error

// Check if the ID contains a version (has a colon)
if strings.Contains(id, ":") {
// New format - use the new function
image, err = findImageVersionByNameAndVersion(ctx, conn, id)
} else {
// Legacy format - just the name
image, err = findImageVersionByName(ctx, conn, id)

// If successful, update the ID to the new format
if err == nil && image != nil {
newID := fmt.Sprintf("%s:%d", id, aws.ToInt32(image.Version))
d.SetId(newID)
id = newID
}
}

if !d.IsNewResource() && tfresource.NotFound(err) {
d.SetId("")
log.Printf("[WARN] Unable to find SageMaker AI Image Version (%s); removing from state", d.Id())
log.Printf("[WARN] Unable to find SageMaker AI Image Version (%s); removing from state", id)
return diags
}

if err != nil {
return sdkdiag.AppendErrorf(diags, "reading SageMaker AI Image Version (%s): %s", d.Id(), err)
return sdkdiag.AppendErrorf(diags, "reading SageMaker AI Image Version (%s): %s", id, err)
}

// Parse the ID to get the name
parts := strings.Split(id, ":")
name := parts[0]

d.Set(names.AttrARN, image.ImageVersionArn)
d.Set("base_image", image.BaseImage)
d.Set("image_arn", image.ImageArn)
d.Set("container_image", image.ContainerImage)
d.Set(names.AttrVersion, image.Version)
d.Set("image_name", d.Id())
d.Set("image_name", name)
d.Set("horovod", image.Horovod)
d.Set("job_type", image.JobType)
d.Set("processor", image.Processor)
Expand All @@ -183,16 +235,54 @@ func resourceImageVersionRead(ctx context.Context, d *schema.ResourceData, meta
d.Set("ml_framework", image.MLFramework)
d.Set("programming_lang", image.ProgrammingLang)

// The AWS SDK doesn't have an Aliases field in DescribeImageVersionOutput
// We need to fetch aliases separately using ListAliases API
idParts := strings.Split(id, ":")
imageName := idParts[0]
versionStr := idParts[1]
versionNum, err := strconv.Atoi(versionStr)
if err != nil {
return sdkdiag.AppendErrorf(diags, "invalid version number in resource ID: %s", d.Id())
}

aliasesInput := &sagemaker.ListAliasesInput{
ImageName: aws.String(imageName),
Version: aws.Int32(int32(versionNum)),
}

aliasesOutput, err := conn.ListAliases(ctx, aliasesInput)
if err != nil {
return sdkdiag.AppendErrorf(diags, "listing aliases for SageMaker AI Image Version (%s): %s", d.Id(), err)
}

if err := d.Set("aliases", aliasesOutput.SageMakerImageVersionAliases); err != nil {
return sdkdiag.AppendErrorf(diags, "setting aliases: %s", err)
}

return diags
}

func resourceImageVersionUpdate(ctx context.Context, d *schema.ResourceData, meta any) diag.Diagnostics {
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)

// Parse the ID to get name and version
parts := strings.Split(d.Id(), ":")
if len(parts) != 2 {
return sdkdiag.AppendErrorf(diags, "invalid resource ID format: %s", d.Id())
}

name := parts[0]
versionStr := parts[1]

version, err := strconv.Atoi(versionStr)
if err != nil {
return sdkdiag.AppendErrorf(diags, "invalid version number in resource ID: %s", d.Id())
}

input := &sagemaker.UpdateImageVersionInput{
ImageName: aws.String(d.Id()),
Version: aws.Int32(int32(d.Get(names.AttrVersion).(int))),
ImageName: aws.String(name),
Version: aws.Int32(int32(version)),
}

if d.HasChange("horovod") {
Expand Down Expand Up @@ -223,6 +313,52 @@ func resourceImageVersionUpdate(ctx context.Context, d *schema.ResourceData, met
input.ProgrammingLang = aws.String(d.Get("programming_lang").(string))
}

if d.HasChange("aliases") {
// For UpdateImageVersion, we need to use AliasesToAdd and AliasesToDelete
// instead of Aliases directly
oldAliasesSet, newAliasesSet := d.GetChange("aliases")
oldAliases := oldAliasesSet.(*schema.Set).List()
newAliases := newAliasesSet.(*schema.Set).List()

// Find aliases to add (in new but not in old)
var aliasesToAdd []string
for _, newAlias := range newAliases {
found := false
for _, oldAlias := range oldAliases {
if newAlias.(string) == oldAlias.(string) {
found = true
break
}
}
if !found {
aliasesToAdd = append(aliasesToAdd, newAlias.(string))
}
}

// Find aliases to delete (in old but not in new)
var aliasesToDelete []string
for _, oldAlias := range oldAliases {
found := false
for _, newAlias := range newAliases {
if oldAlias.(string) == newAlias.(string) {
found = true
break
}
}
if !found {
aliasesToDelete = append(aliasesToDelete, oldAlias.(string))
}
}

if len(aliasesToAdd) > 0 {
input.AliasesToAdd = aliasesToAdd
}

if len(aliasesToDelete) > 0 {
input.AliasesToDelete = aliasesToDelete
}
}

if _, err := conn.UpdateImageVersion(ctx, input); err != nil {
return sdkdiag.AppendErrorf(diags, "updating SageMaker AI Image Version (%s): %s", d.Id(), err)
}
Expand All @@ -234,9 +370,23 @@ func resourceImageVersionDelete(ctx context.Context, d *schema.ResourceData, met
var diags diag.Diagnostics
conn := meta.(*conns.AWSClient).SageMakerClient(ctx)

// Parse the ID to get name and version
parts := strings.Split(d.Id(), ":")
if len(parts) != 2 {
return sdkdiag.AppendErrorf(diags, "invalid resource ID format: %s", d.Id())
}

name := parts[0]
versionStr := parts[1]

version, err := strconv.Atoi(versionStr)
if err != nil {
return sdkdiag.AppendErrorf(diags, "invalid version number in resource ID: %s", d.Id())
}

input := &sagemaker.DeleteImageVersionInput{
ImageName: aws.String(d.Id()),
Version: aws.Int32(int32(d.Get(names.AttrVersion).(int))),
ImageName: aws.String(name),
Version: aws.Int32(int32(version)),
}

if _, err := conn.DeleteImageVersion(ctx, input); err != nil {
Expand All @@ -253,6 +403,47 @@ func resourceImageVersionDelete(ctx context.Context, d *schema.ResourceData, met
return diags
}

func findImageVersionByNameAndVersion(ctx context.Context, conn *sagemaker.Client, id string) (*sagemaker.DescribeImageVersionOutput, error) {
// Parse the ID to get name and version
parts := strings.Split(id, ":")
if len(parts) != 2 {
return nil, fmt.Errorf("invalid resource ID format: %s", id)
}

name := parts[0]
versionStr := parts[1]

version, err := strconv.Atoi(versionStr)
if err != nil {
return nil, fmt.Errorf("invalid version number in resource ID: %s", id)
}

input := &sagemaker.DescribeImageVersionInput{
ImageName: aws.String(name),
Version: aws.Int32(int32(version)),
}

output, err := conn.DescribeImageVersion(ctx, input)

if errs.IsAErrorMessageContains[*awstypes.ResourceNotFound](err, "does not exist") {
return nil, &retry.NotFoundError{
LastError: err,
LastRequest: input,
}
}

if err != nil {
return nil, err
}

if output == nil {
return nil, tfresource.NewEmptyResultError(input)
}

return output, nil
}

// Keep this for backward compatibility
func findImageVersionByName(ctx context.Context, conn *sagemaker.Client, name string) (*sagemaker.DescribeImageVersionOutput, error) {
input := &sagemaker.DescribeImageVersionInput{
ImageName: aws.String(name),
Expand All @@ -277,3 +468,8 @@ func findImageVersionByName(ctx context.Context, conn *sagemaker.Client, name st

return output, nil
}

// FindImageVersionByNameAndVersion finds a SageMaker Image Version by name and version
func FindImageVersionByNameAndVersion(ctx context.Context, conn *sagemaker.Client, id string) (*sagemaker.DescribeImageVersionOutput, error) {
return findImageVersionByNameAndVersion(ctx, conn, id)
}
11 changes: 9 additions & 2 deletions internal/service/sagemaker/image_version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ func TestAccSageMakerImageVersion_full(t *testing.T) {
resource.TestCheckResourceAttr(resourceName, "job_type", "TRAINING"),
resource.TestCheckResourceAttr(resourceName, "ml_framework", "TensorFlow 1.1"),
resource.TestCheckResourceAttr(resourceName, "programming_lang", "Python 3.8"),
resource.TestCheckResourceAttr(resourceName, "aliases.#", "2"),
resource.TestCheckTypeSetElemAttr(resourceName, "aliases.*", "latest"),
resource.TestCheckTypeSetElemAttr(resourceName, "aliases.*", "stable"),
),
},
{
Expand All @@ -118,6 +121,9 @@ func TestAccSageMakerImageVersion_full(t *testing.T) {
resource.TestCheckResourceAttr(resourceName, "job_type", "TRAINING"),
resource.TestCheckResourceAttr(resourceName, "ml_framework", "TensorFlow 1.1"),
resource.TestCheckResourceAttr(resourceName, "programming_lang", "Python 3.8"),
resource.TestCheckResourceAttr(resourceName, "aliases.#", "2"),
resource.TestCheckTypeSetElemAttr(resourceName, "aliases.*", "latest"),
resource.TestCheckTypeSetElemAttr(resourceName, "aliases.*", "stable"),
),
},
},
Expand Down Expand Up @@ -191,7 +197,7 @@ func testAccCheckImageVersionDestroy(ctx context.Context) resource.TestCheckFunc
continue
}

_, err := tfsagemaker.FindImageVersionByName(ctx, conn, rs.Primary.ID)
_, err := tfsagemaker.FindImageVersionByNameAndVersion(ctx, conn, rs.Primary.ID)

if tfresource.NotFound(err) {
continue
Expand Down Expand Up @@ -220,7 +226,7 @@ func testAccCheckImageVersionExists(ctx context.Context, n string, image *sagema
}

conn := acctest.Provider.Meta().(*conns.AWSClient).SageMakerClient(ctx)
resp, err := tfsagemaker.FindImageVersionByName(ctx, conn, rs.Primary.ID)
resp, err := tfsagemaker.FindImageVersionByNameAndVersion(ctx, conn, rs.Primary.ID)
if err != nil {
return err
}
Expand Down Expand Up @@ -285,6 +291,7 @@ resource "aws_sagemaker_image_version" "test" {
vendor_guidance = "STABLE"
ml_framework = "TensorFlow 1.1"
programming_lang = "Python 3.8"
aliases = ["latest", "stable"]
}
`, baseImage, notes)
}
15 changes: 13 additions & 2 deletions internal/service/sagemaker/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package sagemaker

import (
"context"
"strings"

"github.com/aws/aws-sdk-go-v2/service/sagemaker"
awstypes "github.com/aws/aws-sdk-go-v2/service/sagemaker/types"
Expand Down Expand Up @@ -60,9 +61,19 @@ func statusImage(ctx context.Context, conn *sagemaker.Client, name string) retry
}
}

func statusImageVersion(ctx context.Context, conn *sagemaker.Client, name string) retry.StateRefreshFunc {
func statusImageVersion(ctx context.Context, conn *sagemaker.Client, id string) retry.StateRefreshFunc {
return func() (any, string, error) {
output, err := findImageVersionByName(ctx, conn, name)
// Check if the ID contains a version (has a colon)
var output *sagemaker.DescribeImageVersionOutput
var err error

if strings.Contains(id, ":") {
// New format - use the new function
output, err = findImageVersionByNameAndVersion(ctx, conn, id)
} else {
// Legacy format - just the name
output, err = findImageVersionByName(ctx, conn, id)
}

if tfresource.NotFound(err) {
return nil, "", nil
Expand Down
Loading
Loading