Skip to content

Commit

Permalink
impl
Browse files Browse the repository at this point in the history
  • Loading branch information
shibukazu committed Aug 23, 2024
1 parent db3aa53 commit 29dd802
Show file tree
Hide file tree
Showing 8 changed files with 300 additions and 78 deletions.
4 changes: 2 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ services:
environment:
- OPEN-VE_MODE=slave
- OPEN-VE_SLAVE_ID=slave-node
- OPEN-VE_SLAVE_SLAVE_GRPC_ADDR=slave-node:9000
- OPEN-VE_SLAVE_MASTER_GRPC_ADDR=master-node:9000
- OPEN-VE_SLAVE_SLAVE_HTTP_ADDR=http://slave-node:8080
- OPEN-VE_SLAVE_MASTER_HTTP_ADDR=http://master-node:8080
- OPEN-VE_SLAVE_MASTER_TLS_ENABLED=
- OPEN-VE_HTTP_PORT=
- OPEN-VE_HTTP_CORS_ALLOWED_ORIGINS=
Expand Down
34 changes: 17 additions & 17 deletions go/cmd/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ func NewRunCommand() *cobra.Command {
if id == "" {
return failure.New(appError.ErrConfigFileSyntaxError, failure.Message("ID of the slave server is required"))
}
slaveAddr := viper.GetString("slave.slaveGRPCAddr")
if slaveAddr == "" {
return failure.New(appError.ErrConfigFileSyntaxError, failure.Message("gRPC address of the slave server is required"))
slaveHTTPAddr := viper.GetString("slave.slaveHTTPAddr")
if slaveHTTPAddr == "" {
return failure.New(appError.ErrConfigFileSyntaxError, failure.Message("HTTP address of the slave server is required"))
}
masterAddr := viper.GetString("slave.masterGRPCAddr")
if masterAddr == "" {
return failure.New(appError.ErrConfigFileSyntaxError, failure.Message("gRPC address of the master server is required"))
masterHTTPAddr := viper.GetString("slave.masterHTTPAddr")
if masterHTTPAddr == "" {
return failure.New(appError.ErrConfigFileSyntaxError, failure.Message("HTTP address of the master server is required"))
}
}
return nil
Expand All @@ -63,17 +63,17 @@ func NewRunCommand() *cobra.Command {
MustBindPFlag("slave.id", flags.Lookup("slave-id"))
viper.MustBindEnv("slave.id", "OPEN-VE_SLAVE_ID")

flags.String("slave-slave-grpc-addr", defaultConfig.Slave.SlaveGRPCAddr, "gRPC address of the slave server")
MustBindPFlag("slave.slaveGRPCAddr", flags.Lookup("slave-slave-grpc-addr"))
viper.MustBindEnv("slave.slaveGRPCAddr", "OPEN-VE_SLAVE_SLAVE_GRPC_ADDR")
flags.String("slave-slave-http-addr", defaultConfig.Slave.SlaveHTTPAddr, "HTTP address of the slave server")
MustBindPFlag("slave.slaveHTTPAddr", flags.Lookup("slave-slave-http-addr"))
viper.MustBindEnv("slave.slaveHTTPAddr", "OPEN-VE_SLAVE_SLAVE_HTTP_ADDR")

flags.String("slave-master-grpc-addr", defaultConfig.Slave.MasterGRPCAddr, "gRPC address of the master server")
MustBindPFlag("slave.masterGRPCAddr", flags.Lookup("slave-master-grpc-addr"))
viper.MustBindEnv("slave.masterGRPCAddr", "OPEN-VE_SLAVE_MASTER_GRPC_ADDR")
flags.String("slave-master-http-addr", defaultConfig.Slave.MasterHTTPAddr, "HTTP address of the master server")
MustBindPFlag("slave.masterHTTPAddr", flags.Lookup("slave-master-http-addr"))
viper.MustBindEnv("slave.masterHTTPAddr", "OPEN-VE_SLAVE_MASTER_HTTP_ADDR")

flags.Bool("slave-master-grpc-tls-enabled", defaultConfig.Slave.MasterGRPCTLSEnabled, "connect to master server with TLS")
MustBindPFlag("slave.masterGRPCTLSEnabled", flags.Lookup("slave-master-grpc-tls-enabled"))
viper.MustBindEnv("slave.masterGRPCTLSEnabled", "OPEN-VE_SLAVE_MASTER_GRPC_TLS_ENABLED")
flags.Bool("slave-master-http-tls-enabled", defaultConfig.Slave.MasterHTTPTLSEnabled, "connect to master server with TLS")
MustBindPFlag("slave.masterHTTPTLSEnabled", flags.Lookup("slave-master-http-tls-enabled"))
viper.MustBindEnv("slave.masterHTTPTLSEnabled", "OPEN-VE_SLAVE_MASTER_HTTP_TLS_ENABLED")

// HTTP
flags.String("http-port", defaultConfig.Http.Port, "HTTP server port")
Expand Down Expand Up @@ -204,7 +204,7 @@ func run(cmd *cobra.Command, args []string) {
validator := validator.NewValidator(logger, store)
slaveManager := slave.NewSlaveManager(logger)

gw := server.NewGateway(&cfg.Http, &cfg.GRPC, logger, dslReader)
gw := server.NewGateway(cfg.Mode, &cfg.Http, &cfg.GRPC, logger, dslReader, slaveManager)
wg.Add(1)

logger.Info("🚀 Open-VE: starting...", slog.Any("config", cfg))
Expand All @@ -222,7 +222,7 @@ func run(cmd *cobra.Command, args []string) {

if cfg.Mode == "slave" {
wg.Add(1)
slaveRegistrar := slave.NewSlaveRegistrar(cfg.Slave.Id, cfg.Slave.SlaveGRPCAddr, cfg.GRPC.TLS.Enabled, cfg.Slave.MasterGRPCAddr, cfg.Slave.MasterGRPCTLSEnabled, dslReader, logger)
slaveRegistrar := slave.NewSlaveRegistrar(cfg.Slave.Id, cfg.Slave.SlaveHTTPAddr, cfg.GRPC.TLS.Enabled, cfg.Slave.MasterHTTPAddr, cfg.Slave.MasterHTTPTLSEnabled, dslReader, logger)
go func(wg *sync.WaitGroup) {
logger.Info("🚀 slave registration timer: starting..")
slaveRegistrar.RegisterTimer(ctx, wg)
Expand Down
3 changes: 2 additions & 1 deletion go/pkg/appError/serviceError.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ const (

ErrRequestParameterInvalid = "RequestParameterInvalid"

ErrValidateServiceIDNotFound = "ValidateServiceIDNotFound"
ErrValidateServiceIDNotFound = "ValidateServiceIDNotFound"
ErrValidateServiceForwardFailed = "ValidateServiceForwardFailed"

ErrDSLServiceDSLSyntaxError = "DSLServiceDSLSyntaxError"

Expand Down
12 changes: 6 additions & 6 deletions go/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ type Config struct {

type SlaveConfig struct {
Id string `yaml:"id"`
SlaveGRPCAddr string `yaml:"slaveAddr"`
MasterGRPCTLSEnabled bool `yaml:"masterGRPCTLSEnabled"`
MasterGRPCAddr string `yaml:"masterGRPCAddr"`
SlaveHTTPAddr string `yaml:"slaveHTTPAddr"`
MasterHTTPTLSEnabled bool `yaml:"masterHTTPTLSEnabled"`
MasterHTTPAddr string `yaml:"masterHTTPAddr"`
}

type HttpConfig struct {
Expand Down Expand Up @@ -55,9 +55,9 @@ func DefaultConfig() *Config {
Mode: "master",
Slave: SlaveConfig{
Id: "",
SlaveGRPCAddr: "",
MasterGRPCAddr: "",
MasterGRPCTLSEnabled: false,
SlaveHTTPAddr: "",
MasterHTTPAddr: "",
MasterHTTPTLSEnabled: false,
},
Http: HttpConfig{
Port: "8080",
Expand Down
210 changes: 200 additions & 10 deletions go/pkg/server/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package server
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
Expand All @@ -16,32 +18,40 @@ import (
"github.com/shibukazu/open-ve/go/pkg/appError"
"github.com/shibukazu/open-ve/go/pkg/config"
"github.com/shibukazu/open-ve/go/pkg/dsl/reader"
"github.com/shibukazu/open-ve/go/pkg/slave"
pbDSL "github.com/shibukazu/open-ve/go/proto/dsl/v1"
pbSlave "github.com/shibukazu/open-ve/go/proto/slave/v1"
pbValidate "github.com/shibukazu/open-ve/go/proto/validate/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
)

type Gateway struct {
httpConfig *config.HttpConfig
gRPCConfig *config.GRPCConfig
logger *slog.Logger
dslReader *reader.DSLReader
server *http.Server
mode string
httpConfig *config.HttpConfig
gRPCConfig *config.GRPCConfig
logger *slog.Logger
dslReader *reader.DSLReader
slaveManager *slave.SlaveManager
server *http.Server
}

func NewGateway(
mode string,
httpConfig *config.HttpConfig,
gRPCConfig *config.GRPCConfig,
logger *slog.Logger,
dslReader *reader.DSLReader,
slaveManager *slave.SlaveManager,
) *Gateway {
return &Gateway{
httpConfig: httpConfig,
gRPCConfig: gRPCConfig,
logger: logger,
dslReader: dslReader,
mode: mode,
httpConfig: httpConfig,
gRPCConfig: gRPCConfig,
logger: logger,
dslReader: dslReader,
slaveManager: slaveManager,
}
}

Expand Down Expand Up @@ -71,7 +81,13 @@ func (g *Gateway) Run(ctx context.Context, wg *sync.WaitGroup) {
panic(failure.Translate(err, appError.ErrServerStartFailed, failure.Messagef("failed to register dsl service on gateway")))
}

withMiddleware := g.validateRequestTypeConvertMiddleware(grpcGateway)
if g.mode == "master" {
if err := pbSlave.RegisterSlaveServiceHandlerFromEndpoint(ctx, grpcGateway, ":"+g.gRPCConfig.Port, dialOpts); err != nil {
panic(failure.Translate(err, appError.ErrServerStartFailed, failure.Messagef("failed to register slave service on gateway")))
}
}

withMiddleware := g.forwardCheckRequestMiddleware(g.validateRequestTypeConvertMiddleware(grpcGateway))

withCors := cors.New(cors.Options{
AllowedOrigins: g.httpConfig.CORSAllowedOrigins,
Expand Down Expand Up @@ -122,6 +138,180 @@ func (g *Gateway) shutdown(ctx context.Context) {
g.logger.Info("🛑 gateway server is stopped")
}

type responseRecorder struct {
http.ResponseWriter
statusCode int
body *bytes.Buffer
}

func (rec *responseRecorder) WriteHeader(code int) {
rec.statusCode = code
}

func (rec *responseRecorder) Write(b []byte) (int, error) {
return rec.body.Write(b)
}

func (g *Gateway) forwardCheckRequestMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if g.mode == "master" && r.URL.Path == "/v1/check" && r.Method == "POST" {
ctx := context.Background()
modifiedRequestValidations := make([]interface{}, 0)
validationResults := make([]interface{}, 0)

var reqBody map[string]interface{}
var resBody map[string]interface{}
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
http.Error(w, failure.Translate(err, appError.ErrRequestParameterInvalid).Error(), http.StatusBadRequest)
return
}

validations, ok := reqBody["validations"].([]interface{})
if !ok {
http.Error(w, failure.New(appError.ErrRequestParameterInvalid, failure.Messagef("validations field is invalid")).Error(), http.StatusBadRequest)
return
}

dslFound := false
dsl, err := g.dslReader.Read(ctx)
if err == nil {
dslFound = true
}
// TODO: 各処理を並列化する
for _, validation := range validations {
validation, ok := validation.(map[string]interface{})
if !ok {
http.Error(w, failure.New(appError.ErrRequestParameterInvalid, failure.Messagef("validation field is invalid")).Error(), http.StatusBadRequest)
return
}
id, ok := validation["id"].(string)
if !ok {
http.Error(w, failure.New(appError.ErrRequestParameterInvalid, failure.Messagef("id field is invalid")).Error(), http.StatusBadRequest)
return
}

// Check if the request forward is needed
isForwardNeed := false
if !dslFound {
isForwardNeed = true
} else {
for _, validation := range dsl.Validations {
if validation.ID == id {
isForwardNeed = true
break
}
}
}

if isForwardNeed {
// Find the slave node that can handle validation ID
slaveNode, err := g.slaveManager.FindSlave(id)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

var client *http.Client
if slaveNode.TLSEnabled {
transport := &http.Transport{
TLSClientConfig: &tls.Config{},
}
client = &http.Client{Transport: transport}
} else {
client = &http.Client{}
}
client.Timeout = 5 * time.Second

reqBody := map[string]interface{}{
"validations": []interface{}{validation},
}
body, err := json.Marshal(reqBody)
if err != nil {
http.Error(w, failure.Translate(err, appError.ErrValidateServiceForwardFailed).Error(), http.StatusInternalServerError)
return
}
req, err := http.NewRequest("POST", slaveNode.Addr+"/v1/check", bytes.NewBuffer(body))
if err != nil {
http.Error(w, failure.Translate(err, appError.ErrValidateServiceForwardFailed).Error(), http.StatusInternalServerError)
return
}
req.Header.Set("Content-Type", "application/json")

resp, err := client.Do(req)
if err != nil {
http.Error(w, failure.Translate(err, appError.ErrValidateServiceForwardFailed).Error(), http.StatusInternalServerError)
return
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
http.Error(w, failure.New(appError.ErrValidateServiceForwardFailed, failure.Messagef("Failed to forward the validate request to slave: %d", resp.StatusCode)).Error(), http.StatusInternalServerError)
return
}

var respBody map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
http.Error(w, failure.Translate(err, appError.ErrValidateServiceForwardFailed).Error(), http.StatusInternalServerError)
return
}
results, ok := respBody["results"].([]interface{})
if !ok {
http.Error(w, failure.New(appError.ErrValidateServiceForwardFailed, failure.Messagef("results field is invalid")).Error(), http.StatusInternalServerError)
return
}
validationResults = append(validationResults, results...)

g.logger.Info(fmt.Sprintf("⚽️ Request (id:%s) Forwarded to Slave %s", id, slaveNode.Id))
} else {
modifiedRequestValidations = append(modifiedRequestValidations, validation)
}
}

reqBody["validations"] = modifiedRequestValidations
modifiedReqBody, err := json.Marshal(reqBody)
if err != nil {
http.Error(w, failure.Translate(err, appError.ErrRequestParameterInvalid).Error(), http.StatusInternalServerError)
return
}
r.Body = io.NopCloser(bytes.NewBuffer(modifiedReqBody))
r.ContentLength = int64(len(modifiedReqBody))

rec := &responseRecorder{
ResponseWriter: w,
body: &bytes.Buffer{},
}
next.ServeHTTP(rec, r)

// Concat the validation results
if err := json.Unmarshal(rec.body.Bytes(), &resBody); err != nil {
http.Error(w, failure.Translate(err, appError.ErrRequestParameterInvalid).Error(), http.StatusInternalServerError)
return
}
originalValidationResults, ok := resBody["results"].([]interface{})
if !ok {
http.Error(w, failure.New(appError.ErrRequestParameterInvalid, failure.Messagef("results field is invalid")).Error(), http.StatusInternalServerError)
return
}
resBody["results"] = append(originalValidationResults, validationResults...)
resBodyJson, err := json.Marshal(resBody)
if err != nil {
http.Error(w, failure.Translate(err, appError.ErrRequestParameterInvalid).Error(), http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", fmt.Sprint(len(resBodyJson)))
w.WriteHeader(http.StatusOK)
_, err = w.Write(resBodyJson)
if err != nil {
g.logger.Error(failure.Translate(err, appError.ErrServerInternalError).Error())
}
} else {
next.ServeHTTP(w, r)
}
})
}

func (g *Gateway) validateRequestTypeConvertMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/v1/check" && r.Method == "POST" {
Expand Down
7 changes: 4 additions & 3 deletions go/pkg/server/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (g *GRPC) Run(ctx context.Context, wg *sync.WaitGroup, mode string) {
}

grpcServerOpts := []grpc.ServerOption{}
grpcServerOpts = append(grpcServerOpts, grpc.UnaryInterceptor(g.accessLogInterceptor()))
grpcServerOpts = append(grpcServerOpts, grpc.UnaryInterceptor(g.interceptor()))
if g.gRPCConfig.TLS.Enabled {
if g.gRPCConfig.TLS.CertPath == "" || g.gRPCConfig.TLS.KeyPath == "" {
panic(failure.New(appError.ErrServerStartFailed, failure.Message("certPath and keyPath must be set")))
Expand Down Expand Up @@ -122,11 +122,12 @@ func (g *GRPC) shutdown(ctx context.Context) {
}
}

func (g *GRPC) accessLogInterceptor() grpc.UnaryServerInterceptor {
func (g *GRPC) interceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
// Log the request
g.logger.Info("🔍 Access Log", slog.String("Method", info.FullMethod), slog.String("Request", fmt.Sprintf("%+v", req)))
resp, err := handler(ctx, req)

resp, err := handler(ctx, req)
return resp, err
}
}
Loading

0 comments on commit 29dd802

Please sign in to comment.