diff --git a/go.mod b/go.mod index e7a1dcddba..45544ab7a4 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20240919225659-2ad81685e772 + github.com/dolthub/vitess v0.0.0-20241002230050-2c2ea65cf324 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index a703549a96..1a7f1a0bcf 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9X github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= github.com/dolthub/vitess v0.0.0-20240919225659-2ad81685e772 h1:vDwBX7Lc8DnA8Zk0iRIu6slCw0GIUfYfFlYDYJQw8GQ= github.com/dolthub/vitess v0.0.0-20240919225659-2ad81685e772/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= +github.com/dolthub/vitess v0.0.0-20241002230050-2c2ea65cf324 h1:OO1XBXmBM3HBJfbwEwsj8h0m/bwYKIgFgGN8d+S+vrw= +github.com/dolthub/vitess v0.0.0-20241002230050-2c2ea65cf324/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/sql/planbuilder/ddl.go b/sql/planbuilder/ddl.go index 957e7bbaba..a9bb6cee01 100644 --- a/sql/planbuilder/ddl.go +++ b/sql/planbuilder/ddl.go @@ -526,6 +526,14 @@ func (b *Builder) buildAlterTableClause(inScope *scope, ddl *ast.DDL) []*scope { outScopes = append(outScopes, b.buildAlterCollationSpec(tableScope, ddl, rt)) } + if ddl.NotNullSpec != nil { + outScopes = append(outScopes, b.buildAlterNotNull(tableScope, ddl, rt)) + } + + if ddl.ColumnTypeSpec != nil { + outScopes = append(outScopes, b.buildAlterChangeColumnType(tableScope, ddl, rt)) + } + for _, s := range outScopes { if ts, ok := s.node.(sql.SchemaTarget); ok { s.node = b.modifySchemaTarget(s, ts, rt.Schema()) @@ -924,6 +932,56 @@ func (b *Builder) buildAlterAutoIncrement(inScope *scope, ddl *ast.DDL, table *p return } +func (b *Builder) buildAlterNotNull(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { + outScope = inScope + spec := ddl.NotNullSpec + for _, c := range table.Schema() { + if strings.EqualFold(c.Name, spec.Column.String()) { + colCopy := *c + switch strings.ToLower(spec.Action) { + case ast.SetStr: + // Set NOT NULL constraint + colCopy.Nullable = false + case ast.DropStr: + // Drop NOT NULL constraint + colCopy.Nullable = true + default: + err := sql.ErrUnsupportedFeature.New(ast.String(ddl)) + b.handleErr(err) + } + + modifyColumn := plan.NewModifyColumnResolved(table, c.Name, colCopy, nil) + outScope.node = b.modifySchemaTarget(inScope, modifyColumn, table.Schema()) + return + } + } + err := sql.ErrTableColumnNotFound.New(table.Name(), spec.Column.String()) + b.handleErr(err) + return +} + +func (b *Builder) buildAlterChangeColumnType(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { + outScope = inScope + spec := ddl.ColumnTypeSpec + for _, c := range table.Schema() { + if strings.EqualFold(c.Name, spec.Column.String()) { + colCopy := *c + typ, err := types.ColumnTypeToType(&spec.Type) + if err != nil { + b.handleErr(err) + return + } + colCopy.Type = typ + modifyColumn := plan.NewModifyColumnResolved(table, c.Name, colCopy, nil) + outScope.node = b.modifySchemaTarget(inScope, modifyColumn, table.Schema()) + return + } + } + err := sql.ErrTableColumnNotFound.New(table.Name(), spec.Column.String()) + b.handleErr(err) + return +} + func (b *Builder) buildAlterDefault(inScope *scope, ddl *ast.DDL, table *plan.ResolvedTable) (outScope *scope) { outScope = inScope switch strings.ToLower(ddl.DefaultSpec.Action) {