From 41746f0df44236f1173891b213fd62195421be92 Mon Sep 17 00:00:00 2001 From: Edwin Buck Date: Mon, 23 Sep 2024 18:18:45 -0500 Subject: [PATCH] Implement the Validate RPC on built-in plugins (#5303) The old API performed all configuration checks coupled with plugin reconfiguration under the Configure() func. The new API adds a Validation() func that only performs configuration checks but has no impact on the running plugin. To facilitate easier user, the pluginconf package was added that makes it easier to handle the merged code streams through a pluginconf.Status struct that will capture the first error (for integration with Configure() while permitting the Validation() to capture all errors that can be captured. Unit tests had to be reworked, as a side-effect of using the new pluginconf package is that all plugins now automatically check their trustdomain, instead of each plugin checking it in a haphazard manner. Occasionally, very small fixes were performed on plugins, and plugin coding standards were tweaked in small ways to be more similar to each other. Signed-off-by: Edwin Buck --- go.mod | 2 +- go.sum | 4 +- pkg/agent/plugin/nodeattestor/awsiid/iid.go | 30 ++- .../plugin/nodeattestor/awsiid/iid_test.go | 15 +- pkg/agent/plugin/nodeattestor/azuremsi/msi.go | 43 ++-- .../plugin/nodeattestor/azuremsi/msi_test.go | 40 +++- pkg/agent/plugin/nodeattestor/gcpiit/iit.go | 44 ++-- .../plugin/nodeattestor/gcpiit/iit_test.go | 16 +- .../httpchallenge/httpchallenge.go | 109 +++++----- .../httpchallenge/httpchallenge_test.go | 58 +++-- pkg/agent/plugin/nodeattestor/k8spsat/psat.go | 60 ++++-- .../nodeattestor/k8spsat/psat_posix_test.go | 21 +- .../plugin/nodeattestor/k8spsat/psat_test.go | 37 +++- .../nodeattestor/k8spsat/psat_windows_test.go | 17 +- pkg/agent/plugin/nodeattestor/k8ssat/sat.go | 60 ++++-- .../nodeattestor/k8ssat/sat_posix_test.go | 20 +- .../plugin/nodeattestor/k8ssat/sat_test.go | 41 +++- .../nodeattestor/k8ssat/sat_windows_test.go | 18 +- .../plugin/nodeattestor/sshpop/sshpop.go | 16 +- .../plugin/nodeattestor/sshpop/sshpop_test.go | 10 +- .../plugin/nodeattestor/tpmdevid/devid.go | 103 ++++----- .../nodeattestor/tpmdevid/devid_test.go | 94 +++++--- .../plugin/nodeattestor/x509pop/x509pop.go | 55 +++-- .../nodeattestor/x509pop/x509pop_test.go | 24 ++- .../plugin/svidstore/awssecretsmanager/aws.go | 53 +++-- .../svidstore/awssecretsmanager/aws_test.go | 15 +- .../svidstore/gcpsecretmanager/gcloud.go | 49 +++-- .../svidstore/gcpsecretmanager/gcloud_test.go | 11 +- .../plugin/workloadattestor/docker/docker.go | 91 +++++--- .../workloadattestor/docker/docker_posix.go | 18 +- .../docker/docker_posix_test.go | 64 +++--- .../workloadattestor/docker/docker_test.go | 32 ++- .../workloadattestor/docker/docker_windows.go | 5 +- pkg/agent/plugin/workloadattestor/k8s/k8s.go | 201 ++++++++++-------- .../plugin/workloadattestor/k8s/k8s_test.go | 99 ++++++--- .../systemd/systemd_windows.go | 4 + .../workloadattestor/unix/unix_posix.go | 38 +++- .../workloadattestor/unix/unix_posix_test.go | 134 +++++++----- .../workloadattestor/unix/unix_windows.go | 4 + .../workloadattestor/windows/windows_posix.go | 4 + .../windows/windows_windows.go | 38 +++- .../windows/windows_windows_test.go | 59 +++-- pkg/common/catalog/configure.go | 17 ++ pkg/common/plugin/sshpop/sshpop.go | 195 ++++++++++++----- pkg/common/plugin/sshpop/sshpop_test.go | 4 +- pkg/common/pluginconf/pluginconf.go | 63 ++++++ .../awsrolesanywhere/awsrolesanywhere.go | 64 +++--- .../plugin/bundlepublisher/awss3/awss3.go | 98 +++++---- .../gcpcloudstorage/gcpcloudstorage.go | 95 +++++---- pkg/server/plugin/keymanager/awskms/awskms.go | 93 ++++---- .../plugin/keymanager/awskms/awskms_test.go | 23 +- .../azurekeyvault/azure_key_vault.go | 92 ++++---- .../azurekeyvault/azure_key_vault_test.go | 13 +- pkg/server/plugin/keymanager/disk/disk.go | 36 +++- .../plugin/keymanager/disk/disk_test.go | 5 + pkg/server/plugin/keymanager/gcpkms/gcpkms.go | 93 ++++---- .../plugin/keymanager/gcpkms/gcpkms_test.go | 6 +- pkg/server/plugin/nodeattestor/awsiid/iid.go | 117 +++++----- .../plugin/nodeattestor/awsiid/iid_test.go | 2 +- .../nodeattestor/awsiid/organization.go | 2 +- .../plugin/nodeattestor/azuremsi/msi.go | 194 +++++++++-------- .../plugin/nodeattestor/azuremsi/msi_test.go | 4 +- pkg/server/plugin/nodeattestor/gcpiit/iit.go | 108 +++++----- .../plugin/nodeattestor/gcpiit/iit_test.go | 3 +- .../httpchallenge/httpchallenge.go | 123 +++++------ .../httpchallenge/httpchallenge_test.go | 4 +- .../nodeattestor/jointoken/join_token.go | 37 +++- .../plugin/nodeattestor/k8spsat/psat.go | 130 +++++------ .../plugin/nodeattestor/k8spsat/psat_test.go | 4 +- pkg/server/plugin/nodeattestor/k8ssat/sat.go | 138 ++++++------ .../plugin/nodeattestor/k8ssat/sat_test.go | 2 +- .../plugin/nodeattestor/sshpop/sshpop.go | 19 +- .../plugin/nodeattestor/sshpop/sshpop_test.go | 2 +- .../plugin/nodeattestor/tpmdevid/devid.go | 126 +++++------ .../nodeattestor/tpmdevid/devid_test.go | 14 +- .../plugin/nodeattestor/x509pop/x509pop.go | 128 ++++++----- .../nodeattestor/x509pop/x509pop_test.go | 2 +- .../plugin/notifier/gcsbundle/gcsbundle.go | 57 +++-- .../notifier/gcsbundle/gcsbundle_test.go | 34 ++- .../plugin/notifier/k8sbundle/k8sbundle.go | 59 +++-- .../notifier/k8sbundle/k8sbundle_test.go | 48 +++-- .../plugin/upstreamauthority/awspca/pca.go | 67 +++--- .../upstreamauthority/awspca/pca_test.go | 49 ++++- .../upstreamauthority/awssecret/awssecret.go | 77 +++---- .../awssecret/awssecret_test.go | 16 +- .../upstreamauthority/certmanager/api_test.go | 2 +- .../certmanager/certmanager.go | 86 ++++---- .../certmanager/certmanager_test.go | 18 +- .../plugin/upstreamauthority/disk/disk.go | 53 +++-- .../upstreamauthority/disk/disk_test.go | 4 +- .../plugin/upstreamauthority/ejbca/ejbca.go | 69 +++++- .../upstreamauthority/ejbca/ejbca_client.go | 50 ----- .../upstreamauthority/ejbca/ejbca_test.go | 7 + .../plugin/upstreamauthority/gcpcas/gcpcas.go | 83 +++++--- .../upstreamauthority/gcpcas/gcpcas_test.go | 9 +- .../plugin/upstreamauthority/spire/spire.go | 51 +++-- .../upstreamauthority/spire/spire_test.go | 4 +- .../plugin/upstreamauthority/vault/vault.go | 39 +++- .../upstreamauthority/vault/vault_test.go | 2 +- 99 files changed, 2969 insertions(+), 1827 deletions(-) create mode 100644 pkg/common/pluginconf/pluginconf.go diff --git a/go.mod b/go.mod index 8c888468e2..8e436fa92c 100644 --- a/go.mod +++ b/go.mod @@ -73,7 +73,7 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spiffe/go-spiffe/v2 v2.3.0 github.com/spiffe/spire-api-sdk v1.2.5-0.20240807182354-18e423ce2c1c - github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d + github.com/spiffe/spire-plugin-sdk v1.4.4-0.20240701180828-594312f4444d github.com/stretchr/testify v1.9.0 github.com/uber-go/tally/v4 v4.1.16 github.com/valyala/fastjson v1.6.4 diff --git a/go.sum b/go.sum index f371d80799..f23068fd81 100644 --- a/go.sum +++ b/go.sum @@ -1393,8 +1393,8 @@ github.com/spiffe/go-spiffe/v2 v2.3.0 h1:g2jYNb/PDMB8I7mBGL2Zuq/Ur6hUhoroxGQFyD6 github.com/spiffe/go-spiffe/v2 v2.3.0/go.mod h1:Oxsaio7DBgSNqhAO9i/9tLClaVlfRok7zvJnTV8ZyIY= github.com/spiffe/spire-api-sdk v1.2.5-0.20240807182354-18e423ce2c1c h1:lK/B2paDUiqbngUGsLxDBmNX/BsG2yKxS8W/iGT+x2c= github.com/spiffe/spire-api-sdk v1.2.5-0.20240807182354-18e423ce2c1c/go.mod h1:4uuhFlN6KBWjACRP3xXwrOTNnvaLp1zJs8Lribtr4fI= -github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d h1:LCRQGU6vOqKLfRrG+GJQrwMwDILcAddAEIf4/1PaSVc= -github.com/spiffe/spire-plugin-sdk v1.4.4-0.20230721151831-bf67dde4721d/go.mod h1:GA6o2PVLwyJdevT6KKt5ZXCY/ziAPna13y/seGk49Ik= +github.com/spiffe/spire-plugin-sdk v1.4.4-0.20240701180828-594312f4444d h1:Upcyq8u1aWFHTQSEskwxBE2PehobpY+M21LXXDS/mPw= +github.com/spiffe/spire-plugin-sdk v1.4.4-0.20240701180828-594312f4444d/go.mod h1:GA6o2PVLwyJdevT6KKt5ZXCY/ziAPna13y/seGk49Ik= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= diff --git a/pkg/agent/plugin/nodeattestor/awsiid/iid.go b/pkg/agent/plugin/nodeattestor/awsiid/iid.go index 294b3e2d5d..44ac241ceb 100644 --- a/pkg/agent/plugin/nodeattestor/awsiid/iid.go +++ b/pkg/agent/plugin/nodeattestor/awsiid/iid.go @@ -15,6 +15,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" caws "github.com/spiffe/spire/pkg/common/plugin/aws" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -40,6 +41,16 @@ type IIDAttestorConfig struct { EC2MetadataEndpoint string `hcl:"ec2_metadata_endpoint"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *IIDAttestorConfig { + newConfig := &IIDAttestorConfig{} + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + return newConfig +} + // IIDAttestorPlugin implements aws nodeattestation in the agent. type IIDAttestorPlugin struct { nodeattestorv1.UnsafeNodeAttestorServer @@ -155,20 +166,27 @@ func readStringAndClose(r io.ReadCloser) (string, error) { // Configure implements the Config interface method of the same name func (p *IIDAttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - // Parse HCL config payload into config struct - config := &IIDAttestorConfig{} - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } p.mtx.Lock() defer p.mtx.Unlock() - - p.config = config + p.config = newConfig return &configv1.ConfigureResponse{}, nil } +func (p *IIDAttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *IIDAttestorPlugin) getConfig() (*IIDAttestorConfig, error) { p.mtx.RLock() defer p.mtx.RUnlock() diff --git a/pkg/agent/plugin/nodeattestor/awsiid/iid_test.go b/pkg/agent/plugin/nodeattestor/awsiid/iid_test.go index 1ae4b3e9f3..1e858e1e96 100644 --- a/pkg/agent/plugin/nodeattestor/awsiid/iid_test.go +++ b/pkg/agent/plugin/nodeattestor/awsiid/iid_test.go @@ -18,8 +18,10 @@ import ( "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/fullsailor/pkcs7" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/pemutil" "github.com/spiffe/spire/pkg/common/plugin/aws" "github.com/spiffe/spire/test/plugintest" @@ -93,6 +95,9 @@ func (s *Suite) SetupTest() { })) s.p = s.loadPlugin( + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.Configuref(`ec2_metadata_endpoint = "http://%s/latest"`, s.server.Listener.Addr()), ) @@ -141,13 +146,21 @@ func (s *Suite) TestConfigure() { var err error s.loadPlugin( + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.CaptureConfigureError(&err), plugintest.Configure("malformed"), ) require.Error(err) // success - s.loadPlugin(plugintest.Configure("")) + s.loadPlugin( + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(""), + ) } func (s *Suite) loadPlugin(opts ...plugintest.Option) nodeattestor.NodeAttestor { diff --git a/pkg/agent/plugin/nodeattestor/azuremsi/msi.go b/pkg/agent/plugin/nodeattestor/azuremsi/msi.go index eea7479b10..c28031e5fe 100644 --- a/pkg/agent/plugin/nodeattestor/azuremsi/msi.go +++ b/pkg/agent/plugin/nodeattestor/azuremsi/msi.go @@ -11,6 +11,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/azure" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -40,6 +41,20 @@ type MSIAttestorConfig struct { ResourceID string `hcl:"resource_id"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *MSIAttestorConfig { + newConfig := new(MSIAttestorConfig) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.ResourceID == "" { + newConfig.ResourceID = azure.DefaultMSIResourceID + } + + return newConfig +} + type MSIAttestorPlugin struct { nodeattestorv1.UnsafeNodeAttestorServer configv1.UnsafeConfigServer @@ -85,19 +100,27 @@ func (p *MSIAttestorPlugin) AidAttestation(stream nodeattestorv1.NodeAttestor_Ai } func (p *MSIAttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config := new(MSIAttestorConfig) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - if config.ResourceID == "" { - config.ResourceID = azure.DefaultMSIResourceID - } + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig - p.setConfig(config) return &configv1.ConfigureResponse{}, nil } +func (p *MSIAttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *MSIAttestorPlugin) getConfig() (*MSIAttestorConfig, error) { p.mu.RLock() defer p.mu.RUnlock() @@ -106,9 +129,3 @@ func (p *MSIAttestorPlugin) getConfig() (*MSIAttestorConfig, error) { } return p.config, nil } - -func (p *MSIAttestorPlugin) setConfig(config *MSIAttestorConfig) { - p.mu.Lock() - defer p.mu.Unlock() - p.config = config -} diff --git a/pkg/agent/plugin/nodeattestor/azuremsi/msi_test.go b/pkg/agent/plugin/nodeattestor/azuremsi/msi_test.go index 3bd4690a78..0ae0e69829 100644 --- a/pkg/agent/plugin/nodeattestor/azuremsi/msi_test.go +++ b/pkg/agent/plugin/nodeattestor/azuremsi/msi_test.go @@ -9,8 +9,10 @@ import ( jose "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/azure" "github.com/spiffe/spire/test/plugintest" "github.com/spiffe/spire/test/spiretest" @@ -49,7 +51,12 @@ func (s *MSIAttestorSuite) TestAidAttestationNotConfigured() { func (s *MSIAttestorSuite) TestAidAttestationFailedToObtainToken() { s.tokenErr = errors.New("FAILED") - attestor := s.loadAttestor(plugintest.Configure("")) + attestor := s.loadAttestor( + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(""), + ) err := attestor.Attest(context.Background(), streamBuilder.Build()) s.RequireGRPCStatus(err, codes.Internal, "nodeattestor(azure_msi): unable to fetch token: FAILED") } @@ -59,7 +66,12 @@ func (s *MSIAttestorSuite) TestAidAttestationSuccess() { expectPayload := []byte(fmt.Sprintf(`{"token":%q}`, s.token)) - attestor := s.loadAttestor(plugintest.Configure("")) + attestor := s.loadAttestor( + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(""), + ) err := attestor.Attest(context.Background(), streamBuilder.ExpectAndBuild(expectPayload)) s.Require().NoError(err) } @@ -67,15 +79,33 @@ func (s *MSIAttestorSuite) TestAidAttestationSuccess() { func (s *MSIAttestorSuite) TestConfigure() { // malformed configuration var err error - s.loadAttestor(plugintest.CaptureConfigureError(&err), plugintest.Configure("blah")) + s.loadAttestor( + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure("blah"), + ) s.RequireGRPCStatusContains(err, codes.InvalidArgument, "unable to decode configuration") // success - s.loadAttestor(plugintest.CaptureConfigureError(&err), plugintest.Configure("")) + s.loadAttestor( + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(""), + ) s.Require().NoError(err) // success with resource_id - s.loadAttestor(plugintest.CaptureConfigureError(&err), plugintest.Configure(`resource_id = "foo"`)) + s.loadAttestor( + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(`resource_id = "foo"`), + ) s.Require().NoError(err) } diff --git a/pkg/agent/plugin/nodeattestor/gcpiit/iit.go b/pkg/agent/plugin/nodeattestor/gcpiit/iit.go index 330d420d84..065dcf86f9 100644 --- a/pkg/agent/plugin/nodeattestor/gcpiit/iit.go +++ b/pkg/agent/plugin/nodeattestor/gcpiit/iit.go @@ -16,6 +16,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/gcp" + "github.com/spiffe/spire/pkg/common/pluginconf" ) const ( @@ -50,6 +51,24 @@ type IITAttestorConfig struct { ServiceAccount string `hcl:"service_account"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *IITAttestorConfig { + newConfig := &IITAttestorConfig{} + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.ServiceAccount == "" { + newConfig.ServiceAccount = defaultServiceAccount + } + + if newConfig.IdentityTokenHost == "" { + newConfig.IdentityTokenHost = defaultIdentityTokenHost + } + + return newConfig +} + // NewIITAttestorPlugin creates a new IITAttestorPlugin. func New() *IITAttestorPlugin { return &IITAttestorPlugin{} @@ -76,26 +95,27 @@ func (p *IITAttestorPlugin) AidAttestation(stream nodeattestorv1.NodeAttestor_Ai } func (p *IITAttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config := &IITAttestorConfig{} - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.ServiceAccount == "" { - config.ServiceAccount = defaultServiceAccount - } - - if config.IdentityTokenHost == "" { - config.IdentityTokenHost = defaultIdentityTokenHost + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } p.mtx.Lock() defer p.mtx.Unlock() - p.config = config + p.config = newConfig return &configv1.ConfigureResponse{}, nil } +func (p *IITAttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *IITAttestorPlugin) getConfig() (*IITAttestorConfig, error) { p.mtx.Lock() defer p.mtx.Unlock() diff --git a/pkg/agent/plugin/nodeattestor/gcpiit/iit_test.go b/pkg/agent/plugin/nodeattestor/gcpiit/iit_test.go index 3bdf6b64ec..523e136948 100644 --- a/pkg/agent/plugin/nodeattestor/gcpiit/iit_test.go +++ b/pkg/agent/plugin/nodeattestor/gcpiit/iit_test.go @@ -11,8 +11,10 @@ import ( "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/cryptosigner" "github.com/go-jose/go-jose/v4/jwt" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/gcp" "github.com/spiffe/spire/test/plugintest" "github.com/spiffe/spire/test/spiretest" @@ -66,7 +68,10 @@ func (s *Suite) SetupSuite() { func (s *Suite) SetupTest() { s.status = http.StatusOK s.body = "" - s.na = s.loadPlugin(plugintest.Configuref(` + s.na = s.loadPlugin(plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configuref(` service_account = "%s" identity_token_host = "%s" `, testServiceAccount, s.server.Listener.Addr().String())) @@ -111,7 +116,12 @@ func (s *Suite) TestConfigure() { // malformed var err error - s.loadPlugin(plugintest.CaptureConfigureError(&err), plugintest.Configure("malformed")) + s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure("malformed"), + ) require.Error(err) } @@ -135,7 +145,7 @@ func TestRetrieveIdentity(t *testing.T) { }, { msg: "invalid port", - url: "http://0.0.0.0:70000", + url: "http://127.0.0.1:70000", expectErrContains: "invalid port", }, { diff --git a/pkg/agent/plugin/nodeattestor/httpchallenge/httpchallenge.go b/pkg/agent/plugin/nodeattestor/httpchallenge/httpchallenge.go index ea6eb21aaf..07f3e5795f 100644 --- a/pkg/agent/plugin/nodeattestor/httpchallenge/httpchallenge.go +++ b/pkg/agent/plugin/nodeattestor/httpchallenge/httpchallenge.go @@ -16,6 +16,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/httpchallenge" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -48,6 +49,43 @@ type Config struct { AdvertisedPort int `hcl:"advertised_port"` } +func (p *Plugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *configData { + hclConfig := new(Config) + if err := hcl.Decode(hclConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + } + + hostName := hclConfig.HostName + // hostname unset, autodetect hostname + if hostName == "" { + var err error + hostName, err = os.Hostname() + if err != nil { + status.ReportErrorf("unable to fetch hostname: %v", err) + } + } + + agentName := hclConfig.AgentName + if agentName == "" { + agentName = "default" + } + + advertisedPort := hclConfig.AdvertisedPort + // if unset, advertised port is same as hcl:"port" + if advertisedPort == 0 { + advertisedPort = hclConfig.Port + } + + newConfig := &configData{ + port: hclConfig.Port, + advertisedPort: advertisedPort, + hostName: hostName, + agentName: agentName, + } + + return newConfig +} + type Plugin struct { nodeattestorv1.UnsafeNodeAttestorServer configv1.UnsafeConfigServer @@ -70,14 +108,14 @@ func New() *Plugin { } func (p *Plugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAttestationServer) (err error) { - data, err := p.getConfig() + config, err := p.getConfig() if err != nil { return err } ctx := stream.Context() - port := data.port + port := config.port l, err := net.Listen("tcp", fmt.Sprintf("%s:%d", p.hooks.bindHost, port)) if err != nil { @@ -85,14 +123,14 @@ func (p *Plugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAttestatio } defer l.Close() - advertisedPort := data.advertisedPort + advertisedPort := config.advertisedPort if advertisedPort == 0 { advertisedPort = l.Addr().(*net.TCPAddr).Port } attestationPayload, err := json.Marshal(httpchallenge.AttestationData{ - HostName: data.hostName, - AgentName: data.agentName, + HostName: config.hostName, + AgentName: config.agentName, Port: advertisedPort, }) if err != nil { @@ -129,7 +167,7 @@ func (p *Plugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAttestatio return err } - err = p.serveNonce(ctx, l, data.agentName, challenge.Nonce) + err = p.serveNonce(ctx, l, config.agentName, challenge.Nonce) if err != nil { return status.Errorf(codes.Internal, "failed to start webserver: %v", err) } @@ -137,23 +175,27 @@ func (p *Plugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAttestatio } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - // Parse HCL config payload into config struct - config := new(Config) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - // Make sure the configuration produces valid data - configData, err := p.loadConfigData(config) + newConfig, _, err := pluginconf.Build(req, p.buildConfig) if err != nil { return nil, err } - p.setConfig(configData) + p.m.Lock() + defer p.m.Unlock() + p.c = newConfig return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *Plugin) serveNonce(ctx context.Context, l net.Listener, agentName string, nonce string) (err error) { h := http.NewServeMux() s := &http.Server{ @@ -192,40 +234,3 @@ func (p *Plugin) getConfig() (*configData, error) { } return p.c, nil } - -func (p *Plugin) setConfig(c *configData) { - p.m.Lock() - defer p.m.Unlock() - p.c = c -} - -func (p *Plugin) loadConfigData(config *Config) (*configData, error) { - // Determine the host name to pass to the server. Values are preferred in - // this order: - // 1. HCL HostName configuration value - // 2. OS hostname value - hostName := config.HostName - if hostName == "" { - var err error - hostName, err = os.Hostname() - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to fetch hostname: %v", err) - } - } - - var agentName = "default" - if config.AgentName != "" { - agentName = config.AgentName - } - - if config.AdvertisedPort == 0 { - config.AdvertisedPort = config.Port - } - - return &configData{ - port: config.Port, - advertisedPort: config.AdvertisedPort, - hostName: hostName, - agentName: agentName, - }, nil -} diff --git a/pkg/agent/plugin/nodeattestor/httpchallenge/httpchallenge_test.go b/pkg/agent/plugin/nodeattestor/httpchallenge/httpchallenge_test.go index 245309c4a4..4be50e94cc 100644 --- a/pkg/agent/plugin/nodeattestor/httpchallenge/httpchallenge_test.go +++ b/pkg/agent/plugin/nodeattestor/httpchallenge/httpchallenge_test.go @@ -9,9 +9,11 @@ import ( "net/http" "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" + "github.com/spiffe/spire/pkg/common/catalog" common_httpchallenge "github.com/spiffe/spire/pkg/common/plugin/httpchallenge" "github.com/spiffe/spire/test/plugintest" "github.com/stretchr/testify/require" @@ -23,14 +25,16 @@ var ( func TestConfigureCommon(t *testing.T) { tests := []struct { - name string - hclConf string - expErr string + name string + trustDomain string + hclConf string + expErr string }{ { - name: "Configure fails if receives wrong HCL configuration", - hclConf: "not HCL conf", - expErr: "rpc error: code = InvalidArgument desc = unable to decode configuration", + name: "Configure fails if receives wrong HCL configuration", + trustDomain: "example.org", + hclConf: "not HCL conf", + expErr: "rpc error: code = InvalidArgument desc = unable to decode configuration", }, } @@ -39,7 +43,12 @@ func TestConfigureCommon(t *testing.T) { t.Run(tt.name, func(t *testing.T) { plugin := newPlugin() - resp, err := plugin.Configure(context.Background(), &configv1.ConfigureRequest{HclConfiguration: tt.hclConf}) + resp, err := plugin.Configure(context.Background(), &configv1.ConfigureRequest{ + CoreConfiguration: &configv1.CoreConfiguration{ + TrustDomain: tt.trustDomain, + }, + HclConfiguration: tt.hclConf}, + ) if tt.expErr != "" { require.Contains(t, err.Error(), tt.expErr) require.Nil(t, resp) @@ -55,18 +64,21 @@ func TestConfigureCommon(t *testing.T) { func TestAidAttestationFailures(t *testing.T) { tests := []struct { name string + trustDomain string config string expErr string serverStream nodeattestor.ServerStream }{ { name: "AidAttestation fails if server does not sends a challenge", + trustDomain: "example.org", config: "", expErr: "the error", serverStream: streamBuilder.FailAndBuild(errors.New("the error")), }, { name: "AidAttestation fails if agent cannot unmarshal server challenge", + trustDomain: "example.org", config: "", expErr: "rpc error: code = Internal desc = nodeattestor(http_challenge): unable to unmarshal challenge: invalid character 'o' in literal null (expecting 'u')", serverStream: streamBuilder.IgnoreThenChallenge([]byte("not-a-challenge")).Build(), @@ -76,7 +88,7 @@ func TestAidAttestationFailures(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { var err error - p := loadAndConfigurePlugin(t, tt.config) + p := loadAndConfigurePlugin(t, tt.trustDomain, tt.config) err = p.Attest(context.Background(), tt.serverStream) if tt.expErr != "" { @@ -96,13 +108,15 @@ func TestAidAttestationSucceeds(t *testing.T) { tests := []struct { name string + trustDomain string config string attestationData common_httpchallenge.AttestationData serverStream func(attestationData *common_httpchallenge.AttestationData, challenge []byte, expectPayload []byte, challengeobj *common_httpchallenge.Challenge, port int) nodeattestor.ServerStream }{ { - name: "Check for random port", - config: "", + name: "Check for random port", + trustDomain: "example.org", + config: "", attestationData: common_httpchallenge.AttestationData{ HostName: "spire-dev", AgentName: "default", @@ -122,8 +136,9 @@ func TestAidAttestationSucceeds(t *testing.T) { }, }, { - name: "Check for advertised port", - config: fmt.Sprintf("advertised_port = %d", port), + name: "Check for advertised port", + trustDomain: "example.org", + config: fmt.Sprintf("advertised_port = %d", port), attestationData: common_httpchallenge.AttestationData{ HostName: "spire-dev", AgentName: "default", @@ -143,8 +158,9 @@ func TestAidAttestationSucceeds(t *testing.T) { }, }, { - name: "Test with defaults except port", - config: "port=9999", + name: "Test with defaults except port", + trustDomain: "example.org", + config: "port=9999", attestationData: common_httpchallenge.AttestationData{ HostName: "localhost", AgentName: "default", @@ -159,8 +175,9 @@ func TestAidAttestationSucceeds(t *testing.T) { }, }, { - name: "Full test with all the settings", - config: "hostname=\"localhost\"\nagentname=\"test\"\nport=9999\nadvertised_port=9999", + name: "Full test with all the settings", + trustDomain: "example.org", + config: "hostname=\"localhost\"\nagentname=\"test\"\nport=9999\nadvertised_port=9999", attestationData: common_httpchallenge.AttestationData{ HostName: "localhost", AgentName: "test", @@ -189,7 +206,7 @@ func TestAidAttestationSucceeds(t *testing.T) { challenge, err := json.Marshal(challengeobj) require.NoError(t, err) - p := loadAndConfigurePlugin(t, tt.config) + p := loadAndConfigurePlugin(t, tt.trustDomain, tt.config) err = p.Attest(context.Background(), tt.serverStream(&tt.attestationData, challenge, expectPayload, challengeobj, port)) require.NoError(t, err) @@ -197,8 +214,11 @@ func TestAidAttestationSucceeds(t *testing.T) { } } -func loadAndConfigurePlugin(t *testing.T, config string) nodeattestor.NodeAttestor { - return loadPlugin(t, plugintest.Configure(config)) +func loadAndConfigurePlugin(t *testing.T, trustDomain string, config string) nodeattestor.NodeAttestor { + return loadPlugin(t, plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(trustDomain), + }), + plugintest.Configure(config)) } func loadPlugin(t *testing.T, options ...plugintest.Option) nodeattestor.NodeAttestor { diff --git a/pkg/agent/plugin/nodeattestor/k8spsat/psat.go b/pkg/agent/plugin/nodeattestor/k8spsat/psat.go index 4499b0d224..20e33c4c84 100644 --- a/pkg/agent/plugin/nodeattestor/k8spsat/psat.go +++ b/pkg/agent/plugin/nodeattestor/k8spsat/psat.go @@ -11,6 +11,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/k8s" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -54,6 +55,29 @@ type AttestorConfig struct { TokenPath string `hcl:"token_path"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *attestorConfig { + hclConfig := new(AttestorConfig) + if err := hcl.Decode(hclConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if hclConfig.Cluster == "" { + status.ReportError("missing required cluster block") + } + + newConfig := &attestorConfig{ + cluster: hclConfig.Cluster, + tokenPath: hclConfig.TokenPath, + } + + if newConfig.tokenPath == "" { + newConfig.tokenPath = getDefaultTokenPath() + } + + return newConfig +} + type attestorConfig struct { cluster string tokenPath string @@ -88,27 +112,27 @@ func (p *AttestorPlugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAt // Configure decodes JSON config from request and populates AttestorPlugin with it func (p *AttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (resp *configv1.ConfigureResponse, err error) { - hclConfig := new(AttestorConfig) - if err := hcl.Decode(hclConfig, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if hclConfig.Cluster == "" { - return nil, status.Error(codes.InvalidArgument, "configuration missing cluster") + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - config := &attestorConfig{ - cluster: hclConfig.Cluster, - tokenPath: hclConfig.TokenPath, - } - if config.tokenPath == "" { - config.tokenPath = getDefaultTokenPath() - } + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig - p.setConfig(config) return &configv1.ConfigureResponse{}, nil } +func (p *AttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (resp *configv1.ValidateResponse, err error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { p.mu.RLock() defer p.mu.RUnlock() @@ -118,12 +142,6 @@ func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { return p.config, nil } -func (p *AttestorPlugin) setConfig(config *attestorConfig) { - p.mu.Lock() - defer p.mu.Unlock() - p.config = config -} - func loadTokenFromFile(path string) (string, error) { data, err := os.ReadFile(path) if err != nil { diff --git a/pkg/agent/plugin/nodeattestor/k8spsat/psat_posix_test.go b/pkg/agent/plugin/nodeattestor/k8spsat/psat_posix_test.go index 05be271683..b908f60ff4 100644 --- a/pkg/agent/plugin/nodeattestor/k8spsat/psat_posix_test.go +++ b/pkg/agent/plugin/nodeattestor/k8spsat/psat_posix_test.go @@ -5,7 +5,9 @@ package k8spsat import ( "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/test/plugintest" "github.com/stretchr/testify/require" ) @@ -13,13 +15,24 @@ import ( func TestConfigureDefaultToken(t *testing.T) { p := New() var err error - plugintest.Load(t, builtin(p), new(nodeattestor.V1), plugintest.CaptureConfigureError(&err), plugintest.Configure(`cluster = "production"`)) + plugintest.Load(t, builtin(p), new(nodeattestor.V1), + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(`cluster = "production"`), + ) require.NoError(t, err) require.Equal(t, "/var/run/secrets/tokens/spire-agent", p.config.tokenPath) - plugintest.Load(t, builtin(p), new(nodeattestor.V1), plugintest.CaptureConfigureError(&err), plugintest.Configure(` - cluster = "production" - token_path = "/tmp/token"`)) + plugintest.Load(t, builtin(p), new(nodeattestor.V1), + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(`cluster = "production" + token_path = "/tmp/token"`), + ) require.NoError(t, err) require.Equal(t, "/tmp/token", p.config.tokenPath) diff --git a/pkg/agent/plugin/nodeattestor/k8spsat/psat_test.go b/pkg/agent/plugin/nodeattestor/k8spsat/psat_test.go index b688f4fefe..0e383d1f94 100644 --- a/pkg/agent/plugin/nodeattestor/k8spsat/psat_test.go +++ b/pkg/agent/plugin/nodeattestor/k8spsat/psat_test.go @@ -9,8 +9,10 @@ import ( jose "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/pemutil" sat_common "github.com/spiffe/spire/pkg/common/plugin/k8s" "github.com/spiffe/spire/test/plugintest" @@ -52,7 +54,8 @@ func (s *AttestorSuite) SetupTest() { func (s *AttestorSuite) TestAttestNotConfigured() { na := s.loadPlugin() err := na.Attest(context.Background(), streamBuilder.Build()) - s.RequireGRPCStatus(err, codes.FailedPrecondition, "nodeattestor(k8s_psat): not configured") + s.T().Logf("failed: %s", err.Error()) + s.RequireGRPCStatusContains(err, codes.FailedPrecondition, "nodeattestor(k8s_psat): not configured") } func (s *AttestorSuite) TestAttestNoToken() { @@ -81,23 +84,43 @@ func (s *AttestorSuite) TestConfigure() { var err error // malformed configuration - s.loadPlugin(plugintest.CaptureConfigureError(&err), plugintest.Configure("malformed")) + s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure("malformed"), + ) s.RequireGRPCStatusContains(err, codes.InvalidArgument, "unable to decode configuration") // missing cluster - s.loadPlugin(plugintest.CaptureConfigureError(&err), plugintest.Configure("")) - s.RequireGRPCStatus(err, codes.InvalidArgument, "configuration missing cluster") + s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(""), + ) + s.RequireGRPCStatus(err, codes.InvalidArgument, "missing required cluster block") // success - s.loadPlugin(plugintest.CaptureConfigureError(&err), plugintest.Configure(`cluster = "production"`)) + s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(`cluster = "production"`), + ) s.Require().NoError(err) } func (s *AttestorSuite) loadPluginWithTokenPath(tokenPath string) nodeattestor.NodeAttestor { - return s.loadPlugin(plugintest.Configuref(` + return s.loadPlugin( + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configuref(` cluster = "production" token_path = %q - `, tokenPath)) + `, tokenPath), + ) } func (s *AttestorSuite) loadPlugin(options ...plugintest.Option) nodeattestor.NodeAttestor { diff --git a/pkg/agent/plugin/nodeattestor/k8spsat/psat_windows_test.go b/pkg/agent/plugin/nodeattestor/k8spsat/psat_windows_test.go index a5c2378f8d..5e3fad8d69 100644 --- a/pkg/agent/plugin/nodeattestor/k8spsat/psat_windows_test.go +++ b/pkg/agent/plugin/nodeattestor/k8spsat/psat_windows_test.go @@ -5,7 +5,9 @@ package k8spsat import ( "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/test/plugintest" "github.com/stretchr/testify/require" ) @@ -13,24 +15,28 @@ import ( func TestConfigureDefaultToken(t *testing.T) { for _, tt := range []struct { name string + trustDomain string mountPoint string config string expectTokenPath string }{ { name: "mountPoint set", + trustDomain: "example.org", mountPoint: "c:/somepath", config: `cluster = "production"`, expectTokenPath: "c:\\somepath\\var\\run\\secrets\\tokens\\spire-agent", }, { name: "no mountPoint", + trustDomain: "example.org", config: `cluster = "production"`, expectTokenPath: "\\var\\run\\secrets\\tokens\\spire-agent", }, { - name: "token path set on configuration", - mountPoint: "c:/somepath", + name: "token path set on configuration", + trustDomain: "example.org", + mountPoint: "c:/somepath", config: ` cluster = "production" token_path = "c:\\token"`, @@ -44,7 +50,12 @@ func TestConfigureDefaultToken(t *testing.T) { p := New() var err error - plugintest.Load(t, builtin(p), new(nodeattestor.V1), plugintest.CaptureConfigureError(&err), plugintest.Configure(tt.config)) + plugintest.Load(t, builtin(p), new(nodeattestor.V1), + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(tt.trustDomain), + }), + plugintest.Configure(tt.config)) require.NoError(t, err) require.Equal(t, tt.expectTokenPath, p.config.tokenPath) diff --git a/pkg/agent/plugin/nodeattestor/k8ssat/sat.go b/pkg/agent/plugin/nodeattestor/k8ssat/sat.go index fada3a84c7..bce6fd91e6 100644 --- a/pkg/agent/plugin/nodeattestor/k8ssat/sat.go +++ b/pkg/agent/plugin/nodeattestor/k8ssat/sat.go @@ -13,6 +13,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/k8s" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/zeebo/errs" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -39,6 +40,29 @@ type AttestorConfig struct { TokenPath string `hcl:"token_path"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *attestorConfig { + hclConfig := new(AttestorConfig) + if err := hcl.Decode(hclConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if hclConfig.Cluster == "" { + status.ReportError("configuration missing cluster") + } + + newConfig := &attestorConfig{ + cluster: hclConfig.Cluster, + tokenPath: hclConfig.TokenPath, + } + + if newConfig.tokenPath == "" { + newConfig.tokenPath = getDefaultTokenPath() + } + + return newConfig +} + type attestorConfig struct { cluster string tokenPath string @@ -91,27 +115,27 @@ func (p *AttestorPlugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAt func (p *AttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (resp *configv1.ConfigureResponse, err error) { p.log.Warn(fmt.Sprintf("The %q node attestor plugin has been deprecated in favor of the \"k8s_psat\" plugin and will be removed in a future release", pluginName)) - hclConfig := new(AttestorConfig) - if err := hcl.Decode(hclConfig, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if hclConfig.Cluster == "" { - return nil, status.Error(codes.InvalidArgument, "configuration missing cluster") + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - config := &attestorConfig{ - cluster: hclConfig.Cluster, - tokenPath: hclConfig.TokenPath, - } - if config.tokenPath == "" { - config.tokenPath = getDefaultTokenPath() - } + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig - p.setConfig(config) return &configv1.ConfigureResponse{}, nil } +func (p *AttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (resp *configv1.ValidateResponse, err error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { p.mu.RLock() defer p.mu.RUnlock() @@ -121,12 +145,6 @@ func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { return p.config, nil } -func (p *AttestorPlugin) setConfig(config *attestorConfig) { - p.mu.Lock() - defer p.mu.Unlock() - p.config = config -} - func loadTokenFromFile(path string) (string, error) { data, err := os.ReadFile(path) if err != nil { diff --git a/pkg/agent/plugin/nodeattestor/k8ssat/sat_posix_test.go b/pkg/agent/plugin/nodeattestor/k8ssat/sat_posix_test.go index 7692a0055b..344bf3f63a 100644 --- a/pkg/agent/plugin/nodeattestor/k8ssat/sat_posix_test.go +++ b/pkg/agent/plugin/nodeattestor/k8ssat/sat_posix_test.go @@ -5,7 +5,9 @@ package k8ssat import ( "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/test/plugintest" "github.com/stretchr/testify/require" ) @@ -13,13 +15,25 @@ import ( func TestConfigureDefaultToken(t *testing.T) { p := New() var err error - plugintest.Load(t, builtin(p), new(nodeattestor.V1), plugintest.CaptureConfigureError(&err), plugintest.Configure(`cluster = "production"`)) + plugintest.Load(t, builtin(p), new(nodeattestor.V1), + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(`cluster = "production"`), + ) require.NoError(t, err) require.Equal(t, "/var/run/secrets/kubernetes.io/serviceaccount/token", p.config.tokenPath) - plugintest.Load(t, builtin(p), new(nodeattestor.V1), plugintest.CaptureConfigureError(&err), plugintest.Configure(` + plugintest.Load(t, builtin(p), new(nodeattestor.V1), + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(` cluster = "production" - token_path = "/tmp/token"`)) + token_path = "/tmp/token"`), + ) require.NoError(t, err) require.Equal(t, "/tmp/token", p.config.tokenPath) diff --git a/pkg/agent/plugin/nodeattestor/k8ssat/sat_test.go b/pkg/agent/plugin/nodeattestor/k8ssat/sat_test.go index 48feb62300..3ecb83171a 100644 --- a/pkg/agent/plugin/nodeattestor/k8ssat/sat_test.go +++ b/pkg/agent/plugin/nodeattestor/k8ssat/sat_test.go @@ -6,8 +6,10 @@ import ( "path/filepath" "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/test/plugintest" "github.com/spiffe/spire/test/spiretest" "google.golang.org/grpc/codes" @@ -38,19 +40,19 @@ func (s *AttestorSuite) TestAttestNotConfigured() { } func (s *AttestorSuite) TestAttestNoToken() { - na := s.loadPluginWithTokenPath(s.joinPath("token")) + na := s.loadPluginWithTokenPath("example.org", s.joinPath("token")) err := na.Attest(context.Background(), streamBuilder.Build()) s.RequireGRPCStatusContains(err, codes.InvalidArgument, "nodeattestor(k8s_sat): unable to load token from") } func (s *AttestorSuite) TestAttestEmptyToken() { - na := s.loadPluginWithTokenPath(s.writeValue("token", "")) + na := s.loadPluginWithTokenPath("example.org", s.writeValue("token", "")) err := na.Attest(context.Background(), streamBuilder.Build()) s.RequireGRPCStatusContains(err, codes.InvalidArgument, "nodeattestor(k8s_sat): unable to load token from") } func (s *AttestorSuite) TestAttestSuccess() { - na := s.loadPluginWithTokenPath(s.writeValue("token", "TOKEN")) + na := s.loadPluginWithTokenPath("example.org", s.writeValue("token", "TOKEN")) err := na.Attest(context.Background(), streamBuilder.ExpectAndBuild([]byte(`{"cluster":"production","token":"TOKEN"}`))) s.Require().NoError(err) @@ -60,23 +62,42 @@ func (s *AttestorSuite) TestConfigure() { var err error // malformed configuration - s.loadPlugin(plugintest.CaptureConfigureError(&err), plugintest.Configure("malformed")) + s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure("malformed"), + ) s.RequireGRPCStatusContains(err, codes.InvalidArgument, "unable to decode configuration") // missing cluster - s.loadPlugin(plugintest.CaptureConfigureError(&err), plugintest.Configure("")) + s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(""), + ) s.RequireGRPCStatus(err, codes.InvalidArgument, "configuration missing cluster") // success - s.loadPlugin(plugintest.CaptureConfigureError(&err), plugintest.Configure(`cluster = "production"`)) + s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(`cluster = "production"`), + ) s.Require().NoError(err) } -func (s *AttestorSuite) loadPluginWithTokenPath(tokenPath string) nodeattestor.NodeAttestor { - return s.loadPlugin(plugintest.Configuref(` +func (s *AttestorSuite) loadPluginWithTokenPath(trustDomain string, tokenPath string) nodeattestor.NodeAttestor { + return s.loadPlugin( + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(trustDomain), + }), + plugintest.Configuref(` cluster = "production" - token_path = %q - `, tokenPath)) + token_path = %q`, tokenPath), + ) } func (s *AttestorSuite) loadPlugin(options ...plugintest.Option) nodeattestor.NodeAttestor { diff --git a/pkg/agent/plugin/nodeattestor/k8ssat/sat_windows_test.go b/pkg/agent/plugin/nodeattestor/k8ssat/sat_windows_test.go index 40753dc8d7..95e1a011d8 100644 --- a/pkg/agent/plugin/nodeattestor/k8ssat/sat_windows_test.go +++ b/pkg/agent/plugin/nodeattestor/k8ssat/sat_windows_test.go @@ -5,7 +5,9 @@ package k8ssat import ( "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/test/plugintest" "github.com/stretchr/testify/require" ) @@ -13,24 +15,28 @@ import ( func TestConfigureDefaultToken(t *testing.T) { for _, tt := range []struct { name string + trustDomain string mountPoint string config string expectTokenPath string }{ { name: "mountPoint set", + trustDomain: "example.org", mountPoint: "c:\\somepath", config: `cluster = "production"`, expectTokenPath: "c:\\somepath\\var\\run\\secrets\\kubernetes.io\\serviceaccount\\token", }, { name: "no mountPoint", + trustDomain: "example.org", config: `cluster = "production"`, expectTokenPath: "\\var\\run\\secrets\\kubernetes.io\\serviceaccount\\token", }, { - name: "token path set on configuration", - mountPoint: "c:\\somepath", + name: "token path set on configuration", + trustDomain: "example.org", + mountPoint: "c:\\somepath", config: ` cluster = "production" token_path = "c:\\token"`, @@ -44,7 +50,13 @@ func TestConfigureDefaultToken(t *testing.T) { p := New() var err error - plugintest.Load(t, builtin(p), new(nodeattestor.V1), plugintest.CaptureConfigureError(&err), plugintest.Configure(tt.config)) + plugintest.Load(t, builtin(p), new(nodeattestor.V1), + plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(tt.trustDomain), + }), + plugintest.Configure(tt.config), + ) require.NoError(t, err) require.Equal(t, tt.expectTokenPath, p.config.tokenPath) diff --git a/pkg/agent/plugin/nodeattestor/sshpop/sshpop.go b/pkg/agent/plugin/nodeattestor/sshpop/sshpop.go index 6da6be2f46..d5b9bf93b4 100644 --- a/pkg/agent/plugin/nodeattestor/sshpop/sshpop.go +++ b/pkg/agent/plugin/nodeattestor/sshpop/sshpop.go @@ -8,6 +8,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/sshpop" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -74,12 +75,23 @@ func (p *Plugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAttestatio // Configure configures the Plugin. func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - sshclient, err := sshpop.NewClient(req.HclConfiguration) + newConfig, _, err := pluginconf.Build(req, sshpop.BuildClientConfig) if err != nil { return nil, err } + p.mu.Lock() - p.sshclient = sshclient + p.sshclient = newConfig.NewClient() p.mu.Unlock() + return &configv1.ConfigureResponse{}, nil } + +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, sshpop.BuildClientConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} diff --git a/pkg/agent/plugin/nodeattestor/sshpop/sshpop_test.go b/pkg/agent/plugin/nodeattestor/sshpop/sshpop_test.go index e3a9ed8fbc..69ea062e80 100644 --- a/pkg/agent/plugin/nodeattestor/sshpop/sshpop_test.go +++ b/pkg/agent/plugin/nodeattestor/sshpop/sshpop_test.go @@ -6,8 +6,10 @@ import ( "os" "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/sshpop" "github.com/spiffe/spire/test/fixture" "github.com/spiffe/spire/test/plugintest" @@ -42,9 +44,13 @@ func (s *Suite) SetupTest() { host_key_path = %q host_cert_path = %q`, privateKeyPath, certificatePath) - s.na = s.loadPlugin(plugintest.Configure(clientConfig)) + s.na = s.loadPlugin(plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(clientConfig), + ) - sshclient, err := sshpop.NewClient(clientConfig) + sshclient, err := sshpop.NewClient("example.org", clientConfig) require.NoError(err) s.sshclient = sshclient diff --git a/pkg/agent/plugin/nodeattestor/tpmdevid/devid.go b/pkg/agent/plugin/nodeattestor/tpmdevid/devid.go index 724c4d01de..8c5c331d82 100644 --- a/pkg/agent/plugin/nodeattestor/tpmdevid/devid.go +++ b/pkg/agent/plugin/nodeattestor/tpmdevid/devid.go @@ -3,7 +3,6 @@ package tpmdevid import ( "context" "encoding/json" - "errors" "fmt" "os" "runtime" @@ -16,6 +15,7 @@ import ( "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/tpmdevid/tpmutil" "github.com/spiffe/spire/pkg/common/catalog" common_devid "github.com/spiffe/spire/pkg/common/plugin/tpmdevid" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -49,6 +49,37 @@ type Config struct { EndorsementHierarchyPassword string `hcl:"endorsement_hierarchy_password"` DevicePath string `hcl:"tpm_device_path"` + Autodetect bool +} + +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + newConfig := new(Config) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.DevIDCertPath == "" { + status.ReportError("invalid configuration: devid_cert_path is required") + } + + if newConfig.DevIDPrivPath == "" { + status.ReportError("invalid configuration: devid_priv_path is required") + } + + if newConfig.DevIDPubPath == "" { + status.ReportError("invalid configuration: devid_pub_path is required") + } + + if newConfig.DevicePath != "" && runtime.GOOS == "windows" { + status.ReportError("device path is not allowed on windows") + } + + if newConfig.DevicePath == "" && runtime.GOOS != "windows" { + newConfig.Autodetect = true + } + + return newConfig } type config struct { @@ -194,46 +225,45 @@ func (p *Plugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAttestatio } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - extConf, err := decodePluginConfig(req.HclConfiguration) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - err = validatePluginConfig(extConf) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid configuration: %v", err) + return nil, err } - p.m.Lock() - defer p.m.Unlock() - - switch { - case runtime.GOOS == "windows" && extConf.DevicePath == "": - // OK - case runtime.GOOS == "windows" && extConf.DevicePath != "": - return nil, status.Error(codes.InvalidArgument, "device path is not allowed on windows") - case runtime.GOOS != "windows" && extConf.DevicePath != "": - p.c.devicePath = extConf.DevicePath - case runtime.GOOS != "windows" && extConf.DevicePath == "": + if newConfig.Autodetect { tpmPath, err := AutoDetectTPMPath(BaseTPMDir) if err != nil { return nil, status.Errorf(codes.Internal, "tpm autodetection failed: %v", err) } - p.c.devicePath = tpmPath + newConfig.DevicePath = tpmPath } - err = p.loadDevIDFiles(extConf) + p.m.Lock() + defer p.m.Unlock() + + p.c.devicePath = newConfig.DevicePath + + err = p.loadDevIDFiles(newConfig) if err != nil { return nil, status.Errorf(codes.Internal, "unable to load DevID files: %v", err) } - p.c.passwords.DevIDKey = extConf.DevIDKeyPassword - p.c.passwords.OwnerHierarchy = extConf.OwnerHierarchyPassword - p.c.passwords.EndorsementHierarchy = extConf.EndorsementHierarchyPassword + p.c.passwords.DevIDKey = newConfig.DevIDKeyPassword + p.c.passwords.OwnerHierarchy = newConfig.OwnerHierarchyPassword + p.c.passwords.EndorsementHierarchy = newConfig.EndorsementHierarchyPassword return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *Plugin) SetLogger(log hclog.Logger) { p.log = log } @@ -266,28 +296,3 @@ func (p *Plugin) loadDevIDFiles(c *Config) error { return nil } - -func decodePluginConfig(hclConf string) (*Config, error) { - extConfig := new(Config) - if err := hcl.Decode(extConfig, hclConf); err != nil { - return nil, err - } - - return extConfig, nil -} - -func validatePluginConfig(c *Config) error { - // DevID certificate, public and private key are always required - switch { - case c.DevIDCertPath == "": - return errors.New("devid_cert_path is required") - - case c.DevIDPrivPath == "": - return errors.New("devid_priv_path is required") - - case c.DevIDPubPath == "": - return errors.New("devid_pub_path is required") - } - - return nil -} diff --git a/pkg/agent/plugin/nodeattestor/tpmdevid/devid_test.go b/pkg/agent/plugin/nodeattestor/tpmdevid/devid_test.go index e164342d2b..562bb076d1 100644 --- a/pkg/agent/plugin/nodeattestor/tpmdevid/devid_test.go +++ b/pkg/agent/plugin/nodeattestor/tpmdevid/devid_test.go @@ -15,11 +15,13 @@ import ( "github.com/google/go-tpm/legacy/tpm2" "github.com/hashicorp/go-hclog" + "github.com/spiffe/go-spiffe/v2/spiffeid" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/tpmdevid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/tpmdevid/tpmutil" + "github.com/spiffe/spire/pkg/common/catalog" common_devid "github.com/spiffe/spire/pkg/common/plugin/tpmdevid" server_devid "github.com/spiffe/spire/pkg/server/plugin/nodeattestor/tpmdevid" "github.com/spiffe/spire/test/plugintest" @@ -34,6 +36,7 @@ var ( tpmDevicePath = "/dev/tpmrm0" + trustDomain string devIDCertPath string devIDPrivPath string devIDPubPath string @@ -85,6 +88,7 @@ func setupSimulator(t *testing.T) *tpmsimulator.TPMSimulator { func writeDevIDFiles(t *testing.T) { dir := t.TempDir() + trustDomain = "example.org" devIDCertPath = path.Join(dir, "devid-certificate.pem") devIDPrivPath = path.Join(dir, "devid-priv-path") devIDPubPath = path.Join(dir, "devid-pub-path") @@ -109,33 +113,39 @@ func TestConfigureCommon(t *testing.T) { tests := []struct { name string + trustDomain string hclConf string expErr string autoDetectTPMFails bool }{ { - name: "Configure fails if receives wrong HCL configuration", - hclConf: "not HCL conf", - expErr: "rpc error: code = InvalidArgument desc = unable to decode configuration", + name: "Configure fails if receives wrong HCL configuration", + trustDomain: "example.org", + hclConf: "not HCL conf", + expErr: "rpc error: code = InvalidArgument desc = unable to decode configuration", }, { - name: "Configure fails if DevID certificate path is empty", - hclConf: "", - expErr: "rpc error: code = InvalidArgument desc = invalid configuration: devid_cert_path is required", + name: "Configure fails if DevID certificate path is empty", + trustDomain: "example.org", + hclConf: "", + expErr: "rpc error: code = InvalidArgument desc = invalid configuration: devid_cert_path is required", }, { - name: "Configure fails if DevID private key path is empty", - hclConf: `devid_cert_path = "non-existent-path/to/devid.cert"`, - expErr: "rpc error: code = InvalidArgument desc = invalid configuration: devid_priv_path is required", + name: "Configure fails if DevID private key path is empty", + trustDomain: "example.org", + hclConf: `devid_cert_path = "non-existent-path/to/devid.cert"`, + expErr: "rpc error: code = InvalidArgument desc = invalid configuration: devid_priv_path is required", }, { - name: "Configure fails if DevID public key path is empty", + name: "Configure fails if DevID public key path is empty", + trustDomain: "example.org", hclConf: ` devid_cert_path = "non-existent-path/to/devid.cert" devid_priv_path = "non-existent-path/to/devid-private-blob"`, expErr: "rpc error: code = InvalidArgument desc = invalid configuration: devid_pub_path is required", }, { - name: "Configure succeeds auto detecting the TPM path", + name: "Configure succeeds auto detecting the TPM path", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = %q devid_pub_path = %q`, @@ -144,7 +154,8 @@ func TestConfigureCommon(t *testing.T) { devIDPubPath), }, { - name: "Configure succeeds if DevID does not have intermediates certificates", + name: "Configure succeeds if DevID does not have intermediates certificates", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = %q devid_pub_path = %q`, @@ -170,7 +181,12 @@ func TestConfigureCommon(t *testing.T) { plugin := tpmdevid.New() - resp, err := plugin.Configure(context.Background(), &configv1.ConfigureRequest{HclConfiguration: tt.hclConf}) + resp, err := plugin.Configure(context.Background(), &configv1.ConfigureRequest{ + CoreConfiguration: &configv1.CoreConfiguration{ + TrustDomain: tt.trustDomain, + }, + HclConfiguration: tt.hclConf, + }) if tt.expErr != "" { require.Contains(t, err.Error(), tt.expErr) require.Nil(t, resp) @@ -192,12 +208,14 @@ func TestConfigurePosix(t *testing.T) { tests := []struct { name string + trustDomain string hclConf string expErr string autoDetectTPMFails bool }{ { - name: "Configure fails if DevID certificate cannot be opened", + name: "Configure fails if DevID certificate cannot be opened", + trustDomain: "example.org", hclConf: ` devid_cert_path = "non-existent-path/to/devid.cert" devid_priv_path = "non-existent-path/to/devid-private-blob" devid_pub_path = "non-existent-path/to/devid-public-blob" @@ -205,7 +223,8 @@ func TestConfigurePosix(t *testing.T) { expErr: "rpc error: code = Internal desc = unable to load DevID files: cannot load certificate(s): open non-existent-path/to/devid.cert:", }, { - name: "Configure fails if TPM path is not provided and it cannot be auto detected", + name: "Configure fails if TPM path is not provided and it cannot be auto detected", + trustDomain: "example.org", hclConf: `devid_cert_path = "non-existent-path/to/devid.cert" devid_priv_path = "non-existent-path/to/devid-private-blob" devid_pub_path = "non-existent-path/to/devid-public-blob"`, @@ -213,7 +232,8 @@ func TestConfigurePosix(t *testing.T) { autoDetectTPMFails: true, }, { - name: "Configure fails if DevID private key cannot be opened", + name: "Configure fails if DevID private key cannot be opened", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = "non-existent-path/to/devid-private-blob" devid_pub_path = "non-existent-path/to/devid-public-blob" @@ -221,7 +241,8 @@ func TestConfigurePosix(t *testing.T) { expErr: "rpc error: code = Internal desc = unable to load DevID files: cannot load private key: open non-existent-path/to/devid-private-blob:", }, { - name: "Configure fails if DevID public key cannot be opened", + name: "Configure fails if DevID public key cannot be opened", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = %q devid_pub_path = "non-existent-path/to/devid-public-blob" @@ -231,7 +252,8 @@ func TestConfigurePosix(t *testing.T) { expErr: "rpc error: code = Internal desc = unable to load DevID files: cannot load public key: open non-existent-path/to/devid-public-blob:", }, { - name: "Configure succeeds providing a TPM path", + name: "Configure succeeds providing a TPM path", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = %q devid_pub_path = %q @@ -254,7 +276,12 @@ func TestConfigurePosix(t *testing.T) { plugin := tpmdevid.New() - resp, err := plugin.Configure(context.Background(), &configv1.ConfigureRequest{HclConfiguration: tt.hclConf}) + resp, err := plugin.Configure(context.Background(), &configv1.ConfigureRequest{ + CoreConfiguration: &configv1.CoreConfiguration{ + TrustDomain: tt.trustDomain, + }, + HclConfiguration: tt.hclConf, + }) if tt.expErr != "" { require.Contains(t, err.Error(), tt.expErr) require.Nil(t, resp) @@ -276,26 +303,30 @@ func TestConfigureWindows(t *testing.T) { tests := []struct { name string + trustDomain string hclConf string expErr string autoDetectTPMFails bool }{ { - name: "Configure fails if DevID certificate cannot be opened", + name: "Configure fails if DevID certificate cannot be opened", + trustDomain: "example.org", hclConf: ` devid_cert_path = "non-existent-path/to/devid.cert" devid_priv_path = "non-existent-path/to/devid-private-blob" devid_pub_path = "non-existent-path/to/devid-public-blob"`, expErr: "rpc error: code = Internal desc = unable to load DevID files: cannot load certificate(s): open non-existent-path/to/devid.cert:", }, { - name: "Configure fails if DevID private key cannot be opened", + name: "Configure fails if DevID private key cannot be opened", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = "non-existent-path/to/devid-private-blob" devid_pub_path = "non-existent-path/to/devid-public-blob"`, devIDCertPath), expErr: "rpc error: code = Internal desc = unable to load DevID files: cannot load private key: open non-existent-path/to/devid-private-blob:", }, { - name: "Configure fails if Device Path is provided", + name: "Configure fails if Device Path is provided", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = %q devid_pub_path = %q @@ -306,7 +337,8 @@ func TestConfigureWindows(t *testing.T) { expErr: "rpc error: code = InvalidArgument desc = device path is not allowed on windows", }, { - name: "Configure fails if DevID public key cannot be opened", + name: "Configure fails if DevID public key cannot be opened", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = %q devid_pub_path = "non-existent-path/to/devid-public-blob"`, @@ -315,7 +347,8 @@ func TestConfigureWindows(t *testing.T) { expErr: "rpc error: code = Internal desc = unable to load DevID files: cannot load public key: open non-existent-path/to/devid-public-blob:", }, { - name: "Configure succeeds providing a TPM path", + name: "Configure succeeds providing a TPM path", + trustDomain: "example.org", hclConf: fmt.Sprintf(`devid_cert_path = %q devid_priv_path = %q devid_pub_path = %q`, @@ -334,7 +367,12 @@ func TestConfigureWindows(t *testing.T) { plugin := tpmdevid.New() - resp, err := plugin.Configure(context.Background(), &configv1.ConfigureRequest{HclConfiguration: tt.hclConf}) + resp, err := plugin.Configure(context.Background(), &configv1.ConfigureRequest{ + CoreConfiguration: &configv1.CoreConfiguration{ + TrustDomain: tt.trustDomain, + }, + HclConfiguration: tt.hclConf, + }) if tt.expErr != "" { require.Contains(t, err.Error(), tt.expErr) require.Nil(t, resp) @@ -578,7 +616,11 @@ func loadAndConfigurePlugin(t *testing.T, passwords tpmutil.TPMPasswords) nodeat passwords.EndorsementHierarchy, ) - return loadPlugin(t, plugintest.Configure(config)) + return loadPlugin(t, plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(trustDomain), + }), + plugintest.Configure(config), + ) } func loadPlugin(t *testing.T, options ...plugintest.Option) nodeattestor.NodeAttestor { diff --git a/pkg/agent/plugin/nodeattestor/x509pop/x509pop.go b/pkg/agent/plugin/nodeattestor/x509pop/x509pop.go index e0b36edca1..802b3e6a46 100644 --- a/pkg/agent/plugin/nodeattestor/x509pop/x509pop.go +++ b/pkg/agent/plugin/nodeattestor/x509pop/x509pop.go @@ -13,6 +13,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/x509pop" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -43,6 +44,24 @@ type Config struct { IntermediatesPath string `hcl:"intermediates_path"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + newConfig := new(Config) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.PrivateKeyPath == "" { + status.ReportError("private_key_path is required") + } + + if newConfig.CertificatePath == "" { + status.ReportError("certificate_path is required") + } + + return newConfig +} + type Plugin struct { nodeattestorv1.UnsafeNodeAttestorServer configv1.UnsafeConfigServer @@ -100,39 +119,36 @@ func (p *Plugin) AidAttestation(stream nodeattestorv1.NodeAttestor_AidAttestatio } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - // Parse HCL config payload into config struct - config := new(Config) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.PrivateKeyPath == "" { - return nil, status.Error(codes.InvalidArgument, "private_key_path is required") - } - if config.CertificatePath == "" { - return nil, status.Error(codes.InvalidArgument, "certificate_path is required") + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } // make sure the configuration produces valid data - if _, err := loadConfigData(config); err != nil { + if _, err := loadConfigData(newConfig); err != nil { return nil, err } - p.setConfig(config) + p.m.Lock() + defer p.m.Unlock() + p.c = newConfig return &configv1.ConfigureResponse{}, nil } -func (p *Plugin) getConfig() *Config { - p.m.Lock() - defer p.m.Unlock() - return p.c +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil } -func (p *Plugin) setConfig(c *Config) { +func (p *Plugin) getConfig() *Config { p.m.Lock() defer p.m.Unlock() - p.c = c + return p.c } func (p *Plugin) loadConfigData() (*configData, error) { @@ -143,6 +159,7 @@ func (p *Plugin) loadConfigData() (*configData, error) { return loadConfigData(config) } +// TODO: this needs more attention. Parts of it might belong in buildConfig func loadConfigData(config *Config) (*configData, error) { certificate, err := tls.LoadX509KeyPair(config.CertificatePath, config.PrivateKeyPath) if err != nil { diff --git a/pkg/agent/plugin/nodeattestor/x509pop/x509pop_test.go b/pkg/agent/plugin/nodeattestor/x509pop/x509pop_test.go index 5301033806..e64622a4c1 100644 --- a/pkg/agent/plugin/nodeattestor/x509pop/x509pop_test.go +++ b/pkg/agent/plugin/nodeattestor/x509pop/x509pop_test.go @@ -8,8 +8,10 @@ import ( "fmt" "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor" nodeattestortest "github.com/spiffe/spire/pkg/agent/plugin/nodeattestor/test" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/x509pop" "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/test/fixture" @@ -19,6 +21,7 @@ import ( ) var ( + trustDomain = "example.org" leafKeyPath = fixture.Join("nodeattestor", "x509pop", "leaf-key.pem") leafCertPath = fixture.Join("nodeattestor", "x509pop", "leaf-crt-bundle.pem") intermediatePath = fixture.Join("nodeattestor", "x509pop", "intermediate.pem") @@ -87,10 +90,13 @@ func (s *Suite) TestConfigure() { // malformed s.loadPlugin(plugintest.CaptureConfigureError(&err), plugintest.Configure(`bad juju`)) - s.RequireGRPCStatusContains(err, codes.InvalidArgument, "unable to decode configuration") + s.RequireGRPCStatusContains(err, codes.InvalidArgument, "server core configuration must contain trust_domain") // missing private_key_path s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.Configure(` certificate_path = "blah" `), @@ -99,6 +105,9 @@ func (s *Suite) TestConfigure() { // missing certificate_path s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.Configure(` private_key_path = "blah" `), @@ -107,6 +116,9 @@ func (s *Suite) TestConfigure() { // cannot load keypair s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.Configure(` private_key_path = "blah" certificate_path = "blah" @@ -116,6 +128,9 @@ func (s *Suite) TestConfigure() { // cannot load intermediates s.loadPlugin(plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.Configuref(` private_key_path = %q certificate_path = %q @@ -138,7 +153,12 @@ func (s *Suite) loadAndConfigurePlugin(withIntermediate bool) nodeattestor.NodeA config += fmt.Sprintf(` intermediates_path = %q`, intermediatePath) } - return s.loadPlugin(plugintest.Configure(config)) + return s.loadPlugin( + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(trustDomain), + }), + plugintest.Configure(config), + ) } func (s *Suite) testAttestSuccess(p nodeattestor.NodeAttestor, expectBundle [][]byte) { diff --git a/pkg/agent/plugin/svidstore/awssecretsmanager/aws.go b/pkg/agent/plugin/svidstore/awssecretsmanager/aws.go index c16962a762..dad088db71 100644 --- a/pkg/agent/plugin/svidstore/awssecretsmanager/aws.go +++ b/pkg/agent/plugin/svidstore/awssecretsmanager/aws.go @@ -16,6 +16,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/agent/plugin/svidstore" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -53,6 +54,28 @@ type Configuration struct { Region string `hcl:"region" json:"region"` } +func (p *SecretsManagerPlugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := &Configuration{} + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.AccessKeyID == "" { + newConfig.AccessKeyID = p.hooks.getenv("AWS_ACCESS_KEY_ID") + } + + if newConfig.SecretAccessKey == "" { + newConfig.SecretAccessKey = p.hooks.getenv("AWS_SECRET_ACCESS_KEY") + } + + if newConfig.Region == "" { + status.ReportError("region is required") + } + + return newConfig +} + type SecretsManagerPlugin struct { svidstorev1.UnsafeSVIDStoreServer configv1.UnsafeConfigServer @@ -73,25 +96,12 @@ func (p *SecretsManagerPlugin) SetLogger(log hclog.Logger) { // Configure configures the SecretsManagerPlugin. func (p *SecretsManagerPlugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - // Parse HCL config payload into config struct - config := &Configuration{} - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.AccessKeyID == "" { - config.AccessKeyID = p.hooks.getenv("AWS_ACCESS_KEY_ID") - } - - if config.SecretAccessKey == "" { - config.SecretAccessKey = p.hooks.getenv("AWS_SECRET_ACCESS_KEY") - } - - if config.Region == "" { - return nil, status.Error(codes.InvalidArgument, "region is required") + newConfig, _, err := pluginconf.Build(req, p.buildConfig) + if err != nil { + return nil, err } - smClient, err := p.hooks.newClient(ctx, config.SecretAccessKey, config.AccessKeyID, config.Region) + smClient, err := p.hooks.newClient(ctx, newConfig.SecretAccessKey, newConfig.AccessKeyID, newConfig.Region) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create secrets manager client: %v", err) } @@ -104,6 +114,15 @@ func (p *SecretsManagerPlugin) Configure(ctx context.Context, req *configv1.Conf return &configv1.ConfigureResponse{}, nil } +func (p *SecretsManagerPlugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + // PutX509SVID puts the specified X509-SVID in the configured AWS Secrets Manager func (p *SecretsManagerPlugin) PutX509SVID(ctx context.Context, req *svidstorev1.PutX509SVIDRequest) (*svidstorev1.PutX509SVIDResponse, error) { opt, err := optionsFromSecretData(req.Metadata) diff --git a/pkg/agent/plugin/svidstore/awssecretsmanager/aws_test.go b/pkg/agent/plugin/svidstore/awssecretsmanager/aws_test.go index 6dde197509..e7f1bf8879 100644 --- a/pkg/agent/plugin/svidstore/awssecretsmanager/aws_test.go +++ b/pkg/agent/plugin/svidstore/awssecretsmanager/aws_test.go @@ -82,6 +82,7 @@ func TestConfigure(t *testing.T) { for _, tt := range []struct { name string + trustDomain string envs map[string]string accessKeyID string secretAccessKey string @@ -94,6 +95,7 @@ func TestConfigure(t *testing.T) { }{ { name: "access key and secret from config", + trustDomain: "example.org", envs: envs, accessKeyID: "ACCESS_KEY", secretAccessKey: "ID", @@ -105,9 +107,10 @@ func TestConfigure(t *testing.T) { }, }, { - name: "access key and secret from env vars", - envs: envs, - region: "r1", + name: "access key and secret from env vars", + trustDomain: "example.org", + envs: envs, + region: "r1", expectConfig: &Configuration{ AccessKeyID: "foh", SecretAccessKey: "bar", @@ -116,12 +119,14 @@ func TestConfigure(t *testing.T) { }, { name: "no region provided", + trustDomain: "example.org", envs: envs, expectCode: codes.InvalidArgument, expectMsgPrefix: "region is required", }, { name: "new client fails", + trustDomain: "example.org", envs: envs, region: "r1", expectClientErr: errors.New("oh no"), @@ -130,6 +135,7 @@ func TestConfigure(t *testing.T) { }, { name: "malformed configuration", + trustDomain: "example.org", envs: envs, region: "r1", customConfig: "{ not a config }", @@ -142,6 +148,9 @@ func TestConfigure(t *testing.T) { options := []plugintest.Option{ plugintest.CaptureConfigureError(&err), } + options = append(options, plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(tt.trustDomain), + })) if tt.customConfig != "" { options = append(options, plugintest.Configure(tt.customConfig)) diff --git a/pkg/agent/plugin/svidstore/gcpsecretmanager/gcloud.go b/pkg/agent/plugin/svidstore/gcpsecretmanager/gcloud.go index eb92c2b943..db7efe7258 100644 --- a/pkg/agent/plugin/svidstore/gcpsecretmanager/gcloud.go +++ b/pkg/agent/plugin/svidstore/gcpsecretmanager/gcloud.go @@ -19,6 +19,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/agent/plugin/svidstore" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -54,6 +55,26 @@ type Configuration struct { UnusedKeyPositions map[string][]token.Pos `hcl:",unusedKeyPositions" json:",omitempty"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := &Configuration{} + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if len(newConfig.UnusedKeyPositions) != 0 { + var keys []string + for k := range newConfig.UnusedKeyPositions { + keys = append(keys, k) + } + + sort.Strings(keys) + status.ReportErrorf("unknown configurations detected: %s", strings.Join(keys, ",")) + } + + return newConfig +} + type SecretManagerPlugin struct { svidstorev1.UnsafeSVIDStoreServer configv1.UnsafeConfigServer @@ -74,23 +95,12 @@ func (p *SecretManagerPlugin) SetLogger(log hclog.Logger) { // Configure configures the SecretManagerPlugin. func (p *SecretManagerPlugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - // Parse HCL config payload into config struct - config := &Configuration{} - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if len(config.UnusedKeyPositions) != 0 { - var keys []string - for k := range config.UnusedKeyPositions { - keys = append(keys, k) - } - - sort.Strings(keys) - return nil, status.Errorf(codes.InvalidArgument, "unknown configurations detected: %s", strings.Join(keys, ",")) + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - secretMangerClient, err := p.hooks.newSecretManagerClient(ctx, config.ServiceAccountFile) + secretMangerClient, err := p.hooks.newSecretManagerClient(ctx, newConfig.ServiceAccountFile) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create secretmanager client: %v", err) } @@ -107,6 +117,15 @@ func (p *SecretManagerPlugin) Configure(ctx context.Context, req *configv1.Confi return &configv1.ConfigureResponse{}, nil } +func (p *SecretManagerPlugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + // PutX509SVID puts the specified X509-SVID in the configured Google Cloud Secrets Manager func (p *SecretManagerPlugin) PutX509SVID(ctx context.Context, req *svidstorev1.PutX509SVIDRequest) (*svidstorev1.PutX509SVIDResponse, error) { opt, err := optionsFromSecretData(req.Metadata) diff --git a/pkg/agent/plugin/svidstore/gcpsecretmanager/gcloud_test.go b/pkg/agent/plugin/svidstore/gcpsecretmanager/gcloud_test.go index 85f44781d7..8fe2aad9b3 100644 --- a/pkg/agent/plugin/svidstore/gcpsecretmanager/gcloud_test.go +++ b/pkg/agent/plugin/svidstore/gcpsecretmanager/gcloud_test.go @@ -87,9 +87,9 @@ func TestConfigure(t *testing.T) { for _, tt := range []struct { name string + trustDomain spiffeid.TrustDomain customConfig string newClientErr error - trustDomain spiffeid.TrustDomain expectCode codes.Code expectMsgPrefix string expectFilePath string @@ -98,32 +98,35 @@ func TestConfigure(t *testing.T) { }{ { name: "success", + trustDomain: trustDomain, expectFilePath: "someFile", expectConfig: &Configuration{ServiceAccountFile: "someFile"}, - trustDomain: trustDomain, expectTD: tdHash, }, { name: "no config file", - expectConfig: &Configuration{ServiceAccountFile: ""}, trustDomain: trustDomain, + expectConfig: &Configuration{ServiceAccountFile: ""}, expectTD: tdHash, }, { name: "malformed configuration", + trustDomain: trustDomain, customConfig: "{no a config}", expectCode: codes.InvalidArgument, expectMsgPrefix: "unable to decode configuration:", }, { name: "failed to create client", + trustDomain: trustDomain, expectConfig: &Configuration{ServiceAccountFile: "someFile"}, newClientErr: errors.New("oh! no"), expectCode: codes.Internal, expectMsgPrefix: "failed to create secretmanager client: oh! no", }, { - name: "contains unused keys", + name: "contains unused keys", + trustDomain: trustDomain, customConfig: ` service_account_file = "some_file" invalid1 = "something" diff --git a/pkg/agent/plugin/workloadattestor/docker/docker.go b/pkg/agent/plugin/workloadattestor/docker/docker.go index 36ec0b3081..21bca906fa 100644 --- a/pkg/agent/plugin/workloadattestor/docker/docker.go +++ b/pkg/agent/plugin/workloadattestor/docker/docker.go @@ -17,6 +17,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/agent/common/sigstore" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/telemetry" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -74,6 +75,10 @@ type dockerPluginConfig struct { UnusedKeyPositions map[string][]token.Pos `hcl:",unusedKeyPositions"` Experimental experimentalConfig `hcl:"experimental,omitempty" json:"experimental,omitempty"` + + containerHelper *containerHelper + dockerOpts []dockerclient.Opt + sigstoreConfig *sigstore.Config } type experimentalConfig struct { @@ -81,6 +86,43 @@ type experimentalConfig struct { Sigstore *sigstore.HCLConfig `hcl:"sigstore,omitempty"` } +func (p *Plugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *dockerPluginConfig { + var err error + newConfig := &dockerPluginConfig{} + if err = hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if len(newConfig.UnusedKeyPositions) > 0 { + var keys []string + for k := range newConfig.UnusedKeyPositions { + keys = append(keys, k) + } + + sort.Strings(keys) + status.ReportErrorf("unknown configurations detected: %s", strings.Join(keys, ",")) + } + + newConfig.containerHelper = p.createHelper(newConfig, status) + + dockerHost := getDockerHost(newConfig) + if dockerHost != "" { + newConfig.dockerOpts = append(newConfig.dockerOpts, dockerclient.WithHost(dockerHost)) + } + if newConfig.DockerVersion == "" { + newConfig.dockerOpts = append(newConfig.dockerOpts, dockerclient.WithAPIVersionNegotiation()) + } else { + newConfig.dockerOpts = append(newConfig.dockerOpts, dockerclient.WithVersion(newConfig.DockerVersion)) + } + + if newConfig.Experimental.Sigstore != nil { + newConfig.sigstoreConfig = sigstore.NewConfigFromHCL(newConfig.Experimental.Sigstore, p.log) + } + + return newConfig +} + func (p *Plugin) SetLogger(log hclog.Logger) { p.log = log } @@ -165,48 +207,19 @@ func getSelectorValuesFromConfig(cfg *container.Config) []string { } func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - var err error - config := &dockerPluginConfig{} - if err = hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if len(config.UnusedKeyPositions) > 0 { - var keys []string - for k := range config.UnusedKeyPositions { - keys = append(keys, k) - } - - sort.Strings(keys) - return nil, status.Errorf(codes.InvalidArgument, "unknown configurations detected: %s", strings.Join(keys, ",")) - } - - containerHelper, err := createHelper(config, p.log) + newConfig, _, err := pluginconf.Build(req, p.buildConfig) if err != nil { return nil, err } - var opts []dockerclient.Opt - dockerHost := getDockerHost(config) - if dockerHost != "" { - opts = append(opts, dockerclient.WithHost(dockerHost)) - } - switch { - case config.DockerVersion != "": - opts = append(opts, dockerclient.WithVersion(config.DockerVersion)) - default: - opts = append(opts, dockerclient.WithAPIVersionNegotiation()) - } - - docker, err := dockerclient.NewClientWithOpts(opts...) + docker, err := dockerclient.NewClientWithOpts(newConfig.dockerOpts...) if err != nil { return nil, err } var sigstoreVerifier sigstore.Verifier - if config.Experimental.Sigstore != nil { - cfg := sigstore.NewConfigFromHCL(config.Experimental.Sigstore, p.log) - verifier := sigstore.NewVerifier(cfg) + if newConfig.sigstoreConfig != nil { + verifier := sigstore.NewVerifier(newConfig.sigstoreConfig) err = verifier.Init(ctx) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "error initializing sigstore verifier: %v", err) @@ -217,7 +230,17 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) p.mtx.Lock() defer p.mtx.Unlock() p.docker = docker - p.c = containerHelper + p.c = newConfig.containerHelper p.sigstoreVerifier = sigstoreVerifier + return &configv1.ConfigureResponse{}, nil } + +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} diff --git a/pkg/agent/plugin/workloadattestor/docker/docker_posix.go b/pkg/agent/plugin/workloadattestor/docker/docker_posix.go index 19e24ff29b..9c2bdacaeb 100644 --- a/pkg/agent/plugin/workloadattestor/docker/docker_posix.go +++ b/pkg/agent/plugin/workloadattestor/docker/docker_posix.go @@ -9,12 +9,11 @@ import ( "os" "path/filepath" - "github.com/gogo/status" "github.com/hashicorp/go-hclog" "github.com/spiffe/spire/pkg/agent/common/cgroups" "github.com/spiffe/spire/pkg/agent/plugin/workloadattestor/docker/cgroup" "github.com/spiffe/spire/pkg/common/containerinfo" - "google.golang.org/grpc/codes" + "github.com/spiffe/spire/pkg/common/pluginconf" ) type OSConfig struct { @@ -38,22 +37,25 @@ type OSConfig struct { rootDir string } -func createHelper(c *dockerPluginConfig, log hclog.Logger) (*containerHelper, error) { +func (p *Plugin) createHelper(c *dockerPluginConfig, status *pluginconf.Status) *containerHelper { useNewContainerLocator := c.UseNewContainerLocator == nil || *c.UseNewContainerLocator var containerIDFinder cgroup.ContainerIDFinder if len(c.ContainerIDCGroupMatchers) > 0 { if useNewContainerLocator { - return nil, status.Error(codes.InvalidArgument, "the new container locator and custom cgroup matchers cannot both be used; please open an issue if the new container locator fails to locate workload containers in your environment; to continue using custom matchers set use_new_container_locator=false") + status.ReportError("the new container locator and custom cgroup matchers cannot both be used; please open an issue if the new container locator fails to locate workload containers in your environment; to continue using custom matchers set use_new_container_locator=false") + return nil } - log.Warn("Using the legacy container locator with custom cgroup matchers. This feature will be removed in a future release.") + p.log.Warn("Using the legacy container locator with custom cgroup matchers. This feature will be removed in a future release.") + status.ReportInfo("Using the legacy container locator with custom cgroup matchers. This feature will be removed in a future release.") var err error containerIDFinder, err = cgroup.NewContainerIDFinder(c.ContainerIDCGroupMatchers) if err != nil { - return nil, err + status.ReportError(err.Error()) + return nil } } else { - log.Info("Using the new container locator") + status.ReportInfo("Using the new container locator") } rootDir := c.rootDir @@ -65,7 +67,7 @@ func createHelper(c *dockerPluginConfig, log hclog.Logger) (*containerHelper, er rootDir: rootDir, containerIDFinder: containerIDFinder, verboseContainerLocatorLogs: c.VerboseContainerLocatorLogs, - }, nil + } } type dirFS string diff --git a/pkg/agent/plugin/workloadattestor/docker/docker_posix_test.go b/pkg/agent/plugin/workloadattestor/docker/docker_posix_test.go index 0b8fa6efb7..9458952530 100644 --- a/pkg/agent/plugin/workloadattestor/docker/docker_posix_test.go +++ b/pkg/agent/plugin/workloadattestor/docker/docker_posix_test.go @@ -19,15 +19,17 @@ const ( func TestContainerExtraction(t *testing.T) { tests := []struct { - desc string - cfg string - cgroups string - hasMatch bool - expectErr string + desc string + trustDomain string + cfg string + cgroups string + hasMatch bool + expectErr string }{ { - desc: "no match", - cgroups: testCgroupEntries, + desc: "no match", + trustDomain: "example.org", + cgroups: testCgroupEntries, cfg: ` use_new_container_locator = false container_id_cgroup_matchers = [ @@ -36,8 +38,9 @@ func TestContainerExtraction(t *testing.T) { `, }, { - desc: "one miss one match", - cgroups: testCgroupEntries, + desc: "one miss one match", + trustDomain: "example.org", + cgroups: testCgroupEntries, cfg: ` use_new_container_locator = false container_id_cgroup_matchers = [ @@ -48,8 +51,9 @@ func TestContainerExtraction(t *testing.T) { hasMatch: true, }, { - desc: "no container id", - cgroups: "10:cpu:/docker/", + desc: "no container id", + trustDomain: "example.org", + cgroups: "10:cpu:/docker/", cfg: ` use_new_container_locator = false container_id_cgroup_matchers = [ @@ -59,24 +63,28 @@ func TestContainerExtraction(t *testing.T) { expectErr: "a pattern matched, but no container id was found", }, { - desc: "RHEL docker cgroups", - cgroups: "4:devices:/system.slice/docker-6469646e742065787065637420616e796f6e6520746f20726561642074686973.scope", - hasMatch: true, + desc: "RHEL docker cgroups", + trustDomain: "example.org", + cgroups: "4:devices:/system.slice/docker-6469646e742065787065637420616e796f6e6520746f20726561642074686973.scope", + hasMatch: true, }, { - desc: "docker for desktop", - cgroups: "6:devices:/docker/6469646e742065787065637420616e796f6e6520746f20726561642074686973/docker/6469646e742065787065637420616e796f6e6520746f20726561642074686973/system.slice/containerd.service", - hasMatch: true, + desc: "docker for desktop", + trustDomain: "example.org", + cgroups: "6:devices:/docker/6469646e742065787065637420616e796f6e6520746f20726561642074686973/docker/6469646e742065787065637420616e796f6e6520746f20726561642074686973/system.slice/containerd.service", + hasMatch: true, }, { - desc: "more than one id", - cgroups: testCgroupEntries + "\n" + "4:devices:/system.slice/docker-41e4ab61d2860b0e1467de0da0a9c6068012761febec402dc04a5a94f32ea867.scope", - expectErr: "multiple container IDs found", + desc: "more than one id", + trustDomain: "example.org", + cgroups: testCgroupEntries + "\n" + "4:devices:/system.slice/docker-41e4ab61d2860b0e1467de0da0a9c6068012761febec402dc04a5a94f32ea867.scope", + expectErr: "multiple container IDs found", }, { - desc: "default configuration matches cgroup missing docker prefix", - cgroups: "4:devices:/system.slice/6469646e742065787065637420616e796f6e6520746f20726561642074686973.scope", - hasMatch: true, + desc: "default configuration matches cgroup missing docker prefix", + trustDomain: "example.org", + cgroups: "4:devices:/system.slice/6469646e742065787065637420616e796f6e6520746f20726561642074686973.scope", + hasMatch: true, }, } @@ -93,7 +101,7 @@ func TestContainerExtraction(t *testing.T) { p := newTestPlugin( t, - withConfig(t, tt.cfg), // this must be the first option + withConfig(t, tt.trustDomain, tt.cfg), // this must be the first option withDocker(d), withRootDirOpt, ) @@ -132,7 +140,7 @@ func TestDockerConfigPosix(t *testing.T) { expectFinder, err := cgroup.NewContainerIDFinder([]string{"/docker/"}) require.NoError(t, err) - p := newTestPlugin(t, withConfig(t, ` + p := newTestPlugin(t, withConfig(t, "example.org", ` use_new_container_locator = false docker_socket_path = "unix:///socket_path" docker_version = "1.20" @@ -152,7 +160,7 @@ use_new_container_locator = false container_id_cgroup_matchers = [ "/docker/", ]` - err := doConfigure(t, p, cfg) + err := doConfigure(t, p, "example.org", cfg) require.Error(t, err) require.Contains(t, err.Error(), `must contain the container id token "" exactly once`) }) @@ -184,9 +192,9 @@ func withRootDir(dir string) testPluginOpt { } // this must be the first plugin opt -func withConfig(t *testing.T, cfg string) testPluginOpt { +func withConfig(t *testing.T, trustDomain string, cfg string) testPluginOpt { return func(p *Plugin) { - err := doConfigure(t, p, cfg) + err := doConfigure(t, p, trustDomain, cfg) require.NoError(t, err) } } diff --git a/pkg/agent/plugin/workloadattestor/docker/docker_test.go b/pkg/agent/plugin/workloadattestor/docker/docker_test.go index f8cabf6717..ec583c859f 100644 --- a/pkg/agent/plugin/workloadattestor/docker/docker_test.go +++ b/pkg/agent/plugin/workloadattestor/docker/docker_test.go @@ -12,8 +12,10 @@ import ( "github.com/docker/docker/api/types/container" dockerclient "github.com/docker/docker/client" "github.com/hashicorp/go-hclog" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/common/sigstore" "github.com/spiffe/spire/pkg/agent/plugin/workloadattestor" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/test/clock" "github.com/spiffe/spire/test/plugintest" "github.com/spiffe/spire/test/spiretest" @@ -23,8 +25,9 @@ import ( ) const ( - testContainerID = "6469646e742065787065637420616e796f6e6520746f20726561642074686973" - testImageID = "test-image-id" + testContainerID = "6469646e742065787065637420616e796f6e6520746f20726561642074686973" + testImageID = "test-image-id" + defaultTrustDomain = "example.org" ) var disabledRetryer = &retryer{disabled: true} @@ -156,17 +159,20 @@ func TestDockerErrorContextCancel(t *testing.T) { func TestDockerConfig(t *testing.T) { for _, tt := range []struct { name string + trustDomain string expectCode codes.Code expectMsg string config string sigstoreConfigured bool }{ { - name: "success configuration", - config: `docker_version = "/123/"`, + name: "success configuration", + trustDomain: "example.org", + config: `docker_version = "/123/"`, }, { - name: "sigstore configuration", + name: "sigstore configuration", + trustDomain: "example.org", config: ` experimental { sigstore { @@ -186,7 +192,8 @@ func TestDockerConfig(t *testing.T) { sigstoreConfigured: true, }, { - name: "bad hcl", + name: "bad hcl", + trustDomain: "example.org", config: ` container_id_cgroup_matchers = [ "/docker/"`, @@ -194,7 +201,8 @@ container_id_cgroup_matchers = [ expectMsg: "unable to decode configuration:", }, { - name: "unknown configuration", + name: "unknown configuration", + trustDomain: "example.org", config: ` invalid1 = "/oh/" invalid2 = "/no/"`, @@ -207,6 +215,9 @@ invalid2 = "/no/"`, var err error plugintest.Load(t, builtin(p), new(workloadattestor.V1), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(tt.trustDomain), + }), plugintest.Configure(tt.config), plugintest.CaptureConfigureError(&err)) @@ -345,9 +356,12 @@ func doAttestWithContext(ctx context.Context, t *testing.T, p *Plugin) ([]string return selectorValues, nil } -func doConfigure(t *testing.T, p *Plugin, cfg string) error { +func doConfigure(t *testing.T, p *Plugin, trustDomain string, cfg string) error { var err error plugintest.Load(t, builtin(p), new(workloadattestor.V1), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(trustDomain), + }), plugintest.Configure(cfg), plugintest.CaptureConfigureError(&err)) return err @@ -381,7 +395,7 @@ func withSigstoreVerifier(v sigstore.Verifier) testPluginOpt { func newTestPlugin(t *testing.T, opts ...testPluginOpt) *Plugin { p := New() - err := doConfigure(t, p, "") + err := doConfigure(t, p, defaultTrustDomain, "") require.NoError(t, err) for _, o := range opts { diff --git a/pkg/agent/plugin/workloadattestor/docker/docker_windows.go b/pkg/agent/plugin/workloadattestor/docker/docker_windows.go index 1101c8ee55..ba98477d85 100644 --- a/pkg/agent/plugin/workloadattestor/docker/docker_windows.go +++ b/pkg/agent/plugin/workloadattestor/docker/docker_windows.go @@ -5,6 +5,7 @@ package docker import ( hclog "github.com/hashicorp/go-hclog" "github.com/spiffe/spire/pkg/common/container/process" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -14,10 +15,10 @@ type OSConfig struct { DockerHost string `hcl:"docker_host" json:"docker_host"` } -func createHelper(*dockerPluginConfig, hclog.Logger) (*containerHelper, error) { +func (p *Plugin) createHelper(*dockerPluginConfig, *pluginconf.Status) *containerHelper { return &containerHelper{ ph: process.CreateHelper(), - }, nil + } } type containerHelper struct { diff --git a/pkg/agent/plugin/workloadattestor/k8s/k8s.go b/pkg/agent/plugin/workloadattestor/k8s/k8s.go index 8432c5d8b5..e52b67be62 100644 --- a/pkg/agent/plugin/workloadattestor/k8s/k8s.go +++ b/pkg/agent/plugin/workloadattestor/k8s/k8s.go @@ -25,6 +25,7 @@ import ( "github.com/spiffe/spire/pkg/agent/common/sigstore" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/pemutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/valyala/fastjson" "golang.org/x/sync/singleflight" @@ -156,11 +157,104 @@ type k8sConfig struct { NodeName string ReloadInterval time.Duration DisableContainerSelectors bool + ContainerHelper ContainerHelper + sigstoreConfig *sigstore.Config Client *kubeletClient LastReload time.Time } +func (p *Plugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *k8sConfig { + // Parse HCL config payload into config struct + newConfig := new(HCLConfig) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + // Determine max poll attempts with default + maxPollAttempts := newConfig.MaxPollAttempts + if maxPollAttempts <= 0 { + maxPollAttempts = defaultMaxPollAttempts + } + + // Determine poll retry interval with default + var pollRetryInterval time.Duration + var err error + if newConfig.PollRetryInterval != "" { + pollRetryInterval, err = time.ParseDuration(newConfig.PollRetryInterval) + if err != nil { + status.ReportErrorf("unable to parse poll retry interval: %v", err) + } + } + if pollRetryInterval <= 0 { + pollRetryInterval = defaultPollRetryInterval + } + + // Determine reload interval + var reloadInterval time.Duration + if newConfig.ReloadInterval != "" { + reloadInterval, err = time.ParseDuration(newConfig.ReloadInterval) + if err != nil { + status.ReportErrorf("unable to parse reload interval: %v", err) + } + } + if reloadInterval <= 0 { + reloadInterval = defaultReloadInterval + } + + // Determine which kubelet port to hit. Default to the secure port if none + // is specified (this is backwards compatible because the read-only-port + // config value has always been required, so it should already be set in + // existing configurations that rely on it). + if newConfig.KubeletSecurePort > 0 && newConfig.KubeletReadOnlyPort > 0 { + status.ReportError("cannot use both the read-only and secure port") + } + + port := newConfig.KubeletReadOnlyPort + secure := false + if port <= 0 { + port = newConfig.KubeletSecurePort + secure = true + } + if port <= 0 { + port = defaultSecureKubeletPort + secure = true + } + + containerHelper := createHelper(p) + if err := containerHelper.Configure(newConfig, p.log); err != nil { + status.ReportError(err.Error()) + } + + // Determine the node name + nodeName := p.getNodeName(newConfig.NodeName, newConfig.NodeNameEnv) + + var sigstoreConfig *sigstore.Config + if newConfig.Experimental.Sigstore != nil { + sigstoreConfig = sigstore.NewConfigFromHCL(newConfig.Experimental.Sigstore, p.log) + } + + // return the kubelet client + return &k8sConfig{ + Secure: secure, + Port: port, + MaxPollAttempts: maxPollAttempts, + PollRetryInterval: pollRetryInterval, + SkipKubeletVerification: newConfig.SkipKubeletVerification, + TokenPath: newConfig.TokenPath, + CertificatePath: newConfig.CertificatePath, + PrivateKeyPath: newConfig.PrivateKeyPath, + UseAnonymousAuthentication: newConfig.UseAnonymousAuthentication, + KubeletCAPath: newConfig.KubeletCAPath, + NodeName: nodeName, + ReloadInterval: reloadInterval, + DisableContainerSelectors: newConfig.DisableContainerSelectors, + ContainerHelper: containerHelper, + sigstoreConfig: sigstoreConfig, + } +} + type ContainerHelper interface { Configure(config *HCLConfig, log hclog.Logger) error GetPodUIDAndContainerID(pID int32, log hclog.Logger) (types.UID, string, error) @@ -321,94 +415,18 @@ func (p *Plugin) Attest(ctx context.Context, req *workloadattestorv1.AttestReque } func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (resp *configv1.ConfigureResponse, err error) { - // Parse HCL config payload into config struct - config := new(HCLConfig) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - // Determine max poll attempts with default - maxPollAttempts := config.MaxPollAttempts - if maxPollAttempts <= 0 { - maxPollAttempts = defaultMaxPollAttempts - } - - // Determine poll retry interval with default - var pollRetryInterval time.Duration - if config.PollRetryInterval != "" { - pollRetryInterval, err = time.ParseDuration(config.PollRetryInterval) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to parse poll retry interval: %v", err) - } - } - if pollRetryInterval <= 0 { - pollRetryInterval = defaultPollRetryInterval - } - - // Determine reload interval - var reloadInterval time.Duration - if config.ReloadInterval != "" { - reloadInterval, err = time.ParseDuration(config.ReloadInterval) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to parse reload interval: %v", err) - } - } - if reloadInterval <= 0 { - reloadInterval = defaultReloadInterval - } - - // Determine which kubelet port to hit. Default to the secure port if none - // is specified (this is backwards compatible because the read-only-port - // config value has always been required, so it should already be set in - // existing configurations that rely on it). - if config.KubeletSecurePort > 0 && config.KubeletReadOnlyPort > 0 { - return nil, status.Error(codes.InvalidArgument, "cannot use both the read-only and secure port") - } - - containerHelper := createHelper(p) - if err := containerHelper.Configure(config, p.log); err != nil { + newConfig, _, err := pluginconf.Build(req, p.buildConfig) + if err != nil { return nil, err } - port := config.KubeletReadOnlyPort - secure := false - if port <= 0 { - port = config.KubeletSecurePort - secure = true - } - if port <= 0 { - port = defaultSecureKubeletPort - secure = true - } - - // Determine the node name - nodeName := p.getNodeName(config.NodeName, config.NodeNameEnv) - - // Configure the kubelet client - c := &k8sConfig{ - Secure: secure, - Port: port, - MaxPollAttempts: maxPollAttempts, - PollRetryInterval: pollRetryInterval, - SkipKubeletVerification: config.SkipKubeletVerification, - TokenPath: config.TokenPath, - CertificatePath: config.CertificatePath, - PrivateKeyPath: config.PrivateKeyPath, - UseAnonymousAuthentication: config.UseAnonymousAuthentication, - KubeletCAPath: config.KubeletCAPath, - NodeName: nodeName, - ReloadInterval: reloadInterval, - DisableContainerSelectors: config.DisableContainerSelectors, - } - - if err := p.reloadKubeletClient(c); err != nil { + if err := p.reloadKubeletClient(newConfig); err != nil { return nil, err } var sigstoreVerifier sigstore.Verifier - if config.Experimental.Sigstore != nil { - cfg := sigstore.NewConfigFromHCL(config.Experimental.Sigstore, p.log) - verifier := sigstore.NewVerifier(cfg) + if newConfig.sigstoreConfig != nil { + verifier := sigstore.NewVerifier(newConfig.sigstoreConfig) err = verifier.Init(ctx) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "error initializing sigstore verifier: %v", err) @@ -416,17 +434,22 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) sigstoreVerifier = verifier } - // Set the config - p.setConfig(c, containerHelper, sigstoreVerifier) - return &configv1.ConfigureResponse{}, nil -} - -func (p *Plugin) setConfig(config *k8sConfig, containerHelper ContainerHelper, sigstoreVerifier sigstore.Verifier) { p.mu.Lock() defer p.mu.Unlock() - p.config = config - p.containerHelper = containerHelper + p.config = newConfig + p.containerHelper = newConfig.ContainerHelper p.sigstoreVerifier = sigstoreVerifier + + return &configv1.ConfigureResponse{}, nil +} + +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (resp *configv1.ValidateResponse, err error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil } func (p *Plugin) getConfig() (*k8sConfig, ContainerHelper, sigstore.Verifier, error) { diff --git a/pkg/agent/plugin/workloadattestor/k8s/k8s_test.go b/pkg/agent/plugin/workloadattestor/k8s/k8s_test.go index 97d269334c..3c34b81443 100644 --- a/pkg/agent/plugin/workloadattestor/k8s/k8s_test.go +++ b/pkg/agent/plugin/workloadattestor/k8s/k8s_test.go @@ -18,8 +18,10 @@ import ( "testing" "time" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/common/sigstore" "github.com/spiffe/spire/pkg/agent/plugin/workloadattestor" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/pemutil" "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/proto/spire/common" @@ -335,15 +337,17 @@ func (s *Suite) TestConfigure() { } testCases := []struct { - name string - raw string - hcl string - config *config - errCode codes.Code - errMsg string + name string + trustDomain string + raw string + hcl string + config *config + errCode codes.Code + errMsg string }{ { - name: "insecure defaults", + name: "insecure defaults", + trustDomain: "example.org", hcl: ` kubelet_read_only_port = 12345 `, @@ -356,8 +360,9 @@ func (s *Suite) TestConfigure() { }, }, { - name: "secure defaults", - hcl: ``, + name: "secure defaults", + trustDomain: "example.org", + hcl: ``, config: &config{ VerifyKubelet: true, Token: "default-token", @@ -368,7 +373,8 @@ func (s *Suite) TestConfigure() { }, }, { - name: "skip kubelet verification", + name: "skip kubelet verification", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true `, @@ -382,7 +388,8 @@ func (s *Suite) TestConfigure() { }, }, { - name: "secure overrides", + name: "secure overrides", + trustDomain: "example.org", hcl: ` kubelet_secure_port = 12345 kubelet_ca_path = "some-other-ca" @@ -401,7 +408,8 @@ func (s *Suite) TestConfigure() { }, }, { - name: "secure with keypair", + name: "secure with keypair", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true certificate_path = "cert.pem" @@ -415,7 +423,8 @@ func (s *Suite) TestConfigure() { }, }, { - name: "secure with node name", + name: "secure with node name", + trustDomain: "example.org", hcl: ` node_name = "boo" `, @@ -431,13 +440,15 @@ func (s *Suite) TestConfigure() { }, { - name: "invalid hcl", - hcl: "bad", - errCode: codes.InvalidArgument, - errMsg: "unable to decode configuration", + name: "invalid hcl", + trustDomain: "example.org", + hcl: "bad", + errCode: codes.InvalidArgument, + errMsg: "unable to decode configuration", }, { - name: "both insecure and secure ports specified", + name: "both insecure and secure ports specified", + trustDomain: "example.org", hcl: ` kubelet_read_only_port = 10255 kubelet_secure_port = 10250 @@ -446,7 +457,8 @@ func (s *Suite) TestConfigure() { errMsg: "cannot use both the read-only and secure port", }, { - name: "non-existent kubelet ca", + name: "non-existent kubelet ca", + trustDomain: "example.org", hcl: ` kubelet_ca_path = "no-such-file" `, @@ -454,7 +466,8 @@ func (s *Suite) TestConfigure() { errMsg: "unable to load kubelet CA", }, { - name: "bad kubelet ca", + name: "bad kubelet ca", + trustDomain: "example.org", hcl: ` kubelet_ca_path = "bad-pem" `, @@ -462,7 +475,8 @@ func (s *Suite) TestConfigure() { errMsg: "unable to parse kubelet CA", }, { - name: "non-existent token", + name: "non-existent token", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true token_path = "no-such-file" @@ -471,7 +485,8 @@ func (s *Suite) TestConfigure() { errMsg: "unable to load token", }, { - name: "invalid poll retry interval", + name: "invalid poll retry interval", + trustDomain: "example.org", hcl: ` kubelet_read_only_port = 10255 poll_retry_interval = "blah" @@ -480,7 +495,8 @@ func (s *Suite) TestConfigure() { errMsg: "unable to parse poll retry interval", }, { - name: "invalid reload interval", + name: "invalid reload interval", + trustDomain: "example.org", hcl: ` kubelet_read_only_port = 10255 reload_interval = "blah" @@ -489,7 +505,8 @@ func (s *Suite) TestConfigure() { errMsg: "unable to parse reload interval", }, { - name: "cert but no key", + name: "cert but no key", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true certificate_path = "cert" @@ -498,7 +515,8 @@ func (s *Suite) TestConfigure() { errMsg: "the private key path is required with the certificate path", }, { - name: "key but no cert", + name: "key but no cert", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true private_key_path = "key" @@ -507,7 +525,8 @@ func (s *Suite) TestConfigure() { errMsg: "the certificate path is required with the private key path", }, { - name: "bad cert", + name: "bad cert", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true certificate_path = "bad-pem" @@ -517,7 +536,8 @@ func (s *Suite) TestConfigure() { errMsg: "unable to load keypair", }, { - name: "non-existent cert", + name: "non-existent cert", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true certificate_path = "no-such-file" @@ -527,7 +547,8 @@ func (s *Suite) TestConfigure() { errMsg: "unable to load certificate", }, { - name: "bad key", + name: "bad key", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true certificate_path = "cert.pem" @@ -537,7 +558,8 @@ func (s *Suite) TestConfigure() { errMsg: "unable to load keypair", }, { - name: "non-existent key", + name: "non-existent key", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true certificate_path = "cert.pem" @@ -555,6 +577,9 @@ func (s *Suite) TestConfigure() { var err error plugintest.Load(s.T(), builtin(p), nil, + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(testCase.trustDomain), + }), plugintest.Configure(testCase.hcl), plugintest.CaptureConfigureError(&err)) @@ -598,12 +623,14 @@ func (s *Suite) TestConfigure() { func (s *Suite) TestConfigureWithSigstore() { cases := []struct { name string + trustDomain string hcl string expectedError string want *sigstore.Config }{ { - name: "complete sigstore configuration", + name: "complete sigstore configuration", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true experimental { @@ -626,7 +653,8 @@ func (s *Suite) TestConfigureWithSigstore() { expectedError: "", }, { - name: "empty sigstore configuration", + name: "empty sigstore configuration", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true experimental { sigstore {} } @@ -634,7 +662,8 @@ func (s *Suite) TestConfigureWithSigstore() { expectedError: "", }, { - name: "invalid HCL", + name: "invalid HCL", + trustDomain: "example.org", hcl: ` skip_kubelet_verification = true experimental { sigstore = "invalid" } @@ -650,6 +679,9 @@ func (s *Suite) TestConfigureWithSigstore() { var err error plugintest.Load(s.T(), builtin(p), nil, + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(tc.trustDomain), + }), plugintest.Configure(tc.hcl), plugintest.CaptureConfigureError(&err)) @@ -713,6 +745,9 @@ func (s *Suite) loadPlugin(configuration string) workloadattestor.WorkloadAttest p := s.newPlugin() plugintest.Load(s.T(), builtin(p), v1, + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.Configure(configuration), ) diff --git a/pkg/agent/plugin/workloadattestor/systemd/systemd_windows.go b/pkg/agent/plugin/workloadattestor/systemd/systemd_windows.go index c35bdce462..338725f795 100644 --- a/pkg/agent/plugin/workloadattestor/systemd/systemd_windows.go +++ b/pkg/agent/plugin/workloadattestor/systemd/systemd_windows.go @@ -31,3 +31,7 @@ func New() *Plugin { func (p *Plugin) Configure(context.Context, *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { return nil, status.Error(codes.Unimplemented, "plugin not supported in this platform") } + +func (p *Plugin) Validate(context.Context, *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + return nil, status.Error(codes.Unimplemented, "plugin not supported in this platform") +} diff --git a/pkg/agent/plugin/workloadattestor/unix/unix_posix.go b/pkg/agent/plugin/workloadattestor/unix/unix_posix.go index 90a131af7d..10547d4c77 100644 --- a/pkg/agent/plugin/workloadattestor/unix/unix_posix.go +++ b/pkg/agent/plugin/workloadattestor/unix/unix_posix.go @@ -20,6 +20,7 @@ import ( workloadattestorv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/agent/workloadattestor/v1" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -91,6 +92,16 @@ type Configuration struct { WorkloadSizeLimit int64 `hcl:"workload_size_limit"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("failed to decode configuration: %v", err) + return nil + } + + return newConfig +} + type Plugin struct { workloadattestorv1.UnsafeWorkloadAttestorServer configv1.UnsafeConfigServer @@ -195,14 +206,27 @@ func (p *Plugin) Attest(_ context.Context, req *workloadattestorv1.AttestRequest } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config := new(Configuration) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to decode configuration: %v", err) + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - p.setConfig(config) + + p.mu.Lock() + p.config = newConfig + p.mu.Unlock() + return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *Plugin) getConfig() (*Configuration, error) { p.mu.Lock() config := p.config @@ -213,12 +237,6 @@ func (p *Plugin) getConfig() (*Configuration, error) { return config, nil } -func (p *Plugin) setConfig(config *Configuration) { - p.mu.Lock() - p.config = config - p.mu.Unlock() -} - func (p *Plugin) getUID(proc processInfo) (string, error) { uids, err := proc.Uids() if err != nil { diff --git a/pkg/agent/plugin/workloadattestor/unix/unix_posix_test.go b/pkg/agent/plugin/workloadattestor/unix/unix_posix_test.go index 46847e2906..52017323d0 100644 --- a/pkg/agent/plugin/workloadattestor/unix/unix_posix_test.go +++ b/pkg/agent/plugin/workloadattestor/unix/unix_posix_test.go @@ -14,7 +14,9 @@ import ( "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/workloadattestor" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/test/plugintest" "github.com/spiffe/spire/test/spiretest" "github.com/stretchr/testify/require" @@ -52,6 +54,7 @@ func (s *Suite) TestAttest() { } testCases := []struct { name string + trustDomain string pid int selectorValues []string config string @@ -59,20 +62,23 @@ func (s *Suite) TestAttest() { expectMsg string }{ { - name: "pid with no uids", - pid: 1, - expectCode: codes.Internal, - expectMsg: "workloadattestor(unix): UIDs lookup: no UIDs for process", + name: "pid with no uids", + trustDomain: "example.org", + pid: 1, + expectCode: codes.Internal, + expectMsg: "workloadattestor(unix): UIDs lookup: no UIDs for process", }, { - name: "fail to get uids", - pid: 2, - expectCode: codes.Internal, - expectMsg: "workloadattestor(unix): UIDs lookup: unable to get UIDs for PID 2", + name: "fail to get uids", + trustDomain: "example.org", + pid: 2, + expectCode: codes.Internal, + expectMsg: "workloadattestor(unix): UIDs lookup: unable to get UIDs for PID 2", }, { - name: "user lookup fails", - pid: 3, + name: "user lookup fails", + trustDomain: "example.org", + pid: 3, selectorValues: []string{ "uid:1999", "gid:2000", @@ -81,20 +87,23 @@ func (s *Suite) TestAttest() { expectCode: codes.OK, }, { - name: "pid with no gids", - pid: 4, - expectCode: codes.Internal, - expectMsg: "workloadattestor(unix): GIDs lookup: no GIDs for process", + name: "pid with no gids", + trustDomain: "example.org", + pid: 4, + expectCode: codes.Internal, + expectMsg: "workloadattestor(unix): GIDs lookup: no GIDs for process", }, { - name: "fail to get gids", - pid: 5, - expectCode: codes.Internal, - expectMsg: "workloadattestor(unix): GIDs lookup: unable to get GIDs for PID 5", + name: "fail to get gids", + trustDomain: "example.org", + pid: 5, + expectCode: codes.Internal, + expectMsg: "workloadattestor(unix): GIDs lookup: unable to get GIDs for PID 5", }, { - name: "group lookup fails", - pid: 6, + name: "group lookup fails", + trustDomain: "example.org", + pid: 6, selectorValues: []string{ "uid:1000", "user:u1000", @@ -103,8 +112,9 @@ func (s *Suite) TestAttest() { expectCode: codes.OK, }, { - name: "primary user and gid", - pid: 7, + name: "primary user and gid", + trustDomain: "example.org", + pid: 7, selectorValues: []string{ "uid:1000", "user:u1000", @@ -114,8 +124,9 @@ func (s *Suite) TestAttest() { expectCode: codes.OK, }, { - name: "effective user and gid", - pid: 8, + name: "effective user and gid", + trustDomain: "example.org", + pid: 8, selectorValues: []string{ "uid:1100", "user:u1100", @@ -125,30 +136,34 @@ func (s *Suite) TestAttest() { expectCode: codes.OK, }, { - name: "fail to get process binary path", - pid: 9, - config: "discover_workload_path = true", - expectCode: codes.Internal, - expectMsg: "workloadattestor(unix): path lookup: unable to get EXE for PID 9", + name: "fail to get process binary path", + trustDomain: "example.org", + pid: 9, + config: "discover_workload_path = true", + expectCode: codes.Internal, + expectMsg: "workloadattestor(unix): path lookup: unable to get EXE for PID 9", }, { - name: "fail to hash process binary", - pid: 10, - config: "discover_workload_path = true", - expectCode: codes.Internal, - expectMsg: fmt.Sprintf("workloadattestor(unix): SHA256 digest: open %s: no such file or directory", unreadableExePath), + name: "fail to hash process binary", + trustDomain: "example.org", + pid: 10, + config: "discover_workload_path = true", + expectCode: codes.Internal, + expectMsg: fmt.Sprintf("workloadattestor(unix): SHA256 digest: open %s: no such file or directory", unreadableExePath), }, { - name: "process binary exceeds size limits", - pid: 11, - config: "discover_workload_path = true\nworkload_size_limit = 2", - expectCode: codes.Internal, - expectMsg: fmt.Sprintf("workloadattestor(unix): SHA256 digest: workload %s exceeds size limit (4 > 2)", filepath.Join(s.dir, "exe")), + name: "process binary exceeds size limits", + trustDomain: "example.org", + pid: 11, + config: "discover_workload_path = true\nworkload_size_limit = 2", + expectCode: codes.Internal, + expectMsg: fmt.Sprintf("workloadattestor(unix): SHA256 digest: workload %s exceeds size limit (4 > 2)", filepath.Join(s.dir, "exe")), }, { - name: "success getting path and hashing process binary", - pid: 12, - config: "discover_workload_path = true", + name: "success getting path and hashing process binary", + trustDomain: "example.org", + pid: 12, + config: "discover_workload_path = true", selectorValues: []string{ "uid:1000", "user:u1000", @@ -160,9 +175,10 @@ func (s *Suite) TestAttest() { expectCode: codes.OK, }, { - name: "success getting path and hashing process binary", - pid: 12, - config: "discover_workload_path = true", + name: "success getting path and hashing process binary", + trustDomain: "example.org", + pid: 12, + config: "discover_workload_path = true", selectorValues: []string{ "uid:1000", "user:u1000", @@ -174,9 +190,10 @@ func (s *Suite) TestAttest() { expectCode: codes.OK, }, { - name: "success getting path, disabled hashing process binary", - pid: 12, - config: "discover_workload_path = true\nworkload_size_limit = -1", + name: "success getting path, disabled hashing process binary", + trustDomain: "example.org", + pid: 12, + config: "discover_workload_path = true\nworkload_size_limit = -1", selectorValues: []string{ "uid:1000", "user:u1000", @@ -187,8 +204,9 @@ func (s *Suite) TestAttest() { expectCode: codes.OK, }, { - name: "pid with supplementary gids", - pid: 13, + name: "pid with supplementary gids", + trustDomain: "example.org", + pid: 13, selectorValues: []string{ "uid:1000", "user:u1000", @@ -205,10 +223,11 @@ func (s *Suite) TestAttest() { }, }, { - name: "fail to get supplementary gids", - pid: 14, - expectCode: codes.Internal, - expectMsg: "workloadattestor(unix): supplementary GIDs lookup: some error for PID 14", + name: "fail to get supplementary gids", + trustDomain: "example.org", + pid: 14, + expectCode: codes.Internal, + expectMsg: "workloadattestor(unix): supplementary GIDs lookup: some error for PID 14", }, } @@ -220,7 +239,7 @@ func (s *Suite) TestAttest() { s.T().Run(testCase.name, func(t *testing.T) { defer s.logHook.Reset() - p := s.loadPlugin(t, testCase.config) + p := s.loadPlugin(t, testCase.trustDomain, testCase.config) selectors, err := p.Attest(ctx, testCase.pid) spiretest.RequireGRPCStatus(t, err, testCase.expectCode, testCase.expectMsg) if testCase.expectCode != codes.OK { @@ -245,12 +264,15 @@ func (s *Suite) writeFile(path string, data []byte) { s.Require().NoError(os.WriteFile(filepath.Join(s.dir, path), data, 0600)) } -func (s *Suite) loadPlugin(t *testing.T, config string) workloadattestor.WorkloadAttestor { +func (s *Suite) loadPlugin(t *testing.T, trustDomain string, config string) workloadattestor.WorkloadAttestor { p := s.newPlugin() v1 := new(workloadattestor.V1) plugintest.Load(t, builtin(p), v1, plugintest.Log(s.log), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(trustDomain), + }), plugintest.Configure(config)) return v1 } diff --git a/pkg/agent/plugin/workloadattestor/unix/unix_windows.go b/pkg/agent/plugin/workloadattestor/unix/unix_windows.go index ed5befd262..bcf109f1b7 100644 --- a/pkg/agent/plugin/workloadattestor/unix/unix_windows.go +++ b/pkg/agent/plugin/workloadattestor/unix/unix_windows.go @@ -31,3 +31,7 @@ func New() *Plugin { func (p *Plugin) Configure(context.Context, *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { return nil, status.Error(codes.Unimplemented, "plugin not supported in this platform") } + +func (p *Plugin) Validate(context.Context, *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + return nil, status.Error(codes.Unimplemented, "plugin not supported in this platform") +} diff --git a/pkg/agent/plugin/workloadattestor/windows/windows_posix.go b/pkg/agent/plugin/workloadattestor/windows/windows_posix.go index 30840b9bfa..64ef11c3ee 100644 --- a/pkg/agent/plugin/workloadattestor/windows/windows_posix.go +++ b/pkg/agent/plugin/workloadattestor/windows/windows_posix.go @@ -31,3 +31,7 @@ func New() *Plugin { func (p *Plugin) Configure(context.Context, *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { return nil, status.Error(codes.Unimplemented, "plugin not supported in this platform") } + +func (p *Plugin) Validate(context.Context, *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + return nil, status.Error(codes.Unimplemented, "plugin not supported in this platform") +} diff --git a/pkg/agent/plugin/workloadattestor/windows/windows_windows.go b/pkg/agent/plugin/workloadattestor/windows/windows_windows.go index 8e185e6ed1..3e9d034be6 100644 --- a/pkg/agent/plugin/workloadattestor/windows/windows_windows.go +++ b/pkg/agent/plugin/workloadattestor/windows/windows_windows.go @@ -12,6 +12,7 @@ import ( workloadattestorv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/agent/workloadattestor/v1" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/pkg/common/util" "golang.org/x/sys/windows" @@ -36,6 +37,16 @@ type Configuration struct { WorkloadSizeLimit int64 `hcl:"workload_size_limit"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("failed to decode configuration: %v", err) + return nil + } + + return newConfig +} + type Plugin struct { workloadattestorv1.UnsafeWorkloadAttestorServer configv1.UnsafeConfigServer @@ -175,14 +186,27 @@ func (p *Plugin) newProcessInfo(pid int32, queryPath bool) (*processInfo, error) } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config := new(Configuration) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to decode configuration: %v", err) + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - p.setConfig(config) + + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig + return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *Plugin) getConfig() (*Configuration, error) { p.mu.Lock() config := p.config @@ -193,12 +217,6 @@ func (p *Plugin) getConfig() (*Configuration, error) { return config, nil } -func (p *Plugin) setConfig(config *Configuration) { - p.mu.Lock() - p.config = config - p.mu.Unlock() -} - type processQueryer interface { // OpenProcess returns an open handle to the specified process id. OpenProcess(int32) (windows.Handle, error) diff --git a/pkg/agent/plugin/workloadattestor/windows/windows_windows_test.go b/pkg/agent/plugin/workloadattestor/windows/windows_windows_test.go index b0dcbacdae..a37b29482b 100644 --- a/pkg/agent/plugin/workloadattestor/windows/windows_windows_test.go +++ b/pkg/agent/plugin/workloadattestor/windows/windows_windows_test.go @@ -12,7 +12,9 @@ import ( "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" + "github.com/spiffe/go-spiffe/v2/spiffeid" "github.com/spiffe/spire/pkg/agent/plugin/workloadattestor" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/telemetry" "github.com/spiffe/spire/test/plugintest" "github.com/spiffe/spire/test/spiretest" @@ -49,6 +51,7 @@ func TestAttest(t *testing.T) { testCases := []struct { name string + trustDomain string expectSelectors []string config string pq *fakeProcessQuery @@ -57,7 +60,8 @@ func TestAttest(t *testing.T) { expectLogs []spiretest.LogEntry }{ { - name: "successful no groups", + name: "successful no groups", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -72,7 +76,8 @@ func TestAttest(t *testing.T) { expectCode: codes.OK, }, { - name: "successful with groups all enabled", + name: "successful with groups all enabled", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -92,7 +97,8 @@ func TestAttest(t *testing.T) { expectCode: codes.OK, }, { - name: "successful with not enabled group", + name: "successful with not enabled group", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -110,7 +116,8 @@ func TestAttest(t *testing.T) { expectCode: codes.OK, }, { - name: "successful getting path and hashing process binary", + name: "successful getting path and hashing process binary", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -129,7 +136,8 @@ func TestAttest(t *testing.T) { expectCode: codes.OK, }, { - name: "successful getting path, disabled hashing process binary", + name: "successful getting path, disabled hashing process binary", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -147,7 +155,8 @@ func TestAttest(t *testing.T) { expectCode: codes.OK, }, { - name: "failed to get binary path", + name: "failed to get binary path", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -161,7 +170,8 @@ func TestAttest(t *testing.T) { expectMsg: "workloadattestor(windows): failed to get process information: error getting process exe: get process exe error", }, { - name: "failed to hash binary", + name: "failed to hash binary", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -175,7 +185,8 @@ func TestAttest(t *testing.T) { expectMsg: "workloadattestor(windows): SHA256 digest: open unreadable: The system cannot find the file specified.", }, { - name: "binary exceeds limit size", + name: "binary exceeds limit size", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -189,7 +200,8 @@ func TestAttest(t *testing.T) { expectMsg: fmt.Sprintf("workloadattestor(windows): SHA256 digest: workload %s exceeds size limit (4 > 2)", exe), }, { - name: "OpenProcess error", + name: "OpenProcess error", + trustDomain: "example.org", pq: &fakeProcessQuery{ openProcessErr: errors.New("open process error"), }, @@ -197,7 +209,8 @@ func TestAttest(t *testing.T) { expectMsg: "workloadattestor(windows): failed to get process information: failed to open process: open process error", }, { - name: "OpenProcessToken error", + name: "OpenProcessToken error", + trustDomain: "example.org", pq: &fakeProcessQuery{ openProcessTokenErr: errors.New("open process token error"), handle: windows.InvalidHandle, @@ -206,7 +219,8 @@ func TestAttest(t *testing.T) { expectMsg: "workloadattestor(windows): failed to get process information: failed to open the access token associated with the process: open process token error", }, { - name: "GetTokenUser error", + name: "GetTokenUser error", + trustDomain: "example.org", pq: &fakeProcessQuery{ getTokenUserErr: errors.New("get token user error"), handle: windows.InvalidHandle, @@ -215,7 +229,8 @@ func TestAttest(t *testing.T) { expectMsg: "workloadattestor(windows): failed to get process information: failed to retrieve user account information from access token: get token user error", }, { - name: "GetTokenGroups error", + name: "GetTokenGroups error", + trustDomain: "example.org", pq: &fakeProcessQuery{ getTokenGroupsErr: errors.New("get token groups error"), handle: windows.InvalidHandle, @@ -225,7 +240,8 @@ func TestAttest(t *testing.T) { expectMsg: "workloadattestor(windows): failed to get process information: failed to retrieve group accounts information from access token: get token groups error", }, { - name: "LookupAccount failure", + name: "LookupAccount failure", + trustDomain: "example.org", pq: &fakeProcessQuery{ lookupAccountErr: errors.New("lookup error"), handle: windows.InvalidHandle, @@ -260,7 +276,8 @@ func TestAttest(t *testing.T) { }, }, { - name: "close handle error", + name: "close handle error", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -286,7 +303,8 @@ func TestAttest(t *testing.T) { }, }, { - name: "close process token error", + name: "close process token error", + trustDomain: "example.org", pq: &fakeProcessQuery{ handle: windows.InvalidHandle, tokenUser: &windows.Tokenuser{User: windows.SIDAndAttributes{Sid: sidUser}}, @@ -316,7 +334,7 @@ func TestAttest(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { test := setupTest() - p, err := test.loadPlugin(t, testCase.pq, testCase.config) + p, err := test.loadPlugin(t, testCase.pq, testCase.trustDomain, testCase.config) require.NoError(t, err) selectors, err := p.Attest(ctx, testPID) @@ -342,11 +360,11 @@ func TestConfigure(t *testing.T) { test := setupTest() // malformed configuration - _, err := test.loadPlugin(t, &fakeProcessQuery{}, "malformed") + _, err := test.loadPlugin(t, &fakeProcessQuery{}, "example.org", "malformed") spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "failed to decode configuration") // success - _, err = test.loadPlugin(t, &fakeProcessQuery{}, "discover_workload_path = true\nworkload_size_limit = 2") + _, err = test.loadPlugin(t, &fakeProcessQuery{}, "example.org", "discover_workload_path = true\nworkload_size_limit = 2") require.NoError(t, err) } @@ -355,7 +373,7 @@ type windowsTest struct { logHook *test.Hook } -func (w *windowsTest) loadPlugin(t *testing.T, q *fakeProcessQuery, config string) (workloadattestor.WorkloadAttestor, error) { +func (w *windowsTest) loadPlugin(t *testing.T, q *fakeProcessQuery, trustDomain string, config string) (workloadattestor.WorkloadAttestor, error) { var err error p := New() p.q = q @@ -363,6 +381,9 @@ func (w *windowsTest) loadPlugin(t *testing.T, q *fakeProcessQuery, config strin v1 := new(workloadattestor.V1) plugintest.Load(t, builtin(p), v1, plugintest.Log(w.log), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(trustDomain), + }), plugintest.Configure(config), plugintest.CaptureConfigureError(&err)) return v1, err diff --git a/pkg/common/catalog/configure.go b/pkg/common/catalog/configure.go index d31660c4f6..5c381cd88b 100644 --- a/pkg/common/catalog/configure.go +++ b/pkg/common/catalog/configure.go @@ -29,6 +29,7 @@ func (c CoreConfig) v1() *configv1.CoreConfiguration { type Configurer interface { Configure(ctx context.Context, coreConfig CoreConfig, configuration string) error + Validate(ctx context.Context, coreConfig CoreConfig, configuration string) error } type ConfigurerFunc func(ctx context.Context, coreConfig CoreConfig, configuration string) error @@ -37,6 +38,10 @@ func (fn ConfigurerFunc) Configure(ctx context.Context, coreConfig CoreConfig, c return fn(ctx, coreConfig, configuration) } +func (fn ConfigurerFunc) Validate(ctx context.Context, coreConfig CoreConfig, configuration string) error { + return fn(ctx, coreConfig, configuration) +} + func ConfigurePlugin(ctx context.Context, coreConfig CoreConfig, configurer Configurer, dataSource DataSource, lastHash string) (string, error) { data, err := dataSource.Load() if err != nil { @@ -172,12 +177,24 @@ func (v1 *configurerV1) Configure(ctx context.Context, coreConfig CoreConfig, hc return err } +func (v1 *configurerV1) Validate(ctx context.Context, coreConfig CoreConfig, hclConfiguration string) error { + _, err := v1.ConfigServiceClient.Validate(ctx, &configv1.ValidateRequest{ + CoreConfiguration: coreConfig.v1(), + HclConfiguration: hclConfiguration, + }) + return err +} + type configurerUnsupported struct{} func (c configurerUnsupported) Configure(context.Context, CoreConfig, string) error { return status.Error(codes.FailedPrecondition, "plugin does not support a configuration interface") } +func (c configurerUnsupported) Validate(context.Context, CoreConfig, string) error { + return status.Error(codes.FailedPrecondition, "plugin does not support a validation interface") +} + func hashData(data string) string { h := sha512.New() _, _ = io.Copy(h, strings.NewReader(data)) diff --git a/pkg/common/plugin/sshpop/sshpop.go b/pkg/common/plugin/sshpop/sshpop.go index 8eab7dbf7c..288340b628 100644 --- a/pkg/common/plugin/sshpop/sshpop.go +++ b/pkg/common/plugin/sshpop/sshpop.go @@ -9,10 +9,11 @@ import ( "github.com/hashicorp/hcl" "github.com/spiffe/go-spiffe/v2/spiffeid" + configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/agentpathtemplate" + "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "golang.org/x/crypto/ssh" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) const ( @@ -55,6 +56,35 @@ type Server struct { type ClientConfig struct { HostKeyPath string `hcl:"host_key_path"` HostCertPath string `hcl:"host_cert_path"` + + cert *ssh.Certificate + signer ssh.Signer +} + +type ClientConfigRequest struct { + coreConfig *configv1.CoreConfiguration + hclText string +} + +func (ccr *ClientConfigRequest) GetCoreConfiguration() *configv1.CoreConfiguration { + return ccr.coreConfig +} + +func (ccr *ClientConfigRequest) GetHclConfiguration() string { + return ccr.hclText +} + +type ServerConfigRequest struct { + coreConfig *configv1.CoreConfiguration + hclText string +} + +func (scr *ServerConfigRequest) GetCoreConfiguration() *configv1.CoreConfiguration { + return scr.coreConfig +} + +func (scr *ServerConfigRequest) GetHclConfiguration() string { + return scr.hclText } // ServerConfig configures the server. @@ -65,31 +95,115 @@ type ServerConfig struct { // the certificate's valid principals. See CanonicalDomains in ssh_config(5). CanonicalDomain string `hcl:"canonical_domain"` AgentPathTemplate string `hcl:"agent_path_template"` + + certChecker *ssh.CertChecker + agentPathTemplate *agentpathtemplate.Template + trustDomain spiffeid.TrustDomain } -func NewClient(configString string) (*Client, error) { - config := new(ClientConfig) - if err := hcl.Decode(config, configString); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to decode configuration: %v", err) +func BuildServerConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *ServerConfig { + newConfig := new(ServerConfig) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("failed to decode configuration: %v", err) + return nil + } + + newConfig.trustDomain = coreConfig.TrustDomain + + if newConfig.CertAuthorities == nil && newConfig.CertAuthoritiesPath == "" { + status.ReportErrorf("missing required config value for \"cert_authorities\" or \"cert_authorities_path\"") + } + var certAuthorities []string + if newConfig.CertAuthorities != nil { + certAuthorities = append(certAuthorities, newConfig.CertAuthorities...) + } + if newConfig.CertAuthoritiesPath != "" { + fileCertAuthorities, err := pubkeysFromPath(newConfig.CertAuthoritiesPath) + if err != nil { + status.ReportErrorf("failed to get cert authorities from file: %v", err) + } + certAuthorities = append(certAuthorities, fileCertAuthorities...) } - config.HostKeyPath = stringOrDefault(config.HostKeyPath, defaultHostKeyPath) - config.HostCertPath = stringOrDefault(config.HostCertPath, defaultHostCertPath) - keyBytes, err := os.ReadFile(config.HostKeyPath) + + certChecker, err := certCheckerFromPubkeys(certAuthorities) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to read host key file: %v", err) + status.ReportErrorf("failed to create cert checker: %v", err) + } + newConfig.certChecker = certChecker + + newConfig.agentPathTemplate = DefaultAgentPathTemplate + if len(newConfig.AgentPathTemplate) != 0 { + tmpl, err := agentpathtemplate.Parse(newConfig.AgentPathTemplate) + if err != nil { + status.ReportErrorf("failed to parse agent svid template: %q", newConfig.AgentPathTemplate) + } else { + newConfig.agentPathTemplate = tmpl + } + } + + return newConfig +} + +func (sc *ServerConfig) NewServer() *Server { + return &Server{ + certChecker: sc.certChecker, + agentPathTemplate: sc.agentPathTemplate, + trustDomain: sc.trustDomain, + canonicalDomain: sc.CanonicalDomain, + } +} + +func BuildClientConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *ClientConfig { + newConfig := new(ClientConfig) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("failed to decode configuration: %v", err) + return nil } - certBytes, err := os.ReadFile(config.HostCertPath) + + newConfig.HostKeyPath = stringOrDefault(newConfig.HostKeyPath, defaultHostKeyPath) + newConfig.HostCertPath = stringOrDefault(newConfig.HostCertPath, defaultHostCertPath) + + keyBytes, err := os.ReadFile(newConfig.HostKeyPath) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to read host cert file: %v", err) + status.ReportErrorf("failed to read host key file: %v", err) } - cert, signer, err := getCertAndSignerFromBytes(certBytes, keyBytes) + certBytes, err := os.ReadFile(newConfig.HostCertPath) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to get cert and signer from pem: %v", err) + status.ReportErrorf("failed to read host cert file: %v", err) + } + if keyBytes != nil && certBytes != nil { + cert, signer, err := getCertAndSignerFromBytes(certBytes, keyBytes) + if err != nil { + status.ReportErrorf("failed to get cert and signer from pem: %v", err) + } + newConfig.cert = cert + newConfig.signer = signer } + + return newConfig +} + +func (cc *ClientConfig) NewClient() *Client { return &Client{ - cert: cert, - signer: signer, - }, nil + cert: cc.cert, + signer: cc.signer, + } +} + +func NewClient(trustDomain string, configString string) (*Client, error) { + request := &ClientConfigRequest{ + coreConfig: &configv1.CoreConfiguration{ + TrustDomain: fmt.Sprintf("spiffe://%s", trustDomain), + }, + hclText: configString, + } + + newClientConfig, _, err := pluginconf.Build(request, BuildClientConfig) + if err != nil { + return nil, err + } + + return newClientConfig.NewClient(), nil } func stringOrDefault(configValue, defaultValue string) string { @@ -116,46 +230,19 @@ func getCertAndSignerFromBytes(certBytes, keyBytes []byte) (*ssh.Certificate, ss } func NewServer(trustDomain, configString string) (*Server, error) { - td, err := spiffeid.TrustDomainFromString(trustDomain) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "trust_domain global configuration is invalid: %v", err) - } - config := new(ServerConfig) - if err := hcl.Decode(config, configString); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to decode configuration: %v", err) - } - if config.CertAuthorities == nil && config.CertAuthoritiesPath == "" { - return nil, status.Errorf(codes.InvalidArgument, "missing required config value for \"cert_authorities\" or \"cert_authorities_path\"") - } - var certAuthorities []string - if config.CertAuthorities != nil { - certAuthorities = append(certAuthorities, config.CertAuthorities...) - } - if config.CertAuthoritiesPath != "" { - fileCertAuthorities, err := pubkeysFromPath(config.CertAuthoritiesPath) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to get cert authorities from file: %v", err) - } - certAuthorities = append(certAuthorities, fileCertAuthorities...) + request := &ServerConfigRequest{ + coreConfig: &configv1.CoreConfiguration{ + TrustDomain: trustDomain, + }, + hclText: configString, } - certChecker, err := certCheckerFromPubkeys(certAuthorities) + + newServerConfig, _, err := pluginconf.Build(request, BuildServerConfig) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to create cert checker: %v", err) - } - agentPathTemplate := DefaultAgentPathTemplate - if len(config.AgentPathTemplate) > 0 { - tmpl, err := agentpathtemplate.Parse(config.AgentPathTemplate) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse agent svid template: %q", config.AgentPathTemplate) - } - agentPathTemplate = tmpl + return nil, err } - return &Server{ - certChecker: certChecker, - agentPathTemplate: agentPathTemplate, - trustDomain: td, - canonicalDomain: config.CanonicalDomain, - }, nil + + return newServerConfig.NewServer(), nil } func pubkeysFromPath(pubkeysPath string) ([]string, error) { diff --git a/pkg/common/plugin/sshpop/sshpop_test.go b/pkg/common/plugin/sshpop/sshpop_test.go index 2f3815040f..4bd32d8bd7 100644 --- a/pkg/common/plugin/sshpop/sshpop_test.go +++ b/pkg/common/plugin/sshpop/sshpop_test.go @@ -56,7 +56,7 @@ func TestNewClient(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.desc, func(t *testing.T) { - c, err := NewClient(tt.configString) + c, err := NewClient("example.org", tt.configString) if tt.expectErr != "" { require.Error(t, err) require.Contains(t, err.Error(), tt.expectErr) @@ -78,7 +78,7 @@ func TestNewServer(t *testing.T) { }{ { desc: "missing trust domain", - expectErr: "trust_domain global configuration is invalid", + expectErr: "server core configuration must contain trust_domain", }, { desc: "bad config", diff --git a/pkg/common/pluginconf/pluginconf.go b/pkg/common/pluginconf/pluginconf.go new file mode 100644 index 0000000000..4f74d44257 --- /dev/null +++ b/pkg/common/pluginconf/pluginconf.go @@ -0,0 +1,63 @@ +package pluginconf + +import ( + "fmt" + + "github.com/spiffe/go-spiffe/v2/spiffeid" + configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" + "github.com/spiffe/spire/pkg/common/catalog" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +type Status struct { + notes []string + err error +} + +func (s *Status) ReportInfo(message string) { + s.notes = append(s.notes, message) +} + +func (s *Status) ReportInfof(format string, args ...any) { + s.ReportInfo(fmt.Sprintf(format, args...)) +} + +func (s *Status) ReportError(message string) { + if s.err == nil { + s.err = status.Error(codes.InvalidArgument, message) + } + s.notes = append(s.notes, message) +} + +func (s *Status) ReportErrorf(format string, args ...any) { + s.ReportError(fmt.Sprintf(format, args...)) +} + +type Request interface { + GetCoreConfiguration() *configv1.CoreConfiguration + GetHclConfiguration() string +} + +func Build[C any](req Request, build func(coreConfig catalog.CoreConfig, hclText string, s *Status) *C) (*C, []string, error) { + var s Status + var coreConfig catalog.CoreConfig + + requestCoreConfig := req.GetCoreConfiguration() + + switch { + case requestCoreConfig == nil: + s.ReportError("server core configuration is required") + case requestCoreConfig.TrustDomain == "": + s.ReportError("server core configuration must contain trust_domain") + default: + var err error + coreConfig.TrustDomain, err = spiffeid.TrustDomainFromString(requestCoreConfig.TrustDomain) + if err != nil { + s.ReportErrorf("server core configuration trust_domain is malformed: %v", err) + } + } + + config := build(coreConfig, req.GetHclConfiguration(), &s) + return config, s.notes, s.err +} diff --git a/pkg/server/plugin/bundlepublisher/awsrolesanywhere/awsrolesanywhere.go b/pkg/server/plugin/bundlepublisher/awsrolesanywhere/awsrolesanywhere.go index 803566c1bc..f43ede99ff 100644 --- a/pkg/server/plugin/bundlepublisher/awsrolesanywhere/awsrolesanywhere.go +++ b/pkg/server/plugin/bundlepublisher/awsrolesanywhere/awsrolesanywhere.go @@ -14,6 +14,7 @@ import ( "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/types" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -43,6 +44,24 @@ type Config struct { TrustAnchorID string `hcl:"trust_anchor_id" json:"trust_anchor_id"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + newConfig := new(Config) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.Region == "" { + status.ReportError("configuration is missing the region") + } + + if newConfig.TrustAnchorID == "" { + status.ReportError("configuration is missing the trust anchor id") + } + + return newConfig +} + // Plugin is the main representation of this bundle publisher plugin. type Plugin struct { bundlepublisherv1.UnsafeBundlePublisherServer @@ -66,12 +85,12 @@ func (p *Plugin) SetLogger(log hclog.Logger) { // Configure configures the plugin. func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := parseAndValidateConfig(req.HclConfiguration) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } - awsCfg, err := newAWSConfig(ctx, config) + awsCfg, err := newAWSConfig(ctx, newConfig) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create client configuration: %v", err) } @@ -81,11 +100,23 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) } p.rolesAnywhereClient = rolesAnywhere - p.setConfig(config) + p.configMtx.Lock() + defer p.configMtx.Unlock() + p.config = newConfig + p.setBundle(nil) return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + // PublishBundle puts the bundle in the Roles Anywhere trust anchor, with // the configured id. func (p *Plugin) PublishBundle(ctx context.Context, req *bundlepublisherv1.PublishBundleRequest) (*bundlepublisherv1.PublishBundleResponse, error) { @@ -166,14 +197,6 @@ func (p *Plugin) setBundle(bundle *types.Bundle) { p.bundle = bundle } -// setConfig sets the configuration for the plugin. -func (p *Plugin) setConfig(config *Config) { - p.configMtx.Lock() - defer p.configMtx.Unlock() - - p.config = config -} - // builtin creates a new BundlePublisher built-in plugin. func builtin(p *Plugin) catalog.BuiltIn { return catalog.MakeBuiltIn(pluginName, @@ -190,22 +213,3 @@ func newPlugin(newRolesAnywhereClientFunc func(c aws.Config) (rolesAnywhere, err }, } } - -// parseAndValidateConfig returns an error if any configuration provided does -// not meet acceptable criteria -func parseAndValidateConfig(c string) (*Config, error) { - config := new(Config) - - if err := hcl.Decode(config, c); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.Region == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the region") - } - - if config.TrustAnchorID == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the trust anchor id") - } - return config, nil -} diff --git a/pkg/server/plugin/bundlepublisher/awss3/awss3.go b/pkg/server/plugin/bundlepublisher/awss3/awss3.go index 7339815cd9..59c64a811f 100644 --- a/pkg/server/plugin/bundlepublisher/awss3/awss3.go +++ b/pkg/server/plugin/bundlepublisher/awss3/awss3.go @@ -14,6 +14,7 @@ import ( "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/types" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -49,6 +50,45 @@ type Config struct { bundleFormat bundleformat.Format } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + newConfig := new(Config) + + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.Region == "" { + status.ReportError("configuration is missing the region") + } + if newConfig.Bucket == "" { + status.ReportError("configuration is missing the bucket name") + } + if newConfig.ObjectKey == "" { + status.ReportError("configuration is missing the object key") + } + if newConfig.Format == "" { + status.ReportError("configuration is missing the bundle format") + } + + bundleFormat, err := bundleformat.FromString(newConfig.Format) + if err != nil { + status.ReportErrorf("could not parse bundle format from configuration: %v", err) + } else { + // This plugin only supports some bundleformats. + switch bundleFormat { + case bundleformat.JWKS: + case bundleformat.SPIFFE: + case bundleformat.PEM: + default: + status.ReportErrorf("bundle format %q is not supported", newConfig.Format) + } + newConfig.bundleFormat = bundleFormat + } + + return newConfig +} + // Plugin is the main representation of this bundle publisher plugin. type Plugin struct { bundlepublisherv1.UnsafeBundlePublisherServer @@ -72,12 +112,13 @@ func (p *Plugin) SetLogger(log hclog.Logger) { // Configure configures the plugin. func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := parseAndValidateConfig(req.HclConfiguration) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } - awsCfg, err := newAWSConfig(ctx, config) + // seems wrong to change plugin s3Client before config change + awsCfg, err := newAWSConfig(ctx, newConfig) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create client configuration: %v", err) } @@ -87,11 +128,20 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) } p.s3Client = s3Client - p.setConfig(config) + p.setConfig(newConfig) p.setBundle(nil) return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + // PublishBundle puts the bundle in the configured S3 bucket name and // object key. func (p *Plugin) PublishBundle(ctx context.Context, req *bundlepublisherv1.PublishBundleRequest) (*bundlepublisherv1.PublishBundleResponse, error) { @@ -182,45 +232,3 @@ func newPlugin(newS3ClientFunc func(c aws.Config) (simpleStorageService, error)) }, } } - -// parseAndValidateConfig returns an error if any configuration provided does -// not meet acceptable criteria -func parseAndValidateConfig(c string) (*Config, error) { - config := new(Config) - - if err := hcl.Decode(config, c); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.Region == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the region") - } - - if config.Bucket == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the bucket name") - } - - if config.ObjectKey == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the object key") - } - - if config.Format == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the bundle format") - } - bundleFormat, err := bundleformat.FromString(config.Format) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "could not parse bundle format from configuration: %v", err) - } - // The bundleformat package may support formats that this plugin does not - // support. Validate that the format is a supported format in this plugin. - switch bundleFormat { - case bundleformat.JWKS: - case bundleformat.SPIFFE: - case bundleformat.PEM: - default: - return nil, status.Errorf(codes.InvalidArgument, "format not supported %q", config.Format) - } - - config.bundleFormat = bundleFormat - return config, nil -} diff --git a/pkg/server/plugin/bundlepublisher/gcpcloudstorage/gcpcloudstorage.go b/pkg/server/plugin/bundlepublisher/gcpcloudstorage/gcpcloudstorage.go index ea0c304e8f..4875b8c720 100644 --- a/pkg/server/plugin/bundlepublisher/gcpcloudstorage/gcpcloudstorage.go +++ b/pkg/server/plugin/bundlepublisher/gcpcloudstorage/gcpcloudstorage.go @@ -13,6 +13,7 @@ import ( "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/types" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/telemetry" "google.golang.org/api/option" "google.golang.org/grpc/codes" @@ -50,6 +51,42 @@ type Config struct { bundleFormat bundleformat.Format } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + newConfig := new(Config) + + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.BucketName == "" { + status.ReportError("configuration is missing the bucket name") + } + if newConfig.ObjectName == "" { + status.ReportError("configuration is missing the object name") + } + + if newConfig.Format == "" { + status.ReportError("configuration is missing the bundle format") + } + bundleFormat, err := bundleformat.FromString(newConfig.Format) + if err != nil { + status.ReportErrorf("could not parse bundle format from configuration: %v", err) + } else { + // Only some bundleformats are supported by this plugin. + switch bundleFormat { + case bundleformat.JWKS: + case bundleformat.SPIFFE: + case bundleformat.PEM: + default: + status.ReportErrorf("format not supported %q", newConfig.Format) + } + } + newConfig.bundleFormat = bundleFormat + + return newConfig +} + // Plugin is the main representation of this bundle publisher plugin. type Plugin struct { bundlepublisherv1.UnsafeBundlePublisherServer @@ -73,27 +110,37 @@ func (p *Plugin) SetLogger(log hclog.Logger) { // Configure configures the plugin. func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := parseAndValidateConfig(req.HclConfiguration) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } var opts []option.ClientOption - if config.ServiceAccountFile != "" { - opts = append(opts, option.WithCredentialsFile(config.ServiceAccountFile)) + if newConfig.ServiceAccountFile != "" { + opts = append(opts, option.WithCredentialsFile(newConfig.ServiceAccountFile)) } - gcsClient, err := p.hooks.newGCSClientFunc(ctx, opts...) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create client: %v", err) } p.gcsClient = gcsClient - p.setConfig(config) + p.setConfig(newConfig) + p.setBundle(nil) + return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + // PublishBundle puts the bundle in the configured GCS bucket and object name. func (p *Plugin) PublishBundle(ctx context.Context, req *bundlepublisherv1.PublishBundleRequest) (*bundlepublisherv1.PublishBundleResponse, error) { config, err := p.getConfig() @@ -223,41 +270,3 @@ func newPlugin(newGCSClientFunc func(ctx context.Context, opts ...option.ClientO }, } } - -// parseAndValidateConfig returns an error if any configuration provided does -// not meet acceptable criteria -func parseAndValidateConfig(c string) (*Config, error) { - config := new(Config) - - if err := hcl.Decode(config, c); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.BucketName == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the bucket name") - } - - if config.ObjectName == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the object name") - } - - if config.Format == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the bundle format") - } - bundleFormat, err := bundleformat.FromString(config.Format) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "could not parse bundle format from configuration: %v", err) - } - // The bundleformat package may support formats that this plugin does not - // support. Validate that the format is a supported format in this plugin. - switch bundleFormat { - case bundleformat.JWKS: - case bundleformat.SPIFFE: - case bundleformat.PEM: - default: - return nil, status.Errorf(codes.InvalidArgument, "format not supported %q", config.Format) - } - - config.bundleFormat = bundleFormat - return config, nil -} diff --git a/pkg/server/plugin/keymanager/awskms/awskms.go b/pkg/server/plugin/keymanager/awskms/awskms.go index f2a51e37fe..f94bc8ba79 100644 --- a/pkg/server/plugin/keymanager/awskms/awskms.go +++ b/pkg/server/plugin/keymanager/awskms/awskms.go @@ -25,6 +25,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/diskutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -101,6 +102,41 @@ type Config struct { KeyPolicyFile string `hcl:"key_policy_file" json:"key_policy_file"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + newConfig := new(Config) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.Region == "" { + status.ReportError("configuration is missing a region") + } + + if newConfig.KeyIdentifierValue != "" { + re := regexp.MustCompile(".*[^A-z0-9/_-].*") + if re.MatchString(newConfig.KeyIdentifierValue) { + status.ReportError("Key identifier must contain only alphanumeric characters, forward slashes (/), underscores (_), and dashes (-)") + } + if strings.HasPrefix(newConfig.KeyIdentifierValue, "alias/aws/") { + status.ReportError("Key identifier must not start with alias/aws/") + } + if len(newConfig.KeyIdentifierValue) > 256 { + status.ReportError("Key identifier must not be longer than 256 characters") + } + } + + if newConfig.KeyIdentifierFile == "" && newConfig.KeyIdentifierValue == "" { + status.ReportError("configuration requires a key identifier file or a key identifier value") + } + + if newConfig.KeyIdentifierFile != "" && newConfig.KeyIdentifierValue != "" { + status.ReportError("configuration can't have a key identifier file and a key identifier value at the same time") + } + + return newConfig +} + // New returns an instantiated plugin func New() *Plugin { return newPlugin(newKMSClient, newSTSClient) @@ -128,13 +164,13 @@ func (p *Plugin) SetLogger(log hclog.Logger) { // Configure sets up the plugin func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := parseAndValidateConfig(req.HclConfiguration) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } - if config.KeyPolicyFile != "" { - policyBytes, err := os.ReadFile(config.KeyPolicyFile) + if newConfig.KeyPolicyFile != "" { + policyBytes, err := os.ReadFile(newConfig.KeyPolicyFile) if err != nil { return nil, status.Errorf(codes.Internal, "failed to read file configured in 'key_policy_file': %v", err) } @@ -142,16 +178,16 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) p.keyPolicy = &policyStr } - serverID := config.KeyIdentifierValue + serverID := newConfig.KeyIdentifierValue if serverID == "" { - serverID, err = getOrCreateServerID(config.KeyIdentifierFile) + serverID, err = getOrCreateServerID(newConfig.KeyIdentifierFile) if err != nil { return nil, err } } p.log.Debug("Loaded server id", "server_id", serverID) - awsCfg, err := newAWSConfig(ctx, config) + awsCfg, err := newAWSConfig(ctx, newConfig) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create client configuration: %v", err) } @@ -202,6 +238,15 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + // GenerateKey creates a key in KMS. If a key already exists in the local storage, it is updated. func (p *Plugin) GenerateKey(ctx context.Context, req *keymanagerv1.GenerateKeyRequest) (*keymanagerv1.GenerateKeyResponse, error) { if req.KeyId == "" { @@ -825,42 +870,6 @@ func sanitizeTrustDomain(trustDomain string) string { return strings.ReplaceAll(trustDomain, ".", "_") } -// parseAndValidateConfig returns an error if any configuration provided does not meet acceptable criteria -func parseAndValidateConfig(c string) (*Config, error) { - config := new(Config) - - if err := hcl.Decode(config, c); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.Region == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing a region") - } - - if config.KeyIdentifierValue != "" { - re := regexp.MustCompile(".*[^A-z0-9/_-].*") - if re.MatchString(config.KeyIdentifierValue) { - return nil, status.Error(codes.InvalidArgument, "Key identifier must contain only alphanumeric characters, forward slashes (/), underscores (_), and dashes (-)") - } - if strings.HasPrefix(config.KeyIdentifierValue, "alias/aws/") { - return nil, status.Error(codes.InvalidArgument, "Key identifier must not start with alias/aws/") - } - if len(config.KeyIdentifierValue) > 256 { - return nil, status.Error(codes.InvalidArgument, "Key identifier must not be longer than 256 characters") - } - } - - if config.KeyIdentifierFile == "" && config.KeyIdentifierValue == "" { - return nil, status.Error(codes.InvalidArgument, "configuration requires a key identifier file or a key identifier value") - } - - if config.KeyIdentifierFile != "" && config.KeyIdentifierValue != "" { - return nil, status.Error(codes.InvalidArgument, "configuration can't have a key identifier file and a key identifier value at the same time") - } - - return config, nil -} - func signingAlgorithmForKMS(keyType keymanagerv1.KeyType, signerOpts any) (types.SigningAlgorithmSpec, error) { var ( hashAlgo keymanagerv1.HashAlgorithm diff --git a/pkg/server/plugin/keymanager/awskms/awskms_test.go b/pkg/server/plugin/keymanager/awskms/awskms_test.go index c689156092..2a1d8695fb 100644 --- a/pkg/server/plugin/keymanager/awskms/awskms_test.go +++ b/pkg/server/plugin/keymanager/awskms/awskms_test.go @@ -19,8 +19,10 @@ import ( "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" + "github.com/spiffe/go-spiffe/v2/spiffeid" keymanagerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/keymanager/v1" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/server/plugin/keymanager" keymanagertest "github.com/spiffe/spire/pkg/server/plugin/keymanager/test" "github.com/spiffe/spire/test/plugintest" @@ -95,7 +97,10 @@ func TestKeyManagerContract(t *testing.T) { if isWindows { keyIdentifierFile = filepath.ToSlash(keyIdentifierFile) } - plugintest.Load(t, builtin(p), km, plugintest.Configuref(` + plugintest.Load(t, builtin(p), km, plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configuref(` region = "fake-region" key_identifier_file = %q `, keyIdentifierFile)) @@ -283,16 +288,16 @@ func TestConfigure(t *testing.T) { }, { name: "list aliases error", + configureRequest: configureRequestWithDefaults(t), err: "failed to fetch aliases: fake list aliases error", code: codes.Internal, - configureRequest: configureRequestWithDefaults(t), listAliasesErr: "fake list aliases error", }, { name: "describe key error", + configureRequest: configureRequestWithDefaults(t), err: "failed to describe key: describe key error", code: codes.Internal, - configureRequest: configureRequestWithDefaults(t), fakeEntries: []fakeKeyEntry{ { AliasName: aws.String(aliasName), @@ -306,9 +311,9 @@ func TestConfigure(t *testing.T) { }, { name: "unsupported key error", + configureRequest: configureRequestWithDefaults(t), err: "unsupported key spec: unsupported key spec", code: codes.Internal, - configureRequest: configureRequestWithDefaults(t), fakeEntries: []fakeKeyEntry{ { AliasName: aws.String(aliasName), @@ -321,9 +326,9 @@ func TestConfigure(t *testing.T) { }, { name: "get public key error", + configureRequest: configureRequestWithDefaults(t), err: "failed to fetch aliases: failed to get public key: get public key error", code: codes.Internal, - configureRequest: configureRequestWithDefaults(t), fakeEntries: []fakeKeyEntry{ { AliasName: aws.String(aliasName), @@ -338,9 +343,9 @@ func TestConfigure(t *testing.T) { { name: "disabled key", + configureRequest: configureRequestWithDefaults(t), err: "failed to fetch aliases: found disabled SPIRE key: \"arn:aws:kms:region:1234:key/abcd-fghi\", alias: \"arn:aws:kms:region:1234:alias/SPIRE_SERVER/test_example_org/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/spireKeyID\"", code: codes.FailedPrecondition, - configureRequest: configureRequestWithDefaults(t), fakeEntries: []fakeKeyEntry{ { AliasName: aws.String(aliasName), @@ -1964,8 +1969,8 @@ func TestDisposeKeys(t *testing.T) { func configureRequestWithString(config string) *configv1.ConfigureRequest { return &configv1.ConfigureRequest{ - HclConfiguration: config, CoreConfiguration: &configv1.CoreConfiguration{TrustDomain: "test.example.org"}, + HclConfiguration: config, } } @@ -1978,6 +1983,7 @@ const ( func configureRequestWithVars(accessKeyID, secretAccessKey, region, keyIdentifierConfigName KeyIdentifierConfigName, keyIdentifierConfigValue, keyPolicyFile string) *configv1.ConfigureRequest { return &configv1.ConfigureRequest{ + CoreConfiguration: &configv1.CoreConfiguration{TrustDomain: "test.example.org"}, HclConfiguration: fmt.Sprintf(`{ "access_key_id": "%s", "secret_access_key": "%s", @@ -1991,14 +1997,13 @@ func configureRequestWithVars(accessKeyID, secretAccessKey, region, keyIdentifie keyIdentifierConfigName, keyIdentifierConfigValue, keyPolicyFile), - CoreConfiguration: &configv1.CoreConfiguration{TrustDomain: "test.example.org"}, } } func configureRequestWithDefaults(t *testing.T) *configv1.ConfigureRequest { return &configv1.ConfigureRequest{ - HclConfiguration: serializedConfiguration(validAccessKeyID, validSecretAccessKey, validRegion, KeyIdentifierFile, getKeyIdentifierFile(t)), CoreConfiguration: &configv1.CoreConfiguration{TrustDomain: "test.example.org"}, + HclConfiguration: serializedConfiguration(validAccessKeyID, validSecretAccessKey, validRegion, KeyIdentifierFile, getKeyIdentifierFile(t)), } } diff --git a/pkg/server/plugin/keymanager/azurekeyvault/azure_key_vault.go b/pkg/server/plugin/keymanager/azurekeyvault/azure_key_vault.go index e8325fb3fa..696dbefe79 100644 --- a/pkg/server/plugin/keymanager/azurekeyvault/azure_key_vault.go +++ b/pkg/server/plugin/keymanager/azurekeyvault/azure_key_vault.go @@ -27,6 +27,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/diskutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" "google.golang.org/grpc/codes" @@ -86,6 +87,35 @@ type Config struct { AppSecret string `hcl:"app_secret" json:"app_secret"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + newConfig := new(Config) + + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.KeyVaultURI == "" { + status.ReportError("configuration is missing the Key Vault URI") + } + + if newConfig.KeyIdentifierValue != "" { + if len(newConfig.KeyIdentifierValue) > 256 { + status.ReportError("Key identifier must not be longer than 256 characters") + } + } + + if newConfig.KeyIdentifierFile == "" && newConfig.KeyIdentifierValue == "" { + status.ReportError("configuration requires a key identifier file or a key identifier value") + } + + if newConfig.KeyIdentifierFile != "" && newConfig.KeyIdentifierValue != "" { + status.ReportError("configuration can't have a key identifier file and a key identifier value at the same time") + } + + return newConfig +} + // Plugin is the main representation of this keymanager plugin type Plugin struct { keymanagerv1.UnsafeKeyManagerServer @@ -131,14 +161,14 @@ func (p *Plugin) SetLogger(log hclog.Logger) { } func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := parseAndValidateConfig(req.HclConfiguration) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } - serverID := config.KeyIdentifierValue + serverID := newConfig.KeyIdentifierValue if serverID == "" { - serverID, err = getOrCreateServerID(config.KeyIdentifierFile) + serverID, err = getOrCreateServerID(newConfig.KeyIdentifierFile) if err != nil { return nil, err } @@ -148,26 +178,26 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) var client cloudKeyManagementService switch { - case config.SubscriptionID != "", config.AppID != "", config.AppSecret != "", config.TenantID != "": - if config.TenantID == "" { + case newConfig.SubscriptionID != "", newConfig.AppID != "", newConfig.AppSecret != "", newConfig.TenantID != "": + if newConfig.TenantID == "" { return nil, status.Errorf(codes.InvalidArgument, "invalid configuration, missing tenant id") } - if config.SubscriptionID == "" { + if newConfig.SubscriptionID == "" { return nil, status.Errorf(codes.InvalidArgument, "invalid configuration, missing subscription id") } - if config.AppID == "" { + if newConfig.AppID == "" { return nil, status.Errorf(codes.InvalidArgument, "invalid configuration, missing application id") } - if config.AppSecret == "" { + if newConfig.AppSecret == "" { return nil, status.Errorf(codes.InvalidArgument, "invalid configuration, missing app secret") } - creds, err := azidentity.NewClientSecretCredential(config.TenantID, config.AppID, config.AppSecret, nil) + creds, err := azidentity.NewClientSecretCredential(newConfig.TenantID, newConfig.AppID, newConfig.AppSecret, nil) if err != nil { return nil, status.Errorf(codes.Internal, "unable to get client credential: %v", err) } - client, err = p.hooks.newKeyVaultClient(creds, config.KeyVaultURI) + client, err = p.hooks.newKeyVaultClient(creds, newConfig.KeyVaultURI) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create Key Vault client with client credentials: %v", err) } @@ -176,7 +206,7 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) if err != nil { return nil, status.Errorf(codes.Internal, "unable to fetch client credential: %v", err) } - client, err = p.hooks.newKeyVaultClient(cred, config.KeyVaultURI) + client, err = p.hooks.newKeyVaultClient(cred, newConfig.KeyVaultURI) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create Key Vault client with MSI credential: %v", err) } @@ -189,7 +219,7 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) trustDomain: req.CoreConfiguration.TrustDomain, } - p.log.Debug("Fetching keys from Azure Key Vault", "key_vault_uri", config.KeyVaultURI) + p.log.Debug("Fetching keys from Azure Key Vault", "key_vault_uri", newConfig.KeyVaultURI) keyEntries, err := fetcher.fetchKeyEntries(ctx) if err != nil { return nil, err @@ -220,6 +250,15 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + // refreshKeysTask will update the keys in the cache every 6 hours. // Keys will be updated with the same Operations they already have (Sign and Verify). // The consequence of this is that the value of the field "Updated" in each key belonging to the server will be set to the current timestamp. @@ -665,35 +704,6 @@ func (p *Plugin) generateKeyName(spireKeyID string) (keyName string, err error) return fmt.Sprintf("%s-%s-%s", keyNamePrefix, uniqueID, spireKeyID), nil } -// parseAndValidateConfig returns an error if any configuration provided does not meet acceptable criteria -func parseAndValidateConfig(c string) (*Config, error) { - config := new(Config) - - if err := hcl.Decode(config, c); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.KeyVaultURI == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the Key Vault URI") - } - - if config.KeyIdentifierValue != "" { - if len(config.KeyIdentifierValue) > 256 { - return nil, status.Error(codes.InvalidArgument, "Key identifier must not be longer than 256 characters") - } - } - - if config.KeyIdentifierFile == "" && config.KeyIdentifierValue == "" { - return nil, status.Error(codes.InvalidArgument, "configuration requires a key identifier file or a key identifier value") - } - - if config.KeyIdentifierFile != "" && config.KeyIdentifierValue != "" { - return nil, status.Error(codes.InvalidArgument, "configuration can't have a key identifier file and a key identifier value at the same time") - } - - return config, nil -} - func getOrCreateServerID(idPath string) (string, error) { data, err := os.ReadFile(idPath) switch { diff --git a/pkg/server/plugin/keymanager/azurekeyvault/azure_key_vault_test.go b/pkg/server/plugin/keymanager/azurekeyvault/azure_key_vault_test.go index d46abaab8f..bbda919591 100644 --- a/pkg/server/plugin/keymanager/azurekeyvault/azure_key_vault_test.go +++ b/pkg/server/plugin/keymanager/azurekeyvault/azure_key_vault_test.go @@ -20,8 +20,10 @@ import ( "github.com/gofrs/uuid/v5" "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" + "github.com/spiffe/go-spiffe/v2/spiffeid" keymanagerv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/keymanager/v1" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/server/plugin/keymanager" keymanagertest "github.com/spiffe/spire/pkg/server/plugin/keymanager/test" "github.com/spiffe/spire/test/plugintest" @@ -68,7 +70,11 @@ func TestKeyManagerContract(t *testing.T) { km := new(keymanager.V1) keyIdentifierFile := createKeyIdentifierFile(t) - plugintest.Load(t, builtin(p), km, plugintest.Configuref(` + plugintest.Load(t, builtin(p), km, + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configuref(` key_identifier_file = %q key_vault_uri = "https://spire-server.vault.azure.net/" use_msi=true @@ -1002,14 +1008,15 @@ func serializedConfiguration(keyIdentifierConfigName KeyIdentifierConfigName, ke func configureRequestWithVars(keyIdentifierConfigName KeyIdentifierConfigName, keyIdentifierConfigValue, keyVaultURI, tenantID, subscriptionID, appID, appSecret string) *configv1.ConfigureRequest { return &configv1.ConfigureRequest{ - HclConfiguration: serializedConfiguration(keyIdentifierConfigName, keyIdentifierConfigValue, keyVaultURI, tenantID, subscriptionID, appID, appSecret), CoreConfiguration: &configv1.CoreConfiguration{TrustDomain: trustDomain}, + HclConfiguration: serializedConfiguration(keyIdentifierConfigName, keyIdentifierConfigValue, keyVaultURI, tenantID, subscriptionID, appID, appSecret), } } func configureRequestWithString(config string) *configv1.ConfigureRequest { return &configv1.ConfigureRequest{ - HclConfiguration: config, + CoreConfiguration: &configv1.CoreConfiguration{TrustDomain: trustDomain}, + HclConfiguration: config, } } diff --git a/pkg/server/plugin/keymanager/disk/disk.go b/pkg/server/plugin/keymanager/disk/disk.go index 891baef0b1..087df5e552 100644 --- a/pkg/server/plugin/keymanager/disk/disk.go +++ b/pkg/server/plugin/keymanager/disk/disk.go @@ -12,6 +12,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" catalog "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/diskutil" + "github.com/spiffe/spire/pkg/common/pluginconf" keymanagerbase "github.com/spiffe/spire/pkg/server/plugin/keymanager/base" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -37,6 +38,20 @@ type configuration struct { KeysPath string `hcl:"keys_path"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *configuration { + newConfig := new(configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if newConfig.KeysPath == "" { + status.ReportError("keys_path is required") + } + + return newConfig +} + type KeyManager struct { *keymanagerbase.Base configv1.UnimplementedConfigServer @@ -55,25 +70,30 @@ func newKeyManager(generator Generator) *KeyManager { } func (m *KeyManager) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config := new(configuration) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.KeysPath == "" { - return nil, status.Error(codes.InvalidArgument, "keys_path is required") + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } m.mu.Lock() defer m.mu.Unlock() - if err := m.configure(config); err != nil { + if err := m.configure(newConfig); err != nil { return nil, err } return &configv1.ConfigureResponse{}, nil } +func (m *KeyManager) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + func (m *KeyManager) configure(config *configuration) error { // only load entry information on first configure if m.config == nil { diff --git a/pkg/server/plugin/keymanager/disk/disk_test.go b/pkg/server/plugin/keymanager/disk/disk_test.go index 482a4bcfb9..3ebc1f6531 100644 --- a/pkg/server/plugin/keymanager/disk/disk_test.go +++ b/pkg/server/plugin/keymanager/disk/disk_test.go @@ -7,6 +7,8 @@ import ( "path/filepath" "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/server/plugin/keymanager" "github.com/spiffe/spire/pkg/server/plugin/keymanager/disk" keymanagertest "github.com/spiffe/spire/pkg/server/plugin/keymanager/test" @@ -84,6 +86,9 @@ func loadPlugin(t *testing.T, configFmt string, configArgs ...any) (keymanager.K km := new(keymanager.V1) var configErr error plugintest.Load(t, disk.TestBuiltIn(keymanagertest.NewGenerator()), km, + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.Configuref(configFmt, configArgs...), plugintest.CaptureConfigureError(&configErr), ) diff --git a/pkg/server/plugin/keymanager/gcpkms/gcpkms.go b/pkg/server/plugin/keymanager/gcpkms/gcpkms.go index 47b696b1ce..36fa63ce1a 100644 --- a/pkg/server/plugin/keymanager/gcpkms/gcpkms.go +++ b/pkg/server/plugin/keymanager/gcpkms/gcpkms.go @@ -27,6 +27,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/diskutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/api/iterator" "google.golang.org/api/option" "google.golang.org/grpc/codes" @@ -137,6 +138,36 @@ type Config struct { ServiceAccountFile string `hcl:"service_account_file" json:"service_account_file"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + newConfig := new(Config) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + } + + if newConfig.KeyRing == "" { + status.ReportError("configuration is missing the key ring") + } + + if newConfig.KeyIdentifierFile == "" && newConfig.KeyIdentifierValue == "" { + status.ReportError("configuration requires a key identifier file or a key identifier value") + } + + if newConfig.KeyIdentifierFile != "" && newConfig.KeyIdentifierValue != "" { + status.ReportError("configuration can't have a key identifier file and a key identifier value at the same time") + } + + if newConfig.KeyIdentifierValue != "" { + if !validateCharacters(newConfig.KeyIdentifierValue) { + status.ReportError("Key identifier must contain only letters, numbers, underscores (_), and dashes (-)") + } + if len(newConfig.KeyIdentifierValue) > 63 { + status.ReportError("Key identifier must not be longer than 63 characters") + } + } + + return newConfig +} + // New returns an instantiated plugin. func New() *Plugin { return newPlugin(newKMSClient) @@ -166,14 +197,14 @@ func (p *Plugin) Close() error { // Configure sets up the plugin. func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := parseAndValidateConfig(req.HclConfiguration) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } - serverID := config.KeyIdentifierValue + serverID := newConfig.KeyIdentifierValue if serverID == "" { - serverID, err = getOrCreateServerID(config.KeyIdentifierFile) + serverID, err = getOrCreateServerID(newConfig.KeyIdentifierFile) if err != nil { return nil, err } @@ -181,8 +212,8 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) p.log.Debug("Loaded server id", "server_id", serverID) var customPolicy *iam.Policy3 - if config.KeyPolicyFile != "" { - if customPolicy, err = parsePolicyFile(config.KeyPolicyFile); err != nil { + if newConfig.KeyPolicyFile != "" { + if customPolicy, err = parsePolicyFile(newConfig.KeyPolicyFile); err != nil { return nil, status.Errorf(codes.Internal, "could not parse policy file: %v", err) } } @@ -200,8 +231,8 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) }) var opts []option.ClientOption - if config.ServiceAccountFile != "" { - opts = append(opts, option.WithCredentialsFile(config.ServiceAccountFile)) + if newConfig.ServiceAccountFile != "" { + opts = append(opts, option.WithCredentialsFile(newConfig.ServiceAccountFile)) } kc, err := p.hooks.newKMSClient(ctx, opts...) @@ -210,13 +241,13 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) } fetcher := &keyFetcher{ - keyRing: config.KeyRing, + keyRing: newConfig.KeyRing, kmsClient: kc, log: p.log, serverID: serverID, tdHash: tdHashString, } - p.log.Debug("Fetching keys from Cloud KMS", "key_ring", config.KeyRing) + p.log.Debug("Fetching keys from Cloud KMS", "key_ring", newConfig.KeyRing) keyEntries, err := fetcher.fetchKeyEntries(ctx) if err != nil { return nil, err @@ -232,7 +263,7 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) p.configMtx.Lock() defer p.configMtx.Unlock() - p.config = config + p.config = newConfig // Start long-running tasks. ctx, p.cancelTasks = context.WithCancel(context.Background()) @@ -243,6 +274,15 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + // GenerateKey creates a key in KMS. If a key already exists in the local storage, // it is updated. func (p *Plugin) GenerateKey(ctx context.Context, req *keymanagerv1.GenerateKeyRequest) (*keymanagerv1.GenerateKeyResponse, error) { @@ -1098,39 +1138,6 @@ func min(x, y time.Duration) time.Duration { return y } -// parseAndValidateConfig returns an error if any configuration provided does -// not meet acceptable criteria -func parseAndValidateConfig(c string) (*Config, error) { - config := new(Config) - - if err := hcl.Decode(config, c); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.KeyRing == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing the key ring") - } - - if config.KeyIdentifierFile == "" && config.KeyIdentifierValue == "" { - return nil, status.Error(codes.InvalidArgument, "configuration requires a key identifier file or a key identifier value") - } - - if config.KeyIdentifierFile != "" && config.KeyIdentifierValue != "" { - return nil, status.Error(codes.InvalidArgument, "configuration can't have a key identifier file and a key identifier value at the same time") - } - - if config.KeyIdentifierValue != "" { - if !validateCharacters(config.KeyIdentifierValue) { - return nil, status.Error(codes.InvalidArgument, "Key identifier must contain only letters, numbers, underscores (_), and dashes (-)") - } - if len(config.KeyIdentifierValue) > 63 { - return nil, status.Error(codes.InvalidArgument, "Key identifier must not be longer than 63 characters") - } - } - - return config, nil -} - func validateCharacters(str string) bool { for _, r := range str { if !unicode.IsLower(r) && !unicode.IsNumber(r) && r != '-' && r != '_' { diff --git a/pkg/server/plugin/keymanager/gcpkms/gcpkms_test.go b/pkg/server/plugin/keymanager/gcpkms/gcpkms_test.go index a20443dadb..5e5254bf2e 100644 --- a/pkg/server/plugin/keymanager/gcpkms/gcpkms_test.go +++ b/pkg/server/plugin/keymanager/gcpkms/gcpkms_test.go @@ -1382,7 +1382,11 @@ func TestKeyManagerContract(t *testing.T) { ) km := new(keymanager.V1) keyIdentifierFile := filepath.ToSlash(filepath.Join(dir, "key_identifier.json")) - plugintest.Load(t, builtin(p), km, plugintest.Configuref(` + plugintest.Load(t, builtin(p), km, + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("test.example.org"), + }), + plugintest.Configuref(` key_identifier_file = %q key_ring = "projects/project-id/locations/location/keyRings/keyring" `, keyIdentifierFile)) diff --git a/pkg/server/plugin/nodeattestor/awsiid/iid.go b/pkg/server/plugin/nodeattestor/awsiid/iid.go index 3df80a0c27..d6c0206893 100644 --- a/pkg/server/plugin/nodeattestor/awsiid/iid.go +++ b/pkg/server/plugin/nodeattestor/awsiid/iid.go @@ -36,6 +36,7 @@ import ( "github.com/spiffe/spire/pkg/common/agentpathtemplate" "github.com/spiffe/spire/pkg/common/catalog" caws "github.com/spiffe/spire/pkg/common/plugin/aws" + "github.com/spiffe/spire/pkg/common/pluginconf" nodeattestorbase "github.com/spiffe/spire/pkg/server/plugin/nodeattestor/base" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -127,6 +128,53 @@ type IIDAttestorConfig struct { getAWSCACertificate func(string, PublicKeyType) (*x509.Certificate, error) } +func (p *IIDAttestorPlugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *IIDAttestorConfig { + newConfig := new(IIDAttestorConfig) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + // Function to get the AWS CA certificate. We do this lazily on configure so deployments + // not using this plugin don't pay for parsing it on startup. This + // operation should not fail, but we check the return value just in case. + newConfig.getAWSCACertificate = p.hooks.getAWSCACertificate + + if err := newConfig.Validate(p.hooks.getenv(accessKeyIDVarName), p.hooks.getenv(secretAccessKeyVarName)); err != nil { + status.ReportError(err.Error()) + } + + newConfig.trustDomain = coreConfig.TrustDomain + + newConfig.pathTemplate = defaultAgentPathTemplate + if len(newConfig.AgentPathTemplate) > 0 { + tmpl, err := agentpathtemplate.Parse(newConfig.AgentPathTemplate) + if err != nil { + status.ReportErrorf("failed to parse agent svid template: %q", newConfig.AgentPathTemplate) + } else { + newConfig.pathTemplate = tmpl + } + } + + if newConfig.Partition == "" { + newConfig.Partition = defaultPartition + } + + if !isValidAWSPartition(newConfig.Partition) { + status.ReportErrorf("invalid partition %q, must be one of: %v", newConfig.Partition, partitions) + } + + // Check if Feature flag for account belongs to organization is enabled. + if newConfig.ValidateOrgAccountID != nil { + err := validateOrganizationConfig(newConfig) + if err != nil { + status.ReportError(err.Error()) + } + } + + return newConfig +} + // New creates a new IIDAttestorPlugin. func New() *IIDAttestorPlugin { p := &IIDAttestorPlugin{} @@ -257,64 +305,22 @@ func (p *IIDAttestorPlugin) Attest(stream nodeattestorv1.NodeAttestor_AttestServ // Configure configures the IIDAttestorPlugin. func (p *IIDAttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config := new(IIDAttestorConfig) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - // Function to get the AWS CA certificate. We do this lazily on configure so deployments - // not using this plugin don't pay for parsing it on startup. This - // operation should not fail, but we check the return value just in case. - config.getAWSCACertificate = p.hooks.getAWSCACertificate - - if err := config.Validate(p.hooks.getenv(accessKeyIDVarName), p.hooks.getenv(secretAccessKeyVarName)); err != nil { - return nil, err - } - - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - - var err error - config.trustDomain, err = spiffeid.TrustDomainFromString(req.CoreConfiguration.TrustDomain) + newConfig, _, err := pluginconf.Build(req, p.buildConfig) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "core configuration has invalid trust domain: %v", err) - } - - config.pathTemplate = defaultAgentPathTemplate - if len(config.AgentPathTemplate) > 0 { - tmpl, err := agentpathtemplate.Parse(config.AgentPathTemplate) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse agent svid template: %q", config.AgentPathTemplate) - } - config.pathTemplate = tmpl - } - - if config.Partition == "" { - config.Partition = defaultPartition - } - if !isValidAWSPartition(config.Partition) { - return nil, status.Errorf(codes.InvalidArgument, "invalid partition %q, must be one of: %v", config.Partition, partitions) - } - - // Check if Feature flag for account belongs to organization is enabled. - orgConfig := &orgValidationConfig{} - if config.ValidateOrgAccountID != nil { - err = validateOrganizationConfig(config) - if err != nil { - return nil, err - } - orgConfig = config.ValidateOrgAccountID + return nil, err } p.mtx.Lock() defer p.mtx.Unlock() + p.config = newConfig - p.config = config - p.clients.configure(config.SessionConfig, *orgConfig) - if config.ValidateOrgAccountID != nil { + if newConfig.ValidateOrgAccountID == nil { + // unconfigure existing clients + p.clients.configure(p.config.SessionConfig, orgValidationConfig{}) + } else { + p.clients.configure(p.config.SessionConfig, *p.config.ValidateOrgAccountID) // Setup required config, for validation and for bootstrapping org client - if err := p.orgValidation.configure(orgConfig); err != nil { + if err := p.orgValidation.configure(p.config.ValidateOrgAccountID); err != nil { return nil, err } } @@ -322,6 +328,15 @@ func (p *IIDAttestorPlugin) Configure(_ context.Context, req *configv1.Configure return &configv1.ConfigureResponse{}, nil } +func (p *IIDAttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + // SetLogger sets this plugin's logger func (p *IIDAttestorPlugin) SetLogger(log hclog.Logger) { p.log = log diff --git a/pkg/server/plugin/nodeattestor/awsiid/iid_test.go b/pkg/server/plugin/nodeattestor/awsiid/iid_test.go index 6dc9fc6e9f..b5f1375cf5 100644 --- a/pkg/server/plugin/nodeattestor/awsiid/iid_test.go +++ b/pkg/server/plugin/nodeattestor/awsiid/iid_test.go @@ -582,7 +582,7 @@ func TestConfigure(t *testing.T) { t.Run("missing trust domain", func(t *testing.T) { err := doConfig(t, catalog.CoreConfig{}, ``) - spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "core configuration has invalid trust domain: trust domain is missing") + spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "server core configuration must contain trust_domain") }) t.Run("fails with access id but no secret", func(t *testing.T) { diff --git a/pkg/server/plugin/nodeattestor/awsiid/organization.go b/pkg/server/plugin/nodeattestor/awsiid/organization.go index 720bd4d752..1e388388fa 100644 --- a/pkg/server/plugin/nodeattestor/awsiid/organization.go +++ b/pkg/server/plugin/nodeattestor/awsiid/organization.go @@ -91,7 +91,7 @@ func (o *orgValidator) configure(config *orgValidationConfig) error { t, err := time.ParseDuration(config.AccountListTTL) if err != nil { - return status.Errorf(codes.InvalidArgument, "issue while parsing ttl for organization, while configuring orgnization validation: %v", err) + return status.Errorf(codes.InvalidArgument, "issue while parsing ttl for organization, while configuring organization validation: %v", err) } o.orgAccountListCacheTTL = t diff --git a/pkg/server/plugin/nodeattestor/azuremsi/msi.go b/pkg/server/plugin/nodeattestor/azuremsi/msi.go index b3b662839f..4d2df59e7a 100644 --- a/pkg/server/plugin/nodeattestor/azuremsi/msi.go +++ b/pkg/server/plugin/nodeattestor/azuremsi/msi.go @@ -26,6 +26,7 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/jwtutil" "github.com/spiffe/spire/pkg/common/plugin/azure" + "github.com/spiffe/spire/pkg/common/pluginconf" nodeattestorbase "github.com/spiffe/spire/pkg/server/plugin/nodeattestor/base" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -99,6 +100,92 @@ type msiAttestorConfig struct { idPathTemplate *agentpathtemplate.Template } +func (p *MSIAttestorPlugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *msiAttestorConfig { + newConfig := new(MSIAttestorConfig) + + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if len(newConfig.Tenants) == 0 { + status.ReportError("configuration must have at least one tenant") + } + for _, tenant := range newConfig.Tenants { + if tenant.ResourceID == "" { + tenant.ResourceID = azure.DefaultMSIResourceID + } + } + + tenants := make(map[string]*tenantConfig) + for tenantID, tenant := range newConfig.Tenants { + var client apiClient + + // Use tenant-specific credentials for resolving selectors + switch { + case tenant.SubscriptionID != "", tenant.AppID != "", tenant.AppSecret != "": + if tenant.SubscriptionID == "" { + status.ReportErrorf("misconfigured tenant %q: missing subscription id", tenantID) + } + if tenant.AppID == "" { + status.ReportErrorf("misconfigured tenant %q: missing app id", tenantID) + } + if tenant.AppSecret == "" { + status.ReportErrorf("misconfigured tenant %q: missing app secret", tenantID) + } + + cred, err := azidentity.NewClientSecretCredential(tenantID, tenant.AppID, tenant.AppSecret, nil) + if err != nil { + status.ReportErrorf("unable to get tenant client credential: %v", err) + } + + client, err = p.hooks.newClient(tenant.SubscriptionID, cred) + if err != nil { + status.ReportErrorf("unable to create client for tenant %q: %v", tenantID, err) + } + + default: + instanceMetadata, err := p.hooks.fetchInstanceMetadata(http.DefaultClient) + if err != nil { + status.ReportError(err.Error()) + } + cred, err := p.hooks.fetchCredential(tenantID) + if err != nil { + status.ReportErrorf("unable to fetch client credential: %v", err) + } + client, err = p.hooks.newClient(instanceMetadata.Compute.SubscriptionID, cred) + if err != nil { + status.ReportErrorf("unable to create client with default credential: %v", err) + } + } + + // If credentials are not configured then selectors won't be gathered. + if client == nil { + status.ReportErrorf("no client credentials available for tenant %q", tenantID) + } + + tenants[tenantID] = &tenantConfig{ + resourceID: tenant.ResourceID, + client: client, + } + } + + tmpl := azure.DefaultAgentPathTemplate + if len(newConfig.AgentPathTemplate) > 0 { + var err error + tmpl, err = agentpathtemplate.Parse(newConfig.AgentPathTemplate) + if err != nil { + status.ReportErrorf("failed to parse agent path template: %q", newConfig.AgentPathTemplate) + } + } + + return &msiAttestorConfig{ + td: coreConfig.TrustDomain, + tenants: tenants, + idPathTemplate: tmpl, + } +} + type MSIAttestorPlugin struct { nodeattestorbase.Base nodeattestorv1.UnsafeNodeAttestorServer @@ -239,100 +326,25 @@ func (p *MSIAttestorPlugin) Attest(stream nodeattestorv1.NodeAttestor_AttestServ } func (p *MSIAttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - hclConfig := new(MSIAttestorConfig) - if err := hcl.Decode(hclConfig, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "core configuration missing trust domain") - } - - if len(hclConfig.Tenants) == 0 { - return nil, status.Error(codes.InvalidArgument, "configuration must have at least one tenant") - } - for _, tenant := range hclConfig.Tenants { - if tenant.ResourceID == "" { - tenant.ResourceID = azure.DefaultMSIResourceID - } - } - - td, err := spiffeid.TrustDomainFromString(req.CoreConfiguration.TrustDomain) + newConfig, _, err := pluginconf.Build(req, p.buildConfig) if err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, err } - tenants := make(map[string]*tenantConfig) - - for tenantID, tenant := range hclConfig.Tenants { - var client apiClient - - // Use tenant-specific credentials for resolving selectors - switch { - case tenant.SubscriptionID != "", tenant.AppID != "", tenant.AppSecret != "": - if tenant.SubscriptionID == "" { - return nil, status.Errorf(codes.InvalidArgument, "misconfigured tenant %q: missing subscription id", tenantID) - } - if tenant.AppID == "" { - return nil, status.Errorf(codes.InvalidArgument, "misconfigured tenant %q: missing app id", tenantID) - } - if tenant.AppSecret == "" { - return nil, status.Errorf(codes.InvalidArgument, "misconfigured tenant %q: missing app secret", tenantID) - } - - cred, err := azidentity.NewClientSecretCredential(tenantID, tenant.AppID, tenant.AppSecret, nil) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to get tenant client credential: %v", err) - } - - client, err = p.hooks.newClient(tenant.SubscriptionID, cred) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to create client for tenant %q: %v", tenantID, err) - } - - default: - instanceMetadata, err := p.hooks.fetchInstanceMetadata(http.DefaultClient) - if err != nil { - return nil, err - } - cred, err := p.hooks.fetchCredential(tenantID) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to fetch client credential: %v", err) - } - client, err = p.hooks.newClient(instanceMetadata.Compute.SubscriptionID, cred) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to create client with default credential: %v", err) - } - } - - // If credentials are not configured then selectors won't be gathered. - if client == nil { - return nil, status.Errorf(codes.Internal, "no client credentials available for tenant %q", tenantID) - } + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig - tenants[tenantID] = &tenantConfig{ - resourceID: tenant.ResourceID, - client: client, - } - } + return &configv1.ConfigureResponse{}, nil +} - tmpl := azure.DefaultAgentPathTemplate - if len(hclConfig.AgentPathTemplate) > 0 { - var err error - tmpl, err = agentpathtemplate.Parse(hclConfig.AgentPathTemplate) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse agent path template: %q", hclConfig.AgentPathTemplate) - } - } +func (p *MSIAttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) - p.setConfig(&msiAttestorConfig{ - td: td, - tenants: tenants, - idPathTemplate: tmpl, - }) - return &configv1.ConfigureResponse{}, nil + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil } func (p *MSIAttestorPlugin) getConfig() (*msiAttestorConfig, error) { @@ -344,12 +356,6 @@ func (p *MSIAttestorPlugin) getConfig() (*msiAttestorConfig, error) { return p.config, nil } -func (p *MSIAttestorPlugin) setConfig(config *msiAttestorConfig) { - p.mu.Lock() - defer p.mu.Unlock() - p.config = config -} - func (p *MSIAttestorPlugin) resolve(ctx context.Context, client apiClient, principalID string) ([]string, error) { // Retrieve the resource belonging to the principal id. vmResourceID, err := client.GetVirtualMachineResourceID(ctx, principalID) diff --git a/pkg/server/plugin/nodeattestor/azuremsi/msi_test.go b/pkg/server/plugin/nodeattestor/azuremsi/msi_test.go index 3f7b1e88c3..5f94cc0964 100644 --- a/pkg/server/plugin/nodeattestor/azuremsi/msi_test.go +++ b/pkg/server/plugin/nodeattestor/azuremsi/msi_test.go @@ -456,7 +456,7 @@ func (s *MSIAttestorSuite) TestConfigure() { s.T().Run("missing trust domain", func(t *testing.T) { err := doConfig(t, catalog.CoreConfig{}, "", nil) - spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "core configuration missing trust domain") + spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "server core configuration must contain trust_domain") }) s.T().Run("missing tenants", func(t *testing.T) { @@ -587,7 +587,7 @@ func (s *MSIAttestorSuite) TestConfigure() { }, }, ) - spiretest.RequireGRPCStatusContains(t, err, codes.Internal, `unable to fetch client credential: some error`) + spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, `unable to fetch client credential: some error`) }) } diff --git a/pkg/server/plugin/nodeattestor/gcpiit/iit.go b/pkg/server/plugin/nodeattestor/gcpiit/iit.go index a38780b6b9..432ced55e5 100644 --- a/pkg/server/plugin/nodeattestor/gcpiit/iit.go +++ b/pkg/server/plugin/nodeattestor/gcpiit/iit.go @@ -18,6 +18,7 @@ import ( "github.com/spiffe/spire/pkg/common/agentpathtemplate" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/gcp" + "github.com/spiffe/spire/pkg/common/pluginconf" nodeattestorbase "github.com/spiffe/spire/pkg/server/plugin/nodeattestor/base" "google.golang.org/api/compute/v1" "google.golang.org/api/option" @@ -84,6 +85,50 @@ type IITAttestorConfig struct { ServiceAccountFile string `hcl:"service_account_file"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *IITAttestorConfig { + newConfig := new(IITAttestorConfig) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if len(newConfig.ProjectIDAllowList) == 0 { + status.ReportError("projectid_allow_list is required") + } + + tmpl := gcp.DefaultAgentPathTemplate + if len(newConfig.AgentPathTemplate) > 0 { + var err error + tmpl, err = agentpathtemplate.Parse(newConfig.AgentPathTemplate) + if err != nil { + status.ReportErrorf("failed to parse agent path template: %q", newConfig.AgentPathTemplate) + } + } + + if len(newConfig.AllowedLabelKeys) > 0 { + newConfig.allowedLabelKeys = make(map[string]bool, len(newConfig.AllowedLabelKeys)) + for _, key := range newConfig.AllowedLabelKeys { + newConfig.allowedLabelKeys[key] = true + } + } + + if len(newConfig.AllowedMetadataKeys) > 0 { + newConfig.allowedMetadataKeys = make(map[string]bool, len(newConfig.AllowedMetadataKeys)) + for _, key := range newConfig.AllowedMetadataKeys { + newConfig.allowedMetadataKeys[key] = true + } + } + + if newConfig.MaxMetadataValueSize == 0 { + newConfig.MaxMetadataValueSize = defaultMaxMetadataValueSize + } + + newConfig.idPathTemplate = tmpl + newConfig.trustDomain = coreConfig.TrustDomain + + return newConfig +} + // New creates a new IITAttestorPlugin. func New() *IITAttestorPlugin { return &IITAttestorPlugin{ @@ -168,66 +213,27 @@ func (p *IITAttestorPlugin) Attest(stream nodeattestorv1.NodeAttestor_AttestServ // Configure configures the IITAttestorPlugin. func (p *IITAttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - hclConfig := new(IITAttestorConfig) - if err := hcl.Decode(hclConfig, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "global configuration is required") - } - - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "trust_domain is required") - } - - trustDomain, err := spiffeid.TrustDomainFromString(req.CoreConfiguration.TrustDomain) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "trust_domain is invalid: %v", err) - } - - if len(hclConfig.ProjectIDAllowList) == 0 { - return nil, status.Error(codes.InvalidArgument, "projectid_allow_list is required") - } - - tmpl := gcp.DefaultAgentPathTemplate - if len(hclConfig.AgentPathTemplate) > 0 { - var err error - tmpl, err = agentpathtemplate.Parse(hclConfig.AgentPathTemplate) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse agent path template: %q", hclConfig.AgentPathTemplate) - } - } - - if len(hclConfig.AllowedLabelKeys) > 0 { - hclConfig.allowedLabelKeys = make(map[string]bool, len(hclConfig.AllowedLabelKeys)) - for _, key := range hclConfig.AllowedLabelKeys { - hclConfig.allowedLabelKeys[key] = true - } - } - - if len(hclConfig.AllowedMetadataKeys) > 0 { - hclConfig.allowedMetadataKeys = make(map[string]bool, len(hclConfig.AllowedMetadataKeys)) - for _, key := range hclConfig.AllowedMetadataKeys { - hclConfig.allowedMetadataKeys[key] = true - } - } - - if hclConfig.MaxMetadataValueSize == 0 { - hclConfig.MaxMetadataValueSize = defaultMaxMetadataValueSize + return nil, err } - hclConfig.idPathTemplate = tmpl - hclConfig.trustDomain = trustDomain - p.mtx.Lock() defer p.mtx.Unlock() - - p.config = hclConfig + p.config = newConfig return &configv1.ConfigureResponse{}, nil } +func (p *IITAttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + func (p *IITAttestorPlugin) getConfig() (*IITAttestorConfig, error) { p.mtx.Lock() defer p.mtx.Unlock() diff --git a/pkg/server/plugin/nodeattestor/gcpiit/iit_test.go b/pkg/server/plugin/nodeattestor/gcpiit/iit_test.go index 449b1b8983..1bdf7943a5 100644 --- a/pkg/server/plugin/nodeattestor/gcpiit/iit_test.go +++ b/pkg/server/plugin/nodeattestor/gcpiit/iit_test.go @@ -259,6 +259,7 @@ agent_path_template = "/{{ .InstanceID }}" func (s *IITAttestorSuite) TestConfigure() { doConfig := func(t *testing.T, coreConfig catalog.CoreConfig, config string) error { + t.Logf("core config: %+v, config: %s\n", coreConfig, config) var err error plugintest.Load(t, BuiltIn(), nil, plugintest.CaptureConfigureError(&err), @@ -282,7 +283,7 @@ func (s *IITAttestorSuite) TestConfigure() { err := doConfig(t, catalog.CoreConfig{}, ` projectid_allow_list = ["bar"] `) - spiretest.AssertGRPCStatusContains(t, err, codes.InvalidArgument, "trust_domain is required") + spiretest.AssertGRPCStatusContains(t, err, codes.InvalidArgument, "server core configuration must contain trust_domain") }) s.T().Run("missing projectID allow list", func(t *testing.T) { diff --git a/pkg/server/plugin/nodeattestor/httpchallenge/httpchallenge.go b/pkg/server/plugin/nodeattestor/httpchallenge/httpchallenge.go index c4ba141a42..0946f6c7f5 100644 --- a/pkg/server/plugin/nodeattestor/httpchallenge/httpchallenge.go +++ b/pkg/server/plugin/nodeattestor/httpchallenge/httpchallenge.go @@ -15,6 +15,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/httpchallenge" + "github.com/spiffe/spire/pkg/common/pluginconf" nodeattestorbase "github.com/spiffe/spire/pkg/server/plugin/nodeattestor/base" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -54,6 +55,55 @@ type configuration struct { tofu bool } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *configuration { + hclConfig := new(Config) + if err := hcl.Decode(hclConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + var dnsPatterns []*regexp.Regexp + for _, r := range hclConfig.AllowedDNSPatterns { + re := regexp.MustCompile(r) + dnsPatterns = append(dnsPatterns, re) + } + + allowNonRootPorts := true + if hclConfig.AllowNonRootPorts != nil { + allowNonRootPorts = *hclConfig.AllowNonRootPorts + } + + tofu := true + if hclConfig.TOFU != nil { + tofu = *hclConfig.TOFU + } + + mustUseTOFU := false + switch { + // User has explicitly asked for a required port that is untrusted + case hclConfig.RequiredPort != nil && *hclConfig.RequiredPort >= 1024: + mustUseTOFU = true + // User has just chosen the defaults, any port is allowed + case hclConfig.AllowNonRootPorts == nil && hclConfig.RequiredPort == nil: + mustUseTOFU = true + // User explicitly set AllowNonRootPorts to true and no required port specified + case hclConfig.AllowNonRootPorts != nil && *hclConfig.AllowNonRootPorts && hclConfig.RequiredPort == nil: + mustUseTOFU = true + } + + if !tofu && mustUseTOFU { + status.ReportError("you can not turn off trust on first use (TOFU) when non-root ports are allowed") + } + + return &configuration{ + trustDomain: coreConfig.TrustDomain, + dnsPatterns: dnsPatterns, + requiredPort: hclConfig.RequiredPort, + allowNonRootPorts: allowNonRootPorts, + tofu: tofu, + } +} + type Config struct { AllowedDNSPatterns []string `hcl:"allowed_dns_patterns"` RequiredPort *int `hcl:"required_port"` @@ -173,66 +223,25 @@ func (p *Plugin) Attest(stream nodeattestorv1.NodeAttestor_AttestServer) error { } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - hclConfig := new(Config) - if err := hcl.Decode(hclConfig, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "trust_domain is required") - } - - trustDomain, err := spiffeid.TrustDomainFromString(req.CoreConfiguration.TrustDomain) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "trust_domain is invalid: %v", err) - } - - var dnsPatterns []*regexp.Regexp - for _, r := range hclConfig.AllowedDNSPatterns { - re := regexp.MustCompile(r) - dnsPatterns = append(dnsPatterns, re) - } - - allowNonRootPorts := true - if hclConfig.AllowNonRootPorts != nil { - allowNonRootPorts = *hclConfig.AllowNonRootPorts - } - - tofu := true - if hclConfig.TOFU != nil { - tofu = *hclConfig.TOFU + return nil, err } - mustUseTOFU := false - switch { - // User has explicitly asked for a required port that is untrusted - case hclConfig.RequiredPort != nil && *hclConfig.RequiredPort >= 1024: - mustUseTOFU = true - // User has just chosen the defaults, any port is allowed - case hclConfig.AllowNonRootPorts == nil && hclConfig.RequiredPort == nil: - mustUseTOFU = true - // User explicitly set AllowNonRootPorts to true and no required port specified - case hclConfig.AllowNonRootPorts != nil && *hclConfig.AllowNonRootPorts && hclConfig.RequiredPort == nil: - mustUseTOFU = true - } + p.m.Lock() + defer p.m.Unlock() + p.config = newConfig - if !tofu && mustUseTOFU { - return nil, status.Errorf(codes.InvalidArgument, "you can not turn off trust on first use (TOFU) when non-root ports are allowed") - } + return &configv1.ConfigureResponse{}, nil +} - p.setConfiguration(&configuration{ - trustDomain: trustDomain, - dnsPatterns: dnsPatterns, - requiredPort: hclConfig.RequiredPort, - allowNonRootPorts: allowNonRootPorts, - tofu: tofu, - }) +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) - return &configv1.ConfigureResponse{}, nil + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil } // SetLogger sets this plugin's logger @@ -249,12 +258,6 @@ func (p *Plugin) getConfig() (*configuration, error) { return p.config, nil } -func (p *Plugin) setConfiguration(config *configuration) { - p.m.Lock() - defer p.m.Unlock() - p.config = config -} - func buildSelectorValues(hostName string) []string { var selectorValues []string diff --git a/pkg/server/plugin/nodeattestor/httpchallenge/httpchallenge_test.go b/pkg/server/plugin/nodeattestor/httpchallenge/httpchallenge_test.go index cdde03f8e9..8cf9ea6a29 100644 --- a/pkg/server/plugin/nodeattestor/httpchallenge/httpchallenge_test.go +++ b/pkg/server/plugin/nodeattestor/httpchallenge/httpchallenge_test.go @@ -32,11 +32,11 @@ func TestConfigure(t *testing.T) { }{ { name: "Configure fails if core config is not provided", - expErr: "rpc error: code = InvalidArgument desc = core configuration is required", + expErr: "rpc error: code = InvalidArgument desc = server core configuration is required", }, { name: "Configure fails if trust domain is empty", - expErr: "rpc error: code = InvalidArgument desc = trust_domain is required", + expErr: "rpc error: code = InvalidArgument desc = server core configuration must contain trust_domain", coreConf: &configv1.CoreConfiguration{}, }, { diff --git a/pkg/server/plugin/nodeattestor/jointoken/join_token.go b/pkg/server/plugin/nodeattestor/jointoken/join_token.go index 9f16ff38ed..4237e58172 100644 --- a/pkg/server/plugin/nodeattestor/jointoken/join_token.go +++ b/pkg/server/plugin/nodeattestor/jointoken/join_token.go @@ -3,9 +3,12 @@ package jointoken import ( "context" + "github.com/hashicorp/hcl" + "github.com/hashicorp/hcl/hcl/token" nodeattestorv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/nodeattestor/v1" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -25,6 +28,24 @@ func builtin(p *Plugin) catalog.BuiltIn { ) } +type Configuration struct { + Extra map[string][]token.Pos `hcl:",unusedKeyPositions"` +} + +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + for key, _ := range newConfig.Extra { + status.ReportInfof("unknown setting \"%s\" encountered", key) + } + + return newConfig +} + type Plugin struct { nodeattestorv1.UnsafeNodeAttestorServer configv1.UnsafeConfigServer @@ -38,6 +59,20 @@ func (p *Plugin) Attest(nodeattestorv1.NodeAttestor_AttestServer) error { return status.Error(codes.Unimplemented, "join token attestation is currently implemented within the server") } -func (p *Plugin) Configure(context.Context, *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { +func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { + _, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err + } + return &configv1.ConfigureResponse{}, nil } + +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} diff --git a/pkg/server/plugin/nodeattestor/k8spsat/psat.go b/pkg/server/plugin/nodeattestor/k8spsat/psat.go index e56a475fce..2b71890ae4 100644 --- a/pkg/server/plugin/nodeattestor/k8spsat/psat.go +++ b/pkg/server/plugin/nodeattestor/k8spsat/psat.go @@ -13,6 +13,7 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/k8s" "github.com/spiffe/spire/pkg/common/plugin/k8s/apiserver" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -79,6 +80,61 @@ type clusterConfig struct { allowedPodLabelKeys map[string]bool } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *attestorConfig { + hclConfig := new(AttestorConfig) + if err := hcl.Decode(hclConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + if len(hclConfig.Clusters) < 1 { + status.ReportInfo("No clusters configured, PSAT attestation is effectively disabled") + } + + newConfig := &attestorConfig{ + trustDomain: coreConfig.TrustDomain.String(), + clusters: make(map[string]*clusterConfig), + } + + for name, hclCluster := range hclConfig.Clusters { + if len(hclCluster.ServiceAccountAllowList) == 0 { + status.ReportErrorf("cluster %q configuration must have at least one service account allowed", name) + } + + serviceAccounts := make(map[string]bool) + for _, serviceAccount := range hclCluster.ServiceAccountAllowList { + serviceAccounts[serviceAccount] = true + } + + var audience []string + if hclCluster.Audience == nil { + audience = defaultAudience + } else { + audience = *hclCluster.Audience + } + + allowedNodeLabelKeys := make(map[string]bool) + for _, label := range hclCluster.AllowedNodeLabelKeys { + allowedNodeLabelKeys[label] = true + } + + allowedPodLabelKeys := make(map[string]bool) + for _, label := range hclCluster.AllowedPodLabelKeys { + allowedPodLabelKeys[label] = true + } + + newConfig.clusters[name] = &clusterConfig{ + serviceAccounts: serviceAccounts, + audience: audience, + client: apiserver.New(hclCluster.KubeConfigFile), + allowedNodeLabelKeys: allowedNodeLabelKeys, + allowedPodLabelKeys: allowedPodLabelKeys, + } + } + + return newConfig +} + // AttestorPlugin is a PSAT (Projected SAT) node attestor plugin type AttestorPlugin struct { nodeattestorv1.UnsafeNodeAttestorServer @@ -214,65 +270,25 @@ func (p *AttestorPlugin) Attest(stream nodeattestorv1.NodeAttestor_AttestServer) } func (p *AttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - hclConfig := new(AttestorConfig) - - if err := hcl.Decode(hclConfig, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "core configuration missing trust domain") - } - - config := &attestorConfig{ - trustDomain: req.CoreConfiguration.TrustDomain, - clusters: make(map[string]*clusterConfig), + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - for name, cluster := range hclConfig.Clusters { - if len(cluster.ServiceAccountAllowList) == 0 { - return nil, status.Errorf(codes.InvalidArgument, "cluster %q configuration must have at least one service account allowed", name) - } - - serviceAccounts := make(map[string]bool) - for _, serviceAccount := range cluster.ServiceAccountAllowList { - serviceAccounts[serviceAccount] = true - } - - var audience []string - if cluster.Audience == nil { - audience = defaultAudience - } else { - audience = *cluster.Audience - } - - allowedNodeLabelKeys := make(map[string]bool) - for _, label := range cluster.AllowedNodeLabelKeys { - allowedNodeLabelKeys[label] = true - } - - allowedPodLabelKeys := make(map[string]bool) - for _, label := range cluster.AllowedPodLabelKeys { - allowedPodLabelKeys[label] = true - } + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig - config.clusters[name] = &clusterConfig{ - serviceAccounts: serviceAccounts, - audience: audience, - client: apiserver.New(cluster.KubeConfigFile), - allowedNodeLabelKeys: allowedNodeLabelKeys, - allowedPodLabelKeys: allowedPodLabelKeys, - } - } + return &configv1.ConfigureResponse{}, nil +} - if len(hclConfig.Clusters) < 1 { - p.log.Warn("No clusters configured, PSAT attestation is effectively disabled") - } +func (p *AttestorPlugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) - p.setConfig(config) - return &configv1.ConfigureResponse{}, nil + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err } func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { @@ -283,9 +299,3 @@ func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { } return p.config, nil } - -func (p *AttestorPlugin) setConfig(config *attestorConfig) { - p.mu.Lock() - defer p.mu.Unlock() - p.config = config -} diff --git a/pkg/server/plugin/nodeattestor/k8spsat/psat_test.go b/pkg/server/plugin/nodeattestor/k8spsat/psat_test.go index f4ba577b15..1d3cc0cadd 100644 --- a/pkg/server/plugin/nodeattestor/k8spsat/psat_test.go +++ b/pkg/server/plugin/nodeattestor/k8spsat/psat_test.go @@ -368,11 +368,11 @@ func (s *AttestorSuite) TestConfigure() { // malformed configuration err := doConfig(coreConfig, "blah") - s.RequireGRPCStatusContains(err, codes.InvalidArgument, "unable to decode configuration") + s.RequireGRPCStatusContains(err, codes.InvalidArgument, "plugin configuration is malformed") // missing trust domain err = doConfig(catalog.CoreConfig{}, "") - s.RequireGRPCStatus(err, codes.InvalidArgument, "core configuration missing trust domain") + s.RequireGRPCStatus(err, codes.InvalidArgument, "server core configuration must contain trust_domain") // missing clusters err = doConfig(coreConfig, "") diff --git a/pkg/server/plugin/nodeattestor/k8ssat/sat.go b/pkg/server/plugin/nodeattestor/k8ssat/sat.go index 0c360739d8..f80045a399 100644 --- a/pkg/server/plugin/nodeattestor/k8ssat/sat.go +++ b/pkg/server/plugin/nodeattestor/k8ssat/sat.go @@ -24,6 +24,7 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/k8s" "github.com/spiffe/spire/pkg/common/plugin/k8s/apiserver" + "github.com/spiffe/spire/pkg/common/pluginconf" nodeattestorbase "github.com/spiffe/spire/pkg/server/plugin/nodeattestor/base" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -86,6 +87,66 @@ type AttestorConfig struct { Clusters map[string]*ClusterConfig `hcl:"clusters"` } +func (p *AttestorPlugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *attestorConfig { + status.ReportInfof("The %q node attestor plugin has been deprecated in favor of the \"k8s_psat\" plugin and will be removed in a future release", pluginName) + p.log.Warn(fmt.Sprintf("The %q node attestor plugin has been deprecated in favor of the \"k8s_psat\" plugin and will be removed in a future release", pluginName)) + + hclConfig := new(AttestorConfig) + if err := hcl.Decode(hclConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + if len(hclConfig.Clusters) == 0 { + status.ReportError("configuration must have at least one cluster") + } + + newConfig := &attestorConfig{ + trustDomain: coreConfig.TrustDomain.String(), + clusters: make(map[string]*clusterConfig), + } + + for name, cluster := range hclConfig.Clusters { + var serviceAccountKeys []crypto.PublicKey + var apiserverClient apiServerClient + var err error + if cluster.UseTokenReviewAPI { + apiserverClient = apiserver.New(cluster.KubeConfigFile) + } else { + if cluster.ServiceAccountKeyFile == "" { + status.ReportErrorf("cluster %q configuration missing service account key file", name) + } + + serviceAccountKeys, err = loadServiceAccountKeys(cluster.ServiceAccountKeyFile) + if err != nil { + status.ReportErrorf("failed to load cluster %q service account keys from %q: %v", name, cluster.ServiceAccountKeyFile, err) + } + + if len(serviceAccountKeys) == 0 { + status.ReportErrorf("cluster %q has no service account keys in %q", name, cluster.ServiceAccountKeyFile) + } + } + + if len(cluster.ServiceAccountAllowList) == 0 { + status.ReportErrorf("cluster %q configuration must have at least one service account allowed", name) + } + + serviceAccounts := make(map[string]bool) + for _, serviceAccount := range cluster.ServiceAccountAllowList { + serviceAccounts[serviceAccount] = true + } + + newConfig.clusters[name] = &clusterConfig{ + serviceAccountKeys: serviceAccountKeys, + serviceAccounts: serviceAccounts, + useTokenReviewAPI: cluster.UseTokenReviewAPI, + client: apiserverClient, + } + } + + return newConfig +} + type apiServerClient interface { ValidateToken(ctx context.Context, token string, audiences []string) (*authv1.TokenReviewStatus, error) } @@ -277,68 +338,25 @@ func (p *AttestorPlugin) getNamesFromClaims(claims *k8s.SATClaims) (namespace st } func (p *AttestorPlugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - p.log.Warn(fmt.Sprintf("The %q node attestor plugin has been deprecated in favor of the \"k8s_psat\" plugin and will be removed in a future release", pluginName)) - - hclConfig := new(AttestorConfig) - if err := hcl.Decode(hclConfig, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "core configuration missing trust domain") - } - - if len(hclConfig.Clusters) == 0 { - return nil, status.Error(codes.InvalidArgument, "configuration must have at least one cluster") + newConfig, _, err := pluginconf.Build(req, p.buildConfig) + if err != nil { + return nil, err } - config := &attestorConfig{ - trustDomain: req.CoreConfiguration.TrustDomain, - clusters: make(map[string]*clusterConfig), - } - config.trustDomain = req.CoreConfiguration.TrustDomain - for name, cluster := range hclConfig.Clusters { - var serviceAccountKeys []crypto.PublicKey - var apiserverClient apiServerClient - var err error - if cluster.UseTokenReviewAPI { - apiserverClient = apiserver.New(cluster.KubeConfigFile) - } else { - if cluster.ServiceAccountKeyFile == "" { - return nil, status.Errorf(codes.InvalidArgument, "cluster %q configuration missing service account key file", name) - } - - serviceAccountKeys, err = loadServiceAccountKeys(cluster.ServiceAccountKeyFile) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to load cluster %q service account keys from %q: %v", name, cluster.ServiceAccountKeyFile, err) - } - - if len(serviceAccountKeys) == 0 { - return nil, status.Errorf(codes.InvalidArgument, "cluster %q has no service account keys in %q", name, cluster.ServiceAccountKeyFile) - } - } - - if len(cluster.ServiceAccountAllowList) == 0 { - return nil, status.Errorf(codes.InvalidArgument, "cluster %q configuration must have at least one service account allowed", name) - } + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig - serviceAccounts := make(map[string]bool) - for _, serviceAccount := range cluster.ServiceAccountAllowList { - serviceAccounts[serviceAccount] = true - } + return &configv1.ConfigureResponse{}, nil +} - config.clusters[name] = &clusterConfig{ - serviceAccountKeys: serviceAccountKeys, - serviceAccounts: serviceAccounts, - useTokenReviewAPI: cluster.UseTokenReviewAPI, - client: apiserverClient, - } - } +func (p *AttestorPlugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) - p.setConfig(config) - return &configv1.ConfigureResponse{}, nil + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err } func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { @@ -350,12 +368,6 @@ func (p *AttestorPlugin) getConfig() (*attestorConfig, error) { return p.config, nil } -func (p *AttestorPlugin) setConfig(config *attestorConfig) { - p.mu.Lock() - defer p.mu.Unlock() - p.config = config -} - func verifyTokenSignature(keys []crypto.PublicKey, token *jwt.JSONWebToken, claims any) (err error) { var lastErr error for _, key := range keys { diff --git a/pkg/server/plugin/nodeattestor/k8ssat/sat_test.go b/pkg/server/plugin/nodeattestor/k8ssat/sat_test.go index b0768c1e85..40bdb38955 100644 --- a/pkg/server/plugin/nodeattestor/k8ssat/sat_test.go +++ b/pkg/server/plugin/nodeattestor/k8ssat/sat_test.go @@ -308,7 +308,7 @@ func (s *AttestorSuite) TestConfigure() { // missing trust domain err = doConfig(catalog.CoreConfig{}, "") - s.RequireGRPCStatus(err, codes.InvalidArgument, "core configuration missing trust domain") + s.RequireGRPCStatus(err, codes.InvalidArgument, "server core configuration must contain trust_domain") // missing clusters err = doConfig(coreConfig, "") diff --git a/pkg/server/plugin/nodeattestor/sshpop/sshpop.go b/pkg/server/plugin/nodeattestor/sshpop/sshpop.go index 4a573b5926..53993cd06a 100644 --- a/pkg/server/plugin/nodeattestor/sshpop/sshpop.go +++ b/pkg/server/plugin/nodeattestor/sshpop/sshpop.go @@ -8,6 +8,7 @@ import ( configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/sshpop" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -95,15 +96,23 @@ func (p *Plugin) Attest(stream nodeattestorv1.NodeAttestor_AttestServer) error { } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - sshserver, err := sshpop.NewServer(req.CoreConfiguration.GetTrustDomain(), req.HclConfiguration) + newConfig, _, err := pluginconf.Build(req, sshpop.BuildServerConfig) if err != nil { return nil, err } + p.mu.Lock() - p.sshserver = sshserver + p.sshserver = newConfig.NewServer() p.mu.Unlock() + return &configv1.ConfigureResponse{}, nil } + +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, sshpop.BuildServerConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} diff --git a/pkg/server/plugin/nodeattestor/sshpop/sshpop_test.go b/pkg/server/plugin/nodeattestor/sshpop/sshpop_test.go index dc5fbe335f..c8ae9dcfa0 100644 --- a/pkg/server/plugin/nodeattestor/sshpop/sshpop_test.go +++ b/pkg/server/plugin/nodeattestor/sshpop/sshpop_test.go @@ -59,7 +59,7 @@ func (s *Suite) loadPlugin(t *testing.T) nodeattestor.NodeAttestor { clientConfig := fmt.Sprintf(` host_key_path = %q host_cert_path = %q`, privateKeyPath, certificatePath) - sshclient, err := sshpop.NewClient(clientConfig) + sshclient, err := sshpop.NewClient("example.org", clientConfig) require.NoError(t, err) s.sshclient = sshclient diff --git a/pkg/server/plugin/nodeattestor/tpmdevid/devid.go b/pkg/server/plugin/nodeattestor/tpmdevid/devid.go index 381b97adcc..c196c2b03b 100644 --- a/pkg/server/plugin/nodeattestor/tpmdevid/devid.go +++ b/pkg/server/plugin/nodeattestor/tpmdevid/devid.go @@ -19,6 +19,7 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/idutil" common_devid "github.com/spiffe/spire/pkg/common/plugin/tpmdevid" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -39,6 +40,11 @@ func builtin(p *Plugin) catalog.BuiltIn { ) } +type Config struct { + DevIDBundlePath string `hcl:"devid_ca_path"` + EndorsementBundlePath string `hcl:"endorsement_ca_path"` +} + type config struct { trustDomain spiffeid.TrustDomain @@ -46,9 +52,39 @@ type config struct { ekRoots *x509.CertPool } -type Config struct { - DevIDBundlePath string `hcl:"devid_ca_path"` - EndorsementBundlePath string `hcl:"endorsement_ca_path"` +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *config { + hclConfig := new(Config) + if err := hcl.Decode(hclConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + if hclConfig.DevIDBundlePath == "" { + status.ReportError("devid_ca_path is required") + } + if hclConfig.EndorsementBundlePath == "" { + status.ReportError("endorsement_ca_path is required") + } + + // Create initial internal configuration + newConfig := &config{ + trustDomain: coreConfig.TrustDomain, + } + + // Load DevID bundle + var err error + newConfig.devIDRoots, err = util.LoadCertPool(hclConfig.DevIDBundlePath) + if err != nil { + status.ReportErrorf("unable to load DevID trust bundle: %v", err) + } + + // Load endorsement bundle if configured + newConfig.ekRoots, err = util.LoadCertPool(hclConfig.EndorsementBundlePath) + if err != nil { + status.ReportErrorf("unable to load endorsement trust bundle: %v", err) + } + + return newConfig } type Plugin struct { @@ -188,41 +224,25 @@ func (p *Plugin) Attest(stream nodeattestorv1.NodeAttestor_AttestServer) error { } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - trustDomain, err := parseCoreConfig(req.CoreConfiguration) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } - extConf, err := decodePluginConfig(req.HclConfiguration) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - err = validatePluginConfig(extConf) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid configuration: %v", err) - } - - // Create initial internal configuration - intConf := &config{ - trustDomain: trustDomain, - } - - // Load DevID bundle - intConf.devIDRoots, err = util.LoadCertPool(extConf.DevIDBundlePath) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to load DevID trust bundle: %v", err) - } + p.m.Lock() + defer p.m.Unlock() + p.c = newConfig - // Load endorsement bundle if configured - intConf.ekRoots, err = util.LoadCertPool(extConf.EndorsementBundlePath) - if err != nil { - return nil, status.Errorf(codes.Internal, "unable to load endorsement trust bundle: %v", err) - } + return &configv1.ConfigureResponse{}, nil +} - p.setConfiguration(intConf) +func (p *Plugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) - return &configv1.ConfigureResponse{}, nil + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err } func (p *Plugin) getConfiguration() *config { @@ -231,50 +251,6 @@ func (p *Plugin) getConfiguration() *config { return p.c } -func (p *Plugin) setConfiguration(c *config) { - p.m.Lock() - defer p.m.Unlock() - p.c = c -} - -func decodePluginConfig(hclConf string) (*Config, error) { - extConfig := new(Config) - if err := hcl.Decode(extConfig, hclConf); err != nil { - return nil, err - } - - return extConfig, nil -} - -func parseCoreConfig(c *configv1.CoreConfiguration) (spiffeid.TrustDomain, error) { - if c == nil { - return spiffeid.TrustDomain{}, status.Error(codes.InvalidArgument, "core configuration is missing") - } - - if c.TrustDomain == "" { - return spiffeid.TrustDomain{}, status.Error(codes.InvalidArgument, "trust_domain is required") - } - - trustDomain, err := spiffeid.TrustDomainFromString(c.TrustDomain) - if err != nil { - return spiffeid.TrustDomain{}, status.Errorf(codes.InvalidArgument, "trust_domain is invalid: %v", err) - } - - return trustDomain, nil -} - -func validatePluginConfig(extConf *Config) error { - switch { - case extConf.DevIDBundlePath == "": - return errors.New("devid_ca_path is required") - - case extConf.EndorsementBundlePath == "": - return errors.New("endorsement_ca_path is required") - } - - return nil -} - func verifyDevIDSignature(cert *x509.Certificate, intermediates *x509.CertPool, roots *x509.CertPool) ([][]*x509.Certificate, error) { chains, err := cert.Verify(x509.VerifyOptions{ Roots: roots, diff --git a/pkg/server/plugin/nodeattestor/tpmdevid/devid_test.go b/pkg/server/plugin/nodeattestor/tpmdevid/devid_test.go index a674505946..50e66ec549 100644 --- a/pkg/server/plugin/nodeattestor/tpmdevid/devid_test.go +++ b/pkg/server/plugin/nodeattestor/tpmdevid/devid_test.go @@ -88,40 +88,40 @@ func TestConfigure(t *testing.T) { }{ { name: "Configure fails if core config is not provided", - expErr: "rpc error: code = InvalidArgument desc = core configuration is missing", + expErr: "rpc error: code = InvalidArgument desc = server core configuration is required", }, { name: "Configure fails if trust domain is empty", - expErr: "rpc error: code = InvalidArgument desc = trust_domain is required", + expErr: "rpc error: code = InvalidArgument desc = server core configuration must contain trust_domain", coreConf: &configv1.CoreConfiguration{}, }, { name: "Configure fails if HCL config cannot be decoded", - expErr: "rpc error: code = InvalidArgument desc = unable to decode configuration", + expErr: "rpc error: code = InvalidArgument desc = plugin configuration is malformed", coreConf: &configv1.CoreConfiguration{TrustDomain: "example.org"}, hclConf: "not an HCL configuration", }, { name: "Configure fails if devid_ca_path is not provided", - expErr: "rpc error: code = InvalidArgument desc = invalid configuration: devid_ca_path is required", + expErr: "rpc error: code = InvalidArgument desc = devid_ca_path is required", coreConf: &configv1.CoreConfiguration{TrustDomain: "example.org"}, }, { name: "Configure fails if endorsement_ca_path is not provided", - expErr: "rpc error: code = InvalidArgument desc = invalid configuration: endorsement_ca_path is required", + expErr: "rpc error: code = InvalidArgument desc = endorsement_ca_path is required", coreConf: &configv1.CoreConfiguration{TrustDomain: "example.org"}, hclConf: `devid_ca_path = "non-existent/devid/bundle/path"`, }, { name: "Configure fails if DevID trust bundle cannot be loaded", - expErr: "rpc error: code = Internal desc = unable to load DevID trust bundle: open non-existent/devid/bundle/path:", + expErr: "rpc error: code = InvalidArgument desc = unable to load DevID trust bundle: open non-existent/devid/bundle/path:", coreConf: &configv1.CoreConfiguration{TrustDomain: "example.org"}, hclConf: `devid_ca_path = "non-existent/devid/bundle/path" endorsement_ca_path = "non-existent/endorsement/bundle/path"`, }, { name: "Configure fails if endorsement trust bundle cannot be opened", - expErr: "rpc error: code = Internal desc = unable to load endorsement trust bundle: open non-existent/endorsement/bundle/path:", + expErr: "rpc error: code = InvalidArgument desc = unable to load endorsement trust bundle: open non-existent/endorsement/bundle/path:", coreConf: &configv1.CoreConfiguration{TrustDomain: "example.org"}, hclConf: fmt.Sprintf(`devid_ca_path = %q endorsement_ca_path = "non-existent/endorsement/bundle/path"`, diff --git a/pkg/server/plugin/nodeattestor/x509pop/x509pop.go b/pkg/server/plugin/nodeattestor/x509pop/x509pop.go index da229f9b9b..d8ba2d7fae 100644 --- a/pkg/server/plugin/nodeattestor/x509pop/x509pop.go +++ b/pkg/server/plugin/nodeattestor/x509pop/x509pop.go @@ -13,6 +13,7 @@ import ( "github.com/spiffe/spire/pkg/common/agentpathtemplate" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/plugin/x509pop" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -33,16 +34,63 @@ func builtin(p *Plugin) catalog.BuiltIn { ) } +type Config struct { + CABundlePath string `hcl:"ca_bundle_path"` + CABundlePaths []string `hcl:"ca_bundle_paths"` + AgentPathTemplate string `hcl:"agent_path_template"` +} + type configuration struct { trustDomain spiffeid.TrustDomain trustBundle *x509.CertPool pathTemplate *agentpathtemplate.Template } -type Config struct { - CABundlePath string `hcl:"ca_bundle_path"` - CABundlePaths []string `hcl:"ca_bundle_paths"` - AgentPathTemplate string `hcl:"agent_path_template"` +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *configuration { + hclConfig := new(Config) + if err := hcl.Decode(hclConfig, hclText); err != nil { + status.ReportErrorf("unable to decode configuration: %v", err) + return nil + } + + var caPaths []string + if hclConfig.CABundlePath != "" && len(hclConfig.CABundlePaths) > 0 { + status.ReportError("only one of ca_bundle_path or ca_bundle_paths can be configured, not both") + } + if hclConfig.CABundlePath != "" { + caPaths = []string{hclConfig.CABundlePath} + } else { + caPaths = hclConfig.CABundlePaths + } + if len(caPaths) == 0 { + status.ReportError("one of ca_bundle_path or ca_bundle_paths must be configured") + } + + var trustBundles []*x509.Certificate + for _, caPath := range caPaths { + certs, err := util.LoadCertificates(caPath) + if err != nil { + status.ReportErrorf("unable to load trust bundle %q: %v", caPath, err) + } + trustBundles = append(trustBundles, certs...) + } + + pathTemplate := x509pop.DefaultAgentPathTemplate + if len(hclConfig.AgentPathTemplate) > 0 { + tmpl, err := agentpathtemplate.Parse(hclConfig.AgentPathTemplate) + if err != nil { + status.ReportErrorf("failed to parse agent svid template: %q", hclConfig.AgentPathTemplate) + } + pathTemplate = tmpl + } + + newConfig := &configuration{ + trustDomain: coreConfig.TrustDomain, + trustBundle: util.NewCertPool(trustBundles...), + pathTemplate: pathTemplate, + } + + return newConfig } type Plugin struct { @@ -157,71 +205,25 @@ func (p *Plugin) Attest(stream nodeattestorv1.NodeAttestor_AttestServer) error { } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - hclConfig := new(Config) - if err := hcl.Decode(hclConfig, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "trust_domain is required") - } - - trustDomain, err := spiffeid.TrustDomainFromString(req.CoreConfiguration.TrustDomain) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "trust_domain is invalid: %v", err) - } - - bundles, err := getBundles(hclConfig) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } - pathTemplate := x509pop.DefaultAgentPathTemplate - if len(hclConfig.AgentPathTemplate) > 0 { - tmpl, err := agentpathtemplate.Parse(hclConfig.AgentPathTemplate) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to parse agent svid template: %q", hclConfig.AgentPathTemplate) - } - pathTemplate = tmpl - } - - p.setConfiguration(&configuration{ - trustDomain: trustDomain, - trustBundle: util.NewCertPool(bundles...), - pathTemplate: pathTemplate, - }) + p.m.Lock() + defer p.m.Unlock() + p.config = newConfig return &configv1.ConfigureResponse{}, nil } -func getBundles(config *Config) ([]*x509.Certificate, error) { - var caPaths []string +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) - switch { - case config.CABundlePath != "" && len(config.CABundlePaths) > 0: - return nil, status.Error(codes.InvalidArgument, "only one of ca_bundle_path or ca_bundle_paths can be configured, not both") - case config.CABundlePath != "": - caPaths = append(caPaths, config.CABundlePath) - case len(config.CABundlePaths) > 0: - caPaths = append(caPaths, config.CABundlePaths...) - default: - return nil, status.Error(codes.InvalidArgument, "ca_bundle_path or ca_bundle_paths must be configured") - } - - var cas []*x509.Certificate - for _, caPath := range caPaths { - certs, err := util.LoadCertificates(caPath) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to load trust bundle %q: %v", caPath, err) - } - cas = append(cas, certs...) - } - - return cas, nil + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err } func (p *Plugin) getConfig() (*configuration, error) { @@ -233,12 +235,6 @@ func (p *Plugin) getConfig() (*configuration, error) { return p.config, nil } -func (p *Plugin) setConfiguration(config *configuration) { - p.m.Lock() - defer p.m.Unlock() - p.config = config -} - func buildSelectorValues(leaf *x509.Certificate, chains [][]*x509.Certificate) []string { var selectorValues []string diff --git a/pkg/server/plugin/nodeattestor/x509pop/x509pop_test.go b/pkg/server/plugin/nodeattestor/x509pop/x509pop_test.go index 1fc69de8a1..59bf072cbb 100644 --- a/pkg/server/plugin/nodeattestor/x509pop/x509pop_test.go +++ b/pkg/server/plugin/nodeattestor/x509pop/x509pop_test.go @@ -236,7 +236,7 @@ func (s *Suite) TestConfigure() { err := doConfig(t, catalog.CoreConfig{}, ` ca_bundle_path = "blah" `) - spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "trust_domain is required") + spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "server core configuration must contain trust_domain") }) s.T().Run("missing ca_bundle_path and ca_bundle_paths", func(t *testing.T) { diff --git a/pkg/server/plugin/notifier/gcsbundle/gcsbundle.go b/pkg/server/plugin/notifier/gcsbundle/gcsbundle.go index 4a91919ee1..8756b00368 100644 --- a/pkg/server/plugin/notifier/gcsbundle/gcsbundle.go +++ b/pkg/server/plugin/notifier/gcsbundle/gcsbundle.go @@ -17,6 +17,7 @@ import ( plugintypes "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/types" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/telemetry" "google.golang.org/api/googleapi" "google.golang.org/api/option" @@ -41,19 +42,36 @@ type bucketClient interface { Close() error } -type pluginConfig struct { +type configuration struct { Bucket string `hcl:"bucket"` ObjectPath string `hcl:"object_path"` ServiceAccountFile string `hcl:"service_account_file"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *configuration { + newConfig := new(configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportErrorf("plugin configuration is malformed: %s", err) + return nil + } + + if newConfig.Bucket == "" { + status.ReportError("bucket must be set") + } + if newConfig.ObjectPath == "" { + status.ReportError("object_path must be set") + } + + return newConfig +} + type Plugin struct { notifierv1.UnsafeNotifierServer configv1.UnsafeConfigServer mu sync.RWMutex log hclog.Logger - config *pluginConfig + config *configuration identityProvider identityproviderv1.IdentityProviderServiceClient hooks struct { @@ -109,23 +127,28 @@ func (p *Plugin) NotifyAndAdvise(ctx context.Context, req *notifierv1.NotifyAndA } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (resp *configv1.ConfigureResponse, err error) { - config := new(pluginConfig) - if err := hcl.Decode(&config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - if config.Bucket == "" { - return nil, status.Error(codes.InvalidArgument, "bucket must be set") - } - if config.ObjectPath == "" { - return nil, status.Error(codes.InvalidArgument, "object_path must be set") - } + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig - p.setConfig(config) return &configv1.ConfigureResponse{}, nil } -func (p *Plugin) getConfig() (*pluginConfig, error) { +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (resp *configv1.ValidateResponse, err error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + +func (p *Plugin) getConfig() (*configuration, error) { p.mu.RLock() defer p.mu.RUnlock() if p.config == nil { @@ -134,13 +157,7 @@ func (p *Plugin) getConfig() (*pluginConfig, error) { return p.config, nil } -func (p *Plugin) setConfig(config *pluginConfig) { - p.mu.Lock() - defer p.mu.Unlock() - p.config = config -} - -func (p *Plugin) updateBundleObject(ctx context.Context, c *pluginConfig) (err error) { +func (p *Plugin) updateBundleObject(ctx context.Context, c *configuration) (err error) { client, err := p.hooks.newBucketClient(ctx, c.ServiceAccountFile) if err != nil { return status.Errorf(codes.Unknown, "unable to instantiate bucket client: %v", err) diff --git a/pkg/server/plugin/notifier/gcsbundle/gcsbundle_test.go b/pkg/server/plugin/notifier/gcsbundle/gcsbundle_test.go index dae509de6e..d211bf816b 100644 --- a/pkg/server/plugin/notifier/gcsbundle/gcsbundle_test.go +++ b/pkg/server/plugin/notifier/gcsbundle/gcsbundle_test.go @@ -8,8 +8,10 @@ import ( "sync" "testing" + "github.com/spiffe/go-spiffe/v2/spiffeid" identityproviderv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/hostservice/server/identityprovider/v1" plugintypes "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/types" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/server/plugin/notifier" "github.com/spiffe/spire/proto/spire/common" "github.com/spiffe/spire/test/fakes/fakeidentityprovider" @@ -28,21 +30,24 @@ func TestRequiresIdentityProvider(t *testing.T) { func TestConfigure(t *testing.T) { testCases := []struct { - name string - config string - code codes.Code - desc string + name string + trustDomain string + config string + code codes.Code + desc string }{ { - name: "malformed", + name: "malformed", + trustDomain: "example.org", config: ` MALFORMED `, code: codes.InvalidArgument, - desc: "unable to decode configuration", + desc: "plugin configuration is malformed", }, { - name: "missing bucket", + name: "missing bucket", + trustDomain: "example.org", config: ` object_path = "bundle.pem" `, @@ -50,7 +55,8 @@ func TestConfigure(t *testing.T) { desc: "bucket must be set", }, { - name: "missing object path", + name: "missing object path", + trustDomain: "example.org", config: ` bucket = "the-bucket" `, @@ -58,7 +64,8 @@ func TestConfigure(t *testing.T) { desc: "object_path must be set", }, { - name: "success without service account file", + name: "success without service account file", + trustDomain: "example.org", config: ` bucket = "the-bucket" object_path = "bundle.pem" @@ -66,7 +73,8 @@ func TestConfigure(t *testing.T) { code: codes.OK, }, { - name: "success with service account file", + name: "success with service account file", + trustDomain: "example.org", config: ` bucket = "the-bucket" object_path = "bundle.pem" @@ -83,6 +91,9 @@ func TestConfigure(t *testing.T) { var err error plugintest.Load(t, BuiltIn(), nil, + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(tt.trustDomain), + }), plugintest.Configure(tt.config), plugintest.CaptureConfigureError(&err), plugintest.HostServices(identityproviderv1.IdentityProviderServiceServer(idp))) @@ -220,6 +231,9 @@ func testUpdateBundleObject(t *testing.T, notify func(notifier.Notifier) error) plugintest.HostServices(identityproviderv1.IdentityProviderServiceServer(idp)), } if !tt.skipConfigure { + options = append(options, plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + })) options = append(options, plugintest.Configure(` bucket = "the-bucket" object_path = "bundle.pem" diff --git a/pkg/server/plugin/notifier/k8sbundle/k8sbundle.go b/pkg/server/plugin/notifier/k8sbundle/k8sbundle.go index a58991675b..a6c3a6daee 100644 --- a/pkg/server/plugin/notifier/k8sbundle/k8sbundle.go +++ b/pkg/server/plugin/notifier/k8sbundle/k8sbundle.go @@ -18,6 +18,7 @@ import ( plugintypes "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/types" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" admissionv1 "k8s.io/api/admissionregistration/v1" @@ -63,24 +64,36 @@ type cluster struct { KubeConfigFilePath string `hcl:"kube_config_file_path"` } -type pluginConfig struct { +type Configuration struct { cluster `hcl:",squash"` // for hcl v2 it should be `hcl:",remain"` Clusters []cluster `hcl:"clusters"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + // TODO: move some of the Configure func stuff here. + + return newConfig +} + type Plugin struct { notifierv1.UnsafeNotifierServer configv1.UnsafeConfigServer mu sync.RWMutex log hclog.Logger - config *pluginConfig + config *Configuration identityProvider identityproviderv1.IdentityProviderServiceClient clients []kubeClient stopCh chan struct{} hooks struct { - newKubeClients func(c *pluginConfig) ([]kubeClient, error) + newKubeClients func(c *Configuration) ([]kubeClient, error) informerCallback informerCallback } } @@ -124,39 +137,49 @@ func (p *Plugin) NotifyAndAdvise(ctx context.Context, req *notifierv1.NotifyAndA } func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (resp *configv1.ConfigureResponse, err error) { - config := new(pluginConfig) - if err := hcl.Decode(&config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - // root set with at least one value or the whole configuration is empty - if hasRootCluster(&config.cluster) || !hasRootCluster(&config.cluster) && !hasMultipleClusters(config.Clusters) { - setDefaultValues(&config.cluster) + if hasRootCluster(&newConfig.cluster) || !hasRootCluster(&newConfig.cluster) && !hasMultipleClusters(newConfig.Clusters) { + setDefaultValues(&newConfig.cluster) } - for i := range config.Clusters { - if config.Clusters[i].KubeConfigFilePath == "" { + + // root set with at least one value or the whole configuration is empty + for i := range newConfig.Clusters { + if newConfig.Clusters[i].KubeConfigFilePath == "" { return nil, status.Error(codes.InvalidArgument, "cluster configuration is missing kube_config_file_path") } - setDefaultValues(&config.Clusters[i]) + setDefaultValues(&newConfig.Clusters[i]) } - clients, err := p.hooks.newKubeClients(config) + clients, err := p.hooks.newKubeClients(newConfig) if err != nil { return nil, status.Errorf(codes.Internal, "unable to create new kubeClients: %v", err) } stopCh := make(chan struct{}) - if err = p.startInformers(ctx, config, clients, stopCh); err != nil { + if err = p.startInformers(ctx, newConfig, clients, stopCh); err != nil { close(stopCh) return nil, status.Errorf(codes.Internal, "unable to start informers: %v", err) } - p.setConfig(config, clients, stopCh) + p.setConfig(newConfig, clients, stopCh) return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + // startInformers creates informers to set CA Bundle in objects created after server has started -func (p *Plugin) startInformers(ctx context.Context, config *pluginConfig, clients []kubeClient, stopCh chan struct{}) error { +func (p *Plugin) startInformers(ctx context.Context, config *Configuration, clients []kubeClient, stopCh chan struct{}) error { if config.WebhookLabel != "" || config.APIServiceLabel != "" { informerSynced := []cache.InformerSynced{} for _, client := range clients { @@ -177,7 +200,7 @@ func (p *Plugin) startInformers(ctx context.Context, config *pluginConfig, clien return nil } -func (p *Plugin) setConfig(config *pluginConfig, clients []kubeClient, stopCh chan struct{}) { +func (p *Plugin) setConfig(config *Configuration, clients []kubeClient, stopCh chan struct{}) { p.mu.Lock() defer p.mu.Unlock() @@ -304,7 +327,7 @@ func (p *Plugin) informerCallback(client kubeClient, obj runtime.Object) { } } -func newKubeClients(c *pluginConfig) ([]kubeClient, error) { +func newKubeClients(c *Configuration) ([]kubeClient, error) { clients := []kubeClient{} if hasRootCluster(&c.cluster) { diff --git a/pkg/server/plugin/notifier/k8sbundle/k8sbundle_test.go b/pkg/server/plugin/notifier/k8sbundle/k8sbundle_test.go index 38a9c14673..68f002ab51 100644 --- a/pkg/server/plugin/notifier/k8sbundle/k8sbundle_test.go +++ b/pkg/server/plugin/notifier/k8sbundle/k8sbundle_test.go @@ -18,6 +18,7 @@ import ( identityproviderv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/hostservice/server/identityprovider/v1" plugintypes "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/types" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/pemutil" "github.com/spiffe/spire/pkg/server/plugin/notifier" "github.com/spiffe/spire/proto/spire/common" @@ -382,8 +383,10 @@ api_service_label = "API_SERVICE_LABEL2" kube_config_file_path = "/some/file/path" ` _, err := test.rawPlugin.Configure(context.Background(), &configv1.ConfigureRequest{ - CoreConfiguration: &configv1.CoreConfiguration{}, - HclConfiguration: finalConfig, + CoreConfiguration: &configv1.CoreConfiguration{ + TrustDomain: "example.org", + }, + HclConfiguration: finalConfig, }) require.NoError(t, err) require.NotNil(t, test.rawPlugin.stopCh) @@ -534,7 +537,7 @@ func TestConfigureWithMalformedConfiguration(t *testing.T) { CoreConfiguration: coreConfig, }) - spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "unable to decode configuration") + spiretest.RequireGRPCStatusContains(t, err, codes.InvalidArgument, "plugin configuration is malformed") } func TestBundleFailsToLoadIfHostServicesUnavailable(t *testing.T) { @@ -547,14 +550,16 @@ func TestBundleFailsToLoadIfHostServicesUnavailable(t *testing.T) { func TestConfigure(t *testing.T) { for _, tt := range []struct { name string + trustDomain string configuration string expectedErr string expectedCode codes.Code - expectedConfig *pluginConfig + expectedConfig *Configuration }{ { - name: "empty configuration", - expectedConfig: &pluginConfig{ + name: "empty configuration", + trustDomain: "example.org", + expectedConfig: &Configuration{ cluster: cluster{ Namespace: "spire", ConfigMap: "spire-bundle", @@ -563,7 +568,8 @@ func TestConfigure(t *testing.T) { }, }, { - name: "full configuration", + name: "full configuration", + trustDomain: "example.org", configuration: ` namespace = "root" config_map = "root_config_map" @@ -590,7 +596,7 @@ func TestConfigure(t *testing.T) { }, ] `, - expectedConfig: &pluginConfig{ + expectedConfig: &Configuration{ cluster: cluster{ Namespace: "root", ConfigMap: "root_config_map", @@ -620,11 +626,12 @@ func TestConfigure(t *testing.T) { }, }, { - name: "root only with partial configuration", + name: "root only with partial configuration", + trustDomain: "example.org", configuration: ` api_service_label = "root_api_label" `, - expectedConfig: &pluginConfig{ + expectedConfig: &Configuration{ cluster: cluster{ Namespace: "spire", ConfigMap: "spire-bundle", @@ -635,7 +642,8 @@ func TestConfigure(t *testing.T) { }, }, { - name: "clusters only with partial configuration", + name: "clusters only with partial configuration", + trustDomain: "example.org", configuration: ` clusters = [ { @@ -648,7 +656,7 @@ func TestConfigure(t *testing.T) { }, ] `, - expectedConfig: &pluginConfig{ + expectedConfig: &Configuration{ Clusters: []cluster{ { Namespace: "spire", @@ -667,6 +675,7 @@ func TestConfigure(t *testing.T) { }, { name: "clusters only missing kube_config_file_path", + trustDomain: "example.org", expectedErr: "cluster configuration is missing kube_config_file_path", expectedCode: codes.InvalidArgument, configuration: ` @@ -711,7 +720,7 @@ type fakeKubeClient struct { configMapKey string } -func newFakeKubeClient(config *pluginConfig, configMaps ...*corev1.ConfigMap) *fakeKubeClient { +func newFakeKubeClient(config *Configuration, configMaps ...*corev1.ConfigMap) *fakeKubeClient { fake := &fakeKubeClient{ configMaps: make(map[string]*corev1.ConfigMap), namespace: config.Namespace, @@ -838,7 +847,7 @@ type fakeWebhookClient struct { watcherStarted chan struct{} } -func newFakeWebhookClient(config *pluginConfig) *fakeWebhookClient { +func newFakeWebhookClient(config *Configuration) *fakeWebhookClient { client := fake.NewSimpleClientset() w := &fakeWebhookClient{ mutatingWebhookClient: mutatingWebhookClient{ @@ -894,7 +903,7 @@ type fakeAPIServiceClient struct { watcherStarted chan struct{} } -func newFakeAPIServiceClient(config *pluginConfig) *fakeAPIServiceClient { +func newFakeAPIServiceClient(config *Configuration) *fakeAPIServiceClient { client := fakeaggregator.NewSimpleClientset() a := &fakeAPIServiceClient{ apiServiceClient: apiServiceClient{ @@ -952,6 +961,7 @@ type test struct { } type testOptions struct { + trustDomain spiffeid.TrustDomain plainConfig string kubeClientError bool doConfigure bool @@ -981,6 +991,7 @@ func withInformerCallback(callback informerCallback) testOption { func setupTest(t *testing.T, options ...testOption) *test { args := &testOptions{ doConfigure: true, + trustDomain: spiffeid.RequireTrustDomainFromString("example.org"), plainConfig: fmt.Sprintf(` namespace = "%s" config_map = "%s" @@ -992,7 +1003,7 @@ func setupTest(t *testing.T, options ...testOption) *test { opt(args) } - config := new(pluginConfig) + config := new(Configuration) err := hcl.Decode(&config, args.plainConfig) require.Nil(t, err) @@ -1007,7 +1018,7 @@ func setupTest(t *testing.T, options ...testOption) *test { } test.kubeClient = newFakeKubeClient(config) - raw.hooks.newKubeClients = func(c *pluginConfig) ([]kubeClient, error) { + raw.hooks.newKubeClients = func(c *Configuration) ([]kubeClient, error) { if args.kubeClientError { return nil, errors.New("kube client not configured") } @@ -1036,6 +1047,9 @@ func setupTest(t *testing.T, options ...testOption) *test { builtIn(raw), notifier, plugintest.HostServices(identityproviderv1.IdentityProviderServiceServer(identityProvider)), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: args.trustDomain, + }), plugintest.Configure(args.plainConfig), ) } else { diff --git a/pkg/server/plugin/upstreamauthority/awspca/pca.go b/pkg/server/plugin/upstreamauthority/awspca/pca.go index e3c596c190..a4ece3856c 100644 --- a/pkg/server/plugin/upstreamauthority/awspca/pca.go +++ b/pkg/server/plugin/upstreamauthority/awspca/pca.go @@ -19,6 +19,7 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/pemutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/x509util" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -61,6 +62,22 @@ type Configuration struct { SupplementalBundlePath string `hcl:"supplemental_bundle_path" json:"supplemental_bundle_path"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + if newConfig.Region == "" { + status.ReportError("plugin configuration is missing the region") + } + if newConfig.CertificateAuthorityARN == "" { + status.ReportError("plugin configuration is missing the certificate_authority_arn") + } + + return newConfig +} + // PCAPlugin is the main representation of this upstreamauthority plugin type PCAPlugin struct { upstreamauthorityv1.UnsafeUpstreamAuthorityServer @@ -105,30 +122,30 @@ func (p *PCAPlugin) SetLogger(log hclog.Logger) { // Configure sets up the plugin for use as an upstream authority func (p *PCAPlugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := p.validateConfig(req) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { return nil, err } var supplementalBundle []*x509.Certificate - if config.SupplementalBundlePath != "" { - p.log.Info("Loading supplemental certificates for inclusion in the bundle", "supplemental_bundle_path", config.SupplementalBundlePath) - supplementalBundle, err = pemutil.LoadCertificates(config.SupplementalBundlePath) + if newConfig.SupplementalBundlePath != "" { + p.log.Info("Loading supplemental certificates for inclusion in the bundle", "supplemental_bundle_path", newConfig.SupplementalBundlePath) + supplementalBundle, err = pemutil.LoadCertificates(newConfig.SupplementalBundlePath) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "failed to load supplemental bundle: %v", err) } } // Create the client - pcaClient, err := p.hooks.newClient(ctx, config) + pcaClient, err := p.hooks.newClient(ctx, newConfig) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create client: %v", err) } // Perform a check for the presence of the CA - p.log.Info("Looking up certificate authority from ACM", "certificate_authority_arn", config.CertificateAuthorityARN) + p.log.Info("Looking up certificate authority from ACM", "certificate_authority_arn", newConfig.CertificateAuthorityARN) describeResponse, err := pcaClient.DescribeCertificateAuthority(ctx, &acmpca.DescribeCertificateAuthorityInput{ - CertificateAuthorityArn: aws.String(config.CertificateAuthorityARN), + CertificateAuthorityArn: aws.String(newConfig.CertificateAuthorityARN), }) if err != nil { return nil, status.Errorf(codes.Internal, "failed to describe CertificateAuthority: %v", err) @@ -138,13 +155,13 @@ func (p *PCAPlugin) Configure(ctx context.Context, req *configv1.ConfigureReques caStatus := describeResponse.CertificateAuthority.Status if caStatus != "ACTIVE" { p.log.Warn("Certificate is in an invalid state for issuance", - "certificate_authority_arn", config.CertificateAuthorityARN, + "certificate_authority_arn", newConfig.CertificateAuthorityARN, "status", caStatus) } // If a signing algorithm has been provided, use it. // Otherwise, fall back to the pre-configured value on the CA - signingAlgorithm := config.SigningAlgorithm + signingAlgorithm := newConfig.SigningAlgorithm if signingAlgorithm == "" { signingAlgorithm = string(describeResponse.CertificateAuthority.CertificateAuthorityConfiguration.SigningAlgorithm) p.log.Info("No signing algorithm specified, using the CA default", "signing_algorithm", signingAlgorithm) @@ -152,7 +169,7 @@ func (p *PCAPlugin) Configure(ctx context.Context, req *configv1.ConfigureReques // If a CA signing template ARN has been provided, use it. // Otherwise, fall back to the default value (PathLen=0) - caSigningTemplateArn := config.CASigningTemplateARN + caSigningTemplateArn := newConfig.CASigningTemplateARN if caSigningTemplateArn == "" { p.log.Info("No CA signing template ARN specified, using the default", "ca_signing_template_arn", defaultCASigningTemplateArn) caSigningTemplateArn = defaultCASigningTemplateArn @@ -167,12 +184,21 @@ func (p *PCAPlugin) Configure(ctx context.Context, req *configv1.ConfigureReques supplementalBundle: supplementalBundle, signingAlgorithm: signingAlgorithm, caSigningTemplateArn: caSigningTemplateArn, - certificateAuthorityArn: config.CertificateAuthorityARN, + certificateAuthorityArn: newConfig.CertificateAuthorityARN, } return &configv1.ConfigureResponse{}, nil } +func (p *PCAPlugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + // MintX509CA mints an X509CA by submitting the CSR to ACM to be signed by the certificate authority func (p *PCAPlugin) MintX509CAAndSubscribe(request *upstreamauthorityv1.MintX509CARequest, stream upstreamauthorityv1.UpstreamAuthority_MintX509CAAndSubscribeServer) error { ctx := stream.Context() @@ -293,22 +319,3 @@ func (p *PCAPlugin) getConfig() (*configuration, error) { } return p.config, nil } - -// validateConfig returns an error if any configuration provided does not meet acceptable criteria -func (p *PCAPlugin) validateConfig(req *configv1.ConfigureRequest) (*Configuration, error) { - config := new(Configuration) - - if err := hcl.Decode(&config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if config.Region == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing a region") - } - - if config.CertificateAuthorityARN == "" { - return nil, status.Error(codes.InvalidArgument, "configuration is missing a certificate authority ARN") - } - - return config, nil -} diff --git a/pkg/server/plugin/upstreamauthority/awspca/pca_test.go b/pkg/server/plugin/upstreamauthority/awspca/pca_test.go index ffea2f64a9..1854db2379 100644 --- a/pkg/server/plugin/upstreamauthority/awspca/pca_test.go +++ b/pkg/server/plugin/upstreamauthority/awspca/pca_test.go @@ -15,6 +15,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/acmpca" acmpcatypes "github.com/aws/aws-sdk-go-v2/service/acmpca/types" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/pemutil" "github.com/spiffe/spire/pkg/server/plugin/upstreamauthority" @@ -53,6 +55,9 @@ func TestConfigure(t *testing.T) { expectDescribeErr error expectConfig *configuration + // core config configurations + trustDomain string + // All allowed configurations region string endpoint string @@ -65,6 +70,7 @@ func TestConfigure(t *testing.T) { { test: "success", expectedDescribeStatus: "ACTIVE", + trustDomain: "example.org", region: validRegion, certificateAuthorityARN: validCertificateAuthorityARN, caSigningTemplateARN: validCASigningTemplateARN, @@ -79,6 +85,7 @@ func TestConfigure(t *testing.T) { }, { test: "using default signing algorithm", + trustDomain: "example.org", expectedDescribeStatus: "ACTIVE", region: validRegion, certificateAuthorityARN: validCertificateAuthorityARN, @@ -93,6 +100,7 @@ func TestConfigure(t *testing.T) { }, { test: "using default signing template ARN", + trustDomain: "example.org", expectedDescribeStatus: "ACTIVE", region: validRegion, certificateAuthorityARN: validCertificateAuthorityARN, @@ -107,6 +115,7 @@ func TestConfigure(t *testing.T) { }, { test: "DISABLED template", + trustDomain: "example.org", expectedDescribeStatus: "DISABLED", region: validRegion, certificateAuthorityARN: validCertificateAuthorityARN, @@ -122,6 +131,7 @@ func TestConfigure(t *testing.T) { }, { test: "Describe certificate fails", + trustDomain: "example.org", expectDescribeErr: awsErr("Internal", "some error", errors.New("oh no")), region: validRegion, certificateAuthorityARN: validCertificateAuthorityARN, @@ -134,6 +144,7 @@ func TestConfigure(t *testing.T) { }, { test: "Invalid supplemental bundle Path", + trustDomain: "example.org", expectedDescribeStatus: "ACTIVE", region: validRegion, certificateAuthorityARN: validCertificateAuthorityARN, @@ -146,6 +157,7 @@ func TestConfigure(t *testing.T) { }, { test: "Missing region", + trustDomain: "example.org", expectedDescribeStatus: "ACTIVE", certificateAuthorityARN: validCertificateAuthorityARN, caSigningTemplateARN: validCASigningTemplateARN, @@ -153,10 +165,11 @@ func TestConfigure(t *testing.T) { assumeRoleARN: validAssumeRoleARN, supplementalBundlePath: validSupplementalBundlePath, expectCode: codes.InvalidArgument, - expectMsgPrefix: "configuration is missing a region", + expectMsgPrefix: "plugin configuration is missing the region", }, { test: "Missing certificate ARN", + trustDomain: "example.org", expectedDescribeStatus: "ACTIVE", region: validRegion, caSigningTemplateARN: validCASigningTemplateARN, @@ -164,18 +177,20 @@ func TestConfigure(t *testing.T) { assumeRoleARN: validAssumeRoleARN, supplementalBundlePath: validSupplementalBundlePath, expectCode: codes.InvalidArgument, - expectMsgPrefix: "configuration is missing a certificate authority ARN", + expectMsgPrefix: "plugin configuration is missing the certificate_authority_arn", }, { - test: "Malformed config", + test: "Malformed config", + trustDomain: "example.org", overrideConfig: `{ badjson }`, expectCode: codes.InvalidArgument, - expectMsgPrefix: "unable to decode configuration:", + expectMsgPrefix: "plugin configuration is malformed", }, { test: "Fail to create client", + trustDomain: "example.org", newClientErr: awsErr("MissingEndpoint", "'Endpoint' configuration is required for this service", nil), region: validRegion, certificateAuthorityARN: validCertificateAuthorityARN, @@ -198,6 +213,12 @@ badjson plugintest.CaptureConfigureError(&err), } + if tt.trustDomain != "" { + options = append(options, plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(tt.trustDomain), + })) + } + if tt.overrideConfig != "" { options = append(options, plugintest.Configure(tt.overrideConfig)) } else { @@ -278,8 +299,9 @@ func TestMintX509CA(t *testing.T) { } for _, tt := range []struct { - test string - config *Configuration + test string + trustDomain string + config *Configuration client *pcaClientFake @@ -298,6 +320,7 @@ func TestMintX509CA(t *testing.T) { }{ { test: "Successful mint", + trustDomain: "example.org", config: successConfig, csr: makeCSR("spiffe://example.com/foo"), preferredTTL: 300 * time.Second, @@ -311,7 +334,8 @@ func TestMintX509CA(t *testing.T) { getCertificateCertChain: encodedCertChain.String(), }, { - test: "With supplemental bundle", + test: "With supplemental bundle", + trustDomain: "example.org", config: &Configuration{ Region: validRegion, CertificateAuthorityARN: validCertificateAuthorityARN, @@ -336,6 +360,7 @@ func TestMintX509CA(t *testing.T) { }, { test: "Issuance fails", + trustDomain: "example.org", config: successConfig, csr: makeCSR("spiffe://example.com/foo"), preferredTTL: 300 * time.Second, @@ -345,6 +370,7 @@ func TestMintX509CA(t *testing.T) { }, { test: "Issuance wait fails", + trustDomain: "example.org", config: successConfig, csr: makeCSR("spiffe://example.com/foo"), preferredTTL: 300 * time.Second, @@ -354,6 +380,7 @@ func TestMintX509CA(t *testing.T) { }, { test: "Get certificate fails", + trustDomain: "example.org", config: successConfig, csr: makeCSR("spiffe://example.com/foo"), preferredTTL: 300 * time.Second, @@ -363,6 +390,7 @@ func TestMintX509CA(t *testing.T) { }, { test: "Fails to parse certificate from GetCertificate", + trustDomain: "example.org", config: successConfig, csr: makeCSR("spiffe://example.com/foo"), preferredTTL: 300 * time.Second, @@ -373,6 +401,7 @@ func TestMintX509CA(t *testing.T) { }, { test: "Fails to parse certificate chain from GetCertificate", + trustDomain: "example.org", config: successConfig, csr: makeCSR("spiffe://example.com/foo"), preferredTTL: 300 * time.Second, @@ -397,6 +426,9 @@ func TestMintX509CA(t *testing.T) { ua := new(upstreamauthority.V1) plugintest.Load(t, builtin(p), ua, + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString(tt.trustDomain), + }), plugintest.ConfigureJSON(tt.config), ) @@ -445,6 +477,9 @@ func TestPublishJWTKey(t *testing.T) { var err error plugintest.Load(t, builtin(p), ua, plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), plugintest.ConfigureJSON(&Configuration{ Region: validRegion, CertificateAuthorityARN: validCertificateAuthorityARN, diff --git a/pkg/server/plugin/upstreamauthority/awssecret/awssecret.go b/pkg/server/plugin/upstreamauthority/awssecret/awssecret.go index 300e90b86f..5bdcca8949 100644 --- a/pkg/server/plugin/upstreamauthority/awssecret/awssecret.go +++ b/pkg/server/plugin/upstreamauthority/awssecret/awssecret.go @@ -16,6 +16,7 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/pemutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/x509svid" "github.com/spiffe/spire/pkg/common/x509util" "google.golang.org/grpc/codes" @@ -24,6 +25,10 @@ import ( const ( pluginName = "awssecret" + + CoreConfigRequired = "server core configuration is required" + CoreConfigTrustdomainRequired = "server core configuration must contain trust_domain" + CoreConfigTrustdomainMalformed = "server core configuration trust_domain is malformed" ) func BuiltIn() catalog.BuiltIn { @@ -48,6 +53,27 @@ type Configuration struct { AssumeRoleARN string `hcl:"assume_role_arn" json:"assume_role_arn"` } +func (p *Plugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + if newConfig.SecurityToken == "" { + newConfig.SecurityToken = p.hooks.getenv("AWS_SESSION_TOKEN") + } + + if newConfig.CertFileARN == "" { + status.ReportError("configuration missing 'cert_file_arn' value") + } + if newConfig.KeyFileARN == "" { + status.ReportError("configuration missing 'key_file_arn' value") + } + + return newConfig +} + type Plugin struct { upstreamauthorityv1.UnsafeUpstreamAuthorityServer configv1.UnsafeConfigServer @@ -83,19 +109,21 @@ func (p *Plugin) SetLogger(log hclog.Logger) { } func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := p.validateConfig(req) + newConfig, _, err := pluginconf.Build(req, p.buildConfig) if err != nil { return nil, err } + // TODO: determine if the items before the lock contain configuration validation. + // set the AWS configuration and reset clients + // Set local vars from config struct - sm, err := p.hooks.newClient(ctx, config, config.Region) + sm, err := p.hooks.newClient(ctx, newConfig, newConfig.Region) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "failed to create AWS client: %v", err) } - keyPEMstr, certsPEMstr, bundleCertsPEMstr, err := fetchFromSecretsManager(ctx, config, sm) + keyPEMstr, certsPEMstr, bundleCertsPEMstr, err := fetchFromSecretsManager(ctx, newConfig, sm) if err != nil { p.log.Error("Error loading files from AWS: %v", err) return nil, err @@ -123,6 +151,15 @@ func (p *Plugin) Configure(ctx context.Context, req *configv1.ConfigureRequest) return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(ctx context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + // MintX509CAAndSubscribe mints an X509CA by signing presented CSR with root CA fetched from AWS Secrets Manager func (p *Plugin) MintX509CAAndSubscribe(request *upstreamauthorityv1.MintX509CARequest, stream upstreamauthorityv1.UpstreamAuthority_MintX509CAAndSubscribeServer) error { ctx := stream.Context() @@ -253,37 +290,3 @@ func fetchFromSecretsManager(ctx context.Context, config *Configuration, sm secr return keyPEMstr, certsPEMstr, bundlePEMstr, nil } - -func (p *Plugin) validateConfig(req *configv1.ConfigureRequest) (*Configuration, error) { - // Parse HCL config payload into config struct - config := new(Configuration) - - if err := hcl.Decode(&config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "trust_domain is required") - } - - // Set defaults from the environment - if config.SecurityToken == "" { - config.SecurityToken = p.hooks.getenv("AWS_SESSION_TOKEN") - } - - switch { - case config.CertFileARN != "" && config.KeyFileARN != "": - case config.CertFileARN != "" && config.KeyFileARN == "": - return nil, status.Error(codes.InvalidArgument, "configuration missing key ARN") - case config.CertFileARN == "" && config.KeyFileARN != "": - return nil, status.Error(codes.InvalidArgument, "configuration missing cert ARN") - case config.CertFileARN == "" && config.KeyFileARN == "": - return nil, status.Error(codes.InvalidArgument, "configuration missing both cert ARN and key ARN") - } - - return config, nil -} diff --git a/pkg/server/plugin/upstreamauthority/awssecret/awssecret_test.go b/pkg/server/plugin/upstreamauthority/awssecret/awssecret_test.go index 4f426b8f70..4f6051679a 100644 --- a/pkg/server/plugin/upstreamauthority/awssecret/awssecret_test.go +++ b/pkg/server/plugin/upstreamauthority/awssecret/awssecret_test.go @@ -59,13 +59,13 @@ func TestConfigure(t *testing.T) { test: "malformed configuration", overrideConfig: "MALFORMED", expectCode: codes.InvalidArgument, - expectMsgPrefix: "unable to decode configuration:", + expectMsgPrefix: "plugin configuration is malformed", }, { test: "no trust domain", overrideCoreConfig: &catalog.CoreConfig{}, expectCode: codes.InvalidArgument, - expectMsgPrefix: "trust_domain is required", + expectMsgPrefix: "server core configuration must contain trust_domain", }, { test: "missing key ARN", @@ -76,7 +76,7 @@ func TestConfigure(t *testing.T) { securityToken: "security_token", assumeRoleARN: "assume_role_arn", expectCode: codes.InvalidArgument, - expectMsgPrefix: "configuration missing key ARN", + expectMsgPrefix: "configuration missing 'key_file_arn' value", }, { test: "missing cert ARN", @@ -87,7 +87,7 @@ func TestConfigure(t *testing.T) { securityToken: "security_token", assumeRoleARN: "assume_role_arn", expectCode: codes.InvalidArgument, - expectMsgPrefix: "configuration missing cert ARN", + expectMsgPrefix: "configuration missing 'cert_file_arn' value", }, { test: "missing cert and key ARNs", @@ -97,7 +97,7 @@ func TestConfigure(t *testing.T) { securityToken: "security_token", assumeRoleARN: "assume_role_arn", expectCode: codes.InvalidArgument, - expectMsgPrefix: "configuration missing both cert ARN and key ARN", + expectMsgPrefix: "configuration missing 'cert_file_arn' value", }, { test: "fails to create client", @@ -228,7 +228,7 @@ func TestConfigure(t *testing.T) { })) } - p := new(Plugin) + p := New() p.hooks.clock = clk p.hooks.newClient = fakeStorageClientCreator @@ -330,7 +330,7 @@ func TestMintX509CA(t *testing.T) { } { tt := tt t.Run(tt.test, func(t *testing.T) { - p := new(Plugin) + p := New() p.hooks.clock = clk p.hooks.getenv = func(s string) string { return "" @@ -396,7 +396,7 @@ func TestMintX509CA(t *testing.T) { func TestPublishJWTKey(t *testing.T) { clk := clock.NewMock(t) _, fakeStorageClientCreator := generateTestData(t, clk) - p := new(Plugin) + p := New() p.hooks.clock = clk p.hooks.newClient = fakeStorageClientCreator diff --git a/pkg/server/plugin/upstreamauthority/certmanager/api_test.go b/pkg/server/plugin/upstreamauthority/certmanager/api_test.go index 4b0ea45237..b97f54b5fe 100644 --- a/pkg/server/plugin/upstreamauthority/certmanager/api_test.go +++ b/pkg/server/plugin/upstreamauthority/certmanager/api_test.go @@ -201,7 +201,7 @@ func Test_cleanupStaleCertificateRequests(t *testing.T) { log: hclog.New(logOptions), cmclient: client, trustDomain: trustDomain, - config: &Config{ + config: &Configuration{ Namespace: namespace, }, } diff --git a/pkg/server/plugin/upstreamauthority/certmanager/certmanager.go b/pkg/server/plugin/upstreamauthority/certmanager/certmanager.go index 7b565f4ab5..3c93a775bf 100644 --- a/pkg/server/plugin/upstreamauthority/certmanager/certmanager.go +++ b/pkg/server/plugin/upstreamauthority/certmanager/certmanager.go @@ -12,6 +12,7 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/pemutil" + "github.com/spiffe/spire/pkg/common/pluginconf" cmapi "github.com/spiffe/spire/pkg/server/plugin/upstreamauthority/certmanager/internal/v1" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -36,7 +37,7 @@ func builtin(p *Plugin) catalog.BuiltIn { ) } -type Config struct { +type Configuration struct { // Options which are used for configuring the target issuer to sign requests. // The CertificateRequest will be created in the configured namespace. IssuerName string `hcl:"issuer_name" json:"issuer_name"` @@ -48,6 +49,36 @@ type Config struct { KubeConfigFilePath string `hcl:"kube_config_file" json:"kube_config_file"` } +func (p *Plugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + // namespace is a required field + if len(newConfig.Namespace) == 0 { + status.ReportError("plugin configuration has empty namespace property") + } + // issuer_name is a required field + if len(newConfig.IssuerName) == 0 { + status.ReportError("plugin configuration has empty issuer_name property") + } + // If no issuer_kind given, default to Issuer + if len(newConfig.IssuerKind) == 0 { + status.ReportInfo("plugin configuration has empty issuer_kind property, defaulting to value 'Issuer'") + newConfig.IssuerKind = "Issuer" + } + // If no issuer_group given, default to cert-manager.io + if len(newConfig.IssuerGroup) == 0 { + status.ReportInfo("plugin configuration has empty issuer_group property, defaulting to value 'cert-manager.io'") + p.log.Debug("plugin configuration has empty issuer_group property, defaulting to 'cert-manager.io'") + newConfig.IssuerGroup = "cert-manager.io" + } + + return newConfig +} + // Event hooks used by unit tests to coordinate goroutines type hooks struct { newClient func(configPath string) (client.Client, error) @@ -62,7 +93,7 @@ type Plugin struct { configv1.UnsafeConfigServer log hclog.Logger - config *Config + config *Configuration mtx sync.RWMutex // trustDomain is the trust domain of this SPIRE server. Used to label @@ -96,20 +127,12 @@ func (p *Plugin) SetLogger(log hclog.Logger) { } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := p.loadConfig(req) + newConfig, _, err := pluginconf.Build(req, p.buildConfig) if err != nil { return nil, err } - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "trust_domain is required") - } - - cmclient, err := p.hooks.newClient(config.KubeConfigFilePath) + cmclient, err := p.hooks.newClient(newConfig.KubeConfigFilePath) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create cert-manager client: %v", err) } @@ -118,7 +141,7 @@ func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (* defer p.mtx.Unlock() p.cmclient = cmclient - p.config = config + p.config = newConfig // Used for adding labels to created CertificateRequests, which can be listed // for cleanup. p.trustDomain = req.CoreConfiguration.TrustDomain @@ -126,6 +149,15 @@ func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (* return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + func (p *Plugin) MintX509CAAndSubscribe(request *upstreamauthorityv1.MintX509CARequest, stream upstreamauthorityv1.UpstreamAuthority_MintX509CAAndSubscribeServer) error { ctx := stream.Context() p.mtx.RLock() @@ -229,34 +261,6 @@ func (*Plugin) PublishJWTKeyAndSubscribe(*upstreamauthorityv1.PublishJWTKeyReque return status.Error(codes.Unimplemented, "publishing upstream is unsupported") } -// loadConfig parses and defaults incoming configure requests -func (p *Plugin) loadConfig(req *configv1.ConfigureRequest) (*Config, error) { - config := new(Config) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to decode configuration file: %v", err) - } - - // namespace is a required field - if len(config.Namespace) == 0 { - return nil, status.Error(codes.InvalidArgument, "configuration has empty namespace property") - } - // issuer_name is a required field - if len(config.IssuerName) == 0 { - return nil, status.Error(codes.InvalidArgument, "configuration has empty issuer_name property") - } - // If no issuer_kind given, default to Issuer - if len(config.IssuerKind) == 0 { - p.log.Debug("Configuration has empty issuer_kind property, defaulting to 'Issuer'") - config.IssuerKind = "Issuer" - } - // If no issuer_group given, default to cert-manager.io - if len(config.IssuerGroup) == 0 { - p.log.Debug("Configuration has empty issuer_group property, defaulting to 'cert-manager.io'") - config.IssuerGroup = "cert-manager.io" - } - return config, nil -} - func newCertManagerClient(configPath string) (client.Client, error) { config, err := getKubeConfig(configPath) if err != nil { diff --git a/pkg/server/plugin/upstreamauthority/certmanager/certmanager_test.go b/pkg/server/plugin/upstreamauthority/certmanager/certmanager_test.go index d4115600d8..6d9a7ceaf8 100644 --- a/pkg/server/plugin/upstreamauthority/certmanager/certmanager_test.go +++ b/pkg/server/plugin/upstreamauthority/certmanager/certmanager_test.go @@ -147,7 +147,7 @@ func Test_MintX509CA(t *testing.T) { }, }, } - config := &Config{ + config := &Configuration{ IssuerName: issuerName, IssuerKind: issuerKind, IssuerGroup: issuerGroup, @@ -210,7 +210,7 @@ func Test_Configure(t *testing.T) { inpConfig string expectCode codes.Code expectMsgPrefix string - expectConfig *Config + expectConfig *Configuration expectConfigFile string overrideCoreConfig *catalog.CoreConfig newClientErr error @@ -218,7 +218,7 @@ func Test_Configure(t *testing.T) { "if config is malformed, expect error": { inpConfig: "MALFORMED", expectCode: codes.InvalidArgument, - expectMsgPrefix: "failed to decode configuration file:", + expectMsgPrefix: "plugin configuration is malformed", }, "if config is missing an issuer_name, expect error": { inpConfig: ` @@ -229,7 +229,7 @@ func Test_Configure(t *testing.T) { `, expectConfig: nil, expectCode: codes.InvalidArgument, - expectMsgPrefix: "configuration has empty issuer_name property", + expectMsgPrefix: "plugin configuration has empty issuer_name property", }, "if config is missing a namespace, expect error": { inpConfig: ` @@ -240,7 +240,7 @@ func Test_Configure(t *testing.T) { `, expectConfig: nil, expectCode: codes.InvalidArgument, - expectMsgPrefix: "configuration has empty namespace property", + expectMsgPrefix: "plugin configuration has empty namespace property", }, "if config is fully populated, return config": { inpConfig: ` @@ -250,7 +250,7 @@ func Test_Configure(t *testing.T) { namespace = "my-namespace" kube_config_file = "/path/to/config" `, - expectConfig: &Config{ + expectConfig: &Configuration{ IssuerName: "my-issuer", IssuerKind: "my-kind", IssuerGroup: "my-group", @@ -265,7 +265,7 @@ func Test_Configure(t *testing.T) { namespace = "my-namespace" kube_config_file = "/path/to/config" `, - expectConfig: &Config{ + expectConfig: &Configuration{ IssuerName: "my-issuer", IssuerKind: "Issuer", IssuerGroup: "cert-manager.io", @@ -282,7 +282,7 @@ func Test_Configure(t *testing.T) { `, overrideCoreConfig: &catalog.CoreConfig{}, expectCode: codes.InvalidArgument, - expectMsgPrefix: "trust_domain is required", + expectMsgPrefix: "server core configuration must contain trust_domain", }, "failed to create client": { inpConfig: ` @@ -350,7 +350,7 @@ func TestPublishJWTKey(t *testing.T) { }, }, } - config := &Config{ + config := &Configuration{ IssuerName: "test-issuer", IssuerKind: "Issuer", IssuerGroup: "example.cert-manager.io", diff --git a/pkg/server/plugin/upstreamauthority/disk/disk.go b/pkg/server/plugin/upstreamauthority/disk/disk.go index 0820d8694d..2b111b00d9 100644 --- a/pkg/server/plugin/upstreamauthority/disk/disk.go +++ b/pkg/server/plugin/upstreamauthority/disk/disk.go @@ -19,10 +19,17 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/pemutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/x509svid" "github.com/spiffe/spire/pkg/common/x509util" ) +const ( + CoreConfigRequired = "server core configuration is required" + CoreConfigTrustDomainRequired = "server core configuration must contain trust_domain" + CoreConfigTrustDomainMalformed = "server core configuration trust_domain is malformed" +) + func BuiltIn() catalog.BuiltIn { return builtin(New()) } @@ -42,6 +49,19 @@ type Configuration struct { BundleFilePath string `hcl:"bundle_file_path" json:"bundle_file_path"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + newConfig.trustDomain = coreConfig.TrustDomain + // TODO: add field validation + + return newConfig +} + type Plugin struct { upstreamauthorityv1.UnsafeUpstreamAuthorityServer configv1.UnsafeConfigServer @@ -73,27 +93,12 @@ func (p *Plugin) SetLogger(log hclog.Logger) { } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - // Parse HCL config payload into config struct - config := &Configuration{} - if err := hcl.Decode(&config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "trust_domain is required") - } - - trustDomain, err := spiffeid.TrustDomainFromString(req.CoreConfiguration.TrustDomain) + newConfig, _, err := pluginconf.Build(req, buildConfig) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "trust_domain is malformed: %v", err) + return nil, err } - config.trustDomain = trustDomain - upstreamCA, certs, err := p.loadUpstreamCAAndCerts(config) + upstreamCA, certs, err := p.loadUpstreamCAAndCerts(newConfig) if err != nil { return nil, err } @@ -102,13 +107,22 @@ func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (* p.mtx.Lock() defer p.mtx.Unlock() - p.config = config + p.config = newConfig p.certs = certs p.upstreamCA = upstreamCA return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + func (p *Plugin) MintX509CAAndSubscribe(request *upstreamauthorityv1.MintX509CARequest, stream upstreamauthorityv1.UpstreamAuthority_MintX509CAAndSubscribeServer) error { ctx := stream.Context() @@ -162,6 +176,7 @@ func (p *Plugin) reloadCA() (*x509svid.UpstreamCA, *caCerts, error) { return upstreamCA, upstreamCerts, nil } +// TODO: perhaps load this into the config func (p *Plugin) loadUpstreamCAAndCerts(config *Configuration) (*x509svid.UpstreamCA, *caCerts, error) { key, err := pemutil.LoadPrivateKey(config.KeyFilePath) if err != nil { diff --git a/pkg/server/plugin/upstreamauthority/disk/disk_test.go b/pkg/server/plugin/upstreamauthority/disk/disk_test.go index 4b854b22a5..eb21fe2393 100644 --- a/pkg/server/plugin/upstreamauthority/disk/disk_test.go +++ b/pkg/server/plugin/upstreamauthority/disk/disk_test.go @@ -296,7 +296,7 @@ func TestConfigure(t *testing.T) { test: "malformed config", overrideConfig: "MALFORMED", expectCode: codes.InvalidArgument, - expectMsgPrefix: "unable to decode configuration: ", + expectMsgPrefix: "plugin configuration is malformed", }, { test: "missing trust domain", @@ -304,7 +304,7 @@ func TestConfigure(t *testing.T) { keyFilePath: testData.ECRootKey, overrideCoreConfig: &catalog.CoreConfig{}, expectCode: codes.InvalidArgument, - expectMsgPrefix: "trust_domain is required", + expectMsgPrefix: "server core configuration must contain trust_domain", }, } { tt := tt diff --git a/pkg/server/plugin/upstreamauthority/ejbca/ejbca.go b/pkg/server/plugin/upstreamauthority/ejbca/ejbca.go index 57226c8f37..d023627b88 100644 --- a/pkg/server/plugin/upstreamauthority/ejbca/ejbca.go +++ b/pkg/server/plugin/upstreamauthority/ejbca/ejbca.go @@ -14,11 +14,13 @@ import ( ejbcaclient "github.com/Keyfactor/ejbca-go-client-sdk/api/ejbca" "github.com/hashicorp/go-hclog" + "github.com/hashicorp/hcl" "github.com/spiffe/spire-plugin-sdk/pluginsdk" upstreamauthorityv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/plugin/server/upstreamauthority/v1" configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -84,6 +86,56 @@ type Config struct { AccountBindingID string `hcl:"account_binding_id" json:"account_binding_id"` } +func (p *Plugin) buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Config { + logger := p.logger.Named("parseConfig") + logger.Debug("Decoding EJBCA configuration") + + newConfig := &Config{} + if err := hcl.Decode(&newConfig, hclText); err != nil { + status.ReportErrorf("failed to decode configuration: %v", err) + return nil + } + + if newConfig.Hostname == "" { + status.ReportError("hostname is required") + } + if newConfig.CAName == "" { + status.ReportError("ca_name is required") + } + if newConfig.EndEntityProfileName == "" { + status.ReportError("end_entity_profile_name is required") + } + if newConfig.CertificateProfileName == "" { + status.ReportError("certificate_profile_name is required") + } + + // If ClientCertPath or ClientCertKeyPath were not found in the main server conf file, + // load them from the environment. + if newConfig.ClientCertPath == "" { + newConfig.ClientCertPath = p.hooks.getEnv("EJBCA_CLIENT_CERT_PATH") + } + if newConfig.ClientCertKeyPath == "" { + newConfig.ClientCertKeyPath = p.hooks.getEnv("EJBCA_CLIENT_CERT_KEY_PATH") + } + + // If ClientCertPath or ClientCertKeyPath were not present in either the conf file or + // the environment, return an error. + if newConfig.ClientCertPath == "" { + logger.Error("Client certificate is required for mTLS authentication") + status.ReportError("client_cert or EJBCA_CLIENT_CERT_PATH is required for mTLS authentication") + } + if newConfig.ClientCertKeyPath == "" { + logger.Error("Client key is required for mTLS authentication") + status.ReportError("client_key or EJBCA_CLIENT_KEY_PATH is required for mTLS authentication") + } + + if newConfig.CaCertPath == "" { + newConfig.CaCertPath = p.hooks.getEnv("EJBCA_CA_CERT_PATH") + } + + return newConfig +} + // New returns an instantiated EJBCA UpstreamAuthority plugin func New() *Plugin { p := &Plugin{} @@ -96,26 +148,35 @@ func New() *Plugin { // Configure configures the EJBCA UpstreamAuthority plugin. This is invoked by SPIRE when the plugin is // first loaded. After the first invocation, it may be used to reconfigure the plugin. func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config, err := p.parseConfig(req) + newConfig, _, err := pluginconf.Build(req, p.buildConfig) if err != nil { return nil, err } - authenticator, err := p.hooks.newAuthenticator(config) + authenticator, err := p.hooks.newAuthenticator(newConfig) if err != nil { return nil, err } - client, err := p.newEjbcaClient(config, authenticator) + client, err := p.newEjbcaClient(newConfig, authenticator) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "failed to create EJBCA client: %v", err) } - p.setConfig(config) + p.setConfig(newConfig) p.setClient(client) return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, p.buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, nil +} + // SetLogger is called by the framework when the plugin is loaded and provides // the plugin with a logger wired up to SPIRE's logging facilities. func (p *Plugin) SetLogger(logger hclog.Logger) { diff --git a/pkg/server/plugin/upstreamauthority/ejbca/ejbca_client.go b/pkg/server/plugin/upstreamauthority/ejbca/ejbca_client.go index 7d263120bb..d6ce78d53f 100644 --- a/pkg/server/plugin/upstreamauthority/ejbca/ejbca_client.go +++ b/pkg/server/plugin/upstreamauthority/ejbca/ejbca_client.go @@ -8,8 +8,6 @@ import ( ejbcaclient "github.com/Keyfactor/ejbca-go-client-sdk/api/ejbca" "github.com/gogo/status" - "github.com/hashicorp/hcl" - configv1 "github.com/spiffe/spire-plugin-sdk/proto/spire/service/common/config/v1" "github.com/spiffe/spire/pkg/common/pemutil" "google.golang.org/grpc/codes" ) @@ -18,54 +16,6 @@ type ejbcaClient interface { EnrollPkcs10Certificate(ctx context.Context) ejbcaclient.ApiEnrollPkcs10CertificateRequest } -func (p *Plugin) parseConfig(req *configv1.ConfigureRequest) (*Config, error) { - logger := p.logger.Named("parseConfig") - config := new(Config) - logger.Debug("Decoding EJBCA configuration") - if err := hcl.Decode(&config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "failed to decode configuration: %v", err) - } - - if config.Hostname == "" { - return nil, status.Error(codes.InvalidArgument, "hostname is required") - } - if config.CAName == "" { - return nil, status.Error(codes.InvalidArgument, "ca_name is required") - } - if config.EndEntityProfileName == "" { - return nil, status.Error(codes.InvalidArgument, "end_entity_profile_name is required") - } - if config.CertificateProfileName == "" { - return nil, status.Error(codes.InvalidArgument, "certificate_profile_name is required") - } - - // If ClientCertPath or ClientCertKeyPath were not found in the main server conf file, - // load them from the environment. - if config.ClientCertPath == "" { - config.ClientCertPath = p.hooks.getEnv("EJBCA_CLIENT_CERT_PATH") - } - if config.ClientCertKeyPath == "" { - config.ClientCertKeyPath = p.hooks.getEnv("EJBCA_CLIENT_CERT_KEY_PATH") - } - - // If ClientCertPath or ClientCertKeyPath were not present in either the conf file or - // the environment, return an error. - if config.ClientCertPath == "" { - logger.Error("Client certificate is required for mTLS authentication") - return nil, status.Error(codes.InvalidArgument, "client_cert or EJBCA_CLIENT_CERT_PATH is required for mTLS authentication") - } - if config.ClientCertKeyPath == "" { - logger.Error("Client key is required for mTLS authentication") - return nil, status.Error(codes.InvalidArgument, "client_key or EJBCA_CLIENT_KEY_PATH is required for mTLS authentication") - } - - if config.CaCertPath == "" { - config.CaCertPath = p.hooks.getEnv("EJBCA_CA_CERT_PATH") - } - - return config, nil -} - func (p *Plugin) getAuthenticator(config *Config) (ejbcaclient.Authenticator, error) { var err error logger := p.logger.Named("getAuthenticator") diff --git a/pkg/server/plugin/upstreamauthority/ejbca/ejbca_test.go b/pkg/server/plugin/upstreamauthority/ejbca/ejbca_test.go index 5b740ffcfd..7fe520f806 100644 --- a/pkg/server/plugin/upstreamauthority/ejbca/ejbca_test.go +++ b/pkg/server/plugin/upstreamauthority/ejbca/ejbca_test.go @@ -23,6 +23,7 @@ import ( ejbcaclient "github.com/Keyfactor/ejbca-go-client-sdk/api/ejbca" "github.com/hashicorp/go-hclog" "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/common/catalog" commonutil "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/pkg/server/plugin/upstreamauthority" "github.com/spiffe/spire/test/plugintest" @@ -374,6 +375,9 @@ func TestConfigure(t *testing.T) { options := []plugintest.Option{ plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: trustDomain, + }), plugintest.Configure(tt.config), } @@ -536,6 +540,9 @@ func TestMintX509CAAndSubscribe(t *testing.T) { options := []plugintest.Option{ plugintest.CaptureConfigureError(&err), + plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: trustDomain, + }), plugintest.ConfigureJSON(config), } diff --git a/pkg/server/plugin/upstreamauthority/gcpcas/gcpcas.go b/pkg/server/plugin/upstreamauthority/gcpcas/gcpcas.go index 15891380d9..7cecf37f05 100644 --- a/pkg/server/plugin/upstreamauthority/gcpcas/gcpcas.go +++ b/pkg/server/plugin/upstreamauthority/gcpcas/gcpcas.go @@ -21,6 +21,7 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/pemutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "github.com/spiffe/spire/pkg/common/x509util" "google.golang.org/api/iterator" "google.golang.org/grpc/codes" @@ -66,6 +67,36 @@ type Configuration struct { RootSpec CertificateAuthoritySpec `hcl:"root_cert_spec,block"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + // Without a project and location, we can never locate CAs + if newConfig.RootSpec.Project == "" { + status.ReportError("plugin configuration root_cert_spec.Project is missing") + } + if newConfig.RootSpec.Location == "" { + status.ReportError("plugin configuration root_cert_spec.Location is missing") + } + + // Even LabelKey/Value pair is necessary + if newConfig.RootSpec.LabelKey == "" { + status.ReportError("plugin configuration root_cert_spec.LabelKey is missing") + } + if newConfig.RootSpec.LabelValue == "" { + status.ReportError("plugin configuration root_cert_spec.LabelValue is missing") + } + + if newConfig.RootSpec.CaPool == "" { + status.ReportInfo("The ca_pool value is not configured. Falling back to searching the region for matching CAs. The ca_pool configurable will be required in a future release.") + } + + return newConfig +} + type CAClient interface { CreateCertificate(ctx context.Context, req *privatecapb.CreateCertificateRequest) (*privatecapb.Certificate, error) LoadCertificateAuthorities(ctx context.Context, spec CertificateAuthoritySpec) ([]*privatecapb.CertificateAuthority, error) @@ -79,8 +110,8 @@ type Plugin struct { // need to support hot-reloading of configuration (by receiving another // call to Configure). So we need to prevent the configuration from // being used concurrently and make sure it is updated atomically. - mu sync.Mutex - c *Configuration + mu sync.Mutex + config *Configuration log hclog.Logger @@ -135,48 +166,36 @@ func (p *Plugin) PublishJWTKeyAndSubscribe(*upstreamauthorityv1.PublishJWTKeyReq func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { // Parse HCL config payload into config struct - config := new(Configuration) - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - // Without a project and location, we can never locate CAs - if config.RootSpec.Project == "" { - return nil, status.Error(codes.InvalidArgument, "configuration has empty root_cert_spec.Project property") - } - if config.RootSpec.Location == "" { - return nil, status.Error(codes.InvalidArgument, "configuration has empty root_cert_spec.Location property") - } - // Even LabelKey/Value pair is necessary - if config.RootSpec.LabelKey == "" { - return nil, status.Error(codes.InvalidArgument, "configuration has empty root_cert_spec.LabelKey property") - } - if config.RootSpec.LabelValue == "" { - return nil, status.Error(codes.InvalidArgument, "configuration has empty root_cert_spec.LabelValue property") - } - if config.RootSpec.CaPool == "" { - p.log.Warn("The ca_pool value is not configured. Falling back to searching the region for matching CAs. The ca_pool configurable will be required in a future release.") + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } - // Swap out the current configuration with the new configuration - p.setConfig(config) + + p.mu.Lock() + defer p.mu.Unlock() + p.config = newConfig return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + func (p *Plugin) getConfig() (*Configuration, error) { p.mu.Lock() defer p.mu.Unlock() - if p.c == nil { + if p.config == nil { return nil, status.Error(codes.FailedPrecondition, "not configured") } - return p.c, nil -} - -func (p *Plugin) setConfig(c *Configuration) { - p.mu.Lock() - defer p.mu.Unlock() - p.c = c + return p.config, nil } func (p *Plugin) mintX509CA(ctx context.Context, csr []byte, preferredTTL int32) (*upstreamauthorityv1.MintX509CAResponse, error) { diff --git a/pkg/server/plugin/upstreamauthority/gcpcas/gcpcas_test.go b/pkg/server/plugin/upstreamauthority/gcpcas/gcpcas_test.go index 8c5de8a7da..e8975cfcf8 100644 --- a/pkg/server/plugin/upstreamauthority/gcpcas/gcpcas_test.go +++ b/pkg/server/plugin/upstreamauthority/gcpcas/gcpcas_test.go @@ -12,6 +12,8 @@ import ( "time" "cloud.google.com/go/security/privateca/apiv1/privatecapb" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/pemutil" commonutil "github.com/spiffe/spire/pkg/common/util" "github.com/spiffe/spire/pkg/server/plugin/upstreamauthority" @@ -136,7 +138,10 @@ func TestGcpCAS(t *testing.T) { } upplugin := new(upstreamauthority.V1) - plugintest.Load(t, builtin(p), upplugin, plugintest.Configure(` + plugintest.Load(t, builtin(p), upplugin, plugintest.CoreConfig(catalog.CoreConfig{ + TrustDomain: spiffeid.RequireTrustDomainFromString("example.org"), + }), + plugintest.Configure(` root_cert_spec { project_name = "proj1" region_name = "us-central1" @@ -144,7 +149,7 @@ func TestGcpCAS(t *testing.T) { label_key = "proj-signer" label_value = "true" } - `)) + `)) priv := testkey.NewEC384(t) csr, err := commonutil.MakeCSRWithoutURISAN(priv) diff --git a/pkg/server/plugin/upstreamauthority/spire/spire.go b/pkg/server/plugin/upstreamauthority/spire/spire.go index 7d2a849150..c955c30fd7 100644 --- a/pkg/server/plugin/upstreamauthority/spire/spire.go +++ b/pkg/server/plugin/upstreamauthority/spire/spire.go @@ -19,6 +19,7 @@ import ( "github.com/spiffe/spire/pkg/common/coretypes/jwtkey" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/idutil" + "github.com/spiffe/spire/pkg/common/pluginconf" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" @@ -37,6 +38,17 @@ type Configuration struct { Experimental experimentalConfig `hcl:"experimental"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + // TODO: add field validation + return newConfig +} + type experimentalConfig struct { WorkloadAPINamedPipeName string `hcl:"workload_api_named_pipe_name" json:"workload_api_named_pipe_name"` } @@ -83,33 +95,17 @@ func New() *Plugin { } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - // Parse HCL config payload into config struct - config := new(Configuration) - - if err := hcl.Decode(config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) - } - - if req.CoreConfiguration == nil { - return nil, status.Error(codes.InvalidArgument, "core configuration is required") - } - - if req.CoreConfiguration.TrustDomain == "" { - return nil, status.Error(codes.InvalidArgument, "trust_domain is required") + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } p.mtx.Lock() defer p.mtx.Unlock() - // Create trust domain - td, err := spiffeid.TrustDomainFromString(req.CoreConfiguration.TrustDomain) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "trust_domain is malformed: %v", err) - } - p.trustDomain = td - - // Set config - p.config = config + // Swap Running Config + p.trustDomain, _ = spiffeid.TrustDomainFromString(req.CoreConfiguration.TrustDomain) + p.config = newConfig // Create spire-server client serverAddr := fmt.Sprintf("%s:%s", p.config.ServerAddr, p.config.ServerPort) @@ -118,7 +114,7 @@ func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (* return nil, status.Errorf(codes.InvalidArgument, "unable to set Workload API address: %v", err) } - serverID, err := idutil.ServerID(td) + serverID, err := idutil.ServerID(p.trustDomain) if err != nil { return nil, status.Errorf(codes.Internal, "unable to build server ID: %v", err) } @@ -128,6 +124,15 @@ func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (* return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + func (p *Plugin) SetLogger(log hclog.Logger) { p.log = log } diff --git a/pkg/server/plugin/upstreamauthority/spire/spire_test.go b/pkg/server/plugin/upstreamauthority/spire/spire_test.go index 86c147e3b3..58ad8c78ae 100644 --- a/pkg/server/plugin/upstreamauthority/spire/spire_test.go +++ b/pkg/server/plugin/upstreamauthority/spire/spire_test.go @@ -66,7 +66,7 @@ func TestConfigure(t *testing.T) { name: "malformed configuration", overrideConfig: "{1}", expectCode: codes.InvalidArgument, - expectMsgPrefix: "unable to decode configuration: expected: STRING got: NUMBER", + expectMsgPrefix: "plugin configuration is malformed", }, { name: "no trust domain", @@ -75,7 +75,7 @@ func TestConfigure(t *testing.T) { workloadAPISocket: "socketPath", overrideCoreConfig: &catalog.CoreConfig{}, expectCode: codes.InvalidArgument, - expectMsgPrefix: "trust_domain is required", + expectMsgPrefix: "server core configuration must contain trust_domain", }, } cases = append(cases, configureCasesOS(t)...) diff --git a/pkg/server/plugin/upstreamauthority/vault/vault.go b/pkg/server/plugin/upstreamauthority/vault/vault.go index 05bb09d057..c2de0a8b45 100644 --- a/pkg/server/plugin/upstreamauthority/vault/vault.go +++ b/pkg/server/plugin/upstreamauthority/vault/vault.go @@ -17,10 +17,13 @@ import ( "github.com/spiffe/spire/pkg/common/catalog" "github.com/spiffe/spire/pkg/common/coretypes/x509certificate" "github.com/spiffe/spire/pkg/common/pemutil" + "github.com/spiffe/spire/pkg/common/pluginconf" ) const ( pluginName = "vault" + + PluginConfigMalformed = "plugin configuration is malformed" ) // BuiltIn constructs a catalog.BuiltIn using a new instance of this plugin. @@ -58,6 +61,22 @@ type Configuration struct { Namespace string `hcl:"namespace" json:"namespace"` } +func buildConfig(coreConfig catalog.CoreConfig, hclText string, status *pluginconf.Status) *Configuration { + newConfig := new(Configuration) + if err := hcl.Decode(newConfig, hclText); err != nil { + status.ReportError("plugin configuration is malformed") + return nil + } + + // TODO: add field validations + + // TODO: consider moving some elements of parseAuthMethod into config checking + // TODO: consider moving some elements of genClientParams into config checking + // TODO: consider moving some elements of NewClientConfig into config checking + + return newConfig +} + // TokenAuthConfig represents parameters for token auth method type TokenAuthConfig struct { // Token string to set into "X-Vault-Token" header @@ -134,20 +153,19 @@ func (p *Plugin) SetLogger(log hclog.Logger) { } func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (*configv1.ConfigureResponse, error) { - config := new(Configuration) - - if err := hcl.Decode(&config, req.HclConfiguration); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "unable to decode configuration: %v", err) + newConfig, _, err := pluginconf.Build(req, buildConfig) + if err != nil { + return nil, err } p.mtx.Lock() defer p.mtx.Unlock() - am, err := parseAuthMethod(config) + am, err := parseAuthMethod(newConfig) if err != nil { return nil, err } - cp, err := p.genClientParams(am, config) + cp, err := p.genClientParams(am, newConfig) if err != nil { return nil, err } @@ -162,6 +180,15 @@ func (p *Plugin) Configure(_ context.Context, req *configv1.ConfigureRequest) (* return &configv1.ConfigureResponse{}, nil } +func (p *Plugin) Validate(_ context.Context, req *configv1.ValidateRequest) (*configv1.ValidateResponse, error) { + _, notes, err := pluginconf.Build(req, buildConfig) + + return &configv1.ValidateResponse{ + Valid: err == nil, + Notes: notes, + }, err +} + func (p *Plugin) MintX509CAAndSubscribe(req *upstreamauthorityv1.MintX509CARequest, stream upstreamauthorityv1.UpstreamAuthority_MintX509CAAndSubscribeServer) error { if p.cc == nil { return status.Error(codes.FailedPrecondition, "plugin not configured") diff --git a/pkg/server/plugin/upstreamauthority/vault/vault_test.go b/pkg/server/plugin/upstreamauthority/vault/vault_test.go index f1913cd546..65f704f728 100644 --- a/pkg/server/plugin/upstreamauthority/vault/vault_test.go +++ b/pkg/server/plugin/upstreamauthority/vault/vault_test.go @@ -149,7 +149,7 @@ func TestConfigure(t *testing.T) { name: "Malformed configuration", plainConfig: "invalid-config", expectCode: codes.InvalidArgument, - expectMsgPrefix: "unable to decode configuration:", + expectMsgPrefix: "plugin configuration is malformed", }, { name: "Required parameters are not given / k8s_auth_role_name",