Skip to content

Commit

Permalink
Pass azure_client_id to Azure MSI if specified (#354)
Browse files Browse the repository at this point in the history
## Changes
<!-- Summary of your changes that are easy to understand -->
Added support for providing client id for MSI use case with multiple
MSIs attached to an azure vm.

## Tests
<!-- 
How is this tested? Please see the checklist below and also describe any
other relevant tests
-->

Manually tested on a vm with multiple MSI.

Created failure first:

![10F7B5BF-23E3-4B62-ACF3-204D91149F5D](https://user-images.githubusercontent.com/54602805/229562710-d6aca548-39b9-4f76-aab8-10b6cab5a62d.png)

Same VM:

![2D0DFC9B-C2E8-468D-ADA0-5C5064D88A11](https://user-images.githubusercontent.com/54602805/229562783-793ae505-4e62-4930-b5b6-99c7b36f192f.png)

- [x] `make test` passing
- [x] `make fmt` applied
- [ ] relevant integration tests applied
  • Loading branch information
stikkireddy authored Apr 3, 2023
1 parent 54d4222 commit 31b924a
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions config/auth_azure_msi.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@ func (c AzureMsiCredentials) Configure(ctx context.Context, cfg *Config) (func(*
return nil, fmt.Errorf("resolve host: %w", err)
}
logger.Debugf(ctx, "Generating AAD token via Azure MSI")
inner := azureMsiTokenSource(cfg.getAzureLoginAppID())
platform := azureMsiTokenSource(env.ServiceManagementEndpoint)
inner := azureMsiTokenSource{
resource: cfg.getAzureLoginAppID(),
clientId: cfg.AzureClientID,
}
platform := azureMsiTokenSource{
resource: env.ServiceManagementEndpoint,
clientId: cfg.AzureClientID,
}
return func(r *http.Request) error {
r.Header.Set("X-Databricks-Azure-Workspace-Resource-Id", cfg.AzureResourceID)
return serviceToServiceVisitor(inner, platform,
Expand Down Expand Up @@ -88,11 +94,17 @@ func (c AzureMsiCredentials) getInstanceEnvironment(ctx context.Context) (*azure
}

// implementing azureHostResolver for ensureWorkspaceUrl to work
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, _ *Config, _ azureEnvironment, resource string) oauth2.TokenSource {
return azureMsiTokenSource(resource)
func (c AzureMsiCredentials) tokenSourceFor(_ context.Context, cfg *Config, _ azureEnvironment, resource string) oauth2.TokenSource {
return azureMsiTokenSource{
resource: resource,
clientId: cfg.AzureClientID,
}
}

type azureMsiTokenSource string
type azureMsiTokenSource struct {
resource string
clientId string
}

func (s azureMsiTokenSource) Token() (*oauth2.Token, error) {
ctx, cancel := context.WithTimeout(context.Background(), azureMsiTimeout)
Expand All @@ -104,7 +116,10 @@ func (s azureMsiTokenSource) Token() (*oauth2.Token, error) {
}
query := req.URL.Query()
query.Add("api-version", "2018-02-01")
query.Add("resource", string(s))
query.Add("resource", s.resource)
if s.clientId != "" {
query.Add("client_id", s.clientId)
}
req.URL.RawQuery = query.Encode()
req.Header.Add("Metadata", "true")
res, err := http.DefaultClient.Do(req)
Expand Down

0 comments on commit 31b924a

Please sign in to comment.