diff --git a/gateways/algolia/algolia-noop.go b/gateways/algolia/algolia-noop.go new file mode 100644 index 0000000..cf08480 --- /dev/null +++ b/gateways/algolia/algolia-noop.go @@ -0,0 +1,48 @@ +package algolia + +import ( + "context" + "registry-backend/ent" + + "github.com/rs/zerolog/log" +) + +var _ AlgoliaService = (*algolianoop)(nil) + +type algolianoop struct{} + +// DeleteNode implements AlgoliaService. +func (a *algolianoop) DeleteNode(ctx context.Context, node *ent.Node) error { + log.Ctx(ctx).Info().Msgf("algolia noop: delete node: %s", node.ID) + return nil +} + +// IndexNodes implements AlgoliaService. +func (a *algolianoop) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { + log.Ctx(ctx).Info().Msgf("algolia noop: index nodes: %d number of nodes", len(nodes)) + return nil +} + +// SearchNodes implements AlgoliaService. +func (a *algolianoop) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) { + log.Ctx(ctx).Info().Msgf("algolia noop: search nodes: %s", query) + return nil, nil +} + +// DeleteNodeVersion implements AlgoliaService. +func (a *algolianoop) DeleteNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error { + log.Ctx(ctx).Info().Msgf("algolia noop: delete node version: %d number of node versions", len(nodes)) + return nil +} + +// IndexNodeVersions implements AlgoliaService. +func (a *algolianoop) IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error { + log.Ctx(ctx).Info().Msgf("algolia noop: index node versions: %d number of node versions", len(nodes)) + return nil +} + +// SearchNodeVersions implements AlgoliaService. +func (a *algolianoop) SearchNodeVersions(ctx context.Context, query string, opts ...interface{}) ([]*ent.NodeVersion, error) { + log.Ctx(ctx).Info().Msgf("algolia noop: search node versions: %s", query) + return nil, nil +} diff --git a/gateways/algolia/algolia.go b/gateways/algolia/algolia.go index 0bf307f..c4f1e20 100644 --- a/gateways/algolia/algolia.go +++ b/gateways/algolia/algolia.go @@ -7,14 +7,16 @@ import ( "registry-backend/ent" "github.com/algolia/algoliasearch-client-go/v3/algolia/search" - "github.com/rs/zerolog/log" ) // AlgoliaService defines the interface for interacting with Algolia search. type AlgoliaService interface { IndexNodes(ctx context.Context, nodes ...*ent.Node) error - SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) DeleteNode(ctx context.Context, node *ent.Node) error + SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) + IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error + DeleteNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error + SearchNodeVersions(ctx context.Context, query string, opts ...interface{}) ([]*ent.NodeVersion, error) } // Ensure algolia struct implements AlgoliaService interface @@ -45,6 +47,17 @@ func NewFromEnv() (AlgoliaService, error) { return New(appid, apikey) } +// NewFromEnvOrNoop creates a new Algolia service using environment variables or noop implementation if no environment found +func NewFromEnvOrNoop() (AlgoliaService, error) { + id := os.Getenv("ALGOLIA_APP_ID") + key := os.Getenv("ALGOLIA_API_KEY") + if id == "" && key == "" { + return &algolianoop{}, nil + } + + return NewFromEnv() +} + // IndexNodes indexes the provided nodes in Algolia. func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { index := a.client.InitIndex("nodes_index") @@ -96,34 +109,57 @@ func (a *algolia) DeleteNode(ctx context.Context, node *ent.Node) error { return res.Wait() } -var _ AlgoliaService = (*algolianoop)(nil) +// IndexNodeVersions implements AlgoliaService. +func (a *algolia) IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error { + index := a.client.InitIndex("node_versions_index") + objects := make([]struct { + ObjectID string `json:"objectID"` + *ent.NodeVersion + }, len(nodes)) -type algolianoop struct{} + for i, n := range nodes { + objects[i] = struct { + ObjectID string `json:"objectID"` + *ent.NodeVersion + }{ + ObjectID: n.ID.String(), + NodeVersion: n, + } + } -func NewFromEnvOrNoop() (AlgoliaService, error) { - id := os.Getenv("ALGOLIA_APP_ID") - key := os.Getenv("ALGOLIA_API_KEY") - if id == "" && key == "" { - return &algolianoop{}, nil + res, err := index.SaveObjects(objects) + if err != nil { + return fmt.Errorf("failed to index nodes: %w", err) } - return NewFromEnv() + return res.Wait() } -// DeleteNode implements AlgoliaService. -func (a *algolianoop) DeleteNode(ctx context.Context, node *ent.Node) error { - log.Ctx(ctx).Info().Msgf("algolia noop: delete node: %s", node.ID) - return nil +// DeleteNodeVersion implements AlgoliaService. +func (a *algolia) DeleteNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error { + index := a.client.InitIndex("node_versions_index") + ids := []string{} + for _, node := range nodes { + ids = append(ids, node.ID.String()) + } + res, err := index.DeleteObjects(ids) + if err != nil { + return fmt.Errorf("failed to delete node: %w", err) + } + return res.Wait() } -// IndexNodes implements AlgoliaService. -func (a *algolianoop) IndexNodes(ctx context.Context, nodes ...*ent.Node) error { - log.Ctx(ctx).Info().Msgf("algolia noop: index nodes: %d number of nodes", len(nodes)) - return nil -} +// SearchNodeVersions implements AlgoliaService. +func (a *algolia) SearchNodeVersions(ctx context.Context, query string, opts ...interface{}) ([]*ent.NodeVersion, error) { + index := a.client.InitIndex("node_versions_index") + res, err := index.Search(query, opts...) + if err != nil { + return nil, fmt.Errorf("failed to search nodes: %w", err) + } -// SearchNodes implements AlgoliaService. -func (a *algolianoop) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) { - log.Ctx(ctx).Info().Msgf("algolia noop: search nodes: %s", query) - return nil, nil + var nodes []*ent.NodeVersion + if err := res.UnmarshalHits(&nodes); err != nil { + return nil, fmt.Errorf("failed to unmarshal search results: %w", err) + } + return nodes, nil } diff --git a/gateways/algolia/algolia_test.go b/gateways/algolia/algolia_test.go index 577639e..b2c62f8 100644 --- a/gateways/algolia/algolia_test.go +++ b/gateways/algolia/algolia_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "registry-backend/ent" + "registry-backend/ent/schema" "testing" "time" @@ -25,24 +26,51 @@ func TestIndex(t *testing.T) { algolia, err := NewFromEnv() require.NoError(t, err) - ctx := context.Background() - node := &ent.Node{ - ID: uuid.NewString(), - Name: t.Name() + "-" + uuid.NewString(), - TotalStar: 98, - TotalReview: 20, - } - for i := 0; i < 10; i++ { - err = algolia.IndexNodes(ctx, node) + t.Run("node", func(t *testing.T) { + ctx := context.Background() + node := &ent.Node{ + ID: uuid.NewString(), + Name: t.Name() + "-" + uuid.NewString(), + TotalStar: 98, + TotalReview: 20, + } + for i := 0; i < 10; i++ { + err = algolia.IndexNodes(ctx, node) + require.NoError(t, err) + } + + <-time.After(time.Second * 10) + nodes, err := algolia.SearchNodes(ctx, node.Name) require.NoError(t, err) - } + require.Len(t, nodes, 1) + assert.Equal(t, node, nodes[0]) + }) - <-time.After(time.Second * 10) - nodes, err := algolia.SearchNodes(ctx, node.Name) - require.NoError(t, err) - require.Len(t, nodes, 1) - assert.Equal(t, node, nodes[0]) + t.Run("nodeVersion", func(t *testing.T) { + ctx := context.Background() + now := time.Now() + nv := &ent.NodeVersion{ + ID: uuid.New(), + NodeID: uuid.NewString(), + Version: "v1.0.0-" + uuid.NewString(), + Changelog: "test", + Status: schema.NodeVersionStatusActive, + StatusReason: "test", + PipDependencies: []string{"test"}, + CreateTime: time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second(), now.Nanosecond(), time.UTC), + UpdateTime: time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second(), now.Nanosecond(), time.UTC), + } + for i := 0; i < 10; i++ { + err = algolia.IndexNodeVersions(ctx, nv) + require.NoError(t, err) + } + <-time.After(time.Second * 10) + nodes, err := algolia.SearchNodeVersions(ctx, nv.Version) + require.NoError(t, err) + require.Len(t, nodes, 1) + assert.Equal(t, nv, nodes[0]) + }) } func TestNoop(t *testing.T) { diff --git a/integration-tests/registry_integration_test.go b/integration-tests/registry_integration_test.go index e8b9e59..99af68c 100644 --- a/integration-tests/registry_integration_test.go +++ b/integration-tests/registry_integration_test.go @@ -128,6 +128,10 @@ func newMockedImpl(client *ent.Client, cfg *config.Config) (impl mockedImpl, aut On("IndexNodes", mock.Anything, mock.Anything). Return(nil). On("DeleteNode", mock.Anything, mock.Anything). + Return(nil). + On("IndexNodeVersions", mock.Anything, mock.Anything). + Return(nil). + On("DeleteNodeVersions", mock.Anything, mock.Anything). Return(nil) impl = mockedImpl{ @@ -766,103 +770,52 @@ func TestRegistryNodeVersion(t *testing.T) { }) t.Run("Scan Node", func(t *testing.T) { - table := []struct { - scenario string - scanStatus int - scanBody []byte - expectedNodeVersionStatus drip.NodeVersionStatus - }{ - { - scenario: "NoIssueFound", - scanStatus: 200, - scanBody: nil, - expectedNodeVersionStatus: drip.NodeVersionStatusActive, - }, - { - scenario: "IssuesFound", - scanStatus: 200, - scanBody: []byte("some issues"), - expectedNodeVersionStatus: drip.NodeVersionStatusFlagged, - }, - { - scenario: "NoLongerExists", - scanStatus: 500, - scanBody: []byte("Failed to download file: 404 Client Error: "), - expectedNodeVersionStatus: drip.NodeVersionStatusDeleted, - }, - { - scenario: "InvalidUrl", - scanStatus: 400, - scanBody: []byte("No URL provided"), - expectedNodeVersionStatus: drip.NodeVersionStatusPending, + node := randomNode() + nodeVersion := randomNodeVersion(0) + downloadUrl := fmt.Sprintf("https://storage.googleapis.com/comfy-registry/%s/%s/%s/node.tar.gz", *pub.Id, *node.Id, *nodeVersion.Version) + + impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return("test-url", nil) + impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return("test-url", nil) + impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil) + _, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ + PublisherId: *pub.Id, + NodeId: *node.Id, + Body: &drip.PublishNodeVersionJSONRequestBody{ + Node: *node, + NodeVersion: *nodeVersion, + PersonalAccessToken: *respat.(drip.CreatePersonalAccessToken201JSONResponse).Token, }, - } + }) + require.NoError(t, err, "should return created node version") - for _, tt := range table { - t.Run(tt.scenario, func(t *testing.T) { - node := randomNode() - nodeVersion := randomNodeVersion(0) - downloadUrl := fmt.Sprintf("https://storage.googleapis.com/comfy-registry/%s/%s/%s/node.tar.gz", *pub.Id, *node.Id, *nodeVersion.Version) + nodesToScans, err := client.NodeVersion.Query().Where(nodeversion.StatusEQ(schema.NodeVersionStatusPending)).Count(ctx) + require.NoError(t, err) - impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return("test-url", nil) - impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return("test-url", nil) - impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil) - _, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{ - PublisherId: *pub.Id, - NodeId: *node.Id, - Body: &drip.PublishNodeVersionJSONRequestBody{ - Node: *node, - NodeVersion: *nodeVersion, - PersonalAccessToken: *respat.(drip.CreatePersonalAccessToken201JSONResponse).Token, - }, - }) - require.NoError(t, err, "should return created node version") + newNodeScanned := false + nodesScanned := 0 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) - nodesToScans, err := client.NodeVersion.Query().Where(nodeversion.StatusEQ(schema.NodeVersionStatusPending)).All(ctx) - require.NoError(t, err) - - newNodeScanned := false - nodesScanned := 0 - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(tt.scanStatus) - w.Write(tt.scanBody) - - req := dripservices_registry.ScanRequest{} - require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) - if downloadUrl == req.URL { - newNodeScanned = true - } - nodesScanned++ - })) - t.Cleanup(s.Close) - - impl, authz := newMockedImpl(client, &config.Config{SecretScannerURL: s.URL}) - dur := time.Duration(0) - scanres, err := withMiddleware(authz, impl.SecurityScan)(ctx, drip.SecurityScanRequestObject{ - Params: drip.SecurityScanParams{ - MinAge: &dur, - }, - }) - require.NoError(t, err) - require.IsType(t, drip.SecurityScan200Response{}, scanres) - assert.True(t, newNodeScanned) - assert.Equal(t, len(nodesToScans), nodesScanned) - - for _, nodeversion := range nodesToScans { - res, err := withMiddleware(authz, impl.GetNodeVersion)(ctx, drip.GetNodeVersionRequestObject{ - NodeId: nodeversion.NodeID, - VersionId: nodeversion.Version, - }) - require.NoError(t, err) - require.IsType(t, drip.GetNodeVersion200JSONResponse{}, res) - nv := res.(drip.GetNodeVersion200JSONResponse) - assert.Equal(t, tt.expectedNodeVersionStatus, *nv.Status) - if tt.expectedNodeVersionStatus == drip.NodeVersionStatusFlagged { - assert.Equal(t, string(tt.scanBody), *nv.StatusReason) - } - } - }) - } + req := dripservices_registry.ScanRequest{} + require.NoError(t, json.NewDecoder(r.Body).Decode(&req)) + if downloadUrl == req.URL { + newNodeScanned = true + } + nodesScanned++ + })) + t.Cleanup(s.Close) + + impl, authz := newMockedImpl(client, &config.Config{SecretScannerURL: s.URL}) + dur := time.Duration(0) + scanres, err := withMiddleware(authz, impl.SecurityScan)(ctx, drip.SecurityScanRequestObject{ + Params: drip.SecurityScanParams{ + MinAge: &dur, + }, + }) + require.NoError(t, err) + require.IsType(t, drip.SecurityScan200Response{}, scanres) + assert.True(t, newNodeScanned) + assert.Equal(t, nodesToScans, nodesScanned) }) } diff --git a/mock/gateways/mock_algolia_service.go b/mock/gateways/mock_algolia_service.go index 8f93b9f..b4ba792 100644 --- a/mock/gateways/mock_algolia_service.go +++ b/mock/gateways/mock_algolia_service.go @@ -14,14 +14,14 @@ type MockAlgoliaService struct { mock.Mock } -// DeleteNode implements algolia.AlgoliaService. -func (m *MockAlgoliaService) DeleteNode(ctx context.Context, n *ent.Node) error { +// IndexNodes implements algolia.AlgoliaService. +func (m *MockAlgoliaService) IndexNodes(ctx context.Context, n ...*ent.Node) error { args := m.Called(ctx, n) return args.Error(0) } -// IndexNodes implements algolia.AlgoliaService. -func (m *MockAlgoliaService) IndexNodes(ctx context.Context, n ...*ent.Node) error { +// DeleteNode implements algolia.AlgoliaService. +func (m *MockAlgoliaService) DeleteNode(ctx context.Context, n *ent.Node) error { args := m.Called(ctx, n) return args.Error(0) } @@ -31,3 +31,21 @@ func (m *MockAlgoliaService) SearchNodes(ctx context.Context, query string, opts args := m.Called(ctx, query, opts) return args.Get(0).([]*ent.Node), args.Error(1) } + +// IndexNodeVersions implements algolia.AlgoliaService. +func (m *MockAlgoliaService) IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error { + args := m.Called(ctx, nodes) + return args.Error(0) +} + +// DeleteNodeVersion implements algolia.AlgoliaService. +func (m *MockAlgoliaService) DeleteNodeVersions(ctx context.Context, node ...*ent.NodeVersion) error { + args := m.Called(ctx, node) + return args.Error(0) +} + +// SearchNodeVersions implements algolia.AlgoliaService. +func (m *MockAlgoliaService) SearchNodeVersions(ctx context.Context, query string, opts ...interface{}) ([]*ent.NodeVersion, error) { + args := m.Called(ctx, query, opts) + return args.Get(0).([]*ent.NodeVersion), args.Error(1) +} diff --git a/services/registry/registry_svc.go b/services/registry/registry_svc.go index 9d26c53..c9735f7 100644 --- a/services/registry/registry_svc.go +++ b/services/registry/registry_svc.go @@ -401,6 +401,11 @@ func (s *RegistryService) CreateNodeVersion( return nil, fmt.Errorf("failed to create node version: %w", err) } + err = s.algolia.IndexNodeVersions(ctx, createdNodeVersion) + if err != nil { + return nil, fmt.Errorf("failed to index node version: %w", err) + } + message := fmt.Sprintf("Version %s of node %s was published successfully. Publisher: %s. https://registry.comfy.org/nodes/%s", createdNodeVersion.Version, createdNodeVersion.NodeID, publisherID, nodeID) slackErr := s.slackService.SendRegistryMessageToSlack(message) s.discordService.SendSecurityCouncilMessage(message) @@ -525,11 +530,19 @@ func (s *RegistryService) GetNodeVersionByVersion(ctx context.Context, client *e func (s *RegistryService) UpdateNodeVersion(ctx context.Context, client *ent.Client, update *ent.NodeVersionUpdateOne) (*ent.NodeVersion, error) { log.Ctx(ctx).Info().Msgf("updating node version fields: %v", update.Mutation().Fields()) - node, err := update.Save(ctx) - if err != nil { - return nil, fmt.Errorf("failed to update node version: %w", err) - } - return node, nil + return db.WithTxResult(ctx, client, func(tx *ent.Tx) (*ent.NodeVersion, error) { + node, err := update.Save(ctx) + if err != nil { + return nil, fmt.Errorf("failed to update node version: %w", err) + } + + err = s.algolia.IndexNodeVersions(ctx, node) + if err != nil { + return nil, fmt.Errorf("failed to index node version: %w", err) + } + + return node, nil + }) } func (s *RegistryService) RecordNodeInstalation(ctx context.Context, client *ent.Client, node *ent.Node) (*ent.Node, error) { @@ -682,7 +695,12 @@ func (s *RegistryService) DeletePublisher(ctx context.Context, client *ent.Clien func (s *RegistryService) DeleteNode(ctx context.Context, client *ent.Client, nodeID string) error { log.Ctx(ctx).Info().Msgf("deleting node: %v", nodeID) db.WithTx(ctx, client, func(tx *ent.Tx) error { - err := tx.Client().Node.DeleteOneID(nodeID).Exec(ctx) + nv, err := tx.Client().NodeVersion.Query().Where(nodeversion.NodeID(nodeID)).All(ctx) + if err != nil { + return fmt.Errorf("fail to fetch node version for algolia deletion: %w", err) + } + + err = tx.Client().Node.DeleteOneID(nodeID).Exec(ctx) if err != nil { return fmt.Errorf("failed to delete node: %w", err) } @@ -690,6 +708,11 @@ func (s *RegistryService) DeleteNode(ctx context.Context, client *ent.Client, no if err = s.algolia.DeleteNode(ctx, &ent.Node{ID: nodeID}); err != nil { return fmt.Errorf("fail to delete node from algolia: %w", err) } + + if err = s.algolia.DeleteNodeVersions(ctx, nv...); err != nil { + return fmt.Errorf("fail to delete node version from algolia: %w", err) + } + return nil }) return nil