diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4a49e5e..6533a88 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -60,7 +60,7 @@ jobs: cd max2max mkdir build go get . - env GOOS=linux GOARCH=amd64 go build -o ./build/max2max . + env CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o ./build/max2max . - name: Login to DockerHub uses: docker/login-action@v1 with: diff --git a/max2max/internal/client/client.go b/max2max/internal/client/client.go index f9131b4..855ab52 100644 --- a/max2max/internal/client/client.go +++ b/max2max/internal/client/client.go @@ -1,11 +1,14 @@ package client import ( - "errors" + "context" + e "errors" "fmt" "log/slog" "os" "strings" + + "github.com/pkg/errors" ) type Loader interface { @@ -14,65 +17,68 @@ type Loader interface { } type OdpsClient interface { - GetPartitionNames(tableID string) ([]string, error) - ExecSQL(query string) error + GetPartitionNames(ctx context.Context, tableID string) ([]string, error) + ExecSQL(ctx context.Context, query string) error } type Client struct { OdpsClient OdpsClient + Loader Loader + appCtx context.Context logger *slog.Logger shutdownFns []func() error } -func NewClient(setupFns ...SetupFn) (*Client, error) { +func NewClient(ctx context.Context, setupFns ...SetupFn) (*Client, error) { c := &Client{ + appCtx: ctx, shutdownFns: make([]func() error, 0), } for _, setupFn := range setupFns { if err := setupFn(c); err != nil { - return nil, err + return nil, errors.WithStack(err) } } return c, nil } func (c *Client) Close() error { + c.logger.Info("closing client") var err error for _, fn := range c.shutdownFns { - err = errors.Join(err, fn()) + err = e.Join(err, fn()) } - return err + return errors.WithStack(err) } -func (c *Client) Execute(loader Loader, tableID, queryFilePath string) error { +func (c *Client) Execute(ctx context.Context, tableID, queryFilePath string) error { // read query from filepath c.logger.Info(fmt.Sprintf("executing query from %s", queryFilePath)) queryRaw, err := os.ReadFile(queryFilePath) if err != nil { - return err + return errors.WithStack(err) } // check if table is partitioned - c.logger.Info(fmt.Sprintf("checking if table %s is partitioned", tableID)) - partitionNames, err := c.OdpsClient.GetPartitionNames(tableID) + partitionNames, err := c.OdpsClient.GetPartitionNames(ctx, tableID) if err != nil { - return err + return errors.WithStack(err) } // prepare query - queryToExec := loader.GetQuery(tableID, string(queryRaw)) + queryToExec := c.Loader.GetQuery(tableID, string(queryRaw)) if len(partitionNames) > 0 { c.logger.Info(fmt.Sprintf("table %s is partitioned by %s", tableID, strings.Join(partitionNames, ", "))) - queryToExec = loader.GetPartitionedQuery(tableID, string(queryRaw), partitionNames) + queryToExec = c.Loader.GetPartitionedQuery(tableID, string(queryRaw), partitionNames) } // execute query with odps client c.logger.Info(fmt.Sprintf("execute: %s", queryToExec)) - if err := c.OdpsClient.ExecSQL(queryToExec); err != nil { - return err + if err := c.OdpsClient.ExecSQL(ctx, queryToExec); err != nil { + return errors.WithStack(err) } c.logger.Info("execution done") - return nil + return errors.WithStack(err) } diff --git a/max2max/internal/client/client_test.go b/max2max/internal/client/client_test.go index ab0750a..8b96569 100644 --- a/max2max/internal/client/client_test.go +++ b/max2max/internal/client/client_test.go @@ -1,6 +1,7 @@ package client_test import ( + "context" "fmt" "os" "testing" @@ -13,17 +14,17 @@ import ( func TestExecute(t *testing.T) { t.Run("should return error when reading query file fails", func(t *testing.T) { // arrange - client, err := client.NewClient(client.SetupLogger("error")) + client, err := client.NewClient(context.TODO(), client.SetupLogger("error")) require.NoError(t, err) client.OdpsClient = &mockOdpsClient{} // act - err = client.Execute(nil, "", "./nonexistentfile") + err = client.Execute(context.TODO(), "", "./nonexistentfile") // assert assert.Error(t, err) }) t.Run("should return error when getting partition name fails", func(t *testing.T) { // arrange - client, err := client.NewClient(client.SetupLogger("error")) + client, err := client.NewClient(context.TODO(), client.SetupLogger("error")) require.NoError(t, err) client.OdpsClient = &mockOdpsClient{ partitionResult: func() ([]string, error) { @@ -32,14 +33,14 @@ func TestExecute(t *testing.T) { } assert.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644)) // act - err = client.Execute(nil, "project_test.table_test", "/tmp/query.sql") + err = client.Execute(context.TODO(), "project_test.table_test", "/tmp/query.sql") // assert assert.Error(t, err) assert.ErrorContains(t, err, "error get partition name") }) t.Run("should return error when executing query fails", func(t *testing.T) { // arrange - client, err := client.NewClient(client.SetupLogger("error")) + client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("APPEND")) require.NoError(t, err) client.OdpsClient = &mockOdpsClient{ partitionResult: func() ([]string, error) { @@ -49,21 +50,16 @@ func TestExecute(t *testing.T) { return fmt.Errorf("error exec sql") }, } - loader := &mockLoader{ - getQueryResult: func() string { - return "INSERT INTO table SELECT * FROM table;" - }, - } require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644)) // act - err = client.Execute(loader, "project_test.table_test", "/tmp/query.sql") + err = client.Execute(context.TODO(), "project_test.table_test", "/tmp/query.sql") // assert assert.Error(t, err) assert.ErrorContains(t, err, "error exec sql") }) t.Run("should return nil when everything is successful", func(t *testing.T) { // arrange - client, err := client.NewClient(client.SetupLogger("error")) + client, err := client.NewClient(context.TODO(), client.SetupLogger("error"), client.SetupLoader("APPEND")) require.NoError(t, err) client.OdpsClient = &mockOdpsClient{ partitionResult: func() ([]string, error) { @@ -73,17 +69,9 @@ func TestExecute(t *testing.T) { return nil }, } - loader := &mockLoader{ - getQueryResult: func() string { - return "INSERT INTO table SELECT * FROM table;" - }, - getPartitionedQueryResult: func() string { - return "INSERT INTO table PARTITION (event_date) SELECT * FROM table;" - }, - } require.NoError(t, os.WriteFile("/tmp/query.sql", []byte("SELECT * FROM table;"), 0644)) // act - err = client.Execute(loader, "project_test.table_test", "/tmp/query.sql") + err = client.Execute(context.TODO(), "project_test.table_test", "/tmp/query.sql") // assert assert.NoError(t, err) }) @@ -94,23 +82,10 @@ type mockOdpsClient struct { execSQLResult func() error } -func (m *mockOdpsClient) GetPartitionNames(tableID string) ([]string, error) { +func (m *mockOdpsClient) GetPartitionNames(ctx context.Context, tableID string) ([]string, error) { return m.partitionResult() } -func (m *mockOdpsClient) ExecSQL(query string) error { +func (m *mockOdpsClient) ExecSQL(ctx context.Context, query string) error { return m.execSQLResult() } - -type mockLoader struct { - getQueryResult func() string - getPartitionedQueryResult func() string -} - -func (m *mockLoader) GetQuery(tableID, query string) string { - return m.getQueryResult() -} - -func (m *mockLoader) GetPartitionedQuery(tableID, query string, partitionName []string) string { - return m.getPartitionedQueryResult() -} diff --git a/max2max/internal/client/odps.go b/max2max/internal/client/odps.go index c478fad..1712eba 100644 --- a/max2max/internal/client/odps.go +++ b/max2max/internal/client/odps.go @@ -1,10 +1,13 @@ package client import ( + "context" + e "errors" "fmt" "log/slog" "github.com/aliyun/aliyun-odps-go-sdk/odps" + "github.com/pkg/errors" ) type odpsClient struct { @@ -12,6 +15,7 @@ type odpsClient struct { client *odps.Odps } +// NewODPSClient creates a new odpsClient instance func NewODPSClient(logger *slog.Logger, client *odps.Odps) *odpsClient { return &odpsClient{ logger: logger, @@ -20,22 +24,40 @@ func NewODPSClient(logger *slog.Logger, client *odps.Odps) *odpsClient { } // ExecSQL executes the given query in syncronous mode (blocking) -// TODO: change the execution mode to async and do graceful shutdown -func (c *odpsClient) ExecSQL(query string) error { +// with capability to do graceful shutdown by terminating task instance +// when context is cancelled. +func (c *odpsClient) ExecSQL(ctx context.Context, query string) error { taskIns, err := c.client.ExecSQl(query) if err != nil { - return err + return errors.WithStack(err) } + // generate log view + url, err := odps.NewLogView(c.client).GenerateLogView(taskIns, 1) + if err != nil { + err = e.Join(err, taskIns.Terminate()) + return errors.WithStack(err) + } + c.logger.Info(fmt.Sprintf("log view: %s", url)) + // wait execution success c.logger.Info(fmt.Sprintf("taskId: %s", taskIns.Id())) - return taskIns.WaitForSuccess() + select { + case <-ctx.Done(): + c.logger.Info("context cancelled, terminating task instance") + err := taskIns.Terminate() + return e.Join(ctx.Err(), err) + case err := <-wait(taskIns): + return errors.WithStack(err) + } } -func (c *odpsClient) GetPartitionNames(tableID string) ([]string, error) { +// GetPartitionNames returns the partition names of the given table +// by querying the table schema. +func (c *odpsClient) GetPartitionNames(_ context.Context, tableID string) ([]string, error) { table := c.client.Table(tableID) if err := table.Load(); err != nil { - return nil, err + return nil, errors.WithStack(err) } var partitionNames []string for _, partition := range table.Schema().PartitionColumns { @@ -43,3 +65,14 @@ func (c *odpsClient) GetPartitionNames(tableID string) ([]string, error) { } return partitionNames, nil } + +// wait waits for the task instance to finish on a separate goroutine +func wait(taskIns *odps.Instance) <-chan error { + errChan := make(chan error) + go func(errChan chan<- error) { + defer close(errChan) + err := taskIns.WaitForSuccess() + errChan <- errors.WithStack(err) + }(errChan) + return errChan +} diff --git a/max2max/internal/client/opentelemetry.go b/max2max/internal/client/opentelemetry.go index c7e53f1..cccf4b0 100644 --- a/max2max/internal/client/opentelemetry.go +++ b/max2max/internal/client/opentelemetry.go @@ -3,6 +3,7 @@ package client import ( "context" + "github.com/pkg/errors" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" @@ -10,14 +11,13 @@ import ( "go.opentelemetry.io/otel/sdk/resource" ) -func setupOTelSDK(collectorGRPCEndpoint string, jobName, scheduledTime string) (shutdown func() error, err error) { - ctx := context.Background() // TODO: use context from main +func setupOTelSDK(ctx context.Context, collectorGRPCEndpoint string, jobName, scheduledTime string) (shutdown func() error, err error) { metricExporter, err := otlpmetricgrpc.New(ctx, otlpmetricgrpc.WithEndpoint(collectorGRPCEndpoint), otlpmetricgrpc.WithInsecure(), ) if err != nil { - return nil, err + return nil, errors.WithStack(err) } // for now, we only need metric provider @@ -33,6 +33,6 @@ func setupOTelSDK(collectorGRPCEndpoint string, jobName, scheduledTime string) ( otel.SetMeterProvider(meterProvider) return func() error { - return meterProvider.Shutdown(ctx) + return meterProvider.Shutdown(context.Background()) }, nil } diff --git a/max2max/internal/client/setup.go b/max2max/internal/client/setup.go index a850048..5507f25 100644 --- a/max2max/internal/client/setup.go +++ b/max2max/internal/client/setup.go @@ -2,7 +2,9 @@ package client import ( "github.com/aliyun/aliyun-odps-go-sdk/odps" + "github.com/goto/transformers/max2max/internal/loader" "github.com/goto/transformers/max2max/internal/logger" + "github.com/pkg/errors" ) type SetupFn func(c *Client) error @@ -11,7 +13,7 @@ func SetupLogger(logLevel string) SetupFn { return func(c *Client) error { logger, err := logger.NewLogger(logLevel) if err != nil { - return err + return errors.WithStack(err) } c.logger = logger return nil @@ -30,11 +32,22 @@ func SetupOTelSDK(collectorGRPCEndpoint, jobName, scheduledTime string) SetupFn if collectorGRPCEndpoint == "" { return nil } - shutdownFn, err := setupOTelSDK(collectorGRPCEndpoint, jobName, scheduledTime) + shutdownFn, err := setupOTelSDK(c.appCtx, collectorGRPCEndpoint, jobName, scheduledTime) if err != nil { - return err + return errors.WithStack(err) } c.shutdownFns = append(c.shutdownFns, shutdownFn) return nil } } + +func SetupLoader(loadMethod string) SetupFn { + return func(c *Client) error { + loader, err := loader.GetLoader(loadMethod, c.logger) + if err != nil { + return errors.WithStack(err) + } + c.Loader = loader + return nil + } +} diff --git a/max2max/internal/config/config.go b/max2max/internal/config/config.go index f438804..06c98ad 100644 --- a/max2max/internal/config/config.go +++ b/max2max/internal/config/config.go @@ -4,6 +4,7 @@ import ( "encoding/json" "github.com/aliyun/aliyun-odps-go-sdk/odps" + "github.com/pkg/errors" ) type Config struct { @@ -41,7 +42,7 @@ func NewConfig() (*Config, error) { scvAcc := getEnv("SERVICE_ACCOUNT", "") cred, err := collectMaxComputeCredential([]byte(scvAcc)) if err != nil { - return nil, err + return nil, errors.WithStack(err) } cfg.Config.AccessId = cred.AccessId cfg.Config.AccessKey = cred.AccessKey @@ -56,7 +57,7 @@ func NewConfig() (*Config, error) { func collectMaxComputeCredential(scvAcc []byte) (*maxComputeCredentials, error) { var creds maxComputeCredentials if err := json.Unmarshal(scvAcc, &creds); err != nil { - return nil, err + return nil, errors.WithStack(err) } return &creds, nil diff --git a/max2max/internal/loader/factory.go b/max2max/internal/loader/factory.go index c222069..673bb26 100644 --- a/max2max/internal/loader/factory.go +++ b/max2max/internal/loader/factory.go @@ -3,6 +3,8 @@ package loader import ( "fmt" "log/slog" + + "github.com/pkg/errors" ) type Loader interface { @@ -23,6 +25,7 @@ func GetLoader(name string, logger *slog.Logger) (Loader, error) { // case MERGE_REPLACE: // return NewMergeReplaceLoader(logger), nil default: - return nil, fmt.Errorf("loader %s not found", name) + err := fmt.Errorf("loader %s not found", name) + return nil, errors.WithStack(err) } } diff --git a/max2max/internal/logger/logger.go b/max2max/internal/logger/logger.go index 18934a5..6fd17f8 100644 --- a/max2max/internal/logger/logger.go +++ b/max2max/internal/logger/logger.go @@ -3,12 +3,14 @@ package logger import ( "log/slog" "os" + + "github.com/pkg/errors" ) func NewLogger(logLevel string) (*slog.Logger, error) { var level slog.Level if err := level.UnmarshalText([]byte(logLevel)); err != nil { - return nil, err + return nil, errors.WithStack(err) } writter := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: level}) diff --git a/max2max/main.go b/max2max/main.go index c3354f8..5777711 100644 --- a/max2max/main.go +++ b/max2max/main.go @@ -1,46 +1,19 @@ package main import ( + "fmt" + "os" + _ "github.com/aliyun/aliyun-odps-go-sdk/sqldriver" - "github.com/goto/transformers/max2max/internal/client" - "github.com/goto/transformers/max2max/internal/config" - "github.com/goto/transformers/max2max/internal/loader" - "github.com/goto/transformers/max2max/internal/logger" ) -// TODO: -// - graceful shutdown -// - error handling func main() { - // load config - cfg, err := config.NewConfig() - if err != nil { - panic(err) - } - - // initiate dependencies - logger, err := logger.NewLogger(cfg.LogLevel) - if err != nil { - panic(err) - } - loader, err := loader.GetLoader(cfg.LoadMethod, logger) - if err != nil { - panic(err) - } - // initiate client - client, err := client.NewClient( - client.SetupLogger(cfg.LogLevel), - client.SetupOTelSDK(cfg.OtelCollectorGRPCEndpoint, cfg.JobName, cfg.ScheduledTime), - client.SetupODPSClient(cfg.GenOdps()), - ) - if err != nil { - panic(err) - } - defer client.Close() - - // execute query - err = client.Execute(loader, cfg.DestinationTableID, cfg.QueryFilePath) - if err != nil { - panic(err) + // max2max is the main function to execute the max2max transformation + // which reads the configuration, sets up the client and executes the query. + // It also handles graceful shutdown by listening to os signals. + // It returns error if any. + if err := max2max(); err != nil { + fmt.Printf("error: %+v\n", err) + os.Exit(1) } } diff --git a/max2max/max2max.go b/max2max/max2max.go new file mode 100644 index 0000000..98dc4ba --- /dev/null +++ b/max2max/max2max.go @@ -0,0 +1,44 @@ +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/goto/transformers/max2max/internal/client" + "github.com/goto/transformers/max2max/internal/config" + "github.com/pkg/errors" +) + +func max2max() error { + // load config + cfg, err := config.NewConfig() + if err != nil { + return errors.WithStack(err) + } + + // graceful shutdown + ctx, cancelFn := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancelFn() + + // initiate client + client, err := client.NewClient( + ctx, + client.SetupLogger(cfg.LogLevel), + client.SetupOTelSDK(cfg.OtelCollectorGRPCEndpoint, cfg.JobName, cfg.ScheduledTime), + client.SetupODPSClient(cfg.GenOdps()), + client.SetupLoader(cfg.LoadMethod), + ) + if err != nil { + return errors.WithStack(err) + } + defer client.Close() + + // execute query + err = client.Execute(ctx, cfg.DestinationTableID, cfg.QueryFilePath) + if err != nil { + return errors.WithStack(err) + } + return nil +}