diff --git a/engine.go b/engine.go index cfdcbb32c5..a8a8988f26 100644 --- a/engine.go +++ b/engine.go @@ -210,7 +210,7 @@ func (e *Engine) AnalyzeQuery( query string, ) (sql.Node, error) { binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser) - parsed, _, _, qFlags, err := binder.Parse(query, false) + parsed, _, _, qFlags, err := binder.Parse(query, nil, false) if err != nil { return nil, err } @@ -238,7 +238,7 @@ func (e *Engine) PrepareParsedQuery( stmt sqlparser.Statement, ) (sql.Node, error) { binder := planbuilder.New(ctx, e.Analyzer.Catalog, e.Parser) - node, _, err := binder.BindOnly(stmt, query) + node, _, err := binder.BindOnly(stmt, query, nil) if err != nil { return nil, err @@ -586,7 +586,7 @@ func (e *Engine) bindQuery(ctx *sql.Context, query string, parsed sqlparser.Stat var bound sql.Node var err error if parsed == nil { - bound, _, _, qFlags, err = binder.Parse(query, false) + bound, _, _, qFlags, err = binder.Parse(query, qFlags, false) if err != nil { clearAutocommitErr := clearAutocommitTransaction(ctx) if clearAutocommitErr != nil { @@ -595,7 +595,7 @@ func (e *Engine) bindQuery(ctx *sql.Context, query string, parsed sqlparser.Stat return nil, nil, err } } else { - bound, qFlags, err = binder.BindOnly(parsed, query) + bound, qFlags, err = binder.BindOnly(parsed, query, qFlags) if err != nil { return nil, nil, err } @@ -651,7 +651,7 @@ func (e *Engine) bindExecuteQueryNode(ctx *sql.Context, query string, eq *plan.E binder.SetBindingsWithExpr(tempBindings) } - bound, _, err := binder.BindOnly(prep, query) + bound, _, err := binder.BindOnly(prep, query, nil) if err != nil { clearAutocommitErr := clearAutocommitTransaction(ctx) if clearAutocommitErr != nil { diff --git a/enginetest/engine_only_test.go b/enginetest/engine_only_test.go index 6fb9c344a7..92c16714f0 100644 --- a/enginetest/engine_only_test.go +++ b/enginetest/engine_only_test.go @@ -511,7 +511,7 @@ func TestAnalyzer_Exp(t *testing.T) { ctx := enginetest.NewContext(harness) b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser()) - parsed, _, _, _, err := b.Parse(tt.query, false) + parsed, _, _, _, err := b.Parse(tt.query, nil, false) require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, nil) diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index f5d5613c04..be74c6ecc2 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -5653,11 +5653,11 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve break } expectedEngineRow := make([]*string, len(engineRow)) - for i := range engineRow { - sqlVal, err := sch[i].Type.SQL(ctx, nil, engineRow[i]) - if !assert.NoError(t, err) { - break - } + row, err := server.RowToSQL(ctx, sch, engineRow, nil) + if !assert.NoError(t, err) { + break + } + for i, sqlVal := range row { if !sqlVal.IsNull() { str := sqlVal.ToString() expectedEngineRow[i] = &str diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index a99c9decd0..ef57d62dbb 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -528,7 +528,7 @@ func injectBindVarsAndPrepare( b := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, sql.NewMysqlParser()) b.SetParserOptions(sql.LoadSqlMode(ctx).ParserOptions()) - resPlan, _, err := b.BindOnly(parsed, q) + resPlan, _, err := b.BindOnly(parsed, q, nil) if err != nil { return q, nil, err } diff --git a/enginetest/plangen/cmd/plangen/main.go b/enginetest/plangen/cmd/plangen/main.go index 30d248a2af..7779baf30b 100644 --- a/enginetest/plangen/cmd/plangen/main.go +++ b/enginetest/plangen/cmd/plangen/main.go @@ -166,7 +166,7 @@ func generatePlansForSuite(spec PlanSpec, w *bytes.Buffer) error { if !tt.Skip { ctx := enginetest.NewContextWithEngine(harness, engine) binder := planbuilder.New(ctx, engine.EngineAnalyzer().Catalog, sql.NewMysqlParser()) - parsed, _, _, qFlags, err := binder.Parse(tt.Query, false) + parsed, _, _, qFlags, err := binder.Parse(tt.Query, nil, false) if err != nil { exit(fmt.Errorf("%w\nfailed to parse query: %s", err, tt.Query)) } diff --git a/server/handler.go b/server/handler.go index 00cf8a162e..09d0940e0a 100644 --- a/server/handler.go +++ b/server/handler.go @@ -40,7 +40,9 @@ import ( "github.com/dolthub/go-mysql-server/internal/sockstate" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/iters" "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/rowexec" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -218,7 +220,7 @@ func (h *Handler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query s func (h *Handler) ComStmtExecute(ctx context.Context, c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { _, err := h.errorWrappedDoQuery(ctx, c, prepare.PrepareStmt, nil, MultiStmtModeOff, prepare.BindVars, func(res *sqltypes.Result, more bool) error { return callback(res) - }, nil) + }, &sql.QueryFlags{}) return err } @@ -295,7 +297,7 @@ func (h *Handler) ComMultiQuery( query string, callback mysql.ResultSpoolFn, ) (string, error) { - return h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOn, nil, callback, nil) + return h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOn, nil, callback, &sql.QueryFlags{}) } // ComQuery executes a SQL query on the SQLe engine. @@ -305,7 +307,7 @@ func (h *Handler) ComQuery( query string, callback mysql.ResultSpoolFn, ) error { - _, err := h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOff, nil, callback, nil) + _, err := h.errorWrappedDoQuery(ctx, c, query, nil, MultiStmtModeOff, nil, callback, &sql.QueryFlags{}) return err } @@ -317,7 +319,7 @@ func (h *Handler) ComParsedQuery( parsed sqlparser.Statement, callback mysql.ResultSpoolFn, ) error { - _, err := h.errorWrappedDoQuery(ctx, c, query, parsed, MultiStmtModeOff, nil, callback, nil) + _, err := h.errorWrappedDoQuery(ctx, c, query, parsed, MultiStmtModeOff, nil, callback, &sql.QueryFlags{}) return err } @@ -424,6 +426,7 @@ func (h *Handler) doQuery( } }() + qFlags.Set(sql.QFlagDeferProjections) schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags) if err != nil { sqlCtx.GetLogger().WithError(err).Warn("error running query") @@ -511,6 +514,37 @@ func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter, resultFields []*quer return &sqltypes.Result{Fields: resultFields}, nil } +// GetDeferredProjections looks for a top-level deferred projection, retrieves its projections, and removes it from the +// iterator tree. +func GetDeferredProjections(iter sql.RowIter) (sql.RowIter, []sql.Expression) { + switch i := iter.(type) { + case *rowexec.ExprCloserIter: + _, projs := GetDeferredProjections(i.GetIter()) + return i, projs + case *plan.TrackedRowIter: + _, projs := GetDeferredProjections(i.GetIter()) + return i, projs + case *rowexec.TransactionCommittingIter: + newChild, projs := GetDeferredProjections(i.GetIter()) + if projs != nil { + i.WithChildIter(newChild) + } + return i, projs + case *iters.LimitIter: + newChild, projs := GetDeferredProjections(i.ChildIter) + if projs != nil { + i.ChildIter = newChild + } + return i, projs + case *rowexec.ProjectIter: + if i.CanDefer() { + return i.GetChildIter(), i.GetProjections() + } + return i, nil + } + return iter, nil +} + // resultForMax1RowIter ensures that an empty iterator returns at most one row func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []*querypb.Field) (*sqltypes.Result, error) { defer trace.StartRegion(ctx, "Handler.resultForMax1RowIter").End() @@ -527,8 +561,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, if err := iter.Close(ctx); err != nil { return nil, err } - - outputRow, err := rowToSQL(ctx, schema, row) + outputRow, err := RowToSQL(ctx, schema, row, nil) if err != nil { return nil, err } @@ -558,16 +591,11 @@ func (h *Handler) resultForDefaultIter( } } - pollCtx, cancelF := ctx.NewSubContext() - eg.Go(func() error { - defer pan2err() - return h.pollForClosedConnection(pollCtx, c) - }) - wg := sync.WaitGroup{} wg.Add(2) // Read rows off the row iterator and send them to the row channel. + iter, projs := GetDeferredProjections(iter) var rowChan = make(chan sql.Row, 512) eg.Go(func() error { defer pan2err() @@ -594,6 +622,12 @@ func (h *Handler) resultForDefaultIter( } }) + pollCtx, cancelF := ctx.NewSubContext() + eg.Go(func() error { + defer pan2err() + return h.pollForClosedConnection(pollCtx, c) + }) + // Default waitTime is one minute if there is no timeout configured, in which case // it will loop to iterate again unless the socket died by the OS timeout or other problems. // If there is a timeout, it will be enforced to ensure that Vitess has a chance to @@ -639,7 +673,7 @@ func (h *Handler) resultForDefaultIter( continue } - outputRow, err := rowToSQL(ctx, schema, row) + outputRow, err := RowToSQL(ctx, schema, row, projs) if err != nil { return err } @@ -648,6 +682,7 @@ func (h *Handler) resultForDefaultIter( r.Rows = append(r.Rows, outputRow) r.RowsAffected++ case <-timer.C: + // TODO: timer should probably go in its own thread, as rowChan is blocking if h.readTimeout != 0 { // Cancel and return so Vitess can call the CloseConnection callback ctx.GetLogger().Tracef("connection timeout") @@ -901,25 +936,43 @@ func updateMaxUsedConnectionsStatusVariable() { }() } -func rowToSQL(ctx *sql.Context, s sql.Schema, row sql.Row) ([]sqltypes.Value, error) { - o := make([]sqltypes.Value, len(row)) +func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression) ([]sqltypes.Value, error) { // need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock) - if len(s) == 0 { - return o, nil + if len(sch) == 0 { + return []sqltypes.Value{}, nil } - var err error - for i, v := range row { - if v == nil { - o[i] = sqltypes.NULL + + outVals := make([]sqltypes.Value, len(sch)) + if len(projs) == 0 { + for i, col := range sch { + if row[i] == nil { + outVals[i] = sqltypes.NULL + continue + } + var err error + outVals[i], err = col.Type.SQL(ctx, nil, row[i]) + if err != nil { + return nil, err + } + } + return outVals, nil + } + + for i, col := range sch { + field, err := projs[i].Eval(ctx, row) + if err != nil { + return nil, err + } + if field == nil { + outVals[i] = sqltypes.NULL continue } - o[i], err = s[i].Type.SQL(ctx, nil, v) + outVals[i], err = col.Type.SQL(ctx, nil, field) if err != nil { return nil, err } } - - return o, nil + return outVals, nil } func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field { diff --git a/sql/analyzer/optimization_rules_test.go b/sql/analyzer/optimization_rules_test.go index 8b6a45d1ff..f36ee26fb7 100644 --- a/sql/analyzer/optimization_rules_test.go +++ b/sql/analyzer/optimization_rules_test.go @@ -223,7 +223,7 @@ func TestPushNotFilters(t *testing.T) { for _, tt := range tests { t.Run(tt.in, func(t *testing.T) { q := fmt.Sprintf("SELECT 1 from xy WHERE %s", tt.in) - node, _, _, _, err := b.Parse(q, false) + node, _, _, _, err := b.Parse(q, nil, false) require.NoError(t, err) cmp, _, err := pushNotFilters(ctx, nil, node, nil, nil, nil) diff --git a/sql/analyzer/rules.go b/sql/analyzer/rules.go index 663ff008fa..dbfd801e5a 100644 --- a/sql/analyzer/rules.go +++ b/sql/analyzer/rules.go @@ -39,7 +39,7 @@ var OnceBeforeDefault = []Rule{ {applyDefaultSelectLimitId, applyDefaultSelectLimit}, {replaceCountStarId, replaceCountStar}, {applyEventSchedulerId, applyEventScheduler}, - {validateOffsetAndLimitId, validateLimitAndOffset}, + {validateOffsetAndLimitId, validateOffsetAndLimit}, {validateCreateTableId, validateCreateTable}, {validateAlterTableId, validateAlterTable}, {validateExprSemId, validateExprSem}, diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index f8a2750983..d1b1575b80 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -56,7 +56,7 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan var parsedProcedure sql.Node b := planbuilder.New(ctx, a.Catalog, sql.NewMysqlParser()) b.SetParserOptions(sql.NewSqlModeFromString(procedure.SqlMode).ParserOptions()) - parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, false) + parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) if err != nil { procToRegister = &plan.Procedure{ CreateProcedureString: procedure.CreateStatement, @@ -300,7 +300,7 @@ func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop b.ProcCtx().AsOf = asOf } b.ProcCtx().DbName = call.Database().Name() - parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, false) + parsedProcedure, _, _, _, err = b.Parse(procedure.CreateStatement, nil, false) if err != nil { return nil, transform.SameTree, err } diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index 9bb6a5da45..4fe8d89db1 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -204,7 +204,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, var parsedTrigger sql.Node sqlMode := sql.NewSqlModeFromString(trigger.SqlMode) b.SetParserOptions(sqlMode.ParserOptions()) - parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, false) + parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, nil, false) b.Reset() if err != nil { return nil, transform.SameTree, err @@ -225,7 +225,7 @@ func applyTriggers(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, // first pass allows unresolved before we know whether trigger is relevant // TODO store destination table name with trigger, so we don't have to do parse twice b.TriggerCtx().Call = true - parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, false) + parsedTrigger, _, _, _, err = b.Parse(trigger.CreateStatement, nil, false) b.TriggerCtx().Call = false b.Reset() if err != nil { diff --git a/sql/analyzer/validation_rules.go b/sql/analyzer/validation_rules.go index 6b0f792d16..54cf51637d 100644 --- a/sql/analyzer/validation_rules.go +++ b/sql/analyzer/validation_rules.go @@ -29,8 +29,8 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) -// validateLimitAndOffset ensures that only integer literals are used for limit and offset values -func validateLimitAndOffset(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { +// validateOffsetAndLimit ensures that only integer literals are used for limit and offset values +func validateOffsetAndLimit(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { var err error var i, i64 interface{} transform.Inspect(n, func(n sql.Node) bool { diff --git a/sql/expression/function/queryinfo.go b/sql/expression/function/queryinfo.go index f53e80f836..00174e8d75 100644 --- a/sql/expression/function/queryinfo.go +++ b/sql/expression/function/queryinfo.go @@ -27,104 +27,105 @@ import ( // RowCount implements the ROW_COUNT() function type RowCount struct{} -func (r RowCount) IsNonDeterministic() bool { - return true -} - func NewRowCount() sql.Expression { - return RowCount{} + return &RowCount{} } -var _ sql.FunctionExpression = RowCount{} -var _ sql.CollationCoercible = RowCount{} +var _ sql.FunctionExpression = &RowCount{} +var _ sql.CollationCoercible = &RowCount{} // Description implements sql.FunctionExpression -func (r RowCount) Description() string { +func (r *RowCount) Description() string { return "returns the number of rows updated." } // Resolved implements sql.Expression -func (r RowCount) Resolved() bool { +func (r *RowCount) Resolved() bool { return true } // String implements sql.Expression -func (r RowCount) String() string { +func (r *RowCount) String() string { return fmt.Sprintf("%s()", r.FunctionName()) } // Type implements sql.Expression -func (r RowCount) Type() sql.Type { +func (r *RowCount) Type() sql.Type { return types.Int64 } // CollationCoercibility implements the interface sql.CollationCoercible. -func (RowCount) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { +func (*RowCount) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } // IsNullable implements sql.Expression -func (r RowCount) IsNullable() bool { +func (r *RowCount) IsNullable() bool { return false } // Eval implements sql.Expression -func (r RowCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (r *RowCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return ctx.GetLastQueryInfoInt(sql.RowCount), nil } // Children implements sql.Expression -func (r RowCount) Children() []sql.Expression { +func (r *RowCount) Children() []sql.Expression { return nil } // WithChildren implements sql.Expression -func (r RowCount) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (r *RowCount) WithChildren(children ...sql.Expression) (sql.Expression, error) { return sql.NillaryWithChildren(r, children...) } // FunctionName implements sql.FunctionExpression -func (r RowCount) FunctionName() string { +func (r *RowCount) FunctionName() string { return "row_count" } +// IsNonDeterministic implements sql.NonDeterministicExpression +func (r *RowCount) IsNonDeterministic() bool { + return true +} + // LastInsertUuid implements the LAST_INSERT_UUID() function. This function is // NOT a standard function in MySQL, but is a useful analogue to LAST_INSERT_ID() // if customers are inserting UUIDs into a table. type LastInsertUuid struct{} -var _ sql.FunctionExpression = LastInsertUuid{} -var _ sql.CollationCoercible = LastInsertUuid{} +var _ sql.FunctionExpression = &LastInsertUuid{} +var _ sql.CollationCoercible = &LastInsertUuid{} func NewLastInsertUuid(children ...sql.Expression) (sql.Expression, error) { if len(children) > 0 { - return nil, sql.ErrInvalidChildrenNumber.New(LastInsertUuid{}.String(), len(children), 0) + return nil, sql.ErrInvalidChildrenNumber.New((&LastInsertUuid{}).String(), len(children), 0) } return &LastInsertUuid{}, nil } -func (l LastInsertUuid) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) { +func (l *LastInsertUuid) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } -func (l LastInsertUuid) Resolved() bool { +func (l *LastInsertUuid) Resolved() bool { return true } -func (l LastInsertUuid) String() string { +func (l *LastInsertUuid) String() string { return fmt.Sprintf("%s()", l.FunctionName()) } -func (l LastInsertUuid) Type() sql.Type { +func (l *LastInsertUuid) Type() sql.Type { return types.MustCreateStringWithDefaults(sqltypes.VarChar, 36) } -func (l LastInsertUuid) IsNullable() bool { +func (l *LastInsertUuid) IsNullable() bool { return false } -func (l LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { +func (l *LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { lastInsertUuid := ctx.GetLastQueryInfoString(sql.LastInsertUuid) result, _, err := l.Type().Convert(lastInsertUuid) if err != nil { @@ -133,19 +134,19 @@ func (l LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) { return result, nil } -func (l LastInsertUuid) Children() []sql.Expression { +func (l *LastInsertUuid) Children() []sql.Expression { return nil } -func (l LastInsertUuid) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (l *LastInsertUuid) WithChildren(children ...sql.Expression) (sql.Expression, error) { return NewLastInsertUuid(children...) } -func (l LastInsertUuid) FunctionName() string { +func (l *LastInsertUuid) FunctionName() string { return "last_insert_uuid" } -func (l LastInsertUuid) Description() string { +func (l *LastInsertUuid) Description() string { return "returns the first value of the UUID() function from the last INSERT statement." } @@ -155,56 +156,52 @@ type LastInsertId struct { expression.UnaryExpression } -func (r LastInsertId) IsNonDeterministic() bool { - return true -} - func NewLastInsertId(children ...sql.Expression) (sql.Expression, error) { switch len(children) { case 0: - return LastInsertId{}, nil + return &LastInsertId{}, nil case 1: - return LastInsertId{UnaryExpression: expression.UnaryExpression{Child: children[0]}}, nil + return &LastInsertId{UnaryExpression: expression.UnaryExpression{Child: children[0]}}, nil default: return nil, sql.ErrInvalidArgumentNumber.New("LastInsertId", len(children), 1) } } -var _ sql.FunctionExpression = LastInsertId{} -var _ sql.CollationCoercible = LastInsertId{} +var _ sql.FunctionExpression = &LastInsertId{} +var _ sql.CollationCoercible = &LastInsertId{} // Description implements sql.FunctionExpression -func (r LastInsertId) Description() string { +func (r *LastInsertId) Description() string { return "returns value of the AUTOINCREMENT column for the last INSERT." } // Resolved implements sql.Expression -func (r LastInsertId) Resolved() bool { +func (r *LastInsertId) Resolved() bool { return true } // String implements sql.Expression -func (r LastInsertId) String() string { +func (r *LastInsertId) String() string { return fmt.Sprintf("%s(%s)", r.FunctionName(), r.Child) } // Type implements sql.Expression -func (r LastInsertId) Type() sql.Type { +func (r *LastInsertId) Type() sql.Type { return types.Uint64 } // CollationCoercibility implements the interface sql.CollationCoercible. -func (LastInsertId) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { +func (*LastInsertId) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } // IsNullable implements sql.Expression -func (r LastInsertId) IsNullable() bool { +func (r *LastInsertId) IsNullable() bool { return false } // Eval implements sql.Expression -func (r LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (r *LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // With no arguments, just return the last insert id for this session if len(r.Children()) == 0 { lastInsertId := ctx.GetLastQueryInfoInt(sql.LastInsertId) @@ -230,7 +227,7 @@ func (r LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } // Children implements sql.Expression -func (r LastInsertId) Children() []sql.Expression { +func (r *LastInsertId) Children() []sql.Expression { if r.Child == nil { return nil } @@ -238,75 +235,81 @@ func (r LastInsertId) Children() []sql.Expression { } // WithChildren implements sql.Expression -func (r LastInsertId) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (r *LastInsertId) WithChildren(children ...sql.Expression) (sql.Expression, error) { return NewLastInsertId(children...) } // FunctionName implements sql.FunctionExpression -func (r LastInsertId) FunctionName() string { +func (r *LastInsertId) FunctionName() string { return "last_insert_id" } -// FoundRows implements the FOUND_ROWS() function -type FoundRows struct{} - -func (r FoundRows) IsNonDeterministic() bool { +// IsNonDeterministic implements sql.NonDeterministicExpression +func (r *LastInsertId) IsNonDeterministic() bool { return true } +// FoundRows implements the FOUND_ROWS() function +type FoundRows struct{} + func NewFoundRows() sql.Expression { - return FoundRows{} + return &FoundRows{} } -var _ sql.FunctionExpression = FoundRows{} -var _ sql.CollationCoercible = FoundRows{} +var _ sql.FunctionExpression = &FoundRows{} +var _ sql.CollationCoercible = &FoundRows{} // FunctionName implements sql.FunctionExpression -func (r FoundRows) FunctionName() string { +func (r *FoundRows) FunctionName() string { return "found_rows" } // Description implements sql.Expression -func (r FoundRows) Description() string { +func (r *FoundRows) Description() string { return "for a SELECT with a LIMIT clause, returns the number of rows that would be returned were there no LIMIT clause." } // Resolved implements sql.Expression -func (r FoundRows) Resolved() bool { +func (r *FoundRows) Resolved() bool { return true } // String implements sql.Expression -func (r FoundRows) String() string { +func (r *FoundRows) String() string { return fmt.Sprintf("%s()", r.FunctionName()) } // Type implements sql.Expression -func (r FoundRows) Type() sql.Type { +func (r *FoundRows) Type() sql.Type { return types.Int64 } // CollationCoercibility implements the interface sql.CollationCoercible. -func (FoundRows) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { +func (*FoundRows) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } // IsNullable implements sql.Expression -func (r FoundRows) IsNullable() bool { +func (r *FoundRows) IsNullable() bool { return false } // Eval implements sql.Expression -func (r FoundRows) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { +func (r *FoundRows) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return ctx.GetLastQueryInfoInt(sql.FoundRows), nil } // Children implements sql.Expression -func (r FoundRows) Children() []sql.Expression { +func (r *FoundRows) Children() []sql.Expression { return nil } // WithChildren implements sql.Expression -func (r FoundRows) WithChildren(children ...sql.Expression) (sql.Expression, error) { +func (r *FoundRows) WithChildren(children ...sql.Expression) (sql.Expression, error) { return sql.NillaryWithChildren(r, children...) } + +// IsNonDeterministic implements sql.NonDeterministicExpression +func (r *FoundRows) IsNonDeterministic() bool { + return true +} diff --git a/sql/plan/process.go b/sql/plan/process.go index 537bc1f40d..07b62b0fbb 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -306,7 +306,7 @@ const ( QueryTypeUpdate ) -type trackedRowIter struct { +type TrackedRowIter struct { node sql.Node iter sql.RowIter numRows int64 @@ -321,11 +321,11 @@ func NewTrackedRowIter( iter sql.RowIter, onNext NotifyFunc, onDone NotifyFunc, -) *trackedRowIter { - return &trackedRowIter{node: node, iter: iter, onDone: onDone, onNext: onNext} +) *TrackedRowIter { + return &TrackedRowIter{node: node, iter: iter, onDone: onDone, onNext: onNext} } -func (i *trackedRowIter) done() { +func (i *TrackedRowIter) done() { if i.onDone != nil { i.onDone() i.onDone = nil @@ -347,13 +347,13 @@ func disposeNode(n sql.Node) { }) } -func (i *trackedRowIter) Dispose() { +func (i *TrackedRowIter) Dispose() { if i.node != nil { disposeNode(i.node) } } -func (i *trackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { +func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { row, err := i.iter.Next(ctx) if err != nil { return nil, err @@ -368,7 +368,7 @@ func (i *trackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { return row, nil } -func (i *trackedRowIter) Close(ctx *sql.Context) error { +func (i *TrackedRowIter) Close(ctx *sql.Context) error { err := i.iter.Close(ctx) i.updateSessionVars(ctx) @@ -377,7 +377,15 @@ func (i *trackedRowIter) Close(ctx *sql.Context) error { return err } -func (i *trackedRowIter) updateSessionVars(ctx *sql.Context) { +func (i *TrackedRowIter) GetNode() sql.Node { + return i.node +} + +func (i *TrackedRowIter) GetIter() sql.RowIter { + return i.iter +} + +func (i *TrackedRowIter) updateSessionVars(ctx *sql.Context) { switch i.QueryType { case QueryTypeSelect: ctx.SetLastQueryInfoInt(sql.RowCount, -1) diff --git a/sql/plan/project.go b/sql/plan/project.go index 3ef980e671..814f63bd13 100644 --- a/sql/plan/project.go +++ b/sql/plan/project.go @@ -26,8 +26,8 @@ import ( // Project is a projection of certain expression from the children node. type Project struct { UnaryNode - // Expression projected. Projections []sql.Expression + CanDefer bool } var _ sql.Expressioner = (*Project)(nil) @@ -160,8 +160,9 @@ func (p *Project) WithChildren(children ...sql.Node) (sql.Node, error) { if len(children) != 1 { return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) } - - return NewProject(p.Projections, children[0]), nil + np := *p + np.Child = children[0] + return &np, nil } // CheckPrivileges implements the interface sql.Node. @@ -179,6 +180,13 @@ func (p *Project) WithExpressions(exprs ...sql.Expression) (sql.Node, error) { if len(exprs) != len(p.Projections) { return nil, sql.ErrInvalidChildrenNumber.New(p, len(exprs), len(p.Projections)) } + np := *p + np.Projections = exprs + return &np, nil +} - return NewProject(exprs, p.Child), nil +func (p *Project) WithCanDefer(canDefer bool) *Project { + np := *p + np.CanDefer = canDefer + return &np } diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index c466c0bac7..5bbd5e156c 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -888,5 +888,5 @@ func (b *Builder) bindOnlyWithDatabase(db sql.Database, stmt ast.Statement, s st b.currentDatabase = curDb }() b.currentDatabase = db - return b.BindOnly(stmt, s) + return b.BindOnly(stmt, s, nil) } diff --git a/sql/planbuilder/parse.go b/sql/planbuilder/parse.go index e5df71a096..58550301c5 100644 --- a/sql/planbuilder/parse.go +++ b/sql/planbuilder/parse.go @@ -41,11 +41,11 @@ func ParseWithOptions(ctx *sql.Context, cat sql.Catalog, query string, options a // TODO: need correct parser b := New(ctx, cat, sql.NewMysqlParser()) b.SetParserOptions(options) - node, _, _, qFlags, err := b.Parse(query, false) + node, _, _, qFlags, err := b.Parse(query, nil, false) return node, qFlags, err } -func (b *Builder) Parse(query string, multi bool) (ret sql.Node, parsed, remainder string, qProps *sql.QueryFlags, err error) { +func (b *Builder) Parse(query string, qFlags *sql.QueryFlags, multi bool) (ret sql.Node, parsed, remainder string, qProps *sql.QueryFlags, err error) { defer trace.StartRegion(b.ctx, "ParseOnly").End() b.nesting++ if b.nesting > maxAnalysisIterations { @@ -74,12 +74,16 @@ func (b *Builder) Parse(query string, multi bool) (ret sql.Node, parsed, remaind return nil, parsed, remainder, nil, sql.ErrSyntaxError.New(err.Error()) } + if qFlags != nil { + b.qFlags = qFlags + } + outScope := b.build(nil, stmt, parsed) return outScope.node, parsed, remainder, b.qFlags, err } -func (b *Builder) BindOnly(stmt ast.Statement, s string) (_ sql.Node, _ *sql.QueryFlags, err error) { +func (b *Builder) BindOnly(stmt ast.Statement, s string, queryFlags *sql.QueryFlags) (_ sql.Node, _ *sql.QueryFlags, err error) { defer trace.StartRegion(b.ctx, "BindOnly").End() defer func() { if r := recover(); r != nil { @@ -91,7 +95,9 @@ func (b *Builder) BindOnly(stmt ast.Statement, s string) (_ sql.Node, _ *sql.Que } } }() - + if queryFlags != nil { + b.qFlags = queryFlags + } outScope := b.build(nil, stmt, s) return outScope.node, b.qFlags, err } diff --git a/sql/planbuilder/parse_test.go b/sql/planbuilder/parse_test.go index 0ab4554644..c0275977ec 100644 --- a/sql/planbuilder/parse_test.go +++ b/sql/planbuilder/parse_test.go @@ -2871,7 +2871,7 @@ func TestPlanBuilderErr(t *testing.T) { stmt, err := sqlparser.Parse(tt.Query) require.NoError(t, err) - _, _, err = b.BindOnly(stmt, tt.Query) + _, _, err = b.BindOnly(stmt, tt.Query, nil) defer b.Reset() require.Error(t, err) diff --git a/sql/planbuilder/project.go b/sql/planbuilder/project.go index ebac3c8b8d..35a42c14a3 100644 --- a/sql/planbuilder/project.go +++ b/sql/planbuilder/project.go @@ -194,6 +194,19 @@ func (b *Builder) selectExprToExpression(inScope *scope, se ast.SelectExpr) sql. return nil } +func (b *Builder) markDeferProjection(proj sql.Node, inScope, outScope *scope) { + if !b.qFlags.IsSet(sql.QFlagDeferProjections) || b.qFlags.IsSet(sql.QFlagUndeferrableExprs) { + return + } + if inScope.parent != nil && inScope.parent.activeSubquery != nil { + return + } + if _, isProj := proj.(*plan.Project); !isProj { + return + } + proj.(*plan.Project).CanDefer = true +} + func (b *Builder) buildProjection(inScope, outScope *scope) { projections := make([]sql.Expression, len(outScope.cols)) for i, sc := range outScope.cols { @@ -203,6 +216,7 @@ func (b *Builder) buildProjection(inScope, outScope *scope) { if err != nil { b.handleErr(err) } + b.markDeferProjection(proj, inScope, outScope) outScope.node = proj } diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 55ff36f52b..2ee0ec8c78 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -169,6 +169,11 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { b.handleErr(err) } + switch rf.(type) { + case *function.Sleep, sql.NonDeterministicExpression: + b.qFlags.Set(sql.QFlagUndeferrableExprs) + } + // NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser will throw // errors for when DISTINCT is used on aggregate functions that don't support DISTINCT. if v.Distinct { diff --git a/sql/planbuilder/show.go b/sql/planbuilder/show.go index d56fb30e1b..418235a2a7 100644 --- a/sql/planbuilder/show.go +++ b/sql/planbuilder/show.go @@ -343,7 +343,7 @@ func (b *Builder) buildShowProcedureStatus(inScope *scope, s *ast.Show) (outScop node, _, _, _, err := b.Parse("select routine_schema as `Db`, routine_name as `Name`, routine_type as `Type`,"+ "definer as `Definer`, last_altered as `Modified`, created as `Created`, security_type as `Security_type`,"+ "routine_comment as `Comment`, CHARACTER_SET_CLIENT as `character_set_client`, COLLATION_CONNECTION as `collation_connection`,"+ - "database_collation as `Database Collation` from information_schema.routines where routine_type = 'PROCEDURE'", false) + "database_collation as `Database Collation` from information_schema.routines where routine_type = 'PROCEDURE'", nil, false) if err != nil { b.handleErr(err) } @@ -379,7 +379,7 @@ func (b *Builder) buildShowFunctionStatus(inScope *scope, s *ast.Show) (outScope node, _, _, _, err := b.Parse("select routine_schema as `Db`, routine_name as `Name`, routine_type as `Type`,"+ "definer as `Definer`, last_altered as `Modified`, created as `Created`, security_type as `Security_type`,"+ "routine_comment as `Comment`, character_set_client, collation_connection,"+ - "database_collation as `Database Collation` from information_schema.routines where routine_type = 'FUNCTION'", false) + "database_collation as `Database Collation` from information_schema.routines where routine_type = 'FUNCTION'", nil, false) if err != nil { b.handleErr(err) } @@ -783,7 +783,7 @@ func (b *Builder) buildShowCollation(inScope *scope, s *ast.Show) (outScope *sco // information_schema, with slightly different syntax and with some columns aliased. // TODO: install information_schema automatically for all catalogs node, _, _, _, err := b.Parse("select collation_name as `collation`, character_set_name as charset, id,"+ - "is_default as `default`, is_compiled as compiled, sortlen, pad_attribute from information_schema.collations order by collation_name", false) + "is_default as `default`, is_compiled as compiled, sortlen, pad_attribute from information_schema.collations order by collation_name", nil, false) if err != nil { b.handleErr(err) } @@ -828,7 +828,7 @@ select XA as XA, SAVEPOINTS as Savepoints from information_schema.engines -`, false) +`, nil, false) if err != nil { b.handleErr(err) } @@ -839,7 +839,7 @@ from information_schema.engines func (b *Builder) buildShowPlugins(inScope *scope, s *ast.Show) (outScope *scope) { outScope = inScope.push() - infoSchemaSelect, _, _, _, err := b.Parse("select * from information_schema.plugins", false) + infoSchemaSelect, _, _, _, err := b.Parse("select * from information_schema.plugins", nil, false) if err != nil { b.handleErr(err) } diff --git a/sql/query_flags.go b/sql/query_flags.go index 5d2ba22f05..fcdd053c31 100644 --- a/sql/query_flags.go +++ b/sql/query_flags.go @@ -42,6 +42,12 @@ const ( // QFlagMax1Row indicates that a query can only return at most one row QFlagMax1Row + + // QFlagDeferProjections indicates that a top-level projections for this query should be deferred and handled by + // RowToSQL + QFlagDeferProjections + // QFlagUndeferrableExprs indicates that the query has expressions that cannot be deferred + QFlagUndeferrableExprs ) type QueryFlags struct { @@ -55,6 +61,13 @@ func (qp *QueryFlags) Set(flag int) { qp.Flags.Add(flag) } +func (qp *QueryFlags) Unset(flag int) { + if qp == nil { + return + } + qp.Flags.Remove(flag) +} + func (qp *QueryFlags) IsSet(flag int) bool { return qp.Flags.Contains(flag) } diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index b10c9c4044..4c86bf17b5 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -375,7 +375,7 @@ func (b *BaseBuilder) buildRowUpdateAccumulator(ctx *sql.Context, n *plan.RowUpd case *updateJoinIter: i.accumulator = rowHandler.(*updateJoinRowHandler) done = true - case *projectIter: + case *ProjectIter: iter = i.childIter case *plan.CheckpointingTableEditorIter: iter = i.InnerIter() diff --git a/sql/rowexec/expr_closer.go b/sql/rowexec/expr_closer.go index b3c995fca8..0b4e43adb0 100644 --- a/sql/rowexec/expr_closer.go +++ b/sql/rowexec/expr_closer.go @@ -19,14 +19,14 @@ import ( "github.com/dolthub/go-mysql-server/sql/transform" ) -// exprCloserIter ensures that all expressions that implement sql.Closer are closed. This is implemented as a capturing +// ExprCloserIter ensures that all expressions that implement sql.Closer are closed. This is implemented as a capturing // iterator, as our workflow only supports closing nodes, not expressions. -type exprCloserIter struct { +type ExprCloserIter struct { iter sql.RowIter exprs []sql.Closer } -var _ sql.RowIter = (*exprCloserIter)(nil) +var _ sql.RowIter = (*ExprCloserIter)(nil) // AddExpressionCloser returns a new iterator that ensures that any expressions that implement sql.Closer are closed. // If there are no expressions that implement sql.Closer in the tree, then the original iterator is returned. @@ -43,19 +43,19 @@ func AddExpressionCloser(node sql.Node, iter sql.RowIter) sql.RowIter { if len(exprs) == 0 { return iter } - return &exprCloserIter{ + return &ExprCloserIter{ iter: iter, exprs: exprs, } } // Next implements the interface sql.RowIter. -func (eci *exprCloserIter) Next(ctx *sql.Context) (sql.Row, error) { +func (eci *ExprCloserIter) Next(ctx *sql.Context) (sql.Row, error) { return eci.iter.Next(ctx) } // Close implements the interface sql.RowIter. -func (eci *exprCloserIter) Close(ctx *sql.Context) error { +func (eci *ExprCloserIter) Close(ctx *sql.Context) error { err := eci.iter.Close(ctx) for _, expr := range eci.exprs { if nErr := expr.Close(ctx); err == nil { @@ -64,3 +64,7 @@ func (eci *exprCloserIter) Close(ctx *sql.Context) error { } return err } + +func (eci *ExprCloserIter) GetIter() sql.RowIter { + return eci.iter +} diff --git a/sql/rowexec/other_iters.go b/sql/rowexec/other_iters.go index 09f0b5961b..dcc2e9b713 100644 --- a/sql/rowexec/other_iters.go +++ b/sql/rowexec/other_iters.go @@ -311,8 +311,8 @@ type rowIterPartitionFunc func(ctx *sql.Context, partition sql.Partition) (sql.R // iterPartitionRows is the parallel worker for an Exchange node. It // is meant to be run as a goroutine in an errgroup.Group. It will // values read off of |partitions|. For each value it reads, it will -// call |getRowIter| to get a row projectIter, and will then call |Next| on -// that row projectIter, passing every row it gets into |rows|. If it +// call |getRowIter| to get a row ProjectIter, and will then call |Next| on +// that row ProjectIter, passing every row it gets into |rows|. If it // receives an error at any point, it returns it. |iterPartitionRows| // stops iterating and returns |nil| when |partitions| is closed. func iterPartitionRows(ctx *sql.Context, getRowIter rowIterPartitionFunc, partitions <-chan sql.Partition, rows chan<- sql.Row) (rerr error) { diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go index 0830f5d94b..17bb1c7f37 100644 --- a/sql/rowexec/rel.go +++ b/sql/rowexec/rel.go @@ -311,8 +311,9 @@ func (b *BaseBuilder) buildProject(ctx *sql.Context, n *plan.Project, row sql.Ro return nil, err } - return sql.NewSpanIter(span, &projectIter{ - p: n.Projections, + return sql.NewSpanIter(span, &ProjectIter{ + projs: n.Projections, + canDefer: n.CanDefer, childIter: i, }), nil } @@ -322,8 +323,8 @@ func (b *BaseBuilder) buildVirtualColumnTable(ctx *sql.Context, n *plan.VirtualC attribute.Int("projections", len(n.Projections)), )) - return sql.NewSpanIter(span, &projectIter{ - p: n.Projections, + return sql.NewSpanIter(span, &ProjectIter{ + projs: n.Projections, childIter: tableIter, }), nil } diff --git a/sql/rowexec/rel_iters.go b/sql/rowexec/rel_iters.go index a331278190..720c81018f 100644 --- a/sql/rowexec/rel_iters.go +++ b/sql/rowexec/rel_iters.go @@ -124,24 +124,36 @@ func (i *offsetIter) Close(ctx *sql.Context) error { var _ sql.RowIter = &iters.JsonTableRowIter{} -type projectIter struct { - p []sql.Expression +type ProjectIter struct { + projs []sql.Expression + canDefer bool childIter sql.RowIter } -func (i *projectIter) Next(ctx *sql.Context) (sql.Row, error) { +func (i *ProjectIter) Next(ctx *sql.Context) (sql.Row, error) { childRow, err := i.childIter.Next(ctx) if err != nil { return nil, err } - - return ProjectRow(ctx, i.p, childRow) + return ProjectRow(ctx, i.projs, childRow) } -func (i *projectIter) Close(ctx *sql.Context) error { +func (i *ProjectIter) Close(ctx *sql.Context) error { return i.childIter.Close(ctx) } +func (i *ProjectIter) GetProjections() []sql.Expression { + return i.projs +} + +func (i *ProjectIter) CanDefer() bool { + return i.canDefer +} + +func (i *ProjectIter) GetChildIter() sql.RowIter { + return i.childIter +} + // ProjectRow evaluates a set of projections. func ProjectRow( ctx *sql.Context, @@ -149,7 +161,7 @@ func ProjectRow( row sql.Row, ) (sql.Row, error) { var fields = make(sql.Row, len(projections)) - var secondPass = make([]int, 0, len(projections)) + var secondPass []int for i, expr := range projections { // Default values that are expressions may reference other fields, thus they must evaluate after all other exprs. // Also default expressions may not refer to other columns that come after them if they also have a default expr. diff --git a/sql/rowexec/transaction.go b/sql/rowexec/transaction.go index ecfd3a5017..c036712d31 100644 --- a/sql/rowexec/transaction.go +++ b/sql/rowexec/transaction.go @@ -306,5 +306,5 @@ func (b *BaseBuilder) buildTransactionCommittingNode(ctx *sql.Context, n *plan.T if err != nil { return nil, err } - return transactionCommittingIter{childIter: iter}, nil + return &TransactionCommittingIter{childIter: iter}, nil } diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index d02cddd34b..c60935175a 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -67,18 +67,18 @@ func getLockableTable(table sql.Table) (sql.Lockable, error) { } } -// transactionCommittingIter is a simple RowIter wrapper to allow the engine to conditionally commit a transaction +// TransactionCommittingIter is a simple RowIter wrapper to allow the engine to conditionally commit a transaction // during the Close() operation -type transactionCommittingIter struct { +type TransactionCommittingIter struct { childIter sql.RowIter transactionDatabase string } -func (t transactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { +func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { return t.childIter.Next(ctx) } -func (t transactionCommittingIter) Close(ctx *sql.Context) error { +func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { var err error if t.childIter != nil { err = t.childIter.Close(ctx) @@ -114,3 +114,13 @@ func (t transactionCommittingIter) Close(ctx *sql.Context) error { return nil } + +func (t *TransactionCommittingIter) GetIter() sql.RowIter { + return t.childIter +} + +func (t *TransactionCommittingIter) WithChildIter(childIter sql.RowIter) sql.RowIter { + nt := *t + t.childIter = childIter + return &nt +} diff --git a/sql/rowexec/update.go b/sql/rowexec/update.go index 2fc20b6dff..c61c0d929d 100644 --- a/sql/rowexec/update.go +++ b/sql/rowexec/update.go @@ -183,7 +183,7 @@ func newUpdateIter( } } -// updateJoinIter wraps the child UpdateSource projectIter and returns join row in such a way that updates per table row are +// updateJoinIter wraps the child UpdateSource ProjectIter and returns join row in such a way that updates per table row are // done once. type updateJoinIter struct { updateSourceIter sql.RowIter