Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization: Defer Projections for Server Queries #2676

Merged
merged 48 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
668323c
small refactor
Sep 20, 2024
bcf8ace
something
Sep 23, 2024
f9de608
Merge branch 'main' into james/proj
Sep 23, 2024
c24dc76
prototype
Sep 23, 2024
9cfa2f7
fix defer
Sep 23, 2024
558c0b5
sdf
Sep 23, 2024
dabce47
only defer
Sep 23, 2024
5b1d37c
fix enginetest
Sep 23, 2024
72b403c
fix enginetests
Sep 23, 2024
dd7161a
tidy up
Sep 25, 2024
9846d13
fix
Sep 25, 2024
6539989
[ga-format-pr] Run ./format_repo.sh to fix formatting
jycor Sep 25, 2024
4804e53
more tidying
Sep 25, 2024
efd578b
conflits
Sep 25, 2024
d454f56
conflicts with main
Sep 25, 2024
94ff580
[ga-format-pr] Run ./format_repo.sh to fix formatting
jycor Sep 25, 2024
7c2a8ef
Merge branch 'main' into james/proj3
Sep 26, 2024
a35cb43
conditionally defer projections
Sep 30, 2024
a695ea8
skip optimization for testing
Sep 30, 2024
a0c5f69
merge
Sep 30, 2024
d4239dc
fix
Sep 30, 2024
c41b9f1
fix again
Sep 30, 2024
87c5cae
fix
Sep 30, 2024
6bfd205
aaaaa
Sep 30, 2024
a617bc6
the same but not
Sep 30, 2024
5fbe0d2
Merge branch 'main' into james/proj3
Oct 1, 2024
693dfae
unnecessary?
Oct 1, 2024
429ca08
revert alloc
Oct 1, 2024
c19546c
fix
Oct 1, 2024
052c0d9
move poll
Oct 1, 2024
78b297f
readd extra row alloc
Oct 1, 2024
05f30a1
undo
Oct 1, 2024
7f9f227
Merge branch 'james/proj' into james/proj3
Oct 1, 2024
1ae4a8f
skip opt for single results
Oct 1, 2024
6967599
inline into planbuilder
Oct 2, 2024
70de43b
[ga-format-pr] Run ./format_repo.sh to fix formatting
jycor Oct 2, 2024
197ff0d
tidying
Oct 2, 2024
5556140
Merge branch 'main' into james/proj3
Oct 3, 2024
a16c2d2
real opt
Oct 3, 2024
4b30e06
[ga-format-pr] Run ./format_repo.sh to fix formatting
jycor Oct 3, 2024
55df1f5
fix tests
Oct 3, 2024
0f10cba
feedback and fixing more tests
Oct 3, 2024
bd2a659
[ga-format-pr] Run ./format_repo.sh to fix formatting
jycor Oct 3, 2024
a8b65de
more tidying
Oct 3, 2024
b0c4ae9
Merge branch 'james/proj3' of https://github.com/dolthub/go-mysql-ser…
Oct 3, 2024
988e247
peek through limit iters as well
Oct 3, 2024
4597dd0
[ga-format-pr] Run ./format_repo.sh to fix formatting
jycor Oct 3, 2024
c30684c
another flag
Oct 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ func (e *Engine) AnalyzeQuery(
query string,
) (sql.Node, error) {
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser)
parsed, _, _, qFlags, err := binder.Parse(query, false)
parsed, _, _, qFlags, err := binder.Parse(query, nil, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -238,7 +238,7 @@ func (e *Engine) PrepareParsedQuery(
stmt sqlparser.Statement,
) (sql.Node, error) {
binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser)
node, _, err := binder.BindOnly(stmt, query)
node, _, err := binder.BindOnly(stmt, query, nil)

if err != nil {
return nil, err
Expand Down Expand Up @@ -586,7 +586,7 @@ func (e *Engine) bindQuery(ctx *sql.Context, query string, parsed sqlparser.Stat
var bound sql.Node
var err error
if parsed == nil {
bound, _, _, qFlags, err = binder.Parse(query, false)
bound, _, _, qFlags, err = binder.Parse(query, qFlags, false)
if err != nil {
clearAutocommitErr := clearAutocommitTransaction(ctx)
if clearAutocommitErr != nil {
Expand All @@ -595,7 +595,7 @@ func (e *Engine) bindQuery(ctx *sql.Context, query string, parsed sqlparser.Stat
return nil, nil, err
}
} else {
bound, qFlags, err = binder.BindOnly(parsed, query)
bound, qFlags, err = binder.BindOnly(parsed, query, qFlags)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -651,7 +651,7 @@ func (e *Engine) bindExecuteQueryNode(ctx *sql.Context, query string, eq *plan.E
binder.SetBindingsWithExpr(tempBindings)
}

bound, _, err := binder.BindOnly(prep, query)
bound, _, err := binder.BindOnly(prep, query, nil)
if err != nil {
clearAutocommitErr := clearAutocommitTransaction(ctx)
if clearAutocommitErr != nil {
Expand Down
2 changes: 1 addition & 1 deletion enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ func TestAnalyzer_Exp(t *testing.T) {

ctx := enginetest.NewContext(harness)
b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser())
parsed, _, _, _, err := b.Parse(tt.query, false)
parsed, _, _, _, err := b.Parse(tt.query, nil, false)
require.NoError(t, err)

analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, nil)
Expand Down
10 changes: 5 additions & 5 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -5653,11 +5653,11 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
break
}
expectedEngineRow := make([]*string, len(engineRow))
for i := range engineRow {
sqlVal, err := sch[i].Type.SQL(ctx, nil, engineRow[i])
if !assert.NoError(t, err) {
break
}
row, err := server.RowToSQL(ctx, sch, engineRow, nil)
if !assert.NoError(t, err) {
break
}
for i, sqlVal := range row {
if !sqlVal.IsNull() {
str := sqlVal.ToString()
expectedEngineRow[i] = &str
Expand Down
2 changes: 1 addition & 1 deletion enginetest/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ func injectBindVarsAndPrepare(

b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser())
b.SetParserOptions(sql.LoadSqlMode(ctx).ParserOptions())
resPlan, _, err := b.BindOnly(parsed, q)
resPlan, _, err := b.BindOnly(parsed, q, nil)
if err != nil {
return q, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion enginetest/plangen/cmd/plangen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ func generatePlansForSuite(spec PlanSpec, w *bytes.Buffer) error {
if !tt.Skip {
ctx := enginetest.NewContextWithEngine(harness, engine)
binder := planbuilder.New(ctx, engine.EngineAnalyzer().Catalog, sql.NewMysqlParser())
parsed, _, _, qFlags, err := binder.Parse(tt.Query, false)
parsed, _, _, qFlags, err := binder.Parse(tt.Query, nil, false)
if err != nil {
exit(fmt.Errorf("%w\nfailed to parse query: %s", err, tt.Query))
}
Expand Down
101 changes: 77 additions & 24 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ import (
"github.com/dolthub/go-mysql-server/internal/sockstate"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/go-mysql-server/sql/iters"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/rowexec"
"github.com/dolthub/go-mysql-server/sql/types"
)

Expand Down Expand Up @@ -218,7 +220,7 @@ func (h *Handler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query s
func (h *Handler) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
_, err := h.errorWrappedDoQuery(ctx, c, prepare.PrepareStmt, nil, MultiStmtModeOff, prepare.BindVars, func(res *sqltypes.Result, more bool) error {
return callback(res)
}, nil)
}, &sql.QueryFlags{})
return err
}

Expand Down Expand Up @@ -295,7 +297,7 @@ func (h *Handler) ComMultiQuery(
query string,
callback mysql.ResultSpoolFn,
) (string, error) {
return h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOn, nil, callback, nil)
return h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOn, nil, callback, &sql.QueryFlags{})
}

// ComQuery executes a SQL query on the SQLe engine.
Expand All @@ -305,7 +307,7 @@ func (h *Handler) ComQuery(
query string,
callback mysql.ResultSpoolFn,
) error {
_, err := h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOff, nil, callback, nil)
_, err := h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOff, nil, callback, &sql.QueryFlags{})
return err
}

Expand All @@ -317,7 +319,7 @@ func (h *Handler) ComParsedQuery(
parsed sqlparser.Statement,
callback mysql.ResultSpoolFn,
) error {
_, err := h.errorWrappedDoQuery(ctx, c, query, parsed, MultiStmtModeOff, nil, callback, nil)
_, err := h.errorWrappedDoQuery(ctx, c, query, parsed, MultiStmtModeOff, nil, callback, &sql.QueryFlags{})
return err
}

Expand Down Expand Up @@ -424,6 +426,7 @@ func (h *Handler) doQuery(
}
}()

qFlags.Set(sql.QFlagDeferProjections)
schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
if err != nil {
sqlCtx.GetLogger().WithError(err).Warn("error running query")
Expand Down Expand Up @@ -511,6 +514,37 @@ func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter, resultFields []*quer
return &sqltypes.Result{Fields: resultFields}, nil
}

// GetDeferredProjections looks for a top-level deferred projection, retrieves its projections, and removes it from the
// iterator tree.
func GetDeferredProjections(iter sql.RowIter) (sql.RowIter, []sql.Expression) {
switch i := iter.(type) {
case *rowexec.ExprCloserIter:
_, projs := GetDeferredProjections(i.GetIter())
return i, projs
case *plan.TrackedRowIter:
_, projs := GetDeferredProjections(i.GetIter())
return i, projs
case *rowexec.TransactionCommittingIter:
newChild, projs := GetDeferredProjections(i.GetIter())
if projs != nil {
i.WithChildIter(newChild)
}
return i, projs
case *iters.LimitIter:
newChild, projs := GetDeferredProjections(i.ChildIter)
if projs != nil {
i.ChildIter = newChild
}
return i, projs
case *rowexec.ProjectIter:
if i.CanDefer() {
return i.GetChildIter(), i.GetProjections()
}
return i, nil
}
return iter, nil
}

// resultForMax1RowIter ensures that an empty iterator returns at most one row
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field) (*sqltypes.Result, error) {
defer trace.StartRegion(ctx, "Handler.resultForMax1RowIter").End()
Expand All @@ -527,8 +561,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
if err := iter.Close(ctx); err != nil {
return nil, err
}

outputRow, err := rowToSQL(ctx, schema, row)
outputRow, err := RowToSQL(ctx, schema, row, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -558,16 +591,11 @@ func (h *Handler) resultForDefaultIter(
}
}

pollCtx, cancelF := ctx.NewSubContext()
eg.Go(func() error {
defer pan2err()
return h.pollForClosedConnection(pollCtx, c)
})

wg := sync.WaitGroup{}
wg.Add(2)

// Read rows off the row iterator and send them to the row channel.
iter, projs := GetDeferredProjections(iter)
var rowChan = make(chan sql.Row, 512)
eg.Go(func() error {
defer pan2err()
Expand All @@ -594,6 +622,12 @@ func (h *Handler) resultForDefaultIter(
}
})

pollCtx, cancelF := ctx.NewSubContext()
eg.Go(func() error {
defer pan2err()
return h.pollForClosedConnection(pollCtx, c)
})

// Default waitTime is one minute if there is no timeout configured, in which case
// it will loop to iterate again unless the socket died by the OS timeout or other problems.
// If there is a timeout, it will be enforced to ensure that Vitess has a chance to
Expand Down Expand Up @@ -639,7 +673,7 @@ func (h *Handler) resultForDefaultIter(
continue
}

outputRow, err := rowToSQL(ctx, schema, row)
outputRow, err := RowToSQL(ctx, schema, row, projs)
if err != nil {
return err
}
Expand All @@ -648,6 +682,7 @@ func (h *Handler) resultForDefaultIter(
r.Rows = append(r.Rows, outputRow)
r.RowsAffected++
case <-timer.C:
// TODO: timer should probably go in its own thread, as rowChan is blocking
if h.readTimeout != 0 {
// Cancel and return so Vitess can call the CloseConnection callback
ctx.GetLogger().Tracef("connection timeout")
Expand Down Expand Up @@ -901,25 +936,43 @@ func updateMaxUsedConnectionsStatusVariable() {
}()
}

func rowToSQL(ctx *sql.Context, s sql.Schema, row sql.Row) ([]sqltypes.Value, error) {
o := make([]sqltypes.Value, len(row))
func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression) ([]sqltypes.Value, error) {
// need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock)
if len(s) == 0 {
return o, nil
if len(sch) == 0 {
return []sqltypes.Value{}, nil
}
var err error
for i, v := range row {
if v == nil {
o[i] = sqltypes.NULL

outVals := make([]sqltypes.Value, len(sch))
if len(projs) == 0 {
for i, col := range sch {
if row[i] == nil {
outVals[i] = sqltypes.NULL
continue
}
var err error
outVals[i], err = col.Type.SQL(ctx, nil, row[i])
if err != nil {
return nil, err
}
}
return outVals, nil
}

for i, col := range sch {
field, err := projs[i].Eval(ctx, row)
if err != nil {
return nil, err
}
if field == nil {
outVals[i] = sqltypes.NULL
continue
}
o[i], err = s[i].Type.SQL(ctx, nil, v)
outVals[i], err = col.Type.SQL(ctx, nil, field)
if err != nil {
return nil, err
}
}

return o, nil
return outVals, nil
}

func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field {
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/optimization_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func TestPushNotFilters(t *testing.T) {
for _, tt := range tests {
t.Run(tt.in, func(t *testing.T) {
q := fmt.Sprintf("SELECT 1 from xy WHERE %s", tt.in)
node, _, _, _, err := b.Parse(q, false)
node, _, _, _, err := b.Parse(q, nil, false)
require.NoError(t, err)

cmp, _, err := pushNotFilters(ctx, nil, node, nil, nil, nil)
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var OnceBeforeDefault = []Rule{
{applyDefaultSelectLimitId, applyDefaultSelectLimit},
{replaceCountStarId, replaceCountStar},
{applyEventSchedulerId, applyEventScheduler},
{validateOffsetAndLimitId, validateLimitAndOffset},
{validateOffsetAndLimitId, validateOffsetAndLimit},
{validateCreateTableId, validateCreateTable},
{validateAlterTableId, validateAlterTable},
{validateExprSemId, validateExprSem},
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/stored_procedures.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan
var parsedProcedure sql.Node
b := planbuilder.New(ctx, a.Catalog, sql.NewMysqlParser())
b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions())
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, false)
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
if err != nil {
procToRegister = &plan.Procedure{
CreateProcedureString: procedure.CreateStatement,
Expand Down Expand Up @@ -300,7 +300,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop
b.ProcCtx().AsOf = asOf
}
b.ProcCtx().DbName = call.Database().Name()
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, false)
parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false)
if err != nil {
return nil, transform.SameTree, err
}
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/triggers.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
var parsedTrigger sql.Node
sqlMode := sql.NewSqlModeFromString(trigger.SqlMode)
b.SetParserOptions(sqlMode.ParserOptions())
parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, false)
parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, nil, false)
b.Reset()
if err != nil {
return nil, transform.SameTree, err
Expand All @@ -225,7 +225,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
// first pass allows unresolved before we know whether trigger is relevant
// TODO store destination table name with trigger, so we don't have to do parse twice
b.TriggerCtx().Call = true
parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, false)
parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, nil, false)
b.TriggerCtx().Call = false
b.Reset()
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/validation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ import (
"github.com/dolthub/go-mysql-server/sql/types"
)

// validateLimitAndOffset ensures that only integer literals are used for limit and offset values
func validateLimitAndOffset(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
// validateOffsetAndLimit ensures that only integer literals are used for limit and offset values
func validateOffsetAndLimit(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
var err error
var i, i64 interface{}
transform.Inspect(n, func(n sql.Node) bool {
Expand Down
Loading
Loading