Skip to content

Commit

Permalink
Add Managed Identity option to azure storage
Browse files Browse the repository at this point in the history
  • Loading branch information
m7hm7t committed Jun 1, 2024
1 parent e0d0d90 commit 7935cf2
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 9 deletions.
44 changes: 38 additions & 6 deletions pkg/storages/azure/folder.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const (
BufferSizeSetting = "AZURE_BUFFER_SIZE"
MaxBuffersSetting = "AZURE_MAX_BUFFERS"
TryTimeoutSetting = "AZURE_TRY_TIMEOUT"
ClientIDSetting = "AZURE_CLIENT_ID"
minBufferSize = 1024
defaultBufferSize = 8 * 1024 * 1024
minBuffers = 1
Expand All @@ -40,8 +41,9 @@ const (
type AzureAuthType string

const (
AzureAccessKeyAuth AzureAuthType = "AzureAccessKeyAuth"
AzureSASTokenAuth AzureAuthType = "AzureSASTokenAuth"
AzureAccessKeyAuth AzureAuthType = "AzureAccessKeyAuth"
AzureSASTokenAuth AzureAuthType = "AzureSASTokenAuth"
AzureManagedIdentityAuth AzureAuthType = "AzureManagedIdentityAuth"
)

var SettingList = []string{
Expand All @@ -52,6 +54,7 @@ var SettingList = []string{
EndpointSuffix,
BufferSizeSetting,
MaxBuffersSetting,
ClientIDSetting,
}

func NewFolderError(err error, format string, args ...interface{}) storage.Error {
Expand All @@ -78,6 +81,31 @@ func NewFolder(
}
}

func getContainerClientWithManagedIndetity(
accountName string,
storageEndpointSuffix string,
containerName string,
timeout time.Duration,
clientID string) (*azblob.ContainerClient, error) {
cred, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
ID: azidentity.ClientID(clientID),
})
if err != nil {
return nil, err
}

containerURLString := fmt.Sprintf("https://%s.blob.%s/%s", accountName, storageEndpointSuffix, containerName)
_, err = url.Parse(containerURLString)
if err != nil {
return nil, NewFolderError(err, "Unable to parse service URL")
}

containerClient, err := azblob.NewContainerClient(containerURLString, cred, &azblob.ClientOptions{
Retry: policy.RetryOptions{TryTimeout: timeout},
})
return containerClient, err
}

func getContainerClientWithSASToken(
accountName string,
storageEndpointSuffix string,
Expand Down Expand Up @@ -136,9 +164,9 @@ func getContainerClient(
return containerClient, err
}

func configureAuthType(settings map[string]string) (AzureAuthType, string, string) {
func configureAuthType(settings map[string]string) (AzureAuthType, string, string, string) {
var ok bool
var accountToken, accessKey string
var accountToken, accessKey, clientID string
var authType AzureAuthType

if accessKey, ok = settings[AccessKeySetting]; ok {
Expand All @@ -149,9 +177,11 @@ func configureAuthType(settings map[string]string) (AzureAuthType, string, strin
if !strings.HasPrefix(accountToken, "?") {
accountToken = "?" + accountToken
}
} else if clientID, ok = settings[ClientIDSetting]; ok {
authType = AzureManagedIdentityAuth
}

return authType, accountToken, accessKey
return authType, accountToken, accessKey, clientID
}

func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder, error) {
Expand All @@ -161,7 +191,7 @@ func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder,
return nil, NewCredentialError(AccountSetting)
}

authType, accountToken, accountKey := configureAuthType(settings)
authType, accountToken, accountKey, clientID := configureAuthType(settings)

var credential *azblob.SharedKeyCredential
var err error
Expand Down Expand Up @@ -201,6 +231,8 @@ func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder,
containerClient, err = getContainerClientWithSASToken(accountName, storageEndpointSuffix, containerName, timeout, accountToken)
} else if authType == AzureAccessKeyAuth {
containerClient, err = getContainerClientWithAccessKey(accountName, storageEndpointSuffix, containerName, timeout, credential)
} else if authType == AzureManagedIdentityAuth {
containerClient, err = getContainerClientWithManagedIndetity(accountName, storageEndpointSuffix, containerName, timeout, clientID)
} else {
// No explicitly configured auth method, try the default credential chain
containerClient, err = getContainerClient(accountName, storageEndpointSuffix, containerName, timeout)
Expand Down
58 changes: 55 additions & 3 deletions pkg/storages/azure/folder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package azure

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/wal-g/wal-g/pkg/storages/storage"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/golang/mock/gomock"
)

func TestAzureFolder(t *testing.T) {
Expand All @@ -22,24 +26,72 @@ var ConfigureAuthType = configureAuthType

func TestConfigureAccessKeyAuthType(t *testing.T) {
settings := map[string]string{AccessKeySetting: "foo"}
authType, accountToken, accessKey := ConfigureAuthType(settings)
authType, accountToken, accessKey, clientID := ConfigureAuthType(settings)
assert.Equal(t, authType, AzureAccessKeyAuth)
assert.Empty(t, accountToken)
assert.Equal(t, accessKey, "foo")
assert.Empty(t, clientID)
}

func TestConfigureSASTokenAuth(t *testing.T) {
settings := map[string]string{SasTokenSetting: "foo"}
authType, accountToken, accessKey := ConfigureAuthType(settings)
authType, accountToken, accessKey, clientID := ConfigureAuthType(settings)
assert.Equal(t, authType, AzureSASTokenAuth)
assert.Equal(t, accountToken, "?foo")
assert.Empty(t, accessKey)
assert.Empty(t, clientID)
}

func TestConfigureDefaultAuth(t *testing.T) {
settings := make(map[string]string)
authType, accountToken, accessKey := ConfigureAuthType(settings)
authType, accountToken, accessKey, clientID := ConfigureAuthType(settings)
assert.Empty(t, authType)
assert.Empty(t, accountToken)
assert.Empty(t, accessKey)
assert.Empty(t, clientID)
}

func TestConfigureManagedIdentityAuth(t *testing.T) {
settings := map[string]string{ClientIDSetting: "foo"}
authType, accountToken, accessKey, clientID := ConfigureAuthType(settings)
assert.Equal(t, authType, AzureManagedIdentityAuth)
assert.Empty(t, accountToken)
assert.Empty(t, accessKey)
assert.Equal(t, clientID, "foo")
}
func TestGetContainerClientWithManagedIdentity(t *testing.T) {
accountName := "test-account"
storageEndpointSuffix := "test-endpoint"
containerName := "test-container"
timeout := time.Minute
clientID := "test-client-id"

containerClient, err := getContainerClientWithManagedIndetity(accountName, storageEndpointSuffix, containerName, timeout, clientID)
assert.NoError(t, err)
assert.NotNil(t, containerClient)
}

func TestGetContainerClientWithManagedIdentity2(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockCred := azidentity.NewMockManagedIdentityCredential(ctrl)

Check failure on line 78 in pkg/storages/azure/folder_test.go

View workflow job for this annotation

GitHub Actions / lint

undefined: azidentity.NewMockManagedIdentityCredential
mockContainerClient := azblob.NewMockContainerClient(ctrl)

Check failure on line 79 in pkg/storages/azure/folder_test.go

View workflow job for this annotation

GitHub Actions / lint

undefined: azblob.NewMockContainerClient

accountName := "testAccount"
storageEndpointSuffix := "core.windows.net"
containerName := "testContainer"
timeout := time.Second * 10
clientID := "testClientID"

mockCred.EXPECT().NewManagedIdentityCredential(gomock.Any()).Return(mockCred, nil)
mockContainerClient.EXPECT().NewContainerClient(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockContainerClient, nil)

containerClient, err := getContainerClientWithManagedIdentity(accountName, storageEndpointSuffix, containerName, timeout, clientID)

Check failure on line 90 in pkg/storages/azure/folder_test.go

View workflow job for this annotation

GitHub Actions / lint

undefined: getContainerClientWithManagedIdentity (typecheck)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if containerClient == nil {
t.Error("Expected ContainerClient, got nil")
}
}

0 comments on commit 7935cf2

Please sign in to comment.