diff --git a/pkg/storages/azure/folder.go b/pkg/storages/azure/folder.go index b8599ce3e..b47c82951 100644 --- a/pkg/storages/azure/folder.go +++ b/pkg/storages/azure/folder.go @@ -28,6 +28,7 @@ const ( BufferSizeSetting = "AZURE_BUFFER_SIZE" MaxBuffersSetting = "AZURE_MAX_BUFFERS" TryTimeoutSetting = "AZURE_TRY_TIMEOUT" + ClientIDSetting = "AZURE_CLIENT_ID" minBufferSize = 1024 defaultBufferSize = 8 * 1024 * 1024 minBuffers = 1 @@ -40,8 +41,9 @@ const ( type AzureAuthType string const ( - AzureAccessKeyAuth AzureAuthType = "AzureAccessKeyAuth" - AzureSASTokenAuth AzureAuthType = "AzureSASTokenAuth" + AzureAccessKeyAuth AzureAuthType = "AzureAccessKeyAuth" + AzureSASTokenAuth AzureAuthType = "AzureSASTokenAuth" + AzureManagedIdentityAuth AzureAuthType = "AzureManagedIdentityAuth" ) var SettingList = []string{ @@ -52,6 +54,7 @@ var SettingList = []string{ EndpointSuffix, BufferSizeSetting, MaxBuffersSetting, + ClientIDSetting, } func NewFolderError(err error, format string, args ...interface{}) storage.Error { @@ -78,6 +81,31 @@ func NewFolder( } } +func getContainerClientWithManagedIndetity( + accountName string, + storageEndpointSuffix string, + containerName string, + timeout time.Duration, + clientID string) (*azblob.ContainerClient, error) { + cred, err := azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{ + ID: azidentity.ClientID(clientID), + }) + if err != nil { + return nil, err + } + + containerURLString := fmt.Sprintf("https://%s.blob.%s/%s", accountName, storageEndpointSuffix, containerName) + _, err = url.Parse(containerURLString) + if err != nil { + return nil, NewFolderError(err, "Unable to parse service URL") + } + + containerClient, err := azblob.NewContainerClient(containerURLString, cred, &azblob.ClientOptions{ + Retry: policy.RetryOptions{TryTimeout: timeout}, + }) + return containerClient, err +} + func getContainerClientWithSASToken( accountName string, storageEndpointSuffix string, @@ -136,9 +164,9 @@ func getContainerClient( return containerClient, err } -func configureAuthType(settings map[string]string) (AzureAuthType, string, string) { +func configureAuthType(settings map[string]string) (AzureAuthType, string, string, string) { var ok bool - var accountToken, accessKey string + var accountToken, accessKey, clientID string var authType AzureAuthType if accessKey, ok = settings[AccessKeySetting]; ok { @@ -149,9 +177,11 @@ func configureAuthType(settings map[string]string) (AzureAuthType, string, strin if !strings.HasPrefix(accountToken, "?") { accountToken = "?" + accountToken } + } else if clientID, ok = settings[ClientIDSetting]; ok { + authType = AzureManagedIdentityAuth } - return authType, accountToken, accessKey + return authType, accountToken, accessKey, clientID } func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder, error) { @@ -161,7 +191,7 @@ func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder, return nil, NewCredentialError(AccountSetting) } - authType, accountToken, accountKey := configureAuthType(settings) + authType, accountToken, accountKey, clientID := configureAuthType(settings) var credential *azblob.SharedKeyCredential var err error @@ -201,6 +231,8 @@ func ConfigureFolder(prefix string, settings map[string]string) (storage.Folder, containerClient, err = getContainerClientWithSASToken(accountName, storageEndpointSuffix, containerName, timeout, accountToken) } else if authType == AzureAccessKeyAuth { containerClient, err = getContainerClientWithAccessKey(accountName, storageEndpointSuffix, containerName, timeout, credential) + } else if authType == AzureManagedIdentityAuth { + containerClient, err = getContainerClientWithManagedIndetity(accountName, storageEndpointSuffix, containerName, timeout, clientID) } else { // No explicitly configured auth method, try the default credential chain containerClient, err = getContainerClient(accountName, storageEndpointSuffix, containerName, timeout) diff --git a/pkg/storages/azure/folder_test.go b/pkg/storages/azure/folder_test.go index da4b1f637..d3d4ccecf 100644 --- a/pkg/storages/azure/folder_test.go +++ b/pkg/storages/azure/folder_test.go @@ -2,9 +2,13 @@ package azure import ( "testing" + "time" "github.com/stretchr/testify/assert" "github.com/wal-g/wal-g/pkg/storages/storage" + "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob" + "github.com/golang/mock/gomock" ) func TestAzureFolder(t *testing.T) { @@ -22,24 +26,72 @@ var ConfigureAuthType = configureAuthType func TestConfigureAccessKeyAuthType(t *testing.T) { settings := map[string]string{AccessKeySetting: "foo"} - authType, accountToken, accessKey := ConfigureAuthType(settings) + authType, accountToken, accessKey, clientID := ConfigureAuthType(settings) assert.Equal(t, authType, AzureAccessKeyAuth) assert.Empty(t, accountToken) assert.Equal(t, accessKey, "foo") + assert.Empty(t, clientID) } func TestConfigureSASTokenAuth(t *testing.T) { settings := map[string]string{SasTokenSetting: "foo"} - authType, accountToken, accessKey := ConfigureAuthType(settings) + authType, accountToken, accessKey, clientID := ConfigureAuthType(settings) assert.Equal(t, authType, AzureSASTokenAuth) assert.Equal(t, accountToken, "?foo") assert.Empty(t, accessKey) + assert.Empty(t, clientID) } func TestConfigureDefaultAuth(t *testing.T) { settings := make(map[string]string) - authType, accountToken, accessKey := ConfigureAuthType(settings) + authType, accountToken, accessKey, clientID := ConfigureAuthType(settings) assert.Empty(t, authType) assert.Empty(t, accountToken) assert.Empty(t, accessKey) + assert.Empty(t, clientID) } + +func TestConfigureManagedIdentityAuth(t *testing.T) { + settings := map[string]string{ClientIDSetting: "foo"} + authType, accountToken, accessKey, clientID := ConfigureAuthType(settings) + assert.Equal(t, authType, AzureManagedIdentityAuth) + assert.Empty(t, accountToken) + assert.Empty(t, accessKey) + assert.Equal(t, clientID, "foo") +} +func TestGetContainerClientWithManagedIdentity(t *testing.T) { + accountName := "test-account" + storageEndpointSuffix := "test-endpoint" + containerName := "test-container" + timeout := time.Minute + clientID := "test-client-id" + + containerClient, err := getContainerClientWithManagedIndetity(accountName, storageEndpointSuffix, containerName, timeout, clientID) + assert.NoError(t, err) + assert.NotNil(t, containerClient) +} + +func TestGetContainerClientWithManagedIdentity2(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCred := azidentity.NewMockManagedIdentityCredential(ctrl) + mockContainerClient := azblob.NewMockContainerClient(ctrl) + + accountName := "testAccount" + storageEndpointSuffix := "core.windows.net" + containerName := "testContainer" + timeout := time.Second * 10 + clientID := "testClientID" + + mockCred.EXPECT().NewManagedIdentityCredential(gomock.Any()).Return(mockCred, nil) + mockContainerClient.EXPECT().NewContainerClient(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockContainerClient, nil) + + containerClient, err := getContainerClientWithManagedIdentity(accountName, storageEndpointSuffix, containerName, timeout, clientID) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if containerClient == nil { + t.Error("Expected ContainerClient, got nil") + } +} \ No newline at end of file