From e6e61ed3f7110e9fca6d152320c62ec472519eb0 Mon Sep 17 00:00:00 2001 From: Alex Shtin Date: Fri, 4 Oct 2024 15:43:34 -0700 Subject: [PATCH] Unalias and validate search attributes when import workflow history (#6563) ## What changed? Unalias and validate search attributes when importing workflow history. ## Why? Search attributes in workflow history JSON file might use aliases (if history was exported using `temporal` cli) or be invalid (name/type/length). Previously import blindly accepted any search attributes and then background visibility task processing could fail. Now search attributes are validated and unaliased (if needed) before getting to the history service. ## How did you test it? Added new unit tests for both valid and invalid cases. ## Potential risks No risks. ## Documentation No. --- Makefile | 10 +- cmd/tools/gensearchattributehelpers/main.go | 158 ++++++++++++ common/searchattribute/event.go | 28 ++ common/searchattribute/event_gen.go | 66 +++++ common/testing/testvars/any.go | 4 + common/testing/testvars/test_vars.go | 12 +- service/frontend/admin_handler.go | 82 +++++- service/frontend/admin_handler_test.go | 272 +++++++++++++++++++- service/frontend/fx.go | 2 + tools/tdbg/app.go | 2 +- tools/tdbg/commands.go | 3 +- 11 files changed, 622 insertions(+), 17 deletions(-) create mode 100644 cmd/tools/gensearchattributehelpers/main.go create mode 100644 common/searchattribute/event.go create mode 100644 common/searchattribute/event_gen.go diff --git a/Makefile b/Makefile index f2ca62a0e92..b938f276309 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ bins: temporal-server temporal-cassandra-tool temporal-sql-tool tdbg all: clean proto bins check test # Used in CI -ci-build-misc: print-go-version proto go-generate buf-breaking bins temporal-server-debug shell-check copyright-check goimports-all gomodtidy ensure-no-changes +ci-build-misc: print-go-version proto go-generate buf-breaking bins temporal-server-debug shell-check copyright-check 
goimports-all gomodtidy ensure-no-changes # Delete all build artifacts clean: clean-bins clean-test-results @@ -18,7 +18,7 @@ clean: clean-bins clean-test-results rm -rf $(LOCALBIN) # Recompile proto files. -proto: lint-protos lint-api protoc service-clients server-interceptors +proto: lint-protos lint-api protoc proto-codegen ######################################################################## .PHONY: proto protoc install bins ci-build-misc clean @@ -243,13 +243,13 @@ protoc: $(PROTOGEN) $(MOCKGEN) $(GOIMPORTS) $(PROTOC_GEN_GO) $(PROTOC_GEN_GO_GRP API_BINPB=$(API_BINPB) PROTO_ROOT=$(PROTO_ROOT) PROTO_OUT=$(PROTO_OUT) \ ./develop/protoc.sh -service-clients: +proto-codegen: @printf $(COLOR) "Generate service clients..." @go generate -run genrpcwrappers ./client/... - -server-interceptors: @printf $(COLOR) "Generate server interceptors..." @go generate ./common/rpc/interceptor/logtags/... + @printf $(COLOR) "Generate search attributes helpers..." + @go generate -run gensearchattributehelpers ./common/searchattribute/... update-go-api: @printf $(COLOR) "Update go.temporal.io/api@master..." diff --git a/cmd/tools/gensearchattributehelpers/main.go b/cmd/tools/gensearchattributehelpers/main.go new file mode 100644 index 00000000000..82a4a6f9222 --- /dev/null +++ b/cmd/tools/gensearchattributehelpers/main.go @@ -0,0 +1,158 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package main + +import ( + "flag" + "fmt" + "io" + "log" + "os" + "path" + "reflect" + "regexp" + "strings" + "text/template" + + historypb "go.temporal.io/api/history/v1" +) + +type ( + eventData struct { + AttributesTypeName string + } + + searchAttributesHelpersData struct { + Events []eventData + } +) + +var ( + // Is used to find attribute getters and extract the event type (match[1]). 
+ attributesGetterRegex = regexp.MustCompile("^Get(.+EventAttributes)$") +) + +func main() { + outPathFlag := flag.String("out", ".", "path to write generated files") + licenseFlag := flag.String("copyright_file", "LICENSE", "path to license to copy into header") + flag.Parse() + + licenseText := readLicenseFile(*licenseFlag) + + eventHelpersFile := path.Join(*outPathFlag, "event_gen.go") + callWithFile(generateSearchAttributesEventHelpers, eventHelpersFile, licenseText) +} + +func generateSearchAttributesEventHelpers(w io.Writer) { + + writeSearchAttributesEventHelpers(w, ` +package searchattribute + +import ( + commonpb "go.temporal.io/api/common/v1" + historypb "go.temporal.io/api/history/v1" +) + +func SetToEvent(event *historypb.HistoryEvent, sas *commonpb.SearchAttributes) bool { + switch e := event.GetAttributes().(type) { + {{- range .Events}} + case *historypb.HistoryEvent_{{.AttributesTypeName}}: + e.{{.AttributesTypeName}}.SearchAttributes = sas + return true + {{- end}} + default: + return false + } +} + +func GetFromEvent(event *historypb.HistoryEvent) (*commonpb.SearchAttributes, bool) { + switch e := event.GetAttributes().(type) { + {{- range .Events}} + case *historypb.HistoryEvent_{{.AttributesTypeName}}: + return e.{{.AttributesTypeName}}.GetSearchAttributes(), true + {{- end}} + default: + return nil, false + } +} +`) +} + +func writeSearchAttributesEventHelpers(w io.Writer, tmpl string) { + sahd := searchAttributesHelpersData{} + + historyEventT := reflect.TypeOf((*historypb.HistoryEvent)(nil)) + + for i := 0; i < historyEventT.NumMethod(); i++ { + attributesGetter := historyEventT.Method(i) + matches := attributesGetterRegex.FindStringSubmatch(attributesGetter.Name) + if len(matches) < 2 { + continue + } + if attributesGetter.Type.NumOut() != 1 { + continue + } + if _, found := attributesGetter.Type.Out(0).MethodByName("GetSearchAttributes"); !found { + continue + } + + ed := eventData{ + AttributesTypeName: matches[1], + } + sahd.Events = 
append(sahd.Events, ed) + } + + fatalIfErr(template.Must(template.New("code").Parse(tmpl)).Execute(w, sahd)) +} + +func callWithFile(generator func(io.Writer), outFile string, licenseText string) { + w, err := os.Create(outFile) + fatalIfErr(err) + defer func() { + fatalIfErr(w.Close()) + }() + _, err = fmt.Fprintf(w, "%s\n// Code generated by cmd/tools/gensearchattributehelpers. DO NOT EDIT.\n", licenseText) + fatalIfErr(err) + generator(w) +} + +func readLicenseFile(filePath string) string { + text, err := os.ReadFile(filePath) + if err != nil { + panic(err) + } + var lines []string + for _, line := range strings.Split(string(text), "\n") { + lines = append(lines, strings.TrimRight("// "+line, " ")) + } + return strings.Join(lines, "\n") + "\n" +} + +func fatalIfErr(err error) { + if err != nil { + //nolint:revive // calls to log.Fatal only in main() or init() functions (revive) + log.Fatal(err) + } +} diff --git a/common/searchattribute/event.go b/common/searchattribute/event.go new file mode 100644 index 00000000000..1b02cddb60d --- /dev/null +++ b/common/searchattribute/event.go @@ -0,0 +1,28 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Generates event helper functions for setting and getting search attributes on history events. +//go:generate go run ../../cmd/tools/gensearchattributehelpers -copyright_file ../../LICENSE + +package searchattribute diff --git a/common/searchattribute/event_gen.go b/common/searchattribute/event_gen.go new file mode 100644 index 00000000000..857ea0bd39f --- /dev/null +++ b/common/searchattribute/event_gen.go @@ -0,0 +1,66 @@ +// The MIT License +// +// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved. +// +// Copyright (c) 2020 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Code generated by cmd/tools/gensearchattributehelpers. DO NOT EDIT. + +package searchattribute + +import ( + commonpb "go.temporal.io/api/common/v1" + historypb "go.temporal.io/api/history/v1" +) + +func SetToEvent(event *historypb.HistoryEvent, sas *commonpb.SearchAttributes) bool { + switch e := event.GetAttributes().(type) { + case *historypb.HistoryEvent_StartChildWorkflowExecutionInitiatedEventAttributes: + e.StartChildWorkflowExecutionInitiatedEventAttributes.SearchAttributes = sas + return true + case *historypb.HistoryEvent_UpsertWorkflowSearchAttributesEventAttributes: + e.UpsertWorkflowSearchAttributesEventAttributes.SearchAttributes = sas + return true + case *historypb.HistoryEvent_WorkflowExecutionContinuedAsNewEventAttributes: + e.WorkflowExecutionContinuedAsNewEventAttributes.SearchAttributes = sas + return true + case *historypb.HistoryEvent_WorkflowExecutionStartedEventAttributes: + e.WorkflowExecutionStartedEventAttributes.SearchAttributes = sas + return true + default: + return false + } +} + +func GetFromEvent(event *historypb.HistoryEvent) (*commonpb.SearchAttributes, bool) { + switch e := event.GetAttributes().(type) { + case *historypb.HistoryEvent_StartChildWorkflowExecutionInitiatedEventAttributes: + return e.StartChildWorkflowExecutionInitiatedEventAttributes.GetSearchAttributes(), true + case *historypb.HistoryEvent_UpsertWorkflowSearchAttributesEventAttributes: + return e.UpsertWorkflowSearchAttributesEventAttributes.GetSearchAttributes(), true + case *historypb.HistoryEvent_WorkflowExecutionContinuedAsNewEventAttributes: + return e.WorkflowExecutionContinuedAsNewEventAttributes.GetSearchAttributes(), true + case 
*historypb.HistoryEvent_WorkflowExecutionStartedEventAttributes: + return e.WorkflowExecutionStartedEventAttributes.GetSearchAttributes(), true + default: + return nil, false + } +} diff --git a/common/testing/testvars/any.go b/common/testing/testvars/any.go index ea75647f55d..52445b42ac0 100644 --- a/common/testing/testvars/any.go +++ b/common/testing/testvars/any.go @@ -61,6 +61,10 @@ func (a Any) Int() int { return randInt(a.testHash, 3, 3, 3) } +func (a Any) Int64() int64 { + return int64(a.Int()) +} + func (a Any) EventID() int64 { // This produces EventID in XX0YY format, where XX is unique for every test and YY is a random number. return int64(randInt(a.testHash, 2, 1, 2)) diff --git a/common/testing/testvars/test_vars.go b/common/testing/testvars/test_vars.go index ef97e57d235..a69fe4c4278 100644 --- a/common/testing/testvars/test_vars.go +++ b/common/testing/testvars/test_vars.go @@ -282,8 +282,16 @@ func (tv *TestVars) QueryType(key ...string) string { return tv.getOrCreate("query_type", key).(string) } -func (tv *TestVars) WithQueryType(queryTypeID string, key ...string) *TestVars { - return tv.cloneSet("query_type", key, queryTypeID) +func (tv *TestVars) WithQueryType(queryType string, key ...string) *TestVars { + return tv.cloneSet("query_type", key, queryType) +} + +func (tv *TestVars) IndexName(key ...string) string { + return tv.getOrCreate("index_name", key).(string) +} + +func (tv *TestVars) WithIndexName(indexName string, key ...string) *TestVars { + return tv.cloneSet("index_name", key, indexName) } // ----------- Generic methods ------------ diff --git a/service/frontend/admin_handler.go b/service/frontend/admin_handler.go index ecf62532124..a607311c33b 100644 --- a/service/frontend/admin_handler.go +++ b/service/frontend/admin_handler.go @@ -36,6 +36,7 @@ import ( "time" "go.temporal.io/server/api/matchingservice/v1" + "go.temporal.io/server/common/persistence/visibility" "github.com/pborman/uuid" commonpb "go.temporal.io/api/common/v1" @@ 
-125,6 +126,8 @@ type ( namespaceRegistry namespace.Registry saProvider searchattribute.Provider saManager searchattribute.Manager + saMapperProvider searchattribute.MapperProvider + saValidator *searchattribute.Validator clusterMetadata cluster.Metadata healthServer *health.Server historyHealthChecker HealthChecker @@ -157,6 +160,7 @@ type ( NamespaceRegistry namespace.Registry SaProvider searchattribute.Provider SaManager searchattribute.Manager + SaMapperProvider searchattribute.MapperProvider ClusterMetadata cluster.Metadata HealthServer *health.Server EventSerializer serialization.Serializer @@ -227,11 +231,25 @@ func NewAdminHandler( namespaceRegistry: args.NamespaceRegistry, saProvider: args.SaProvider, saManager: args.SaManager, - clusterMetadata: args.ClusterMetadata, - healthServer: args.HealthServer, - historyHealthChecker: historyHealthChecker, - taskCategoryRegistry: args.CategoryRegistry, - matchingClient: args.matchingClient, + saMapperProvider: args.SaMapperProvider, + saValidator: searchattribute.NewValidator( + args.SaProvider, + args.SaMapperProvider, + args.Config.SearchAttributesNumberOfKeysLimit, + args.Config.SearchAttributesSizeOfValueLimit, + args.Config.SearchAttributesTotalSizeLimit, + args.visibilityMgr, + visibility.AllowListForValidation( + args.visibilityMgr.GetStoreNames(), + args.Config.VisibilityAllowList, + ), + args.Config.SuppressErrorSetSystemSearchAttribute, + ), + clusterMetadata: args.ClusterMetadata, + healthServer: args.HealthServer, + historyHealthChecker: historyHealthChecker, + taskCategoryRegistry: args.CategoryRegistry, + matchingClient: args.matchingClient, } } @@ -715,10 +733,14 @@ func (adh *AdminHandler) ImportWorkflowExecution( return nil, err } + unaliasedBatches, err := adh.unaliasAndValidateSearchAttributes(request.HistoryBatches, namespace.Name(request.GetNamespace())) + if err != nil { + return nil, err + } resp, err := adh.historyClient.ImportWorkflowExecution(ctx, 
&historyservice.ImportWorkflowExecutionRequest{ NamespaceId: namespaceID.String(), Execution: request.Execution, - HistoryBatches: request.HistoryBatches, + HistoryBatches: unaliasedBatches, VersionHistory: request.VersionHistory, Token: request.Token, }) @@ -730,6 +752,54 @@ func (adh *AdminHandler) ImportWorkflowExecution( }, nil } +func (adh *AdminHandler) unaliasAndValidateSearchAttributes(historyBatches []*commonpb.DataBlob, nsName namespace.Name) ([]*commonpb.DataBlob, error) { + var unaliasedBatches []*commonpb.DataBlob + for _, historyBatch := range historyBatches { + events, err := adh.eventSerializer.DeserializeEvents(historyBatch) + if err != nil { + return nil, serviceerror.NewInvalidArgument(err.Error()) + } + hasSas := false + for _, event := range events { + sas, _ := searchattribute.GetFromEvent(event) + if sas == nil { + continue + } + hasSas = true + + unaliasedSas, err := searchattribute.UnaliasFields(adh.saMapperProvider, sas, nsName.String()) + if err != nil { + var invArgErr *serviceerror.InvalidArgument + if !errors.As(err, &invArgErr) { + return nil, err + } + // Mapper returns InvalidArgument if alias is not found. It means that history has field names, not aliases. + // Ignore the error and proceed with the original search attributes. + unaliasedSas = sas + } + // Now validate that search attributes are valid. + err = adh.saValidator.Validate(unaliasedSas, nsName.String()) + if err != nil { + return nil, err + } + + _ = searchattribute.SetToEvent(event, unaliasedSas) + } + // If blob doesn't have search attributes, it can be used as is w/o serialization. 
+ if !hasSas { + unaliasedBatches = append(unaliasedBatches, historyBatch) + continue + } + + unaliasedBatch, err := adh.eventSerializer.SerializeEvents(events, enumspb.ENCODING_TYPE_PROTO3) + if err != nil { + return nil, serviceerror.NewInvalidArgument(err.Error()) + } + unaliasedBatches = append(unaliasedBatches, unaliasedBatch) + } + return unaliasedBatches, nil +} + // DescribeMutableState returns information about the specified workflow execution. func (adh *AdminHandler) DescribeMutableState(ctx context.Context, request *adminservice.DescribeMutableStateRequest) (_ *adminservice.DescribeMutableStateResponse, retError error) { defer log.CapturePanic(adh.logger, &retError) diff --git a/service/frontend/admin_handler_test.go b/service/frontend/admin_handler_test.go index 3082b376536..e76d4dff0b2 100644 --- a/service/frontend/admin_handler_test.go +++ b/service/frontend/admin_handler_test.go @@ -29,6 +29,7 @@ import ( "errors" "fmt" "math/rand" + "strings" "sync" "testing" @@ -38,6 +39,7 @@ import ( "github.com/stretchr/testify/suite" commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" + historypb "go.temporal.io/api/history/v1" namespacepb "go.temporal.io/api/namespace/v1" "go.temporal.io/api/serviceerror" taskqueuepb "go.temporal.io/api/taskqueue/v1" @@ -58,6 +60,7 @@ import ( "go.temporal.io/server/common/clock" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/config" + "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/membership" "go.temporal.io/server/common/namespace" "go.temporal.io/server/common/persistence" @@ -68,10 +71,15 @@ import ( "go.temporal.io/server/common/primitives/timestamp" "go.temporal.io/server/common/resourcetest" "go.temporal.io/server/common/searchattribute" + test "go.temporal.io/server/common/testing" + "go.temporal.io/server/common/testing/historyrequire" "go.temporal.io/server/common/testing/mocksdk" + "go.temporal.io/server/common/testing/protorequire" + 
"go.temporal.io/server/common/testing/testvars" "go.temporal.io/server/service/history/tasks" "go.temporal.io/server/service/worker/dlq" "go.uber.org/mock/gomock" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/health" "google.golang.org/grpc/metadata" @@ -81,6 +89,8 @@ type ( adminHandlerSuite struct { suite.Suite *require.Assertions + historyrequire.HistoryRequire + protorequire.ProtoAssertions controller *gomock.Controller mockResource *resourcetest.Test @@ -95,6 +105,7 @@ type ( mockMetadata *cluster.MockMetadata mockProducer *persistence.MockNamespaceReplicationQueue mockMatchingClient *matchingservicemock.MockMatchingServiceClient + mockSaMapper *searchattribute.MockMapper namespace namespace.Name namespaceID namespace.ID @@ -111,6 +122,8 @@ func TestAdminHandlerSuite(t *testing.T) { func (s *adminHandlerSuite) SetupTest() { s.Assertions = require.New(s.T()) + s.ProtoAssertions = protorequire.New(s.T()) + s.HistoryRequire = historyrequire.New(s.T()) s.namespace = "some random namespace name" s.namespaceID = "deadd0d0-c001-face-d00d-000000000000" @@ -138,12 +151,22 @@ func (s *adminHandlerSuite) SetupTest() { s.mockProducer = persistence.NewMockNamespaceReplicationQueue(s.controller) s.mockMatchingClient = s.mockResource.MatchingClient + mockSaMapperProvider := searchattribute.NewMockMapperProvider(s.controller) + s.mockSaMapper = searchattribute.NewMockMapper(s.controller) + mockSaMapperProvider.EXPECT().GetMapper(s.namespace).Return(s.mockSaMapper, nil).AnyTimes() + persistenceConfig := &config.Persistence{ NumHistoryShards: 1, } cfg := &Config{ NumHistoryShards: 4, + + SearchAttributesNumberOfKeysLimit: dynamicconfig.GetIntPropertyFnFilteredByNamespace(10), + SearchAttributesSizeOfValueLimit: dynamicconfig.GetIntPropertyFnFilteredByNamespace(10), + SearchAttributesTotalSizeLimit: dynamicconfig.GetIntPropertyFnFilteredByNamespace(10), + VisibilityAllowList: dynamicconfig.GetBoolPropertyFnFilteredByNamespace(false), + 
SuppressErrorSetSystemSearchAttribute: dynamicconfig.GetBoolPropertyFnFilteredByNamespace(false), } args := NewAdminHandlerArgs{ persistenceConfig, @@ -151,7 +174,7 @@ func (s *adminHandlerSuite) SetupTest() { s.mockResource.GetNamespaceReplicationQueue(), s.mockProducer, s.mockResource.ESClient, - s.mockResource.GetVisibilityManager(), + s.mockVisibilityMgr, s.mockResource.GetLogger(), s.mockResource.GetTaskManager(), s.mockResource.GetExecutionManager(), @@ -167,6 +190,7 @@ func (s *adminHandlerSuite) SetupTest() { s.mockResource.GetNamespaceRegistry(), s.mockResource.GetSearchAttributesProvider(), s.mockResource.GetSearchAttributesManager(), + mockSaMapperProvider, s.mockMetadata, health.NewServer(), serialization.NewSerializer(), @@ -176,6 +200,7 @@ func (s *adminHandlerSuite) SetupTest() { } s.mockMetadata.EXPECT().GetCurrentClusterName().Return(uuid.New()).AnyTimes() s.mockExecutionMgr.EXPECT().GetName().Return("mock-execution-manager").AnyTimes() + s.mockVisibilityMgr.EXPECT().GetStoreNames().Return([]string{"mock-vis-store"}) s.handler = NewAdminHandler(args) s.handler.Start() } @@ -1834,6 +1859,251 @@ func (s *adminHandlerSuite) TestDescribeTaskQueuePartition() { s.validatePhysicalTaskQueueInfo(versionedPhysicalTaskQueueInfo, resp.VersionsInfoInternal[buildID].GetPhysicalTaskQueueInfo()) } +func (s *adminHandlerSuite) TestImportWorkflowExecution_NoSearchAttributes() { + tv := testvars.New(s.T()).WithNamespaceName(s.namespace).WithNamespaceID(s.namespaceID) + + serializer := serialization.NewSerializer() + generator := test.InitializeHistoryEventGenerator(tv.NamespaceName(), tv.NamespaceID(), tv.Any().Int64()) + + // Generate random history. 
+ var historyBatches []*commonpb.DataBlob + for generator.HasNextVertex() { + events := generator.GetNextVertices() + var historyEvents []*historypb.HistoryEvent + for _, event := range events { + historyEvent := event.GetData().(*historypb.HistoryEvent) + historyEvents = append(historyEvents, historyEvent) + } + historyBatch, err := serializer.SerializeEvents(historyEvents, enumspb.ENCODING_TYPE_PROTO3) + s.NoError(err) + historyBatches = append(historyBatches, historyBatch) + } + + s.mockNamespaceCache.EXPECT().GetNamespaceID(tv.NamespaceName()).Return(tv.NamespaceID(), nil) + + s.mockHistoryClient.EXPECT().ImportWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, request *historyservice.ImportWorkflowExecutionRequest, opts ...grpc.CallOption) (*historyservice.ImportWorkflowExecutionResponse, error) { + s.Equal(tv.NamespaceID().String(), request.NamespaceId) + s.Equal(historyBatches, request.HistoryBatches, "history batches shouldn't be reserialized because there is no search attributes") + return &historyservice.ImportWorkflowExecutionResponse{}, nil + }) + _, err := s.handler.ImportWorkflowExecution(context.Background(), &adminservice.ImportWorkflowExecutionRequest{ + Namespace: tv.NamespaceName().String(), + Execution: tv.WorkflowExecution(), + HistoryBatches: historyBatches, + VersionHistory: nil, + Token: nil, + }) + s.NoError(err) +} + +func (s *adminHandlerSuite) TestImportWorkflowExecution_WithAliasedSearchAttributes() { + tv := testvars.New(s.T()).WithNamespaceName(s.namespace).WithNamespaceID(s.namespaceID) + + serializer := serialization.NewSerializer() + + subTests := []struct { + Name string + SaName string + ExpectedErr error + }{ + { + Name: "valid SA alias", + SaName: "AliasOfCustomKeywordField", + ExpectedErr: nil, + }, + { + Name: "invalid SA alias", + SaName: "InvalidAlias", + ExpectedErr: &serviceerror.InvalidArgument{}, + }, + { + Name: "invalid SA field", + SaName: "AliasOfInvalidField", + ExpectedErr: 
&serviceerror.InvalidArgument{}, + }, + } + for _, subTest := range subTests { + s.T().Run(subTest.Name, func(t *testing.T) { + generator := test.InitializeHistoryEventGenerator(tv.NamespaceName(), tv.NamespaceID(), tv.Any().Int64()) + saValue := tv.Any().Payload() + aliasedSas := &commonpb.SearchAttributes{IndexedFields: map[string]*commonpb.Payload{ + subTest.SaName: saValue, + }} + + // Generate random history and set search attributes for all events that have search_attributes field. + var historyBatches []*commonpb.DataBlob + eventsWithSasCount := 0 + for generator.HasNextVertex() { + events := generator.GetNextVertices() + var historyEvents []*historypb.HistoryEvent + for _, event := range events { + historyEvent := event.GetData().(*historypb.HistoryEvent) + eventHasSas := searchattribute.SetToEvent(historyEvent, aliasedSas) + if eventHasSas { + eventsWithSasCount++ + } + historyEvents = append(historyEvents, historyEvent) + } + historyBatch, err := serializer.SerializeEvents(historyEvents, enumspb.ENCODING_TYPE_PROTO3) + s.NoError(err) + historyBatches = append(historyBatches, historyBatch) + } + if subTest.ExpectedErr != nil { + // Import will fail fast on first event and won't check other events. + eventsWithSasCount = 1 + } + + s.mockNamespaceCache.EXPECT().GetNamespaceID(tv.NamespaceName()).Return(tv.NamespaceID(), nil) + s.mockVisibilityMgr.EXPECT().GetIndexName().Return(tv.IndexName()).Times(eventsWithSasCount) + + // Mock mapper remove alias from alias name. 
+ s.mockSaMapper.EXPECT().GetFieldName(gomock.Any(), tv.NamespaceName().String()).DoAndReturn(func(alias string, nsName string) (string, error) { + if strings.HasPrefix(alias, "AliasOf") { + return strings.TrimPrefix(alias, "AliasOf"), nil + } + return "", serviceerror.NewInvalidArgument("unknown alias") + }).Times(eventsWithSasCount) + + s.mockResource.SearchAttributesProvider.EXPECT().GetSearchAttributes(tv.IndexName(), gomock.Any()).Return(searchattribute.TestNameTypeMap, nil).Times(eventsWithSasCount) + + if subTest.ExpectedErr != nil { + s.mockSaMapper.EXPECT().GetAlias(gomock.Any(), tv.NamespaceName().String()).Return("", serviceerror.NewInvalidArgument("")) + } else { + s.mockVisibilityMgr.EXPECT().ValidateCustomSearchAttributes(gomock.Any()).Return(nil, nil).Times(eventsWithSasCount) + s.mockHistoryClient.EXPECT().ImportWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, request *historyservice.ImportWorkflowExecutionRequest, opts ...grpc.CallOption) (*historyservice.ImportWorkflowExecutionResponse, error) { + s.Equal(tv.NamespaceID().String(), request.NamespaceId) + for _, historyBatch := range request.HistoryBatches { + events, err := serializer.DeserializeEvents(historyBatch) + s.NoError(err) + for _, event := range events { + unaliasedSas, eventHasSas := searchattribute.GetFromEvent(event) + if eventHasSas { + s.NotNil(unaliasedSas, "search attributes must be set on every event with search_attributes field") + s.Len(unaliasedSas.GetIndexedFields(), 1, "only 1 search attribute must be set") + s.ProtoEqual(saValue, unaliasedSas.GetIndexedFields()["CustomKeywordField"]) + } + } + } + return &historyservice.ImportWorkflowExecutionResponse{}, nil + }) + } + _, err := s.handler.ImportWorkflowExecution(context.Background(), &adminservice.ImportWorkflowExecutionRequest{ + Namespace: tv.NamespaceName().String(), + Execution: tv.WorkflowExecution(), + HistoryBatches: historyBatches, + VersionHistory: nil, + Token: nil, + }) + if 
subTest.ExpectedErr == nil { + s.NoError(err) + } else { + s.Error(err) + s.ErrorAs(err, &subTest.ExpectedErr) + } + }) + } +} + +func (s *adminHandlerSuite) TestImportWorkflowExecution_WithNonAliasedSearchAttributes() { + tv := testvars.New(s.T()).WithNamespaceName(s.namespace).WithNamespaceID(s.namespaceID) + + serializer := serialization.NewSerializer() + subTests := []struct { + Name string + SaName string + ExpectedErr error + }{ + { + Name: "valid SA field", + SaName: "CustomKeywordField", + ExpectedErr: nil, + }, + { + Name: "invalid SA field", + SaName: "InvalidField", + ExpectedErr: &serviceerror.InvalidArgument{}, + }, + } + for _, subTest := range subTests { + s.T().Run(subTest.Name, func(t *testing.T) { + generator := test.InitializeHistoryEventGenerator(tv.NamespaceName(), tv.NamespaceID(), tv.Any().Int64()) + saValue := tv.Any().Payload() + aliasedSas := &commonpb.SearchAttributes{IndexedFields: map[string]*commonpb.Payload{ + subTest.SaName: saValue, + }} + + // Generate random history and set search attributes for all events that have search_attributes field. + var historyBatches []*commonpb.DataBlob + eventsWithSasCount := 0 + for generator.HasNextVertex() { + events := generator.GetNextVertices() + var historyEvents []*historypb.HistoryEvent + for _, event := range events { + historyEvent := event.GetData().(*historypb.HistoryEvent) + eventHasSas := searchattribute.SetToEvent(historyEvent, aliasedSas) + if eventHasSas { + eventsWithSasCount++ + } + historyEvents = append(historyEvents, historyEvent) + } + historyBatch, err := serializer.SerializeEvents(historyEvents, enumspb.ENCODING_TYPE_PROTO3) + s.NoError(err) + historyBatches = append(historyBatches, historyBatch) + } + if subTest.ExpectedErr != nil { + // Import will fail fast on first event and won't check other events. 
+ eventsWithSasCount = 1 + } + + s.mockNamespaceCache.EXPECT().GetNamespaceID(tv.NamespaceName()).Return(tv.NamespaceID(), nil) + s.mockVisibilityMgr.EXPECT().GetIndexName().Return(tv.IndexName()).Times(eventsWithSasCount) + + s.mockResource.SearchAttributesProvider.EXPECT().GetSearchAttributes(tv.IndexName(), gomock.Any()).Return(searchattribute.TestNameTypeMap, nil).Times(eventsWithSasCount) + + // Mock mapper returns error because field name is not an alias. + s.mockSaMapper.EXPECT().GetFieldName(gomock.Any(), tv.NamespaceName().String()).DoAndReturn(func(alias string, nsName string) (string, error) { + return "", serviceerror.NewInvalidArgument("unknown alias") + }).Times(eventsWithSasCount) + + if subTest.ExpectedErr != nil { + s.mockSaMapper.EXPECT().GetAlias(gomock.Any(), tv.NamespaceName().String()).Return("", serviceerror.NewInvalidArgument("")) + } else { + s.mockVisibilityMgr.EXPECT().ValidateCustomSearchAttributes(gomock.Any()).Return(nil, nil).Times(eventsWithSasCount) + s.mockHistoryClient.EXPECT().ImportWorkflowExecution(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, request *historyservice.ImportWorkflowExecutionRequest, opts ...grpc.CallOption) (*historyservice.ImportWorkflowExecutionResponse, error) { + s.Equal(tv.NamespaceID().String(), request.NamespaceId) + for _, historyBatch := range request.HistoryBatches { + events, err := serializer.DeserializeEvents(historyBatch) + s.NoError(err) + for _, event := range events { + unaliasedSas, eventHasSas := searchattribute.GetFromEvent(event) + if eventHasSas { + s.NotNil(unaliasedSas, "search attributes must be set on every event with search_attributes field") + s.Len(unaliasedSas.GetIndexedFields(), 1, "only 1 search attribute must be set") + s.ProtoEqual(saValue, unaliasedSas.GetIndexedFields()["CustomKeywordField"]) + } + } + } + return &historyservice.ImportWorkflowExecutionResponse{}, nil + }) + } + + _, err := s.handler.ImportWorkflowExecution(context.Background(), 
&adminservice.ImportWorkflowExecutionRequest{ + Namespace: tv.NamespaceName().String(), + Execution: tv.WorkflowExecution(), + HistoryBatches: historyBatches, + VersionHistory: nil, + Token: nil, + }) + if subTest.ExpectedErr == nil { + s.NoError(err) + } else { + s.Error(err) + s.ErrorAs(err, &subTest.ExpectedErr) + } + }) + } +} + func (s *adminHandlerSuite) validatePhysicalTaskQueueInfo(expectedPhysicalTaskQueueInfo *taskqueuespb.PhysicalTaskQueueInfo, responsePhysicalTaskQueueInfo *taskqueuespb.PhysicalTaskQueueInfo) { diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 66abf1ab157..0e7a320b37b 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -602,6 +602,7 @@ func AdminHandlerProvider( namespaceRegistry namespace.Registry, saProvider searchattribute.Provider, saManager searchattribute.Manager, + saMapperProvider searchattribute.MapperProvider, clusterMetadata cluster.Metadata, healthServer *health.Server, eventSerializer serialization.Serializer, @@ -631,6 +632,7 @@ func AdminHandlerProvider( namespaceRegistry, saProvider, saManager, + saMapperProvider, clusterMetadata, healthServer, eventSerializer, diff --git a/tools/tdbg/app.go b/tools/tdbg/app.go index 25eca920d5a..7c53662cc6f 100644 --- a/tools/tdbg/app.go +++ b/tools/tdbg/app.go @@ -126,7 +126,7 @@ func NewCliApp(opts ...Option) *cli.App { }, &cli.StringFlag{ Name: color.FlagColor, - Usage: fmt.Sprintf("when to use color: %v, %v, %v.", color.Auto, color.Always, color.Never), + Usage: fmt.Sprintf("When to use color: %v, %v, %v.", color.Auto, color.Always, color.Never), Value: string(color.Auto), }, } diff --git a/tools/tdbg/commands.go b/tools/tdbg/commands.go index cd01dc6ef9c..806e314c940 100644 --- a/tools/tdbg/commands.go +++ b/tools/tdbg/commands.go @@ -203,8 +203,7 @@ func AdminImportWorkflow(c *cli.Context, clientFactory ClientFactory) error { } var token []byte - - blobs := []*commonpb.DataBlob{} + var blobs []*commonpb.DataBlob blobSize := 0 for i := 0; i < 
len(historyBatches)+1; i++ { if i < len(historyBatches) {