From 0f10cbaf5339714a93ea940471451c5c938a7c00 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 3 Oct 2024 12:21:00 -0700 Subject: [PATCH] feedback and fixing more tests --- server/handler.go | 5 +- sql/expression/function/queryinfo.go | 133 ++++++++++++++------------- sql/planbuilder/project.go | 34 +++---- sql/planbuilder/scalar.go | 5 + sql/query_flags.go | 7 ++ sql/rowexec/rel_iters.go | 4 + sql/rowexec/transaction_iters.go | 6 ++ 7 files changed, 108 insertions(+), 86 deletions(-) diff --git a/server/handler.go b/server/handler.go index f72336d276..24020d239a 100644 --- a/server/handler.go +++ b/server/handler.go @@ -425,7 +425,7 @@ func (h *Handler) doQuery( } }() - qFlags.Set(sql.QFlagDeferProjections) + qFlags.Set(sql.QFlagDeferProjections) // TODO: this somehow breaks timeout??????? schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags) if err != nil { sqlCtx.GetLogger().WithError(err).Warn("error running query") @@ -522,7 +522,7 @@ func GetDeferredProjections(iter sql.RowIter) []sql.Expression { if commitIter, isCommitIter := i.GetIter().(*rowexec.TransactionCommittingIter); isCommitIter { if projIter, isProjIter := commitIter.GetIter().(*rowexec.ProjectIter); isProjIter { if projIter.CanDefer() { - projIter.Defer() + commitIter.WithChildIter(projIter.GetChildIter()) return projIter.GetProjections() } } @@ -668,6 +668,7 @@ func (h *Handler) resultForDefaultIter( r.Rows = append(r.Rows, outputRow) r.RowsAffected++ case <-timer.C: + // TODO: timer should probably go in its own thread, as rowChan is blocking if h.readTimeout != 0 { // Cancel and return so Vitess can call the CloseConnection callback ctx.GetLogger().Tracef("connection timeout") diff --git a/sql/expression/function/queryinfo.go b/sql/expression/function/queryinfo.go index f53e80f836..902085ae93 100644 --- a/sql/expression/function/queryinfo.go +++ b/sql/expression/function/queryinfo.go @@ -27,104 +27,105 @@ import ( // RowCount implements the ROW_COUNT() function type RowCount struct{} -func (r RowCount) IsNonDeterministic() bool { - return true -} - func NewRowCount() sql.Expression { - return RowCount{} + return &RowCount{} } -var _ sql.FunctionExpression = RowCount{} -var _ sql.CollationCoercible = RowCount{} +var _ sql.FunctionExpression = &RowCount{} +var _ sql.CollationCoercible = &RowCount{} // Description implements sql.FunctionExpression -func (r RowCount) Description() string { +func (r *RowCount) Description() string { return "returns the number of rows updated." } // Resolved implements sql.Expression -func (r RowCount) Resolved() bool { +func (r *RowCount) Resolved() bool { return true } // String implements sql.Expression -func (r RowCount) String() string { +func (r *RowCount) String() string { return fmt.Sprintf("%s()", r.FunctionName()) } // Type implements sql.Expression -func (r RowCount) Type() sql.Type { +func (r *RowCount) Type() sql.Type { return types.Int64 } // CollationCoercibility implements the interface sql.CollationCoercible. -func (RowCount) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { +func (*RowCount) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } // IsNullable implements sql.Expression -func (r RowCount) IsNullable() bool { +func (r *RowCount) IsNullable() bool { return false } // Eval implements sql.Expression -func (r RowCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (r *RowCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return ctx.GetLastQueryInfoInt(sql.RowCount), nil } // Children implements sql.Expression -func (r RowCount) Children() []sql.Expression { +func (r *RowCount) Children() []sql.Expression { return nil } // WithChildren implements sql.Expression -func (r RowCount) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (r *RowCount) WithChildren(children ...sql.Expression) (sql.Expression, error) { return sql.NillaryWithChildren(r, children...) } // FunctionName implements sql.FunctionExpression -func (r RowCount) FunctionName() string { +func (r *RowCount) FunctionName() string { return "row_count" } +// IsNonDeterministic implements sql.NonDeterministicExpression +func (r *RowCount) IsNonDeterministic() bool { + return true +} + // LastInsertUuid implements the LAST_INSERT_UUID() function. This function is // NOT a standard function in MySQL, but is a useful analogue to LAST_INSERT_ID() // if customers are inserting UUIDs into a table. type LastInsertUuid struct{} -var _ sql.FunctionExpression = LastInsertUuid{} -var _ sql.CollationCoercible = LastInsertUuid{} +var _ sql.FunctionExpression = &LastInsertUuid{} +var _ sql.CollationCoercible = &LastInsertUuid{} func NewLastInsertUuid(children ...sql.Expression) (sql.Expression, error) { if len(children) > 0 { - return nil, sql.ErrInvalidChildrenNumber.New(LastInsertUuid{}.String(), len(children), 0) + return nil, sql.ErrInvalidChildrenNumber.New((&LastInsertUuid{}).String(), len(children), 0) } return &LastInsertUuid{}, nil } -func (l LastInsertUuid) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) { +func (l *LastInsertUuid) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } -func (l LastInsertUuid) Resolved() bool { +func (l *LastInsertUuid) Resolved() bool { return true } -func (l LastInsertUuid) String() string { +func (l *LastInsertUuid) String() string { return fmt.Sprintf("%s()", l.FunctionName()) } -func (l LastInsertUuid) Type() sql.Type { +func (l *LastInsertUuid) Type() sql.Type { return types.MustCreateStringWithDefaults(sqltypes.VarChar, 36) } -func (l LastInsertUuid) IsNullable() bool { +func (l *LastInsertUuid) IsNullable() bool { return false } -func (l LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { +func (l *LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { lastInsertUuid := ctx.GetLastQueryInfoString(sql.LastInsertUuid) result, _, err := l.Type().Convert(lastInsertUuid) if err != nil { @@ -133,19 +134,19 @@ func (l LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { return result, nil } -func (l LastInsertUuid) Children() []sql.Expression { +func (l *LastInsertUuid) Children() []sql.Expression { return nil } -func (l LastInsertUuid) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (l *LastInsertUuid) WithChildren(children ...sql.Expression) (sql.Expression, error) { return NewLastInsertUuid(children...) } -func (l LastInsertUuid) FunctionName() string { +func (l *LastInsertUuid) FunctionName() string { return "last_insert_uuid" } -func (l LastInsertUuid) Description() string { +func (l *LastInsertUuid) Description() string { return "returns the first value of the UUID() function from the last INSERT statement." } @@ -155,56 +156,52 @@ type LastInsertId struct { expression.UnaryExpression } -func (r LastInsertId) IsNonDeterministic() bool { - return true -} - func NewLastInsertId(children ...sql.Expression) (sql.Expression, error) { switch len(children) { case 0: - return LastInsertId{}, nil + return &LastInsertId{}, nil case 1: - return LastInsertId{UnaryExpression: expression.UnaryExpression{Child: children[0]}}, nil + return &LastInsertId{UnaryExpression: expression.UnaryExpression{Child: children[0]}}, nil default: return nil, sql.ErrInvalidArgumentNumber.New("LastInsertId", len(children), 1) } } -var _ sql.FunctionExpression = LastInsertId{} -var _ sql.CollationCoercible = LastInsertId{} +var _ sql.FunctionExpression = &LastInsertId{} +var _ sql.CollationCoercible = &LastInsertId{} // Description implements sql.FunctionExpression -func (r LastInsertId) Description() string { +func (r *LastInsertId) Description() string { return "returns value of the AUTOINCREMENT column for the last INSERT." } // Resolved implements sql.Expression -func (r LastInsertId) Resolved() bool { +func (r *LastInsertId) Resolved() bool { return true } // String implements sql.Expression -func (r LastInsertId) String() string { +func (r *LastInsertId) String() string { return fmt.Sprintf("%s(%s)", r.FunctionName(), r.Child) } // Type implements sql.Expression -func (r LastInsertId) Type() sql.Type { +func (r *LastInsertId) Type() sql.Type { return types.Uint64 } // CollationCoercibility implements the interface sql.CollationCoercible. -func (LastInsertId) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { +func (*LastInsertId) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } // IsNullable implements sql.Expression -func (r LastInsertId) IsNullable() bool { +func (r *LastInsertId) IsNullable() bool { return false } // Eval implements sql.Expression -func (r LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (r *LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // With no arguments, just return the last insert id for this session if len(r.Children()) == 0 { lastInsertId := ctx.GetLastQueryInfoInt(sql.LastInsertId) @@ -230,7 +227,7 @@ func (r LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } // Children implements sql.Expression -func (r LastInsertId) Children() []sql.Expression { +func (r *LastInsertId) Children() []sql.Expression { if r.Child == nil { return nil } @@ -238,75 +235,81 @@ func (r LastInsertId) Children() []sql.Expression { } // WithChildren implements sql.Expression -func (r LastInsertId) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (r *LastInsertId) WithChildren(children ...sql.Expression) (sql.Expression, error) { return NewLastInsertId(children...) } // FunctionName implements sql.FunctionExpression -func (r LastInsertId) FunctionName() string { +func (r *LastInsertId) FunctionName() string { return "last_insert_id" } -// FoundRows implements the FOUND_ROWS() function -type FoundRows struct{} - -func (r FoundRows) IsNonDeterministic() bool { +// IsNonDeterministic implements sql.NonDeterministicExpression +func (r *LastInsertId) IsNonDeterministic() bool { return true } +// FoundRows implements the FOUND_ROWS() function +type FoundRows struct{} + func NewFoundRows() sql.Expression { - return FoundRows{} + return &FoundRows{} } -var _ sql.FunctionExpression = FoundRows{} -var _ sql.CollationCoercible = FoundRows{} +var _ sql.FunctionExpression = &FoundRows{} +var _ sql.CollationCoercible = &FoundRows{} // FunctionName implements sql.FunctionExpression -func (r FoundRows) FunctionName() string { +func (r *FoundRows) FunctionName() string { return "found_rows" } // Description implements sql.Expression -func (r FoundRows) Description() string { +func (r *FoundRows) Description() string { return "for a SELECT with a LIMIT clause, returns the number of rows that would be returned were there no LIMIT clause." } // Resolved implements sql.Expression -func (r FoundRows) Resolved() bool { +func (r *FoundRows) Resolved() bool { return true } // String implements sql.Expression -func (r FoundRows) String() string { +func (r *FoundRows) String() string { return fmt.Sprintf("%s()", r.FunctionName()) } // Type implements sql.Expression -func (r FoundRows) Type() sql.Type { +func (r *FoundRows) Type() sql.Type { return types.Int64 } // CollationCoercibility implements the interface sql.CollationCoercible. -func (FoundRows) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { +func (*FoundRows) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } // IsNullable implements sql.Expression -func (r FoundRows) IsNullable() bool { +func (r *FoundRows) IsNullable() bool { return false } // Eval implements sql.Expression -func (r FoundRows) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (r *FoundRows) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return ctx.GetLastQueryInfoInt(sql.FoundRows), nil } // Children implements sql.Expression -func (r FoundRows) Children() []sql.Expression { +func (r *FoundRows) Children() []sql.Expression { return nil } // WithChildren implements sql.Expression -func (r FoundRows) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (r *FoundRows) WithChildren(children ...sql.Expression) (sql.Expression, error) { return sql.NillaryWithChildren(r, children...) } + +// IsNonDeterministic implements sql.NonDeterministicExpression +func (r *FoundRows) IsNonDeterministic() bool { + return true +} \ No newline at end of file diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index 95d98f5a33..6897150b62 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -21,8 +21,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/go-mysql-server/sql/expression/function" - "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/transform" ) @@ -195,16 +194,7 @@ func (b *Builder) selectExprToExpression(inScope *scope, se ast.SelectExpr) sql. return nil } -func (b *Builder) buildProjection(inScope, outScope *scope) { - projections := make([]sql.Expression, len(outScope.cols)) - for i, sc := range outScope.cols { - projections[i] = sc.scalar - } - proj, err := b.f.buildProject(plan.NewProject(projections, inScope.node), outScope.refsSubquery) - if err != nil { - b.handleErr(err) - } - outScope.node = proj +func (b *Builder) markDeferProjection(proj sql.Node, inScope, outScope *scope) { if !b.qFlags.IsSet(sql.QFlagDeferProjections) { return } @@ -214,16 +204,22 @@ func (b *Builder) buildProjection(inScope, outScope *scope) { if _, isProj := proj.(*plan.Project); !isProj { return } - for _, sc := range outScope.cols { - switch sc.scalar.(type) { - // TODO: column default expression are also not deferrable, but they don't appear in top level projections - case function.RowCount: - return - } - } proj.(*plan.Project).CanDefer = true } +func (b *Builder) buildProjection(inScope, outScope *scope) { + projections := make([]sql.Expression, len(outScope.cols)) + for i, sc := range outScope.cols { + projections[i] = sc.scalar + } + proj, err := b.f.buildProject(plan.NewProject(projections, inScope.node), outScope.refsSubquery) + if err != nil { + b.handleErr(err) + } + b.markDeferProjection(proj, inScope, outScope) + outScope.node = proj +} + func selectExprNeedsAlias(e *ast.AliasedExpr, expr sql.Expression) bool { if len(e.InputExpression) == 0 { return false diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 55ff36f52b..87a52fe911 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -169,6 +169,11 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { b.handleErr(err) } + switch rf.(type) { + case *function.Sleep, sql.NonDeterministicExpression: + b.qFlags.Unset(sql.QFlagDeferProjections) + } + // NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw // errors for when DISTINCT is used on aggregate functions that don't support DISTINCT. if v.Distinct { diff --git a/sql/query_flags.go b/sql/query_flags.go index 86c9310dff..432a1c8eef 100644 --- a/sql/query_flags.go +++ b/sql/query_flags.go @@ -59,6 +59,13 @@ func (qp *QueryFlags) Set(flag int) { qp.Flags.Add(flag) } +func (qp *QueryFlags) Unset(flag int) { + if qp == nil { + return + } + qp.Flags.Remove(flag) +} + func (qp *QueryFlags) IsSet(flag int) bool { return qp.Flags.Contains(flag) } diff --git a/sql/rowexec/rel_iters.go b/sql/rowexec/rel_iters.go index 7569306e19..b67b6cc5fe 100644 --- a/sql/rowexec/rel_iters.go +++ b/sql/rowexec/rel_iters.go @@ -158,6 +158,10 @@ func (i *ProjectIter) Defer() { i.deferred = true } +func (i *ProjectIter) GetChildIter() sql.RowIter { + return i.childIter +} + // ProjectRow evaluates a set of projections. func ProjectRow( ctx *sql.Context, diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index f358f96621..c60935175a 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -118,3 +118,9 @@ func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { func (t *TransactionCommittingIter) GetIter() sql.RowIter { return t.childIter } + +func (t *TransactionCommittingIter) WithChildIter(childIter sql.RowIter) sql.RowIter { + nt := *t + t.childIter = childIter + return &nt +}