diff --git a/pkg/util/ec2/ec2_tags.go b/pkg/util/ec2/ec2_tags.go index d60f93bf6256b..2dc6c80683ea1 100644 --- a/pkg/util/ec2/ec2_tags.go +++ b/pkg/util/ec2/ec2_tags.go @@ -15,6 +15,7 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" @@ -97,7 +98,7 @@ func fetchEc2TagsFromAPI(ctx context.Context) ([]string, error) { // except when a more specific role (e.g. task role in ECS) does not have // EC2:DescribeTags permission, but a more general role (e.g. instance role) // does have it. - tags, err := getTagsWithCreds(ctx, instanceIdentity, nil) + tags, err := getTagsWithCreds(ctx, instanceIdentity, ec2Cconnection(ctx, instanceIdentity.Region, nil)) if err == nil { return tags, nil } @@ -111,14 +112,25 @@ func fetchEc2TagsFromAPI(ctx context.Context) ([]string, error) { } awsCreds := credentials.NewStaticCredentialsProvider(iamParams.AccessKeyID, iamParams.SecretAccessKey, iamParams.Token) - return getTagsWithCreds(ctx, instanceIdentity, awsCreds) + return getTagsWithCreds(ctx, instanceIdentity, ec2Cconnection(ctx, instanceIdentity.Region, awsCreds)) } -func getTagsWithCreds(ctx context.Context, instanceIdentity *EC2Identity, awsCreds aws.CredentialsProvider) ([]string, error) { - connection := ec2.New(ec2.Options{ - Region: instanceIdentity.Region, - Credentials: awsCreds, - }) +type ec2ClientInterface interface { + DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) +} + +// ec2Cconnection creates an ec2 client with the given region and credentials +func ec2Cconnection(ctx context.Context, region string, awsCreds aws.CredentialsProvider) ec2ClientInterface { + // using aws config to read the build in credentials to set up the ec2 client. + cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region), config.WithCredentialsProvider(awsCreds)) + if err != nil { + log.Warnf("unable to get aws configurations: %s", err) + return nil + } + return ec2.NewFromConfig(cfg) +} + +func getTagsWithCreds(ctx context.Context, instanceIdentity *EC2Identity, connection ec2ClientInterface) ([]string, error) { // We want to use 'ec2_metadata_timeout' here instead of current context. 'ctx' comes from the agent main and will // only be canceled if the agent is stopped. The default timeout for the AWS SDK is 1 minutes (20s timeout with diff --git a/pkg/util/ec2/ec2_tags_test.go b/pkg/util/ec2/ec2_tags_test.go index cce8a5800df2c..b3b9fdf20bf13 100644 --- a/pkg/util/ec2/ec2_tags_test.go +++ b/pkg/util/ec2/ec2_tags_test.go @@ -9,6 +9,7 @@ package ec2 import ( "context" + "errors" "fmt" "io" "net/http" @@ -16,12 +17,14 @@ import ( "os" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - configmock "github.com/DataDog/datadog-agent/pkg/config/mock" pkgconfigsetup "github.com/DataDog/datadog-agent/pkg/config/setup" "github.com/DataDog/datadog-agent/pkg/util/cache" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGetIAMRole(t *testing.T) { @@ -208,3 +211,105 @@ func TestGetTagsFullWorkflow(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []string{"tag1", "tag2"}, tags) } + +// Mock implementation of ec2ClientInterface +type mockEC2Client struct { + DescribeTagsFunc func(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) +} + +func (m *mockEC2Client) DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return m.DescribeTagsFunc(ctx, params, optFns...) +} + +// Helper function to compare slices of strings +func equalStringSlices(a, b []string) bool { + if len(a) != len(b) { + return false + } + aMap := make(map[string]struct{}, len(a)) + for _, v := range a { + aMap[v] = struct{}{} + } + for _, v := range b { + if _, ok := aMap[v]; !ok { + return false + } + } + return true +} + +func TestGetTagsWithCreds(t *testing.T) { + tests := []struct { + name string + instanceIdentity *EC2Identity + mockDescribeTags func(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) + expectedTags []string + expectedError assert.ErrorAssertionFunc + }{ + { + name: "Successful retrieval of tags", + instanceIdentity: &EC2Identity{ + InstanceID: "i-1234567890abcdef0", + }, + mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return &ec2.DescribeTagsOutput{ + Tags: []types.TagDescription{ + {Key: aws.String("Name"), Value: aws.String("TestInstance")}, + {Key: aws.String("Env"), Value: aws.String("Production")}, + }, + }, nil + }, + expectedTags: []string{"Name:TestInstance", "Env:Production"}, + expectedError: assert.NoError, + }, + { + name: "Excluded tags are filtered out", + instanceIdentity: &EC2Identity{ + InstanceID: "i-1234567890abcdef0", + }, + mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return &ec2.DescribeTagsOutput{ + Tags: []types.TagDescription{ + {Key: aws.String("Name"), Value: aws.String("TestInstance")}, + {Key: aws.String("aws:cloudformation:stack-name"), Value: aws.String("MyStack")}, + }, + }, nil + }, + expectedTags: []string{"Name:TestInstance", "aws:cloudformation:stack-name:MyStack"}, + expectedError: assert.NoError, + }, + { + name: "DescribeTags returns an error", + instanceIdentity: &EC2Identity{ + InstanceID: "i-1234567890abcdef0", + }, + mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) { + return nil, errors.New("DescribeTags error") + }, + expectedTags: nil, + expectedError: assert.Error, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock the EC2 client + mockClient := &mockEC2Client{ + DescribeTagsFunc: tt.mockDescribeTags, + } + + // Create a background context + ctx := context.Background() + + // Call the function under test + tags, err := getTagsWithCreds(ctx, tt.instanceIdentity, mockClient) + + // Validate the error + tt.expectedError(t, err) + // Validate the tags + if !equalStringSlices(tags, tt.expectedTags) { + t.Fatalf("Expected tags '%v', got '%v'", tt.expectedTags, tags) + } + }) + } +}