Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Sep 27, 2024
1 parent 7b71cc1 commit 8f81f65
Show file tree
Hide file tree
Showing 14 changed files with 229 additions and 10,107 deletions.
6 changes: 3 additions & 3 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -877,12 +877,12 @@ func (c *Client) BulkWrite(ctx context.Context, models *ClientWriteModels,
}

wc := c.writeConcern
if sess.TransactionRunning() && wc != nil {
return nil, errors.New("cannot set write concern after starting a transaction")
}
if bwo.WriteConcern != nil {
wc = bwo.WriteConcern
}
if sess.TransactionRunning() {
wc = nil
}
if !writeconcern.AckWrite(wc) {
sess = nil
}
Expand Down
100 changes: 79 additions & 21 deletions mongo/client_bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package mongo

import (
"context"
"errors"
"strconv"

"go.mongodb.org/mongo-driver/bson"
Expand Down Expand Up @@ -40,6 +41,9 @@ type clientBulkWrite struct {
}

func (bw *clientBulkWrite) execute(ctx context.Context) error {
if len(bw.models) == 0 {
return errors.New("empty write models")
}
docs := make([]bsoncore.Document, len(bw.models))
nsMap := make(map[string]int)
var nsList []string
Expand Down Expand Up @@ -170,12 +174,21 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
Database("admin").
Deployment(bw.client.deployment).Crypt(bw.client.cryptFLE).
ServerAPI(bw.client.serverAPI).Timeout(bw.client.timeout).
Logger(bw.client.logger).Authenticator(bw.client.authenticator)
err := op.Execute(ctx)
if err != nil {
return err
Logger(bw.client.logger).Authenticator(bw.client.authenticator).Name("bulkWrite")
opErr := op.Execute(ctx)
var wcErrs []*WriteConcernError
if opErr != nil {
if errors.Is(opErr, driver.ErrUnacknowledgedWrite) {
return nil
}
var writeErr driver.WriteCommandError
if errors.As(opErr, &writeErr) {
wcErr := convertDriverWriteConcernError(writeErr.WriteConcernError)
wcErrs = append(wcErrs, wcErr)
}
}
var res struct {
Ok bool
Cursor struct {
FirstBatch []bson.Raw
}
Expand All @@ -184,33 +197,48 @@ func (bw *clientBulkWrite) execute(ctx context.Context) error {
NMatched int32
NModified int32
NUpserted int32
NErrors int32
Code int32
Errmsg string
}
rawRes := op.Result()
err = bson.Unmarshal(rawRes, &res)
if err != nil {
if err := bson.Unmarshal(rawRes, &res); err != nil {
return err
}
bw.result.DeletedCount = int64(res.NDeleted)
bw.result.InsertedCount = int64(res.NInserted)
bw.result.MatchedCount = int64(res.NMatched)
bw.result.ModifiedCount = int64(res.NModified)
bw.result.UpsertedCount = int64(res.NUpserted)
errors := make(map[int64]WriteError)
for i, cur := range res.Cursor.FirstBatch {
switch res := resMap[i].(type) {
case map[int64]ClientDeleteResult:
if err = appendDeleteResult(cur, res); err != nil {
if err := appendDeleteResult(cur, res, errors); err != nil {
return err
}
case map[int64]ClientInsertResult:
if err = appendInsertResult(cur, res, insIDMap); err != nil {
if err := appendInsertResult(cur, res, errors, insIDMap); err != nil {
return err
}
case map[int64]ClientUpdateResult:
if err = appendUpdateResult(cur, res); err != nil {
if err := appendUpdateResult(cur, res, errors); err != nil {
return err
}
}
}
if !res.Ok || res.NErrors > 0 || opErr != nil {
return ClientBulkWriteException{
TopLevelError: &WriteError{
Code: int(res.Code),
Message: res.Errmsg,
Raw: bson.Raw(rawRes),
},
WriteConcernErrors: wcErrs,
WriteErrors: errors,
PartialResult: &bw.result,
}
}
return nil
}

Expand Down Expand Up @@ -383,45 +411,75 @@ func createClientDeleteDoc(
return bsoncore.AppendDocumentEnd(doc, didx)
}

func appendDeleteResult(cur bson.Raw, m map[int64]ClientDeleteResult) error {
func appendDeleteResult(cur bson.Raw, m map[int64]ClientDeleteResult, e map[int64]WriteError) error {
var res struct {
Idx int32
N int32
Ok bool
Idx int32
N int32
Code int32
Errmsg string
}
if err := bson.Unmarshal(cur, &res); err != nil {
return err
}
m[int64(res.Idx)] = ClientDeleteResult{int64(res.N)}
if res.Ok {
m[int64(res.Idx)] = ClientDeleteResult{int64(res.N)}
} else {
e[int64(res.Idx)] = WriteError{
Code: int(res.Code),
Message: res.Errmsg,
}
}
return nil
}

func appendInsertResult(cur bson.Raw, m map[int64]ClientInsertResult, insIdMap map[int]interface{}) error {
func appendInsertResult(cur bson.Raw, m map[int64]ClientInsertResult, e map[int64]WriteError, insIDMap map[int]interface{}) error {
var res struct {
Idx int32
Ok bool
Idx int32
Code int32
Errmsg string
}
if err := bson.Unmarshal(cur, &res); err != nil {
return err
}
m[int64(res.Idx)] = ClientInsertResult{insIdMap[int(res.Idx)]}
if res.Ok {
m[int64(res.Idx)] = ClientInsertResult{insIDMap[int(res.Idx)]}
} else {
e[int64(res.Idx)] = WriteError{
Code: int(res.Code),
Message: res.Errmsg,
}
}
return nil
}

func appendUpdateResult(cur bson.Raw, m map[int64]ClientUpdateResult) error {
func appendUpdateResult(cur bson.Raw, m map[int64]ClientUpdateResult, e map[int64]WriteError) error {
var res struct {
Ok bool
Idx int32
N int32
NModified int32
Upserted struct {
ID interface{} `bson:"_id"`
}
Code int32
Errmsg string
}
if err := bson.Unmarshal(cur, &res); err != nil {
return err
}
m[int64(res.Idx)] = ClientUpdateResult{
MatchedCount: int64(res.N),
ModifiedCount: int64(res.NModified),
UpsertedID: res.Upserted.ID,
if res.Ok {
m[int64(res.Idx)] = ClientUpdateResult{
MatchedCount: int64(res.N),
ModifiedCount: int64(res.NModified),
UpsertedID: res.Upserted.ID,
}
} else {
e[int64(res.Idx)] = WriteError{
Code: int(res.Code),
Message: res.Errmsg,
}
}
return nil
}
13 changes: 5 additions & 8 deletions mongo/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,10 @@ func (bwe BulkWriteException) serverError() {}

// ClientBulkWriteException is the error type returned by ClientBulkWrite operations.
type ClientBulkWriteException struct {
TopLevelError *error
TopLevelError *WriteError

// The write concern errors that occurred.
WriteConcernErrors []WriteConcernError
WriteConcernErrors []*WriteConcernError

// The write errors that occurred during individual operation execution.
WriteErrors map[int64]WriteError
Expand All @@ -626,12 +626,12 @@ type ClientBulkWriteException struct {
func (bwe ClientBulkWriteException) Error() string {
causes := make([]string, 0, 4)
if bwe.TopLevelError != nil {
causes = append(causes, "top level error: "+(*bwe.TopLevelError).Error())
causes = append(causes, "top level error: "+bwe.TopLevelError.Error())
}
if len(bwe.WriteConcernErrors) > 0 {
errs := make([]error, len(bwe.WriteConcernErrors))
for i := 0; i < len(bwe.WriteConcernErrors); i++ {
errs[i] = &bwe.WriteConcernErrors[i]
errs[i] = bwe.WriteConcernErrors[i]
}
causes = append(causes, "write concern errors: "+joinBatchErrors(errs))
}
Expand All @@ -643,7 +643,7 @@ func (bwe ClientBulkWriteException) Error() string {
causes = append(causes, "write errors: "+joinBatchErrors(errs))
}
if bwe.PartialResult != nil {
causes = append(causes, fmt.Sprintf("result: %v", bwe.PartialResult))
causes = append(causes, fmt.Sprintf("result: %v", *bwe.PartialResult))
}

message := "bulk write exception: "
Expand All @@ -653,9 +653,6 @@ func (bwe ClientBulkWriteException) Error() string {
return "bulk write exception: " + strings.Join(causes, ", ")
}

// serverError implements the ServerError interface.
func (bwe ClientBulkWriteException) serverError() {}

// returnResult is used to determine if a function calling processWriteError should return
// the result or return nil. Since the processWriteError function is used by many different
// methods, both *One and *Many, we need a way to differentiate if the method should return
Expand Down
3 changes: 3 additions & 0 deletions mongo/integration/unified/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ func extractErrorDetails(err error) (errorDetails, bool) {
details.raw = we.Raw
}
details.labels = converted.Labels
case mongo.ClientBulkWriteException:
details.raw = converted.TopLevelError.Raw
details.codes = append(details.codes, int32(converted.TopLevelError.Code))
default:
return errorDetails{}, false
}
Expand Down
2 changes: 1 addition & 1 deletion mongo/integration/unified/schema_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (

var (
supportedSchemaVersions = map[int]string{
1: "1.17",
1: "1.21",
}
)

Expand Down
Loading

0 comments on commit 8f81f65

Please sign in to comment.