diff --git a/client/client_bean.go b/client/client_bean.go index 0b034fde254..3e7fffdd234 100644 --- a/client/client_bean.go +++ b/client/client_bean.go @@ -51,7 +51,6 @@ type ( GetMatchingClient(namespaceIDToName NamespaceIDToNameFunc) (matchingservice.MatchingServiceClient, error) GetFrontendClient() workflowservice.WorkflowServiceClient GetRemoteAdminClient(string) (adminservice.AdminServiceClient, error) - SetRemoteAdminClient(string, adminservice.AdminServiceClient) GetRemoteFrontendClient(string) (grpc.ClientConnInterface, workflowservice.WorkflowServiceClient, error) } @@ -107,27 +106,6 @@ func NewClientBean(factory Factory, clusterMetadata cluster.Metadata) (Bean, err WorkflowServiceClient: client, } - for clusterName, info := range clusterMetadata.GetAllClusterInfo() { - if !info.Enabled || clusterName == currentClusterName { - continue - } - adminClient = factory.NewRemoteAdminClientWithTimeout( - info.RPCAddress, - admin.DefaultTimeout, - admin.DefaultLargeTimeout, - ) - conn, client = factory.NewRemoteFrontendClientWithTimeout( - info.RPCAddress, - frontend.DefaultTimeout, - frontend.DefaultLongPollTimeout, - ) - adminClients[clusterName] = adminClient - frontendClients[clusterName] = frontendClient{ - connection: conn, - WorkflowServiceClient: client, - } - } - bean := &clientBeanImpl{ factory: factory, historyClient: historyClient, @@ -212,16 +190,6 @@ func (h *clientBeanImpl) GetRemoteAdminClient(cluster string) (adminservice.Admi return client, nil } -func (h *clientBeanImpl) SetRemoteAdminClient( - cluster string, - client adminservice.AdminServiceClient, -) { - h.adminClientsLock.Lock() - defer h.adminClientsLock.Unlock() - - h.adminClients[cluster] = client -} - func (h *clientBeanImpl) GetRemoteFrontendClient(clusterName string) (grpc.ClientConnInterface, workflowservice.WorkflowServiceClient, error) { h.frontendClientsLock.RLock() client, ok := h.frontendClients[clusterName] @@ -266,13 +234,6 @@ func (h *clientBeanImpl) GetRemoteFrontendClient(clusterName string) (grpc.Clien return client.connection, client, nil } -func (h *clientBeanImpl) setRemoteAdminClientLocked( - cluster string, - client adminservice.AdminServiceClient, -) { - h.adminClients[cluster] = client -} - func (h *clientBeanImpl) lazyInitMatchingClient(namespaceIDToName NamespaceIDToNameFunc) (matchingservice.MatchingServiceClient, error) { h.Lock() defer h.Unlock() diff --git a/client/client_bean_mock.go b/client/client_bean_mock.go index 37648542a15..b591e99a864 100644 --- a/client/client_bean_mock.go +++ b/client/client_bean_mock.go @@ -140,15 +140,3 @@ func (mr *MockBeanMockRecorder) GetRemoteFrontendClient(arg0 any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRemoteFrontendClient", reflect.TypeOf((*MockBean)(nil).GetRemoteFrontendClient), arg0) } - -// SetRemoteAdminClient mocks base method. -func (m *MockBean) SetRemoteAdminClient(arg0 string, arg1 adminservice.AdminServiceClient) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetRemoteAdminClient", arg0, arg1) -} - -// SetRemoteAdminClient indicates an expected call of SetRemoteAdminClient. -func (mr *MockBeanMockRecorder) SetRemoteAdminClient(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRemoteAdminClient", reflect.TypeOf((*MockBean)(nil).SetRemoteAdminClient), arg0, arg1) -} diff --git a/tests/onebox.go b/tests/onebox.go index 21bab7eecfb..782d043807b 100644 --- a/tests/onebox.go +++ b/tests/onebox.go @@ -461,6 +461,7 @@ func (c *temporalImpl) startFrontend( fx.Supply( persistenceConfig, serviceName, + c.mockAdminClient, ), fx.Provide(c.frontendConfigProvider), fx.Provide(func() listenHostPort { return listenHostPort(c.FrontendGRPCAddress()) }), @@ -482,7 +483,7 @@ func (c *temporalImpl) startFrontend( fx.Provide(func() authorization.Authorizer { return c }), fx.Provide(func() authorization.ClaimMapper { return c }), fx.Provide(func() authorization.JWTAudienceMapper { return nil }), - fx.Provide(func() client.FactoryProvider { return client.NewFactoryProvider() }), + fx.Provide(c.newClientFactoryProvider), fx.Provide(func() searchattribute.Mapper { return nil }), // Comment the line above and uncomment the line below to test with search attributes mapper. // fx.Provide(func() searchattribute.Mapper { return NewSearchAttributeTestMapper() }), @@ -509,14 +510,6 @@ func (c *temporalImpl) startFrontend( c.logger.Fatal("unable to construct frontend service", tag.Error(err)) } - if c.mockAdminClient != nil { - if clientBean != nil { - for serviceName, client := range c.mockAdminClient { - clientBean.SetRemoteAdminClient(serviceName, client) - } - } - } - c.frontendApp = feApp c.frontendService = frontendService c.frontendNamespaceRegistry = namespaceRegistry @@ -565,6 +558,7 @@ func (c *temporalImpl) startHistory( fx.Supply( persistenceConfig, serviceName, + c.mockAdminClient, ), fx.Provide(c.GetMetricsHandler), fx.Provide(func() listenHostPort { return listenHostPort(host) }), @@ -580,7 +574,7 @@ func (c *temporalImpl) startHistory( fx.Provide(func() carchiver.ArchivalMetadata { return c.archiverMetadata }), fx.Provide(func() provider.ArchiverProvider { return c.archiverProvider }), fx.Provide(sdkClientFactoryProvider), - fx.Provide(func() client.FactoryProvider { return client.NewFactoryProvider() }), + fx.Provide(c.newClientFactoryProvider), fx.Provide(func() searchattribute.Mapper { return nil }), // Comment the line above and uncomment the line below to test with search attributes mapper. // fx.Provide(func() searchattribute.Mapper { return NewSearchAttributeTestMapper() }), @@ -611,14 +605,6 @@ func (c *temporalImpl) startHistory( c.logger.Fatal("unable to construct history service", tag.Error(err)) } - if c.mockAdminClient != nil { - if clientBean != nil { - for serviceName, client := range c.mockAdminClient { - clientBean.SetRemoteAdminClient(serviceName, client) - } - } - } - // TODO: this is not correct when there are multiple history hosts as later client will overwrite previous ones. // However current interface for getting history client doesn't specify which client it needs and the tests that use this API // depends on the fact that there's only one history host. @@ -668,6 +654,7 @@ func (c *temporalImpl) startMatching( fx.Supply( persistenceConfig, serviceName, + c.mockAdminClient, ), fx.Provide(c.GetMetricsHandler), fx.Provide(func() listenHostPort { return listenHostPort(c.MatchingGRPCServiceAddress()) }), @@ -681,7 +668,7 @@ func (c *temporalImpl) startMatching( fx.Provide(func() *cluster.Config { return c.clusterMetadataConfig }), fx.Provide(func() carchiver.ArchivalMetadata { return c.archiverMetadata }), fx.Provide(func() provider.ArchiverProvider { return c.archiverProvider }), - fx.Provide(func() client.FactoryProvider { return client.NewFactoryProvider() }), + fx.Provide(c.newClientFactoryProvider), fx.Provide(func() searchattribute.Mapper { return nil }), fx.Provide(func() resolver.ServiceResolver { return resolver.NewNoopResolver() }), fx.Provide(persistenceClient.FactoryProvider), @@ -705,13 +692,6 @@ func (c *temporalImpl) startMatching( if err != nil { c.logger.Fatal("unable to start matching service", tag.Error(err)) } - if c.mockAdminClient != nil { - if clientBean != nil { - for serviceName, client := range c.mockAdminClient { - clientBean.SetRemoteAdminClient(serviceName, client) - } - } - } matchingConnection, err := rpc.Dial(c.MatchingGRPCServiceAddress(), nil, c.logger) if err != nil { @@ -766,6 +746,7 @@ func (c *temporalImpl) startWorker( fx.Supply( persistenceConfig, serviceName, + c.mockAdminClient, ), fx.Provide(c.GetMetricsHandler), fx.Provide(func() listenHostPort { return listenHostPort(c.WorkerGRPCServiceAddress()) }), @@ -781,7 +762,7 @@ func (c *temporalImpl) startWorker( fx.Provide(func() carchiver.ArchivalMetadata { return c.archiverMetadata }), fx.Provide(func() provider.ArchiverProvider { return c.archiverProvider }), fx.Provide(sdkClientFactoryProvider), - fx.Provide(func() client.FactoryProvider { return client.NewFactoryProvider() }), + fx.Provide(c.newClientFactoryProvider), fx.Provide(func() searchattribute.Mapper { return nil }), fx.Provide(func() resolver.ServiceResolver { return resolver.NewNoopResolver() }), fx.Provide(persistenceClient.FactoryProvider), @@ -943,6 +924,66 @@ func (c *temporalImpl) newRPCFactory( ), nil } +func (c *temporalImpl) newClientFactoryProvider( + config *cluster.Config, + mockAdminClient map[string]adminservice.AdminServiceClient, +) client.FactoryProvider { + return &clientFactoryProvider{ + config: config, + mockAdminClient: mockAdminClient, + } +} + +type clientFactoryProvider struct { + config *cluster.Config + mockAdminClient map[string]adminservice.AdminServiceClient +} + +func (p *clientFactoryProvider) NewFactory( + rpcFactory common.RPCFactory, + monitor membership.Monitor, + metricsHandler metrics.Handler, + dc *dynamicconfig.Collection, + numberOfHistoryShards int32, + logger log.Logger, + throttledLogger log.Logger, +) client.Factory { + f := client.NewFactoryProvider().NewFactory( + rpcFactory, + monitor, + metricsHandler, + dc, + numberOfHistoryShards, + logger, + throttledLogger, + ) + return &clientFactory{ + Factory: f, + config: p.config, + mockAdminClient: p.mockAdminClient, + } +} + +type clientFactory struct { + client.Factory + config *cluster.Config + mockAdminClient map[string]adminservice.AdminServiceClient +} + +// override just this one and look up connections in mock admin client map +func (f *clientFactory) NewRemoteAdminClientWithTimeout(rpcAddress string, timeout time.Duration, largeTimeout time.Duration) adminservice.AdminServiceClient { + var clusterName string + for name, info := range f.config.ClusterInformation { + if rpcAddress == info.RPCAddress { + clusterName = name + } + } + if mock, ok := f.mockAdminClient[clusterName]; ok { + return mock + } + return f.Factory.NewRemoteAdminClientWithTimeout(rpcAddress, timeout, largeTimeout) +} + func (c *temporalImpl) SetOnGetClaims(fn func(*authorization.AuthInfo) (*authorization.Claims, error)) { c.callbackLock.Lock() c.onGetClaims = fn