diff --git a/go.mod b/go.mod index b4abff4..fab6724 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 require ( github.com/go-resty/resty/v2 v2.14.0 + github.com/jarcoal/httpmock v1.3.1 github.com/stretchr/testify v1.9.0 ) diff --git a/go.sum b/go.sum index 0855dea..d1d3c8e 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,10 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/go-resty/resty/v2 v2.14.0 h1:/rhkzsAqGQkozwfKS5aFAbb6TyKd3zyFRWcdRXLPCAU= github.com/go-resty/resty/v2 v2.14.0/go.mod h1:IW6mekUOsElt9C7oWr0XRt9BNSD6D5rr9mhk6NjmNHg= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/jarcoal/httpmock v1.3.1 h1:iUx3whfZWVf3jT01hQTO/Eo5sAYtB2/rqaUuOtpInww= +github.com/jarcoal/httpmock v1.3.1/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= +github.com/maxatome/go-testdeep v1.12.0 h1:Ql7Go8Tg0C1D/uMMX59LAoYK7LffeJQ6X2T04nTH68g= +github.com/maxatome/go-testdeep v1.12.0/go.mod h1:lPZc/HAcJMP92l7yI6TRz1aZN5URwUBUAfUNvrclaNM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= diff --git a/userdata.go b/userdata.go index dcc666e..1757dc6 100644 --- a/userdata.go +++ b/userdata.go @@ -1,13 +1,18 @@ package metadata import ( + "bytes" + "compress/gzip" "context" "encoding/base64" "fmt" + "io" ) +var gzipMagic = []byte{0x1F, 0x8B, 0x08} + // GetUserData returns the user data for the current instance. -// NOTE: The result of this endpoint is automatically decoded from base64. +// NOTE: The result of this endpoint is automatically decoded from base64 and un-gzipped if needed. func (c *Client) GetUserData(ctx context.Context) (string, error) { req := c.R(ctx) @@ -21,6 +26,28 @@ func (c *Client) GetUserData(ctx context.Context) (string, error) { if err != nil { return "", fmt.Errorf("failed to decode user-data: %w", err) } + rawBytes, err := ungzipIfNeeded(decodedBytes) + if err != nil { + return "", fmt.Errorf("failed to ungzip user-data: %w", err) + } + return string(rawBytes), nil +} - return string(decodedBytes), nil +// hasGzipMagicNumber checks for the gzipMagic bytes at the beginning of the source +func hasGzipMagicNumber(source []byte) bool { + return bytes.HasPrefix(source, gzipMagic) +} + +// ungzipIfNeeded checks for the gzip magic number and unzips the bytes if necessary, +// otherwise it returns the original raw bytes +func ungzipIfNeeded(raw []byte) ([]byte, error) { + if !hasGzipMagicNumber(raw) { + return raw, nil + } + reader, err := gzip.NewReader(bytes.NewReader(raw)) + if err != nil { + return nil, err + } + defer reader.Close() + return io.ReadAll(reader) } diff --git a/userdata_test.go b/userdata_test.go index 24ce1e6..06acb93 100644 --- a/userdata_test.go +++ b/userdata_test.go @@ -3,44 +3,115 @@ package metadata import ( "context" "encoding/base64" - "errors" + "fmt" + "net/http" "testing" + "github.com/jarcoal/httpmock" + "github.com/stretchr/testify/assert" ) -type UserdataMockClient struct { - UserData string - GetUserDataError error -} +var mockMetadataHost = fmt.Sprintf("%s://%s/%s", APIProto, APIHost, APIVersion) -func (m *UserdataMockClient) GetUserData(ctx context.Context) (string, error) { - if m.GetUserDataError != nil { - return "", m.GetUserDataError - } - return m.UserData, nil +func SetupMockClient() *http.Client { + // create mock client + mockClient := http.DefaultClient + httpmock.ActivateNonDefault(mockClient) + + // Mock out token request + tokenResponder := httpmock.NewStringResponder(200, "[\"token\"]") + httpmock.RegisterResponder("PUT", fmt.Sprintf("%s/token", mockMetadataHost), tokenResponder) + return mockClient } func TestGetUserData_Success(t *testing.T) { - mockClient := &UserdataMockClient{ - UserData: base64.StdEncoding.EncodeToString([]byte("mock-user-data")), - } + mockClient := SetupMockClient() + // Mock out user-data response with the encoded value for "mock-user-data" + instanceResponder := httpmock.NewStringResponder(200, "bW9jay11c2VyLWRhdGE=") + httpmock.RegisterResponder("GET", fmt.Sprintf("%s/user-data", mockMetadataHost), instanceResponder) + + newClient, err := NewClient(context.Background(), func(options *clientCreateConfig) { + options.HTTPClient = mockClient + }) + assert.NoError(t, err, "Expected no error") + + userData, err := newClient.GetUserData(context.Background()) + + assert.NoError(t, err, "Expected no error") - userData, err := mockClient.GetUserData(context.Background()) + assert.Equal(t, "mock-user-data", userData, "Unexpected user data") +} + +func TestGetUserDataGzip_Success(t *testing.T) { + mockClient := SetupMockClient() + // Mock out user-data response with the gzipped encoded value for "mock-user-data" + instanceResponder := httpmock.NewStringResponder(200, "H4sIAO0n32YAA8vNT87WLS1OLdJNSSxJBACRtuznDgAAAA==") + httpmock.RegisterResponder("GET", fmt.Sprintf("%s/user-data", mockMetadataHost), instanceResponder) + + newClient, err := NewClient(context.Background(), func(options *clientCreateConfig) { + options.HTTPClient = mockClient + }) + assert.NoError(t, err, "Expected no error") + + userData, err := newClient.GetUserData(context.Background()) assert.NoError(t, err, "Expected no error") - // Note "bW9jay11c2VyLWRhdGE=" is the encoded value - assert.Equal(t, "bW9jay11c2VyLWRhdGE=", userData, "Unexpected user data") + + assert.Equal(t, "mock-user-data", userData, "Unexpected user data") +} + +func TestGetUserDataGzip_Error(t *testing.T) { + mockClient := SetupMockClient() + // Mock out user-data response with the invalid gzip encoded value for "mock-user-data" + invalidGzipData := []byte{0x1F, 0x8B, 0x08, 0x23} + userDataResponse := base64.StdEncoding.EncodeToString(invalidGzipData) + instanceResponder := httpmock.NewStringResponder(200, userDataResponse) + httpmock.RegisterResponder("GET", fmt.Sprintf("%s/user-data", mockMetadataHost), instanceResponder) + + newClient, err := NewClient(context.Background(), func(options *clientCreateConfig) { + options.HTTPClient = mockClient + }) + assert.NoError(t, err, "Expected no error") + + userData, err := newClient.GetUserData(context.Background()) + + assert.EqualErrorf(t, err, "failed to ungzip user-data: unexpected EOF", "Unexpected error message") + + assert.Equal(t, "", userData, "expected Empty Userdata") } func TestGetUserData_Error(t *testing.T) { - mockClient := &UserdataMockClient{ - GetUserDataError: errors.New("mock error"), - } + mockClient := SetupMockClient() + + instanceResponder := httpmock.NewStringResponder(500, "{\"errors\": [{\"reason\": \"failed to get metadata\"}]}") + httpmock.RegisterResponder("GET", fmt.Sprintf("%s/user-data", mockMetadataHost), instanceResponder) + newClient, err := NewClient(context.Background(), func(options *clientCreateConfig) { + options.HTTPClient = mockClient + }) + assert.NoError(t, err, "Expected no error") + + userData, err := newClient.GetUserData(context.Background()) + + assert.Error(t, err, "Expected an error") + assert.Equal(t, "", userData, "Expected empty user data") + assert.EqualErrorf(t, err, "[500] failed to get metadata", "Unexpected error message") +} + +func TestGetUserDataDecode_Error(t *testing.T) { + mockClient := SetupMockClient() + // Mock out user-data response with the gzipped encoded value for "mock-user-data" + instanceResponder := httpmock.NewStringResponder(200, "invalid base64") + httpmock.RegisterResponder("GET", fmt.Sprintf("%s/user-data", mockMetadataHost), instanceResponder) + + newClient, err := NewClient(context.Background(), func(options *clientCreateConfig) { + options.HTTPClient = mockClient + }) + assert.NoError(t, err, "Expected no error") - userData, err := mockClient.GetUserData(context.Background()) + userData, err := newClient.GetUserData(context.Background()) assert.Error(t, err, "Expected an error") assert.Equal(t, "", userData, "Expected empty user data") - assert.EqualError(t, err, "mock error", "Unexpected error message") + assert.EqualErrorf(t, err, "failed to decode user-data: illegal base64 data at input byte 7", "Unexpected error message") }