lightning: support prepared statement and client stmt cache in logical import mode (#55482)

close #54850
dbsid committed Sep 21, 2024
1 parent 2651b77 commit 2eb4dc8
Showing 13 changed files with 305 additions and 42 deletions.
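With logical-import-prep-stmt enabled, Lightning stops interpolating row values into the INSERT text and instead builds an INSERT ... VALUES (?,?),... template, executes it with an argument list, and reuses the server-side prepared statement across batches through a client-side LRU cache. A minimal, self-contained sketch of that underlying technique (not Lightning's actual code; the DSN and table name `t` are placeholders):

package main

import (
	"database/sql"
	"fmt"
	"strings"

	_ "github.com/go-sql-driver/mysql"
)

// buildInsert returns a multi-row INSERT template with ? placeholders,
// e.g. "INSERT INTO t VALUES (?,?),(?,?)" for cols=2, rows=2.
func buildInsert(table string, cols, rows int) string {
	row := "(" + strings.TrimSuffix(strings.Repeat("?,", cols), ",") + ")"
	return fmt.Sprintf("INSERT INTO %s VALUES %s",
		table, strings.TrimSuffix(strings.Repeat(row+",", rows), ","))
}

func main() {
	db, err := sql.Open("mysql", "root@tcp(127.0.0.1:4000)/test") // placeholder DSN
	if err != nil {
		panic(err)
	}
	defer db.Close()

	// Prepare once; executing the same template again skips re-parsing
	// the (potentially very long) SQL text on the server.
	stmt, err := db.Prepare(buildInsert("t", 2, 2))
	if err != nil {
		panic(err)
	}
	defer stmt.Close()

	if _, err := stmt.Exec(1, "a", 2, "b"); err != nil {
		panic(err)
	}
}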
14 changes: 8 additions & 6 deletions lightning/pkg/importer/chunk_process.go
@@ -177,9 +177,10 @@ func (cr *chunkProcessor) process(
// Create the encoder.
kvEncoder, err := rc.encBuilder.NewEncoder(ctx, &encode.EncodingConfig{
SessionOptions: encode.SessionOptions{
SQLMode: rc.cfg.TiDB.SQLMode,
Timestamp: cr.chunk.Timestamp,
SysVars: rc.sysVars,
+ LogicalImportPrepStmt: rc.cfg.TikvImporter.LogicalImportPrepStmt,
// use chunk.PrevRowIDMax as the auto random seed, so it stays the same after recovering from a checkpoint.
AutoRandomSeed: cr.chunk.Chunk.PrevRowIDMax,
},
@@ -262,9 +263,10 @@ func (cr *chunkProcessor) encodeLoop(

originalTableEncoder, err = rc.encBuilder.NewEncoder(ctx, &encode.EncodingConfig{
SessionOptions: encode.SessionOptions{
SQLMode: rc.cfg.TiDB.SQLMode,
Timestamp: cr.chunk.Timestamp,
SysVars: rc.sysVars,
+ LogicalImportPrepStmt: rc.cfg.TikvImporter.LogicalImportPrepStmt,
// use chunk.PrevRowIDMax as the auto random seed, so it stays the same after recovering from a checkpoint.
AutoRandomSeed: cr.chunk.Chunk.PrevRowIDMax,
},
1 change: 1 addition & 0 deletions lightning/tests/lightning_tidb_duplicate_data/error.toml
@@ -2,3 +2,4 @@
backend = "tidb"
on-duplicate = "error"
logical-import-batch-rows = 1
+ logical-import-prep-stmt = true
1 change: 1 addition & 0 deletions lightning/tests/lightning_tidb_duplicate_data/ignore.toml
@@ -2,3 +2,4 @@
backend = "tidb"
on-duplicate = "ignore"
logical-import-batch-rows = 1
+ logical-import-prep-stmt = true
1 change: 1 addition & 0 deletions lightning/tests/lightning_tidb_duplicate_data/replace.toml
@@ -2,3 +2,4 @@
backend = "tidb"
on-duplicate = "replace"
logical-import-batch-rows = 1
+ logical-import-prep-stmt = true
2 changes: 2 additions & 0 deletions lightning/tidb-lightning.toml
@@ -166,6 +166,8 @@ addr = "127.0.0.1:8287"
# the rows will be split in a way to respect both settings.
# This value may be decreased to reduce the stress on the cluster due to large transactions.
#logical-import-batch-rows = 65536
+ # logical-import-prep-stmt controls whether to use prepared statements in logical mode (TiDB backend).
+ #logical-import-prep-stmt = false

[mydumper]
# block size of file reading
7 changes: 4 additions & 3 deletions pkg/lightning/backend/encode/encode.go
@@ -52,9 +52,10 @@ type Encoder interface {

// SessionOptions is the initial configuration of the session.
type SessionOptions struct {
SQLMode mysql.SQLMode
Timestamp int64
SysVars map[string]string
+ LogicalImportPrepStmt bool
// a seed used for tableKvEncoder's auto random bits value
AutoRandomSeed int64
// IndexID is used by the dupeDetector. Only the key range with the specified index ID is scanned.
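Callers opt in through the new SessionOptions field. A hedged sketch of constructing it (the field names come from this diff; the SQL mode, timestamp, and sysvars are placeholder values):

import (
	"time"

	"github.com/pingcap/tidb/pkg/lightning/backend/encode"
	"github.com/pingcap/tidb/pkg/parser/mysql"
)

func exampleSessionOptions() encode.SessionOptions {
	return encode.SessionOptions{
		SQLMode:               mysql.ModeStrictAllTables,
		Timestamp:             time.Now().Unix(),
		SysVars:               map[string]string{},
		LogicalImportPrepStmt: true, // wired from tikv-importer.logical-import-prep-stmt
	}
}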
5 changes: 4 additions & 1 deletion pkg/lightning/backend/tidb/BUILD.bazel
@@ -22,6 +22,8 @@ go_library(
"//pkg/table",
"//pkg/types",
"//pkg/util/dbutil",
"//pkg/util/hack",
"//pkg/util/kvcache",
"//pkg/util/redact",
"@com_github_go_sql_driver_mysql//:mysql",
"@com_github_google_uuid//:uuid",
@@ -37,7 +39,7 @@ go_test(
timeout = "short",
srcs = ["tidb_test.go"],
flaky = True,
- shard_count = 15,
+ shard_count = 17,
deps = [
":tidb",
"//pkg/errno",
@@ -60,5 +62,6 @@ go_test(
"@com_github_go_sql_driver_mysql//:mysql",
"@com_github_stretchr_testify//require",
"@org_uber_go_atomic//:atomic",
"@org_uber_go_zap//:zap",
],
)
147 changes: 119 additions & 28 deletions pkg/lightning/backend/tidb/tidb.go
@@ -21,6 +21,7 @@ import (
"fmt"
"strconv"
"strings"
"sync"
"time"

gmysql "github.com/go-sql-driver/mysql"
@@ -43,6 +44,8 @@ import (
"github.com/pingcap/tidb/pkg/table"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/dbutil"
"github.com/pingcap/tidb/pkg/util/hack"
"github.com/pingcap/tidb/pkg/util/kvcache"
"github.com/pingcap/tidb/pkg/util/redact"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
@@ -56,12 +59,16 @@ var extraHandleTableColumn = &table.Column{

const (
writeRowsMaxRetryTimes = 3
+ // To limit memory usage for prepared statements.
+ prepStmtCacheSize uint = 100
)

type tidbRow struct {
insertStmt string
+ preparedInsertStmt string
+ values []any
path string
offset int64
}

var emptyTiDBRow = tidbRow{
@@ -91,8 +98,9 @@ type tidbEncoder struct {
// there are enough columns.
columnCnt int
// data file path
path string
logger log.Logger
+ prepStmt bool
}

type encodingBuilder struct{}
@@ -106,10 +114,11 @@ func NewEncodingBuilder() encode.EncodingBuilder {
// It implements the `backend.EncodingBuilder` interface.
func (*encodingBuilder) NewEncoder(_ context.Context, config *encode.EncodingConfig) (encode.Encoder, error) {
return &tidbEncoder{
mode: config.SQLMode,
tbl: config.Table,
path: config.Path,
logger: config.Logger,
+ prepStmt: config.LogicalImportPrepStmt,
}, nil
}

@@ -288,6 +297,16 @@ func (*targetInfoGetter) CheckRequirements(ctx context.Context, _ *backend.Check
return nil
}

+ // stmtKey defines key for stmtCache.
+ type stmtKey struct {
+     query string
+ }

+ // Hash implements SimpleLRUCache.Key.
+ func (k *stmtKey) Hash() []byte {
+     return hack.Slice(k.query)
+ }

type tidbBackend struct {
db *sql.DB
conflictCfg config.Conflict
@@ -301,6 +320,9 @@ type tidbBackend struct {
// affecting the cluster too much.
maxChunkSize uint64
maxChunkRows int
+ // stmtCache caches prepared statements to improve write performance.
+ stmtCache *kvcache.SimpleLRUCache
+ stmtCacheMutex sync.RWMutex
}

var _ backend.Backend = (*tidbBackend)(nil)
@@ -334,13 +356,23 @@ func NewTiDBBackend(
log.FromContext(ctx).Warn("unsupported conflict strategy for TiDB backend, overwrite with `error`")
onDuplicate = config.ErrorOnDup
}
+ var stmtCache *kvcache.SimpleLRUCache
+ if cfg.TikvImporter.LogicalImportPrepStmt {
+     stmtCache = kvcache.NewSimpleLRUCache(prepStmtCacheSize, 0, 0)
+     stmtCache.SetOnEvict(func(_ kvcache.Key, value kvcache.Value) {
+         stmt := value.(*sql.Stmt)
+         stmt.Close()
+     })
+ }
return &tidbBackend{
db: db,
conflictCfg: conflict,
onDuplicate: onDuplicate,
errorMgr: errorMgr,
maxChunkSize: uint64(cfg.TikvImporter.LogicalImportBatchSize),
maxChunkRows: cfg.TikvImporter.LogicalImportBatchRows,
+ stmtCache: stmtCache,
+ stmtCacheMutex: sync.RWMutex{},
}
}
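Each cached *sql.Stmt pins a statement handle on the server, which is why the cache is capped at prepStmtCacheSize and closes statements as they are evicted; an unbounded cache would leak handles. Below is a simplified stand-in for kvcache.SimpleLRUCache that illustrates the close-on-evict contract (illustrative only, not the TiDB implementation; like SimpleLRUCache it is not goroutine-safe, hence the stmtCacheMutex guarding it in execStmts further down):

package stmtcache

import (
	"container/list"
	"database/sql"
)

// lruStmtCache keeps at most cap prepared statements, closing the least
// recently used one when a new statement would exceed the capacity.
type lruStmtCache struct {
	cap   int
	order *list.List               // front = most recently used
	items map[string]*list.Element // query text -> element
}

type stmtEntry struct {
	key  string
	stmt *sql.Stmt
}

func newLRUStmtCache(capacity int) *lruStmtCache {
	return &lruStmtCache{cap: capacity, order: list.New(), items: map[string]*list.Element{}}
}

// Get returns a cached statement and marks it recently used.
func (c *lruStmtCache) Get(query string) (*sql.Stmt, bool) {
	if el, ok := c.items[query]; ok {
		c.order.MoveToFront(el)
		return el.Value.(*stmtEntry).stmt, true
	}
	return nil, false
}

// Put inserts a statement, evicting and closing the least recently used
// entry when the cache is full (the SetOnEvict hook above does the same).
func (c *lruStmtCache) Put(query string, stmt *sql.Stmt) {
	if _, ok := c.items[query]; ok {
		return // caller must close its duplicate, as execStmts does
	}
	c.items[query] = c.order.PushFront(&stmtEntry{key: query, stmt: stmt})
	if c.order.Len() > c.cap {
		oldest := c.order.Back()
		c.order.Remove(oldest)
		e := oldest.Value.(*stmtEntry)
		delete(c.items, e.key)
		e.stmt.Close() // mirrors stmtCache.SetOnEvict in NewTiDBBackend
	}
}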

@@ -556,16 +588,25 @@ func (enc *tidbEncoder) Encode(row []types.Datum, _ int64, columnPermutation []i
return emptyTiDBRow, errors.Errorf("column count mismatch, at most %d but got %d", len(enc.columnIdx), len(row))
}

- var encoded strings.Builder
+ var encoded, preparedInsertStmt strings.Builder
+ var values []any
encoded.Grow(8 * len(row))
encoded.WriteByte('(')
+ if enc.prepStmt {
+     preparedInsertStmt.Grow(2 * len(row))
+     preparedInsertStmt.WriteByte('(')
+     values = make([]any, 0, len(row))
+ }
cnt := 0
for i, field := range row {
if enc.columnIdx[i] < 0 {
continue
}
if cnt > 0 {
encoded.WriteByte(',')
+ if enc.prepStmt {
+     preparedInsertStmt.WriteByte(',')
+ }
}
datum := field
if err := enc.appendSQL(&encoded, &datum, getColumnByIndex(cols, enc.columnIdx[i])); err != nil {
Expand All @@ -576,13 +617,23 @@ func (enc *tidbEncoder) Encode(row []types.Datum, _ int64, columnPermutation []i
)
return nil, err
}
+ if enc.prepStmt {
+     preparedInsertStmt.WriteByte('?')
+     values = append(values, datum.GetValue())
+ }
cnt++
}
encoded.WriteByte(')')
+ if enc.prepStmt {
+     preparedInsertStmt.WriteByte(')')
+ }

return tidbRow{
insertStmt: encoded.String(),
+ preparedInsertStmt: preparedInsertStmt.String(),
+ values: values,
path: enc.path,
offset: offset,
}, nil
}
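In prepared-statement mode, Encode now produces two parallel renderings of each row: the literal tuple used in plain mode, and a placeholder tuple plus argument slice used with prepared statements. A stripped-down sketch of that shape (hypothetical helper, not Lightning's API; the real code quotes and escapes values per column type and applies the column permutation):

package main

import (
	"fmt"
	"strings"
)

// encodeRow mirrors, in spirit, what tidbEncoder.Encode does after this
// commit: emit the literal tuple alongside the placeholder tuple and the
// argument list.
func encodeRow(vals []any) (literal, prepared string, args []any) {
	var lit, prep strings.Builder
	lit.WriteByte('(')
	prep.WriteByte('(')
	for i, v := range vals {
		if i > 0 {
			lit.WriteByte(',')
			prep.WriteByte(',')
		}
		fmt.Fprintf(&lit, "%#v", v) // stand-in for per-type SQL quoting
		prep.WriteByte('?')
		args = append(args, v)
	}
	lit.WriteByte(')')
	prep.WriteByte(')')
	return lit.String(), prep.String(), args
}

func main() {
	lit, prep, args := encodeRow([]any{1, "a"})
	fmt.Println(lit, prep, args) // (1,"a") (?,?) [1 a]
}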

@@ -665,8 +716,9 @@ rowLoop:
}

type stmtTask struct {
rows tidbRows
stmt string
+ values []any
}

// WriteBatchRowsToDB writes rows in batch mode, which will insert multiple rows like this:
@@ -679,14 +731,23 @@ func (be *tidbBackend) WriteBatchRowsToDB(ctx context.Context, tableName string,
}
// Note: without the statement cache, we do not use interpolation (prepared
// statements), to avoid complications arising from data length overflow of BIT and BINARY columns
+ var values []any
+ if be.stmtCache != nil && len(rows) > 0 {
+     values = make([]any, 0, len(rows[0].values)*len(rows))
+ }
stmtTasks := make([]stmtTask, 1)
for i, row := range rows {
if i != 0 {
insertStmt.WriteByte(',')
}
- insertStmt.WriteString(row.insertStmt)
+ if be.stmtCache != nil {
+     insertStmt.WriteString(row.preparedInsertStmt)
+     values = append(values, row.values...)
+ } else {
+     insertStmt.WriteString(row.insertStmt)
+ }
}
- stmtTasks[0] = stmtTask{rows, insertStmt.String()}
+ stmtTasks[0] = stmtTask{rows, insertStmt.String(), values}
return be.execStmts(ctx, stmtTasks, tableName, true)
}
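Note that the cache key is the full statement text (via stmtKey.Hash above), so a prepared statement is reused only when batches share an identical template: same table, same column list, and same row count. Reusing buildInsert from the first sketch:

tmplA := buildInsert("t", 2, 64) // first full 64-row batch
tmplB := buildInsert("t", 2, 64) // next full batch: identical text, cache hit
tmplC := buildInsert("t", 2, 7)  // trailing short batch: new prepared statement
fmt.Println(tmplA == tmplB, tmplA == tmplC) // true false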

@@ -715,8 +776,12 @@ func (be *tidbBackend) WriteRowsToDB(ctx context.Context, tableName string, colu
for _, row := range rows {
var finalInsertStmt strings.Builder
finalInsertStmt.WriteString(is)
- finalInsertStmt.WriteString(row.insertStmt)
- stmtTasks = append(stmtTasks, stmtTask{[]tidbRow{row}, finalInsertStmt.String()})
+ if be.stmtCache != nil {
+     finalInsertStmt.WriteString(row.preparedInsertStmt)
+ } else {
+     finalInsertStmt.WriteString(row.insertStmt)
+ }
+ stmtTasks = append(stmtTasks, stmtTask{[]tidbRow{row}, finalInsertStmt.String(), row.values})
}
return be.execStmts(ctx, stmtTasks, tableName, false)
}
@@ -754,8 +819,34 @@ stmtLoop:
err error
)
for i := 0; i < writeRowsMaxRetryTimes; i++ {
- stmt := stmtTask.stmt
- result, err = be.db.ExecContext(ctx, stmt)
+ query := stmtTask.stmt
+ if be.stmtCache != nil {
+     var prepStmt *sql.Stmt
+     key := &stmtKey{query: query}
+     be.stmtCacheMutex.RLock()
+     stmt, ok := be.stmtCache.Get(key)
+     be.stmtCacheMutex.RUnlock()
+     if ok {
+         prepStmt = stmt.(*sql.Stmt)
+     } else if stmt, err := be.db.PrepareContext(ctx, query); err == nil {
+         be.stmtCacheMutex.Lock()
+         // check again whether the key is already in the cache,
+         // to avoid overriding an existing stmt without closing it
+         if cachedStmt, ok := be.stmtCache.Get(key); !ok {
+             prepStmt = stmt
+             be.stmtCache.Put(key, stmt)
+         } else {
+             prepStmt = cachedStmt.(*sql.Stmt)
+             stmt.Close()
+         }
+         be.stmtCacheMutex.Unlock()
+     } else {
+         return errors.Trace(err)
+     }
+     result, err = prepStmt.ExecContext(ctx, stmtTask.values...)
+ } else {
+     result, err = be.db.ExecContext(ctx, query)
+ }
if err == nil {
affected, err2 := result.RowsAffected()
if err2 != nil {
Expand All @@ -776,7 +867,7 @@ stmtLoop:

if !common.IsContextCanceledError(err) {
log.FromContext(ctx).Error("execute statement failed",
zap.Array("rows", stmtTask.rows), zap.String("stmt", redact.Value(stmt)), zap.Error(err))
zap.Array("rows", stmtTask.rows), zap.String("stmt", redact.Value(query)), zap.Error(err))
}
// It's batch mode, just return the error. Caller will fall back to row-by-row mode.
if batch {
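The lookup above is double-checked locking: the fast path takes only the read lock; on a miss the statement is prepared outside the cache lock, and the cache is then re-checked under the write lock so that a racing goroutine's statement is reused and the loser's is closed rather than leaked. A condensed sketch of the same pattern, using the simplified lruStmtCache from the earlier sketch (names are illustrative):

import (
	"context"
	"database/sql"
	"sync"
)

func getOrPrepare(ctx context.Context, db *sql.DB, mu *sync.RWMutex, cache *lruStmtCache, query string) (*sql.Stmt, error) {
	mu.RLock()
	stmt, ok := cache.Get(query)
	mu.RUnlock()
	if ok {
		return stmt, nil // fast path: already prepared
	}
	prepared, err := db.PrepareContext(ctx, query) // prepare outside the lock
	if err != nil {
		return nil, err
	}
	mu.Lock()
	defer mu.Unlock()
	if cached, ok := cache.Get(query); ok {
		_ = prepared.Close() // lost the race: reuse the winner's statement
		return cached, nil
	}
	cache.Put(query, prepared)
	return prepared, nil
}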
(Diffs for the remaining 5 changed files are not shown.)
