diff --git a/server/handler.go b/server/handler.go index 09d0940e0a..fa90976769 100644 --- a/server/handler.go +++ b/server/handler.go @@ -392,6 +392,8 @@ func (h *Handler) doQuery( } } + sqlCtx.SetParsedQuery(parsed) + more := remainder != "" var queryStr string diff --git a/sql/session.go b/sql/session.go index f0a4c10ead..b06ff74b8d 100644 --- a/sql/session.go +++ b/sql/session.go @@ -27,6 +27,8 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" + + "github.com/dolthub/vitess/go/vt/sqlparser" ) type key uint @@ -239,6 +241,7 @@ type Context struct { services Services pid uint64 query string + parsedQuery sqlparser.Statement queryTime time.Time tracer trace.Tracer rootSpan trace.Span @@ -399,6 +402,20 @@ func (c *Context) WithQuery(q string) *Context { return &nc } +// ParsedQuery returns the parsed query associated with this context. +// May return nil. +func (c *Context) ParsedQuery() sqlparser.Statement { + if c == nil { + return nil + } + return c.parsedQuery +} + +// SetParsedQuery adds the given parsed query to the context. +func (c *Context) SetParsedQuery(parsed sqlparser.Statement) { + c.parsedQuery = parsed +} + // QueryTime returns the time.Time when the context associated with this query was created func (c *Context) QueryTime() time.Time { if c == nil {