Skip to content

Commit

Permalink
ec2 client credential init fix
Browse files Browse the repository at this point in the history
current use of ec2.New() does not work with IRSA. (not tested but most like same problem with EKS Pod Identity as well)

```
(pkg/util/ec2/ec2_tags.go:104 in fetchEc2TagsFromAPI) | unable to get tags using default credentials (falling back to instance role): operation error EC2: DescribeTags, https response error StatusCode: 400, RequestID: 1234-1234-1234, api error MissingParameter: The request must contain the parameter AWSAccessKeyId
```

This change is to change to use aws config package to help initialise aws ec2 client with the built-in credentials such as AWS IRSA tokens.
  • Loading branch information
Greyeye committed Oct 6, 2024
1 parent 66190f7 commit c1203db
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 10 deletions.
26 changes: 19 additions & 7 deletions pkg/util/ec2/ec2_tags.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
Expand Down Expand Up @@ -97,7 +98,7 @@ func fetchEc2TagsFromAPI(ctx context.Context) ([]string, error) {
// except when a more specific role (e.g. task role in ECS) does not have
// EC2:DescribeTags permission, but a more general role (e.g. instance role)
// does have it.
tags, err := getTagsWithCreds(ctx, instanceIdentity, nil)
tags, err := getTagsWithCreds(ctx, instanceIdentity, ec2Cconnection(ctx, instanceIdentity.Region, nil))
if err == nil {
return tags, nil
}
Expand All @@ -111,14 +112,25 @@ func fetchEc2TagsFromAPI(ctx context.Context) ([]string, error) {
}

awsCreds := credentials.NewStaticCredentialsProvider(iamParams.AccessKeyID, iamParams.SecretAccessKey, iamParams.Token)
return getTagsWithCreds(ctx, instanceIdentity, awsCreds)
return getTagsWithCreds(ctx, instanceIdentity, ec2Cconnection(ctx, instanceIdentity.Region, awsCreds))
}

func getTagsWithCreds(ctx context.Context, instanceIdentity *EC2Identity, awsCreds aws.CredentialsProvider) ([]string, error) {
connection := ec2.New(ec2.Options{
Region: instanceIdentity.Region,
Credentials: awsCreds,
})
type ec2ClientInterface interface {
DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error)
}

// ec2Cconnection creates an ec2 client with the given region and credentials
func ec2Cconnection(ctx context.Context, region string, awsCreds aws.CredentialsProvider) ec2ClientInterface {
// using aws config to read the build in credentials to set up the ec2 client.
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region), config.WithCredentialsProvider(awsCreds))
if err != nil {
log.Warnf("unable to get aws configurations: %s", err)
return nil
}
return ec2.NewFromConfig(cfg)
}

func getTagsWithCreds(ctx context.Context, instanceIdentity *EC2Identity, connection ec2ClientInterface) ([]string, error) {

// We want to use 'ec2_metadata_timeout' here instead of current context. 'ctx' comes from the agent main and will
// only be canceled if the agent is stopped. The default timeout for the AWS SDK is 1 minutes (20s timeout with
Expand Down
111 changes: 108 additions & 3 deletions pkg/util/ec2/ec2_tags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,22 @@ package ec2

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

configmock "github.com/DataDog/datadog-agent/pkg/config/mock"
pkgconfigsetup "github.com/DataDog/datadog-agent/pkg/config/setup"
"github.com/DataDog/datadog-agent/pkg/util/cache"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestGetIAMRole(t *testing.T) {
Expand Down Expand Up @@ -208,3 +211,105 @@ func TestGetTagsFullWorkflow(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, []string{"tag1", "tag2"}, tags)
}

// Mock implementation of ec2ClientInterface
type mockEC2Client struct {
DescribeTagsFunc func(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error)
}

func (m *mockEC2Client) DescribeTags(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) {
return m.DescribeTagsFunc(ctx, params, optFns...)
}

// Helper function to compare slices of strings
func equalStringSlices(a, b []string) bool {
if len(a) != len(b) {
return false
}
aMap := make(map[string]struct{}, len(a))
for _, v := range a {
aMap[v] = struct{}{}
}
for _, v := range b {
if _, ok := aMap[v]; !ok {
return false
}
}
return true
}

func TestGetTagsWithCreds(t *testing.T) {
tests := []struct {
name string
instanceIdentity *EC2Identity
mockDescribeTags func(ctx context.Context, params *ec2.DescribeTagsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error)
expectedTags []string
expectedError assert.ErrorAssertionFunc
}{
{
name: "Successful retrieval of tags",
instanceIdentity: &EC2Identity{
InstanceID: "i-1234567890abcdef0",
},
mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) {
return &ec2.DescribeTagsOutput{
Tags: []types.TagDescription{
{Key: aws.String("Name"), Value: aws.String("TestInstance")},
{Key: aws.String("Env"), Value: aws.String("Production")},
},
}, nil
},
expectedTags: []string{"Name:TestInstance", "Env:Production"},
expectedError: assert.NoError,
},
{
name: "Excluded tags are filtered out",
instanceIdentity: &EC2Identity{
InstanceID: "i-1234567890abcdef0",
},
mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) {
return &ec2.DescribeTagsOutput{
Tags: []types.TagDescription{
{Key: aws.String("Name"), Value: aws.String("TestInstance")},
{Key: aws.String("aws:cloudformation:stack-name"), Value: aws.String("MyStack")},
},
}, nil
},
expectedTags: []string{"Name:TestInstance", "aws:cloudformation:stack-name:MyStack"},
expectedError: assert.NoError,
},
{
name: "DescribeTags returns an error",
instanceIdentity: &EC2Identity{
InstanceID: "i-1234567890abcdef0",
},
mockDescribeTags: func(_ context.Context, _ *ec2.DescribeTagsInput, _ ...func(*ec2.Options)) (*ec2.DescribeTagsOutput, error) {
return nil, errors.New("DescribeTags error")
},
expectedTags: nil,
expectedError: assert.Error,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Mock the EC2 client
mockClient := &mockEC2Client{
DescribeTagsFunc: tt.mockDescribeTags,
}

// Create a background context
ctx := context.Background()

// Call the function under test
tags, err := getTagsWithCreds(ctx, tt.instanceIdentity, mockClient)

// Validate the error
tt.expectedError(t, err)
// Validate the tags
if !equalStringSlices(tags, tt.expectedTags) {
t.Fatalf("Expected tags '%v', got '%v'", tt.expectedTags, tags)
}
})
}
}

0 comments on commit c1203db

Please sign in to comment.