Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ec2 client credential init fix for AWS IRSA #29784

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions pkg/util/ec2/ec2_tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
Expand Down Expand Up @@ -97,7 +98,7 @@ func fetchEc2TagsFromAPI(ctx context.Context) ([]string, error) {
// except when a more specific role (e.g. task role in ECS) does not have
// EC2:DescribeTags permission, but a more general role (e.g. instance role)
// does have it.
tags, err := getTagsWithCreds(ctx, instanceIdentity, nil)
tags, err := getTagsWithCreds(ctx, instanceIdentity, ec2Cconnection(ctx, instanceIdentity.Region, nil))
if err == nil {
return tags, nil
}
Expand All @@ -111,14 +112,25 @@ func fetchEc2TagsFromAPI(ctx context.Context) ([]string, error) {
}

awsCreds := credentials.NewStaticCredentialsProvider(iamParams.AccessKeyID, iamParams.SecretAccessKey, iamParams.Token)
return getTagsWithCreds(ctx, instanceIdentity, awsCreds)
return getTagsWithCreds(ctx, instanceIdentity, ec2Cconnection(ctx, instanceIdentity.Region, awsCreds))
}

func getTagsWithCreds(ctx context.Context, instanceIdentity *EC2Identity, awsCreds aws.CredentialsProvider) ([]string, error) {
connection := ec2.New(ec2.Options{
Region: instanceIdentity.Region,
Credentials: awsCreds,
})
type ec2ClientInterface interface {
DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error)
}

// ec2Cconnection creates an ec2 client with the given region and credentials
func ec2Cconnection(ctx context.Context, region string, awsCreds aws.CredentialsProvider) ec2ClientInterface {
// using aws config to read the build in credentials to set up the ec2 client.
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region), config.WithCredentialsProvider(awsCreds))
if err != nil {
log.Warnf("unable to get aws configurations: %s", err)
return nil
}
return ec2.NewFromConfig(cfg)
}

func getTagsWithCreds(ctx context.Context, instanceIdentity *EC2Identity, connection ec2ClientInterface) ([]string, error) {

// We want to use 'ec2_metadata_timeout' here instead of current context. 'ctx' comes from the agent main and will
// only be canceled if the agent is stopped. The default timeout for the AWS SDK is 1 minutes (20s timeout with
Expand Down
111 changes: 108 additions & 3 deletions pkg/util/ec2/ec2_tags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@ package ec2

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

configmock "github.com/DataDog/datadog-agent/pkg/config/mock"
pkgconfigsetup "github.com/DataDog/datadog-agent/pkg/config/setup"
"github.com/DataDog/datadog-agent/pkg/util/cache"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetIAMRole(t *testing.T) {
Expand Down Expand Up @@ -208,3 +211,105 @@ func TestGetTagsFullWorkflow(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, []string{"tag1", "tag2"}, tags)
}

// Mock implementation of ec2ClientInterface
type mockEC2Client struct {
DescribeTagsFunc func(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error)
}

func (m *mockEC2Client) DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) {
return m.DescribeTagsFunc(ctx, params, optFns...)
}

// Helper function to compare slices of strings
func equalStringSlices(a, b []string) bool {
if len(a) != len(b) {
return false
}
aMap := make(map[string]struct{}, len(a))
for _, v := range a {
aMap[v] = struct{}{}
}
for _, v := range b {
if _, ok := aMap[v]; !ok {
return false
}
}
return true
}

func TestGetTagsWithCreds(t *testing.T) {
tests := []struct {
name string
instanceIdentity *EC2Identity
mockDescribeTags func(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error)
expectedTags []string
expectedError assert.ErrorAssertionFunc
}{
{
name: "Successful retrieval of tags",
instanceIdentity: &EC2Identity{
InstanceID: "i-1234567890abcdef0",
},
mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) {
return &ec2.DescribeTagsOutput{
Tags: []types.TagDescription{
{Key: aws.String("Name"), Value: aws.String("TestInstance")},
{Key: aws.String("Env"), Value: aws.String("Production")},
},
}, nil
},
expectedTags: []string{"Name:TestInstance", "Env:Production"},
expectedError: assert.NoError,
},
{
name: "Excluded tags are filtered out",
instanceIdentity: &EC2Identity{
InstanceID: "i-1234567890abcdef0",
},
mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) {
return &ec2.DescribeTagsOutput{
Tags: []types.TagDescription{
{Key: aws.String("Name"), Value: aws.String("TestInstance")},
{Key: aws.String("aws:cloudformation:stack-name"), Value: aws.String("MyStack")},
},
}, nil
},
expectedTags: []string{"Name:TestInstance", "aws:cloudformation:stack-name:MyStack"},
expectedError: assert.NoError,
},
{
name: "DescribeTags returns an error",
instanceIdentity: &EC2Identity{
InstanceID: "i-1234567890abcdef0",
},
mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) {
return nil, errors.New("DescribeTags error")
},
expectedTags: nil,
expectedError: assert.Error,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock the EC2 client
mockClient := &mockEC2Client{
DescribeTagsFunc: tt.mockDescribeTags,
}

// Create a background context
ctx := context.Background()

// Call the function under test
tags, err := getTagsWithCreds(ctx, tt.instanceIdentity, mockClient)

// Validate the error
tt.expectedError(t, err)
// Validate the tags
if !equalStringSlices(tags, tt.expectedTags) {
t.Fatalf("Expected tags '%v', got '%v'", tt.expectedTags, tags)
}
})
}
}