Skip to content

Commit

Permalink
provider: move helpers to own source file, log the metdata configurat…
Browse files Browse the repository at this point in the history
…ion mode, tidy up provider tests
  • Loading branch information
manicminer committed Apr 11, 2024
1 parent 352421f commit 8470fa4
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 169 deletions.
117 changes: 117 additions & 0 deletions internal/provider/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package provider

import (
"encoding/base64"
"fmt"
"log"
"os"
"strings"

"github.com/hashicorp/terraform-provider-azuread/internal/tf/pluginsdk"
)

// logEntry avoids log entries showing up in test output
func logEntry(f string, v ...interface{}) {
if os.Getenv("TF_LOG") == "" {
return
}

if os.Getenv("TF_ACC") != "" {
return
}

log.Printf(f, v...)
}

func decodeCertificate(clientCertificate string) ([]byte, error) {
var pfx []byte
if clientCertificate != "" {
out := make([]byte, base64.StdEncoding.DecodedLen(len(clientCertificate)))
n, err := base64.StdEncoding.Decode(out, []byte(clientCertificate))
if err != nil {
return pfx, fmt.Errorf("could not decode client certificate data: %v", err)
}
pfx = out[:n]
}
return pfx, nil
}

func getOidcToken(d *pluginsdk.ResourceData) (*string, error) {
idToken := d.Get("oidc_token").(string)

if path := d.Get("oidc_token_file_path").(string); path != "" {
fileTokenRaw, err := os.ReadFile(path)

if err != nil {
return nil, fmt.Errorf("reading OIDC Token from file %q: %v", path, err)
}

fileToken := strings.TrimSpace(string(fileTokenRaw))

if idToken != "" && idToken != fileToken {
return nil, fmt.Errorf("mismatch between supplied OIDC token and supplied OIDC token file contents - please either remove one or ensure they match")
}

idToken = fileToken
}

return &idToken, nil
}

func getClientId(d *pluginsdk.ResourceData) (*string, error) {
clientId := strings.TrimSpace(d.Get("client_id").(string))

if path := d.Get("client_id_file_path").(string); path != "" {
fileClientIdRaw, err := os.ReadFile(path)

if err != nil {
return nil, fmt.Errorf("reading Client ID from file %q: %v", path, err)
}

fileClientId := strings.TrimSpace(string(fileClientIdRaw))

if clientId != "" && clientId != fileClientId {
return nil, fmt.Errorf("mismatch between supplied Client ID and supplied Client ID file contents - please either remove one or ensure they match")
}

clientId = fileClientId
}

return &clientId, nil
}

func getClientSecret(d *pluginsdk.ResourceData) (*string, error) {
clientSecret := strings.TrimSpace(d.Get("client_secret").(string))

if path := d.Get("client_secret_file_path").(string); path != "" {
fileSecretRaw, err := os.ReadFile(path)

if err != nil {
return nil, fmt.Errorf("reading Client Secret from file %q: %v", path, err)
}

fileSecret := strings.TrimSpace(string(fileSecretRaw))

if clientSecret != "" && clientSecret != fileSecret {
return nil, fmt.Errorf("mismatch between supplied Client Secret and supplied Client Secret file contents - please either remove one or ensure they match")
}

clientSecret = fileSecret
}

return &clientSecret, nil
}

func getTenantId(d *pluginsdk.ResourceData) (*string, error) {
tenantId := strings.TrimSpace(d.Get("tenant_id").(string))

if d.Get("use_aks_workload_identity").(bool) && os.Getenv("AZURE_TENANT_ID") != "" {
aksTenantId := os.Getenv("AZURE_TENANT_ID")
if tenantId != "" && tenantId != aksTenantId {
return nil, fmt.Errorf("mismatch between supplied Tenant ID and that provided by AKS Workload Identity - please remove, ensure they match, or disable use_aks_workload_identity")
}
tenantId = aksTenantId
}

return &tenantId, nil
}
169 changes: 50 additions & 119 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@ package provider

import (
"context"
"encoding/base64"
"fmt"
"log"
"os"
"strings"

"github.com/hashicorp/go-azure-sdk/sdk/auth"
"github.com/hashicorp/go-azure-sdk/sdk/environments"
Expand Down Expand Up @@ -39,25 +35,12 @@ type ServiceRegistration interface {

// AzureADProvider returns a schema.Provider.
func AzureADProvider() *schema.Provider {
// avoids this showing up in test output
var debugLog = func(f string, v ...interface{}) {
if os.Getenv("TF_LOG") == "" {
return
}

if os.Getenv("TF_ACC") != "" {
return
}

log.Printf(f, v...)
}

dataSources := make(map[string]*pluginsdk.Resource)
resources := make(map[string]*pluginsdk.Resource)

// first handle the typed services
for _, service := range SupportedTypedServices() {
debugLog("[DEBUG] Registering Data Sources for %q..", service.Name())
logEntry("[DEBUG] Registering Data Sources for %q..", service.Name())
for _, ds := range service.DataSources() {
key := ds.ResourceType()
if existing := dataSources[key]; existing != nil {
Expand All @@ -73,7 +56,7 @@ func AzureADProvider() *schema.Provider {
dataSources[key] = dataSource
}

debugLog("[DEBUG] Registering Resources for %q..", service.Name())
logEntry("[DEBUG] Registering Resources for %q..", service.Name())
for _, r := range service.Resources() {
key := r.ResourceType()
if existing := resources[key]; existing != nil {
Expand All @@ -91,7 +74,7 @@ func AzureADProvider() *schema.Provider {

// then handle the untyped services
for _, service := range SupportedUntypedServices() {
debugLog("[DEBUG] Registering Data Sources for %q..", service.Name())
logEntry("[DEBUG] Registering Data Sources for %q..", service.Name())
for k, v := range service.SupportedDataSources() {
if existing := dataSources[k]; existing != nil {
panic(fmt.Sprintf("An existing Data Source exists for %q", k))
Expand All @@ -100,7 +83,7 @@ func AzureADProvider() *schema.Provider {
dataSources[k] = v
}

debugLog("[DEBUG] Registering Resources for %q..", service.Name())
logEntry("[DEBUG] Registering Resources for %q..", service.Name())
for k, v := range service.SupportedResources() {
if existing := resources[k]; existing != nil {
panic(fmt.Sprintf("An existing Resource exists for %q", k))
Expand Down Expand Up @@ -137,7 +120,7 @@ func AzureADProvider() *schema.Provider {
Type: pluginsdk.TypeString,
Required: true,
DefaultFunc: pluginsdk.EnvDefaultFunc("ARM_ENVIRONMENT", "global"),
Description: "The cloud environment which should be used. Possible values are: `global` (also `public`), `usgovernmentl4` (also `usgovernment`), `usgovernmentl5` (also `dod`), and `china`. Defaults to `global`",
Description: "The cloud environment which should be used. Possible values are: `global` (also `public`), `usgovernmentl4` (also `usgovernment`), `usgovernmentl5` (also `dod`), and `china`. Defaults to `global`. Not used and should not be specified when `metadata_host` is specified.",
},

"metadata_host": {
Expand Down Expand Up @@ -220,6 +203,14 @@ func AzureADProvider() *schema.Provider {
Description: "The URL for the OIDC provider from which to request an ID token. For use when authenticating as a Service Principal using OpenID Connect.",
},

// Azure AKS Workload Identity fields
"use_aks_workload_identity": {
Type: schema.TypeBool,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_USE_AKS_WORKLOAD_IDENTITY", false),
Description: "Allow Azure AKS Workload Identity to be used for Authentication.",
},

// CLI authentication specific fields
"use_cli": {
Type: pluginsdk.TypeBool,
Expand Down Expand Up @@ -280,6 +271,11 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
}
}

idToken, err := getOidcToken(d)
if err != nil {
return nil, pluginsdk.DiagFromErr(err)
}

clientSecret, err := getClientSecret(d)
if err != nil {
return nil, pluginsdk.DiagFromErr(err)
Expand All @@ -290,6 +286,11 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
return nil, pluginsdk.DiagFromErr(err)
}

tenantId, err := getTenantId(d)
if err != nil {
return nil, pluginsdk.DiagFromErr(err)
}

var (
env *environments.Environment

Expand All @@ -298,11 +299,15 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
)

if metadataHost != "" {
logEntry("[DEBUG] Configuring cloud environment from Metadata Service at %q", metadataHost)
if env, err = environments.FromEndpoint(ctx, fmt.Sprintf("https://%s", metadataHost)); err != nil {
return nil, pluginsdk.DiagFromErr(err)
}
} else if env, err = environments.FromName(envName); err != nil {
return nil, pluginsdk.DiagFromErr(err)
} else {
logEntry("[DEBUG] Configuring built-in cloud environment by name: %q", envName)
if env, err = environments.FromName(envName); err != nil {
return nil, pluginsdk.DiagFromErr(err)
}
}

if env.MicrosoftGraph == nil {
Expand All @@ -311,29 +316,34 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
return nil, pluginsdk.DiagErrorf("Microsoft Graph endpoint could not be determined for the specified environment")
}

idToken, err := oidcToken(d)
if err != nil {
return nil, pluginsdk.DiagFromErr(err)
}
var (
enableAzureCli = d.Get("use_cli").(bool)
enableManagedIdentity = d.Get("use_msi").(bool)
enableOidc = d.Get("use_oidc").(bool) || d.Get("use_aks_workload_identity").(bool)
)

authConfig := &auth.Credentials{
Environment: *env,
TenantID: d.Get("tenant_id").(string),
ClientID: *clientId,
ClientCertificateData: certData,
ClientCertificatePassword: d.Get("client_certificate_password").(string),
ClientCertificatePath: d.Get("client_certificate_path").(string),
ClientSecret: *clientSecret,
Environment: *env,
ClientID: *clientId,
TenantID: *tenantId,

ClientCertificateData: certData,
ClientCertificatePassword: d.Get("client_certificate_password").(string),
ClientCertificatePath: d.Get("client_certificate_path").(string),
ClientSecret: *clientSecret,

OIDCAssertionToken: *idToken,
GitHubOIDCTokenRequestURL: d.Get("oidc_request_url").(string),
GitHubOIDCTokenRequestToken: d.Get("oidc_request_token").(string),

CustomManagedIdentityEndpoint: d.Get("msi_endpoint").(string),

EnableAuthenticatingUsingAzureCLI: enableAzureCli,
EnableAuthenticatingUsingClientCertificate: true,
EnableAuthenticatingUsingClientSecret: true,
EnableAuthenticationUsingOIDC: d.Get("use_oidc").(bool),
EnableAuthenticationUsingGitHubOIDC: d.Get("use_oidc").(bool),
EnableAuthenticatingUsingAzureCLI: d.Get("use_cli").(bool),
EnableAuthenticatingUsingManagedIdentity: d.Get("use_msi").(bool),
CustomManagedIdentityEndpoint: d.Get("msi_endpoint").(string),
EnableAuthenticatingUsingManagedIdentity: enableManagedIdentity,
EnableAuthenticationUsingGitHubOIDC: enableOidc,
EnableAuthenticationUsingOIDC: enableOidc,
}

// only one pid can be interpreted currently
Expand Down Expand Up @@ -367,82 +377,3 @@ func buildClient(ctx context.Context, p *schema.Provider, authConfig *auth.Crede

return client, nil
}

func decodeCertificate(clientCertificate string) ([]byte, error) {
var pfx []byte
if clientCertificate != "" {
out := make([]byte, base64.StdEncoding.DecodedLen(len(clientCertificate)))
n, err := base64.StdEncoding.Decode(out, []byte(clientCertificate))
if err != nil {
return pfx, fmt.Errorf("could not decode client certificate data: %v", err)
}
pfx = out[:n]
}
return pfx, nil
}

func oidcToken(d *pluginsdk.ResourceData) (*string, error) {
idToken := d.Get("oidc_token").(string)

if path := d.Get("oidc_token_file_path").(string); path != "" {
fileTokenRaw, err := os.ReadFile(path)

if err != nil {
return nil, fmt.Errorf("reading OIDC Token from file %q: %v", path, err)
}

fileToken := strings.TrimSpace(string(fileTokenRaw))

if idToken != "" && idToken != fileToken {
return nil, fmt.Errorf("mismatch between supplied OIDC token and supplied OIDC token file contents - please either remove one or ensure they match")
}

idToken = fileToken
}

return &idToken, nil
}

func getClientId(d *pluginsdk.ResourceData) (*string, error) {
clientId := strings.TrimSpace(d.Get("client_id").(string))

if path := d.Get("client_id_file_path").(string); path != "" {
fileClientIdRaw, err := os.ReadFile(path)

if err != nil {
return nil, fmt.Errorf("reading Client ID from file %q: %v", path, err)
}

fileClientId := strings.TrimSpace(string(fileClientIdRaw))

if clientId != "" && clientId != fileClientId {
return nil, fmt.Errorf("mismatch between supplied Client ID and supplied Client ID file contents - please either remove one or ensure they match")
}

clientId = fileClientId
}

return &clientId, nil
}

func getClientSecret(d *pluginsdk.ResourceData) (*string, error) {
clientSecret := strings.TrimSpace(d.Get("client_secret").(string))

if path := d.Get("client_secret_file_path").(string); path != "" {
fileSecretRaw, err := os.ReadFile(path)

if err != nil {
return nil, fmt.Errorf("reading Client Secret from file %q: %v", path, err)
}

fileSecret := strings.TrimSpace(string(fileSecretRaw))

if clientSecret != "" && clientSecret != fileSecret {
return nil, fmt.Errorf("mismatch between supplied Client Secret and supplied Client Secret file contents - please either remove one or ensure they match")
}

clientSecret = fileSecret
}

return &clientSecret, nil
}
Loading

0 comments on commit 8470fa4

Please sign in to comment.