Skip to content

Commit

Permalink
index node versions
Browse files Browse the repository at this point in the history
  • Loading branch information
james03160927 committed Aug 26, 2024
1 parent aad7849 commit 8079804
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 141 deletions.
48 changes: 48 additions & 0 deletions gateways/algolia/algolia-noop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package algolia

import (
"context"
"registry-backend/ent"

"github.com/rs/zerolog/log"
)

var _ AlgoliaService = (*algolianoop)(nil)

type algolianoop struct{}

// DeleteNode implements AlgoliaService.
func (a *algolianoop) DeleteNode(ctx context.Context, node *ent.Node) error {
log.Ctx(ctx).Info().Msgf("algolia noop: delete node: %s", node.ID)
return nil
}

// IndexNodes implements AlgoliaService.
func (a *algolianoop) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
log.Ctx(ctx).Info().Msgf("algolia noop: index nodes: %d number of nodes", len(nodes))
return nil
}

// SearchNodes implements AlgoliaService.
func (a *algolianoop) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) {
log.Ctx(ctx).Info().Msgf("algolia noop: search nodes: %s", query)
return nil, nil
}

// DeleteNodeVersion implements AlgoliaService.
func (a *algolianoop) DeleteNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error {
log.Ctx(ctx).Info().Msgf("algolia noop: delete node version: %d number of node versions", len(nodes))
return nil
}

// IndexNodeVersions implements AlgoliaService.
func (a *algolianoop) IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error {
log.Ctx(ctx).Info().Msgf("algolia noop: index node versions: %d number of node versions", len(nodes))
return nil
}

// SearchNodeVersions implements AlgoliaService.
func (a *algolianoop) SearchNodeVersions(ctx context.Context, query string, opts ...interface{}) ([]*ent.NodeVersion, error) {
log.Ctx(ctx).Info().Msgf("algolia noop: search node versions: %s", query)
return nil, nil
}
82 changes: 59 additions & 23 deletions gateways/algolia/algolia.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import (
"registry-backend/ent"

"github.com/algolia/algoliasearch-client-go/v3/algolia/search"
"github.com/rs/zerolog/log"
)

// AlgoliaService defines the interface for interacting with Algolia search.
type AlgoliaService interface {
IndexNodes(ctx context.Context, nodes ...*ent.Node) error
SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error)
DeleteNode(ctx context.Context, node *ent.Node) error
SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error)
IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error
DeleteNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error
SearchNodeVersions(ctx context.Context, query string, opts ...interface{}) ([]*ent.NodeVersion, error)
}

// Ensure algolia struct implements AlgoliaService interface
Expand Down Expand Up @@ -45,6 +47,17 @@ func NewFromEnv() (AlgoliaService, error) {
return New(appid, apikey)
}

// NewFromEnvOrNoop creates a new Algolia service using environment variables or noop implementation if no environment found
func NewFromEnvOrNoop() (AlgoliaService, error) {
id := os.Getenv("ALGOLIA_APP_ID")
key := os.Getenv("ALGOLIA_API_KEY")
if id == "" && key == "" {
return &algolianoop{}, nil
}

return NewFromEnv()
}

// IndexNodes indexes the provided nodes in Algolia.
func (a *algolia) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
index := a.client.InitIndex("nodes_index")
Expand Down Expand Up @@ -96,34 +109,57 @@ func (a *algolia) DeleteNode(ctx context.Context, node *ent.Node) error {
return res.Wait()
}

var _ AlgoliaService = (*algolianoop)(nil)
// IndexNodeVersions implements AlgoliaService.
func (a *algolia) IndexNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error {
index := a.client.InitIndex("node_versions_index")
objects := make([]struct {
ObjectID string `json:"objectID"`
*ent.NodeVersion
}, len(nodes))

type algolianoop struct{}
for i, n := range nodes {
objects[i] = struct {
ObjectID string `json:"objectID"`
*ent.NodeVersion
}{
ObjectID: n.ID.String(),
NodeVersion: n,
}
}

func NewFromEnvOrNoop() (AlgoliaService, error) {
id := os.Getenv("ALGOLIA_APP_ID")
key := os.Getenv("ALGOLIA_API_KEY")
if id == "" && key == "" {
return &algolianoop{}, nil
res, err := index.SaveObjects(objects)
if err != nil {
return fmt.Errorf("failed to index nodes: %w", err)
}

return NewFromEnv()
return res.Wait()
}

// DeleteNode implements AlgoliaService.
func (a *algolianoop) DeleteNode(ctx context.Context, node *ent.Node) error {
log.Ctx(ctx).Info().Msgf("algolia noop: delete node: %s", node.ID)
return nil
// DeleteNodeVersion implements AlgoliaService.
func (a *algolia) DeleteNodeVersions(ctx context.Context, nodes ...*ent.NodeVersion) error {
index := a.client.InitIndex("node_versions_index")
ids := []string{}
for _, node := range nodes {
ids = append(ids, node.ID.String())
}
res, err := index.DeleteObjects(ids)
if err != nil {
return fmt.Errorf("failed to delete node: %w", err)
}
return res.Wait()
}

// IndexNodes implements AlgoliaService.
func (a *algolianoop) IndexNodes(ctx context.Context, nodes ...*ent.Node) error {
log.Ctx(ctx).Info().Msgf("algolia noop: index nodes: %d number of nodes", len(nodes))
return nil
}
// SearchNodeVersions implements AlgoliaService.
func (a *algolia) SearchNodeVersions(ctx context.Context, query string, opts ...interface{}) ([]*ent.NodeVersion, error) {
index := a.client.InitIndex("node_versions_index")
res, err := index.Search(query, opts...)
if err != nil {
return nil, fmt.Errorf("failed to search nodes: %w", err)
}

// SearchNodes implements AlgoliaService.
func (a *algolianoop) SearchNodes(ctx context.Context, query string, opts ...interface{}) ([]*ent.Node, error) {
log.Ctx(ctx).Info().Msgf("algolia noop: search nodes: %s", query)
return nil, nil
var nodes []*ent.NodeVersion
if err := res.UnmarshalHits(&nodes); err != nil {
return nil, fmt.Errorf("failed to unmarshal search results: %w", err)
}
return nodes, nil
}
58 changes: 43 additions & 15 deletions gateways/algolia/algolia_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"os"
"registry-backend/ent"
"registry-backend/ent/schema"
"testing"
"time"

Expand All @@ -25,24 +26,51 @@ func TestIndex(t *testing.T) {
algolia, err := NewFromEnv()
require.NoError(t, err)

ctx := context.Background()
node := &ent.Node{
ID: uuid.NewString(),
Name: t.Name() + "-" + uuid.NewString(),
TotalStar: 98,
TotalReview: 20,
}
for i := 0; i < 10; i++ {
err = algolia.IndexNodes(ctx, node)
t.Run("node", func(t *testing.T) {
ctx := context.Background()
node := &ent.Node{
ID: uuid.NewString(),
Name: t.Name() + "-" + uuid.NewString(),
TotalStar: 98,
TotalReview: 20,
}
for i := 0; i < 10; i++ {
err = algolia.IndexNodes(ctx, node)
require.NoError(t, err)
}

<-time.After(time.Second * 10)
nodes, err := algolia.SearchNodes(ctx, node.Name)
require.NoError(t, err)
}
require.Len(t, nodes, 1)
assert.Equal(t, node, nodes[0])
})

<-time.After(time.Second * 10)
nodes, err := algolia.SearchNodes(ctx, node.Name)
require.NoError(t, err)
require.Len(t, nodes, 1)
assert.Equal(t, node, nodes[0])
t.Run("nodeVersion", func(t *testing.T) {
ctx := context.Background()
now := time.Now()
nv := &ent.NodeVersion{
ID: uuid.New(),
NodeID: uuid.NewString(),
Version: "v1.0.0-" + uuid.NewString(),
Changelog: "test",
Status: schema.NodeVersionStatusActive,
StatusReason: "test",
PipDependencies: []string{"test"},
CreateTime: time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second(), now.Nanosecond(), time.UTC),
UpdateTime: time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), now.Minute(), now.Second(), now.Nanosecond(), time.UTC),
}
for i := 0; i < 10; i++ {
err = algolia.IndexNodeVersions(ctx, nv)
require.NoError(t, err)
}

<-time.After(time.Second * 10)
nodes, err := algolia.SearchNodeVersions(ctx, nv.Version)
require.NoError(t, err)
require.Len(t, nodes, 1)
assert.Equal(t, nv, nodes[0])
})
}

func TestNoop(t *testing.T) {
Expand Down
139 changes: 46 additions & 93 deletions integration-tests/registry_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ func newMockedImpl(client *ent.Client, cfg *config.Config) (impl mockedImpl, aut
On("IndexNodes", mock.Anything, mock.Anything).
Return(nil).
On("DeleteNode", mock.Anything, mock.Anything).
Return(nil).
On("IndexNodeVersions", mock.Anything, mock.Anything).
Return(nil).
On("DeleteNodeVersions", mock.Anything, mock.Anything).
Return(nil)

impl = mockedImpl{
Expand Down Expand Up @@ -766,103 +770,52 @@ func TestRegistryNodeVersion(t *testing.T) {
})

t.Run("Scan Node", func(t *testing.T) {
table := []struct {
scenario string
scanStatus int
scanBody []byte
expectedNodeVersionStatus drip.NodeVersionStatus
}{
{
scenario: "NoIssueFound",
scanStatus: 200,
scanBody: nil,
expectedNodeVersionStatus: drip.NodeVersionStatusActive,
},
{
scenario: "IssuesFound",
scanStatus: 200,
scanBody: []byte("some issues"),
expectedNodeVersionStatus: drip.NodeVersionStatusFlagged,
},
{
scenario: "NoLongerExists",
scanStatus: 500,
scanBody: []byte("Failed to download file: 404 Client Error: "),
expectedNodeVersionStatus: drip.NodeVersionStatusDeleted,
},
{
scenario: "InvalidUrl",
scanStatus: 400,
scanBody: []byte("No URL provided"),
expectedNodeVersionStatus: drip.NodeVersionStatusPending,
node := randomNode()
nodeVersion := randomNodeVersion(0)
downloadUrl := fmt.Sprintf("https://storage.googleapis.com/comfy-registry/%s/%s/%s/node.tar.gz", *pub.Id, *node.Id, *nodeVersion.Version)

impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return("test-url", nil)
impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return("test-url", nil)
impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil)
_, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{
PublisherId: *pub.Id,
NodeId: *node.Id,
Body: &drip.PublishNodeVersionJSONRequestBody{
Node: *node,
NodeVersion: *nodeVersion,
PersonalAccessToken: *respat.(drip.CreatePersonalAccessToken201JSONResponse).Token,
},
}
})
require.NoError(t, err, "should return created node version")

for _, tt := range table {
t.Run(tt.scenario, func(t *testing.T) {
node := randomNode()
nodeVersion := randomNodeVersion(0)
downloadUrl := fmt.Sprintf("https://storage.googleapis.com/comfy-registry/%s/%s/%s/node.tar.gz", *pub.Id, *node.Id, *nodeVersion.Version)
nodesToScans, err := client.NodeVersion.Query().Where(nodeversion.StatusEQ(schema.NodeVersionStatusPending)).Count(ctx)
require.NoError(t, err)

impl.mockStorageService.On("GenerateSignedURL", mock.Anything, mock.Anything).Return("test-url", nil)
impl.mockStorageService.On("GetFileUrl", mock.Anything, mock.Anything, mock.Anything).Return("test-url", nil)
impl.mockDiscordService.On("SendSecurityCouncilMessage", mock.Anything, mock.Anything).Return(nil)
_, err := withMiddleware(authz, impl.PublishNodeVersion)(ctx, drip.PublishNodeVersionRequestObject{
PublisherId: *pub.Id,
NodeId: *node.Id,
Body: &drip.PublishNodeVersionJSONRequestBody{
Node: *node,
NodeVersion: *nodeVersion,
PersonalAccessToken: *respat.(drip.CreatePersonalAccessToken201JSONResponse).Token,
},
})
require.NoError(t, err, "should return created node version")
newNodeScanned := false
nodesScanned := 0
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)

nodesToScans, err := client.NodeVersion.Query().Where(nodeversion.StatusEQ(schema.NodeVersionStatusPending)).All(ctx)
require.NoError(t, err)

newNodeScanned := false
nodesScanned := 0
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tt.scanStatus)
w.Write(tt.scanBody)

req := dripservices_registry.ScanRequest{}
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
if downloadUrl == req.URL {
newNodeScanned = true
}
nodesScanned++
}))
t.Cleanup(s.Close)

impl, authz := newMockedImpl(client, &config.Config{SecretScannerURL: s.URL})
dur := time.Duration(0)
scanres, err := withMiddleware(authz, impl.SecurityScan)(ctx, drip.SecurityScanRequestObject{
Params: drip.SecurityScanParams{
MinAge: &dur,
},
})
require.NoError(t, err)
require.IsType(t, drip.SecurityScan200Response{}, scanres)
assert.True(t, newNodeScanned)
assert.Equal(t, len(nodesToScans), nodesScanned)

for _, nodeversion := range nodesToScans {
res, err := withMiddleware(authz, impl.GetNodeVersion)(ctx, drip.GetNodeVersionRequestObject{
NodeId: nodeversion.NodeID,
VersionId: nodeversion.Version,
})
require.NoError(t, err)
require.IsType(t, drip.GetNodeVersion200JSONResponse{}, res)
nv := res.(drip.GetNodeVersion200JSONResponse)
assert.Equal(t, tt.expectedNodeVersionStatus, *nv.Status)
if tt.expectedNodeVersionStatus == drip.NodeVersionStatusFlagged {
assert.Equal(t, string(tt.scanBody), *nv.StatusReason)
}
}
})
}
req := dripservices_registry.ScanRequest{}
require.NoError(t, json.NewDecoder(r.Body).Decode(&req))
if downloadUrl == req.URL {
newNodeScanned = true
}
nodesScanned++
}))
t.Cleanup(s.Close)

impl, authz := newMockedImpl(client, &config.Config{SecretScannerURL: s.URL})
dur := time.Duration(0)
scanres, err := withMiddleware(authz, impl.SecurityScan)(ctx, drip.SecurityScanRequestObject{
Params: drip.SecurityScanParams{
MinAge: &dur,
},
})
require.NoError(t, err)
require.IsType(t, drip.SecurityScan200Response{}, scanres)
assert.True(t, newNodeScanned)
assert.Equal(t, nodesToScans, nodesScanned)
})

}
Expand Down
Loading

0 comments on commit 8079804

Please sign in to comment.