Skip to content

Commit

Permalink
feedback and fixing more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
James Cor committed Oct 3, 2024
1 parent 55df1f5 commit 0f10cba
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 86 deletions.
5 changes: 3 additions & 2 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ func (h *Handler) doQuery(
}
}()

qFlags.Set(sql.QFlagDeferProjections)
qFlags.Set(sql.QFlagDeferProjections) // TODO: this somehow breaks timeout???????
schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan, bindings, qFlags)
if err != nil {
sqlCtx.GetLogger().WithError(err).Warn("error running query")
Expand Down Expand Up @@ -522,7 +522,7 @@ func GetDeferredProjections(iter sql.RowIter) []sql.Expression {
if commitIter, isCommitIter := i.GetIter().(*rowexec.TransactionCommittingIter); isCommitIter {
if projIter, isProjIter := commitIter.GetIter().(*rowexec.ProjectIter); isProjIter {
if projIter.CanDefer() {
projIter.Defer()
commitIter.WithChildIter(projIter.GetChildIter())
return projIter.GetProjections()
}
}
Expand Down Expand Up @@ -668,6 +668,7 @@ func (h *Handler) resultForDefaultIter(
r.Rows = append(r.Rows, outputRow)
r.RowsAffected++
case <-timer.C:
// TODO: timer should probably go in its own thread, as rowChan is blocking
if h.readTimeout != 0 {
// Cancel and return so Vitess can call the CloseConnection callback
ctx.GetLogger().Tracef("connection timeout")
Expand Down
133 changes: 68 additions & 65 deletions sql/expression/function/queryinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,104 +27,105 @@ import (
// RowCount implements the ROW_COUNT() function
type RowCount struct{}

func (r RowCount) IsNonDeterministic() bool {
return true
}

func NewRowCount() sql.Expression {
return RowCount{}
return &RowCount{}
}

var _ sql.FunctionExpression = RowCount{}
var _ sql.CollationCoercible = RowCount{}
var _ sql.FunctionExpression = &RowCount{}
var _ sql.CollationCoercible = &RowCount{}

// Description implements sql.FunctionExpression
func (r RowCount) Description() string {
func (r *RowCount) Description() string {
return "returns the number of rows updated."
}

// Resolved implements sql.Expression
func (r RowCount) Resolved() bool {
func (r *RowCount) Resolved() bool {
return true
}

// String implements sql.Expression
func (r RowCount) String() string {
func (r *RowCount) String() string {
return fmt.Sprintf("%s()", r.FunctionName())
}

// Type implements sql.Expression
func (r RowCount) Type() sql.Type {
func (r *RowCount) Type() sql.Type {
return types.Int64
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (RowCount) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
func (*RowCount) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
}

// IsNullable implements sql.Expression
func (r RowCount) IsNullable() bool {
func (r *RowCount) IsNullable() bool {
return false
}

// Eval implements sql.Expression
func (r RowCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
func (r *RowCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return ctx.GetLastQueryInfoInt(sql.RowCount), nil
}

// Children implements sql.Expression
func (r RowCount) Children() []sql.Expression {
func (r *RowCount) Children() []sql.Expression {
return nil
}

// WithChildren implements sql.Expression
func (r RowCount) WithChildren(children ...sql.Expression) (sql.Expression, error) {
func (r *RowCount) WithChildren(children ...sql.Expression) (sql.Expression, error) {
return sql.NillaryWithChildren(r, children...)
}

// FunctionName implements sql.FunctionExpression
func (r RowCount) FunctionName() string {
func (r *RowCount) FunctionName() string {
return "row_count"
}

// IsNonDeterministic implements sql.NonDeterministicExpression
func (r *RowCount) IsNonDeterministic() bool {
return true
}

// LastInsertUuid implements the LAST_INSERT_UUID() function. This function is
// NOT a standard function in MySQL, but is a useful analogue to LAST_INSERT_ID()
// if customers are inserting UUIDs into a table.
type LastInsertUuid struct{}

var _ sql.FunctionExpression = LastInsertUuid{}
var _ sql.CollationCoercible = LastInsertUuid{}
var _ sql.FunctionExpression = &LastInsertUuid{}
var _ sql.CollationCoercible = &LastInsertUuid{}

func NewLastInsertUuid(children ...sql.Expression) (sql.Expression, error) {
if len(children) > 0 {
return nil, sql.ErrInvalidChildrenNumber.New(LastInsertUuid{}.String(), len(children), 0)
return nil, sql.ErrInvalidChildrenNumber.New((&LastInsertUuid{}).String(), len(children), 0)
}

return &LastInsertUuid{}, nil
}

func (l LastInsertUuid) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) {
func (l *LastInsertUuid) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
}

func (l LastInsertUuid) Resolved() bool {
func (l *LastInsertUuid) Resolved() bool {
return true
}

func (l LastInsertUuid) String() string {
func (l *LastInsertUuid) String() string {
return fmt.Sprintf("%s()", l.FunctionName())
}

func (l LastInsertUuid) Type() sql.Type {
func (l *LastInsertUuid) Type() sql.Type {
return types.MustCreateStringWithDefaults(sqltypes.VarChar, 36)
}

func (l LastInsertUuid) IsNullable() bool {
func (l *LastInsertUuid) IsNullable() bool {
return false
}

func (l LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) {
func (l *LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) {
lastInsertUuid := ctx.GetLastQueryInfoString(sql.LastInsertUuid)
result, _, err := l.Type().Convert(lastInsertUuid)
if err != nil {
Expand All @@ -133,19 +134,19 @@ func (l LastInsertUuid) Eval(ctx *sql.Context, _ sql.Row) (interface{}, error) {
return result, nil
}

func (l LastInsertUuid) Children() []sql.Expression {
func (l *LastInsertUuid) Children() []sql.Expression {
return nil
}

func (l LastInsertUuid) WithChildren(children ...sql.Expression) (sql.Expression, error) {
func (l *LastInsertUuid) WithChildren(children ...sql.Expression) (sql.Expression, error) {
return NewLastInsertUuid(children...)
}

func (l LastInsertUuid) FunctionName() string {
func (l *LastInsertUuid) FunctionName() string {
return "last_insert_uuid"
}

func (l LastInsertUuid) Description() string {
func (l *LastInsertUuid) Description() string {
return "returns the first value of the UUID() function from the last INSERT statement."
}

Expand All @@ -155,56 +156,52 @@ type LastInsertId struct {
expression.UnaryExpression
}

func (r LastInsertId) IsNonDeterministic() bool {
return true
}

func NewLastInsertId(children ...sql.Expression) (sql.Expression, error) {
switch len(children) {
case 0:
return LastInsertId{}, nil
return &LastInsertId{}, nil
case 1:
return LastInsertId{UnaryExpression: expression.UnaryExpression{Child: children[0]}}, nil
return &LastInsertId{UnaryExpression: expression.UnaryExpression{Child: children[0]}}, nil
default:
return nil, sql.ErrInvalidArgumentNumber.New("LastInsertId", len(children), 1)
}
}

var _ sql.FunctionExpression = LastInsertId{}
var _ sql.CollationCoercible = LastInsertId{}
var _ sql.FunctionExpression = &LastInsertId{}
var _ sql.CollationCoercible = &LastInsertId{}

// Description implements sql.FunctionExpression
func (r LastInsertId) Description() string {
func (r *LastInsertId) Description() string {
return "returns value of the AUTOINCREMENT column for the last INSERT."
}

// Resolved implements sql.Expression
func (r LastInsertId) Resolved() bool {
func (r *LastInsertId) Resolved() bool {
return true
}

// String implements sql.Expression
func (r LastInsertId) String() string {
func (r *LastInsertId) String() string {
return fmt.Sprintf("%s(%s)", r.FunctionName(), r.Child)
}

// Type implements sql.Expression
func (r LastInsertId) Type() sql.Type {
func (r *LastInsertId) Type() sql.Type {
return types.Uint64
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (LastInsertId) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
func (*LastInsertId) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
}

// IsNullable implements sql.Expression
func (r LastInsertId) IsNullable() bool {
func (r *LastInsertId) IsNullable() bool {
return false
}

// Eval implements sql.Expression
func (r LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
func (r *LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
// With no arguments, just return the last insert id for this session
if len(r.Children()) == 0 {
lastInsertId := ctx.GetLastQueryInfoInt(sql.LastInsertId)
Expand All @@ -230,83 +227,89 @@ func (r LastInsertId) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
}

// Children implements sql.Expression
func (r LastInsertId) Children() []sql.Expression {
func (r *LastInsertId) Children() []sql.Expression {
if r.Child == nil {
return nil
}
return []sql.Expression{r.Child}
}

// WithChildren implements sql.Expression
func (r LastInsertId) WithChildren(children ...sql.Expression) (sql.Expression, error) {
func (r *LastInsertId) WithChildren(children ...sql.Expression) (sql.Expression, error) {
return NewLastInsertId(children...)
}

// FunctionName implements sql.FunctionExpression
func (r LastInsertId) FunctionName() string {
func (r *LastInsertId) FunctionName() string {
return "last_insert_id"
}

// FoundRows implements the FOUND_ROWS() function
type FoundRows struct{}

func (r FoundRows) IsNonDeterministic() bool {
// IsNonDeterministic implements sql.NonDeterministicExpression
func (r *LastInsertId) IsNonDeterministic() bool {
return true
}

// FoundRows implements the FOUND_ROWS() function
type FoundRows struct{}

func NewFoundRows() sql.Expression {
return FoundRows{}
return &FoundRows{}
}

var _ sql.FunctionExpression = FoundRows{}
var _ sql.CollationCoercible = FoundRows{}
var _ sql.FunctionExpression = &FoundRows{}
var _ sql.CollationCoercible = &FoundRows{}

// FunctionName implements sql.FunctionExpression
func (r FoundRows) FunctionName() string {
func (r *FoundRows) FunctionName() string {
return "found_rows"
}

// Description implements sql.Expression
func (r FoundRows) Description() string {
func (r *FoundRows) Description() string {
return "for a SELECT with a LIMIT clause, returns the number of rows that would be returned were there no LIMIT clause."
}

// Resolved implements sql.Expression
func (r FoundRows) Resolved() bool {
func (r *FoundRows) Resolved() bool {
return true
}

// String implements sql.Expression
func (r FoundRows) String() string {
func (r *FoundRows) String() string {
return fmt.Sprintf("%s()", r.FunctionName())
}

// Type implements sql.Expression
func (r FoundRows) Type() sql.Type {
func (r *FoundRows) Type() sql.Type {
return types.Int64
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (FoundRows) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
func (*FoundRows) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 5
}

// IsNullable implements sql.Expression
func (r FoundRows) IsNullable() bool {
func (r *FoundRows) IsNullable() bool {
return false
}

// Eval implements sql.Expression
func (r FoundRows) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
func (r *FoundRows) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return ctx.GetLastQueryInfoInt(sql.FoundRows), nil
}

// Children implements sql.Expression
func (r FoundRows) Children() []sql.Expression {
func (r *FoundRows) Children() []sql.Expression {
return nil
}

// WithChildren implements sql.Expression
func (r FoundRows) WithChildren(children ...sql.Expression) (sql.Expression, error) {
func (r *FoundRows) WithChildren(children ...sql.Expression) (sql.Expression, error) {
return sql.NillaryWithChildren(r, children...)
}

// IsNonDeterministic implements sql.NonDeterministicExpression
func (r *FoundRows) IsNonDeterministic() bool {
return true
}
Loading

0 comments on commit 0f10cba

Please sign in to comment.