From 394b263a606e1b3f221cee02e78b1f9567a8ac1f Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 10 May 2024 17:30:21 -0400 Subject: [PATCH] WIP --- bson/array_codec.go | 2 +- bson/bson_test.go | 6 +- bson/bsoncodec.go | 11 +- bson/byte_slice_codec.go | 2 +- bson/cond_addr_codec.go | 2 +- bson/cond_addr_codec_test.go | 4 +- bson/decoder_test.go | 4 +- bson/default_value_decoders.go | 2 +- bson/default_value_decoders_test.go | 89 +-- bson/default_value_encoders.go | 144 ++-- bson/default_value_encoders_test.go | 109 +-- bson/empty_interface_codec.go | 2 +- bson/encoder.go | 76 ++- bson/int_codec.go | 156 ++++- bson/map_codec.go | 4 +- bson/marshal_test.go | 7 +- bson/mgoregistry.go | 67 +- bson/pointer_codec.go | 2 +- bson/primitive_codecs.go | 16 +- bson/raw_value_test.go | 14 +- bson/registry.go | 310 +++++---- bson/registry_examples_test.go | 40 +- bson/registry_test.go | 779 ++++++++-------------- bson/setter_getter.go | 2 +- bson/slice_codec.go | 2 +- bson/string_codec.go | 2 +- bson/struct_codec.go | 50 +- bson/time_codec.go | 2 +- bson/unmarshal_test.go | 5 +- bson/unmarshal_value_test.go | 10 +- internal/integration/client_test.go | 9 +- internal/integration/crud_spec_test.go | 8 +- internal/integration/database_test.go | 8 +- internal/integration/unified_spec_test.go | 10 +- mongo/database_test.go | 4 +- mongo/options/clientoptions_test.go | 2 +- mongo/read_write_concern_spec_test.go | 8 +- x/mongo/driver/topology/server_options.go | 2 +- 38 files changed, 993 insertions(+), 979 deletions(-) diff --git a/bson/array_codec.go b/bson/array_codec.go index 9ea43d4028..757fd60004 100644 --- a/bson/array_codec.go +++ b/bson/array_codec.go @@ -20,7 +20,7 @@ var ( ) // EncodeValue is the ValueEncoder for bsoncore.Array values. -func (ac *arrayCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (ac *arrayCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreArray { return ValueEncoderError{Name: "CoreArrayEncodeValue", Types: []reflect.Type{tCoreArray}, Received: val} } diff --git a/bson/bson_test.go b/bson/bson_test.go index 5d99e066a8..246b0e913a 100644 --- a/bson/bson_test.go +++ b/bson/bson_test.go @@ -358,12 +358,12 @@ func TestMapCodec(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - mapRegistry := NewRegistry() - mapRegistry.RegisterKindEncoder(reflect.Map, tc.codec) + mapRegistry := NewRegistryBuilder() + mapRegistry.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return tc.codec }) buf := new(bytes.Buffer) vw := NewValueWriter(buf) enc := NewEncoder(vw) - enc.SetRegistry(mapRegistry) + enc.SetRegistry(mapRegistry.Build()) err := enc.Encode(mapObj) assert.Nil(t, err, "Encode error: %v", err) str := buf.String() diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 68c108e104..db176ad906 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -103,21 +103,26 @@ type DecodeContext struct { zeroStructs bool } +// EncoderRegistry is an interface provides a ValueEncoder based on the given reflect.Type. +type EncoderRegistry interface { + LookupEncoder(reflect.Type) (ValueEncoder, error) +} + // ValueEncoder is the interface implemented by types that can encode a provided Go type to BSON. // The value to encode is provided as a reflect.Value and a bson.ValueWriter is used within the // EncodeValue method to actually create the BSON representation. For convenience, ValueEncoderFunc // is provided to allow use of a function with the correct signature as a ValueEncoder. A pointer // to a Registry instance is provided to allow implementations to lookup further ValueEncoders. type ValueEncoder interface { - EncodeValue(*Registry, ValueWriter, reflect.Value) error + EncodeValue(EncoderRegistry, ValueWriter, reflect.Value) error } // ValueEncoderFunc is an adapter function that allows a function with the correct signature to be // used as a ValueEncoder. -type ValueEncoderFunc func(*Registry, ValueWriter, reflect.Value) error +type ValueEncoderFunc func(EncoderRegistry, ValueWriter, reflect.Value) error // EncodeValue implements the ValueEncoder interface. -func (fn ValueEncoderFunc) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (fn ValueEncoderFunc) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { return fn(reg, vw, val) } diff --git a/bson/byte_slice_codec.go b/bson/byte_slice_codec.go index 83dba12ecb..e012c3d913 100644 --- a/bson/byte_slice_codec.go +++ b/bson/byte_slice_codec.go @@ -23,7 +23,7 @@ var ( ) // EncodeValue is the ValueEncoder for []byte. -func (bsc *byteSliceCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (bsc *byteSliceCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tByteSlice { return ValueEncoderError{Name: "ByteSliceEncodeValue", Types: []reflect.Type{tByteSlice}, Received: val} } diff --git a/bson/cond_addr_codec.go b/bson/cond_addr_codec.go index d1baf96ef4..ef87da6250 100644 --- a/bson/cond_addr_codec.go +++ b/bson/cond_addr_codec.go @@ -17,7 +17,7 @@ type condAddrEncoder struct { } // EncodeValue is the ValueEncoderFunc for a value that may be addressable. -func (cae *condAddrEncoder) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (cae *condAddrEncoder) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.CanAddr() { return cae.canAddrEnc.EncodeValue(reg, vw, val) } diff --git a/bson/cond_addr_codec_test.go b/bson/cond_addr_codec_test.go index 26cd9c5534..ee4f61f3a7 100644 --- a/bson/cond_addr_codec_test.go +++ b/bson/cond_addr_codec_test.go @@ -22,11 +22,11 @@ func TestCondAddrCodec(t *testing.T) { t.Run("addressEncode", func(t *testing.T) { invoked := 0 - encode1 := ValueEncoderFunc(func(*Registry, ValueWriter, reflect.Value) error { + encode1 := ValueEncoderFunc(func(EncoderRegistry, ValueWriter, reflect.Value) error { invoked = 1 return nil }) - encode2 := ValueEncoderFunc(func(*Registry, ValueWriter, reflect.Value) error { + encode2 := ValueEncoderFunc(func(EncoderRegistry, ValueWriter, reflect.Value) error { invoked = 2 return nil }) diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 8fe8d07480..3b96f63559 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -29,7 +29,7 @@ func TestBasicDecode(t *testing.T) { got := reflect.New(tc.sType).Elem() vr := NewValueReader(tc.data) - reg := DefaultRegistry + reg := NewRegistryBuilder().Build() decoder, err := reg.LookupDecoder(reflect.TypeOf(got)) noerr(t, err) err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got) @@ -199,7 +199,7 @@ func TestDecoderv2(t *testing.T) { t.Run("SetRegistry", func(t *testing.T) { t.Parallel() - r1, r2 := DefaultRegistry, NewRegistry() + r1, r2 := DefaultRegistry, NewRegistryBuilder().Build() dc1 := DecodeContext{Registry: r1} dc2 := DecodeContext{Registry: r2} dec := NewDecoder(NewValueReader([]byte{})) diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index b105ab0715..56331da9a8 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -36,7 +36,7 @@ func (d decodeBinaryError) Error() string { // There is no support for decoding map[string]interface{} because there is no decoder for // interface{}, so users must either register this decoder themselves or use the // EmptyInterfaceDecoder available in the bson package. -func registerDefaultDecoders(reg *Registry) { +func registerDefaultDecoders(reg *RegistryBuilder) { if reg == nil { panic(errors.New("argument to RegisterDefaultDecoders must not be nil")) } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index d2434a19e1..0e32e64ba7 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -819,7 +819,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", map[string]string{}, - &DecodeContext{Registry: newTestRegistry()}, + &DecodeContext{Registry: newTestRegistryBuilder().Build()}, &valueReaderWriter{}, readDocument, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -905,7 +905,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", [1]string{}, - &DecodeContext{Registry: newTestRegistry()}, + &DecodeContext{Registry: newTestRegistryBuilder().Build()}, &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -999,7 +999,7 @@ func TestDefaultValueDecoders(t *testing.T) { { "Lookup Error", []string{}, - &DecodeContext{Registry: newTestRegistry()}, + &DecodeContext{Registry: newTestRegistryBuilder().Build()}, &valueReaderWriter{BSONType: TypeArray}, readArray, ErrNoDecoder{Type: reflect.TypeOf("")}, @@ -3310,7 +3310,7 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - dc := DecodeContext{Registry: newTestRegistry()} + dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} want := ErrNoTypeMapEntry{Type: tc.bsontype} got := defaultEmptyInterfaceCodec.DecodeValue(dc, llvr, val) if !assert.CompareErrors(got, want) { @@ -3323,8 +3323,9 @@ func TestDefaultValueDecoders(t *testing.T) { t.Skip() } val := reflect.New(tEmpty).Elem() - reg := newTestRegistry() - reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) + reg := newTestRegistryBuilder(). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() dc := DecodeContext{ Registry: reg, } @@ -3341,9 +3342,10 @@ func TestDefaultValueDecoders(t *testing.T) { } want := errors.New("DecodeValue failure error") llc := &llCodec{t: t, err: want} - reg := newTestRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc) - reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) + reg := newTestRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() dc := DecodeContext{ Registry: reg, } @@ -3356,9 +3358,10 @@ func TestDefaultValueDecoders(t *testing.T) { t.Run("Success", func(t *testing.T) { want := tc.val llc := &llCodec{t: t, decodeval: tc.val} - reg := newTestRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf(tc.val), llc) - reg.RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)) + reg := newTestRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf(tc.val), llc). + RegisterTypeMapEntry(tc.bsontype, reflect.TypeOf(tc.val)). + Build() dc := DecodeContext{ Registry: reg, } @@ -3395,7 +3398,7 @@ func TestDefaultValueDecoders(t *testing.T) { llvr := &valueReaderWriter{BSONType: TypeDouble} want := ErrNoTypeMapEntry{Type: TypeDouble} val := reflect.New(tEmpty).Elem() - got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistry()}, llvr, val) + got := defaultEmptyInterfaceCodec.DecodeValue(DecodeContext{Registry: newTestRegistryBuilder().Build()}, llvr, val) if !assert.CompareErrors(got, want) { t.Errorf("Errors are not equal. got %v; want %v", got, want) } @@ -3416,15 +3419,15 @@ func TestDefaultValueDecoders(t *testing.T) { // registering a custom type map entry for both Type(0) anad TypeEmbeddedDocument should cause // both top-level and embedded documents to decode to registered type when unmarshalling to interface{} - topLevelReg := newTestRegistry() - registerDefaultEncoders(topLevelReg) - registerDefaultDecoders(topLevelReg) - topLevelReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + topLevelRb := newTestRegistryBuilder() + registerDefaultEncoders(topLevelRb) + registerDefaultDecoders(topLevelRb) + topLevelRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) - embeddedReg := newTestRegistry() - registerDefaultEncoders(embeddedReg) - registerDefaultDecoders(embeddedReg) - embeddedReg.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) + embeddedRb := newTestRegistryBuilder() + registerDefaultEncoders(embeddedRb) + registerDefaultDecoders(embeddedRb) + embeddedRb.RegisterTypeMapEntry(Type(0), reflect.TypeOf(M{})) // create doc {"nested": {"foo": 1}} innerDoc := bsoncore.BuildDocument( @@ -3445,8 +3448,8 @@ func TestDefaultValueDecoders(t *testing.T) { name string registry *Registry }{ - {"top level", topLevelReg}, - {"embedded", embeddedReg}, + {"top level", topLevelRb.Build()}, + {"embedded", embeddedRb.Build()}, } for _, tc := range testCases { var got interface{} @@ -3464,10 +3467,11 @@ func TestDefaultValueDecoders(t *testing.T) { // If a type map entry is registered for TypeEmbeddedDocument, the decoder should use ancestor // information if available instead of the registered entry. - reg := newTestRegistry() - registerDefaultEncoders(reg) - registerDefaultDecoders(reg) - reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) + rb := newTestRegistryBuilder() + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) + rb.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(M{})) + reg := rb.Build() // build document {"nested": {"foo": 10}} inner := bsoncore.BuildDocument( @@ -3500,8 +3504,9 @@ func TestDefaultValueDecoders(t *testing.T) { emptyInterfaceErrorDecode := func(DecodeContext, ValueReader, reflect.Value) error { return decodeValueError } - emptyInterfaceErrorRegistry := newTestRegistry() - emptyInterfaceErrorRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) + emptyInterfaceErrorRegistry := newTestRegistryBuilder(). + RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)). + Build() // Set up a document {foo: 10} and an error that would happen if the value were decoded into interface{} // using the registry defined above. @@ -3553,9 +3558,9 @@ func TestDefaultValueDecoders(t *testing.T) { outerDoc := buildDocument(bsoncore.AppendDocumentElement(nil, "first", inner1Doc)) // Use a registry that has all default decoders with the custom interface{} decoder that always errors. - nestedRegistry := newTestRegistry() - registerDefaultDecoders(nestedRegistry) - nestedRegistry.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) + nestedRegistryBuilder := newTestRegistryBuilder() + registerDefaultDecoders(nestedRegistryBuilder) + nestedRegistryBuilder.RegisterTypeDecoder(tEmpty, ValueDecoderFunc(emptyInterfaceErrorDecode)) nestedErr := &DecodeError{ keys: []string{"fourth", "1", "third", "randomKey", "second", "first"}, wrapped: decodeValueError, @@ -3640,7 +3645,7 @@ func TestDefaultValueDecoders(t *testing.T) { "struct - no decoder found", stringStruct{}, NewValueReader(docBytes), - newTestRegistry(), + newTestRegistryBuilder().Build(), defaultTestStructCodec, stringStructErr, }, @@ -3648,7 +3653,7 @@ func TestDefaultValueDecoders(t *testing.T) { "deeply nested struct", outer{}, NewValueReader(outerDoc), - nestedRegistry, + nestedRegistryBuilder.Build(), defaultTestStructCodec, nestedErr, }, @@ -3705,11 +3710,11 @@ func TestDefaultValueDecoders(t *testing.T) { bsoncore.BuildArrayElement(nil, "boolArray", trueValue), ) - reg := newTestRegistry() - registerDefaultDecoders(reg) - reg.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) + rb := newTestRegistryBuilder() + registerDefaultDecoders(rb) + rb.RegisterTypeMapEntry(TypeBoolean, reflect.TypeOf(mybool(true))) - dc := DecodeContext{Registry: reg} + dc := DecodeContext{Registry: rb.Build()} vr := NewValueReader(docBytes) val := reflect.New(tD).Elem() err := dDecodeValue(dc, vr, val) @@ -3774,8 +3779,8 @@ func buildDocument(elems []byte) []byte { } func buildDefaultRegistry() *Registry { - reg := newTestRegistry() - registerDefaultEncoders(reg) - registerDefaultDecoders(reg) - return reg + rb := newTestRegistryBuilder() + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) + return rb.Build() } diff --git a/bson/default_value_encoders.go b/bson/default_value_encoders.go index df80ef0080..ca6a4a9cad 100644 --- a/bson/default_value_encoders.go +++ b/bson/default_value_encoders.go @@ -28,7 +28,7 @@ var sliceWriterPool = sync.Pool{ }, } -func encodeElement(reg *Registry, dw DocumentWriter, e E) error { +func encodeElement(reg EncoderRegistry, dw DocumentWriter, e E) error { vw, err := dw.WriteDocumentElement(e.Key) if err != nil { return err @@ -50,59 +50,59 @@ func encodeElement(reg *Registry, dw DocumentWriter, e E) error { } // registerDefaultEncoders will register the default encoder methods with the provided Registry. -func registerDefaultEncoders(reg *Registry) { - if reg == nil { +func registerDefaultEncoders(rb *RegistryBuilder) { + if rb == nil { panic(errors.New("argument to RegisterDefaultEncoders must not be nil")) } - intEncoder := &intCodec{} - uintEncoder := &uintCodec{} - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{}) - reg.RegisterTypeEncoder(tTime, &timeCodec{}) - reg.RegisterTypeEncoder(tEmpty, &emptyInterfaceCodec{}) - reg.RegisterTypeEncoder(tCoreArray, &arrayCodec{}) - reg.RegisterTypeEncoder(tOID, ValueEncoderFunc(objectIDEncodeValue)) - reg.RegisterTypeEncoder(tDecimal, ValueEncoderFunc(decimal128EncodeValue)) - reg.RegisterTypeEncoder(tJSONNumber, ValueEncoderFunc(jsonNumberEncodeValue)) - reg.RegisterTypeEncoder(tURL, ValueEncoderFunc(urlEncodeValue)) - reg.RegisterTypeEncoder(tJavaScript, ValueEncoderFunc(javaScriptEncodeValue)) - reg.RegisterTypeEncoder(tSymbol, ValueEncoderFunc(symbolEncodeValue)) - reg.RegisterTypeEncoder(tBinary, ValueEncoderFunc(binaryEncodeValue)) - reg.RegisterTypeEncoder(tUndefined, ValueEncoderFunc(undefinedEncodeValue)) - reg.RegisterTypeEncoder(tDateTime, ValueEncoderFunc(dateTimeEncodeValue)) - reg.RegisterTypeEncoder(tNull, ValueEncoderFunc(nullEncodeValue)) - reg.RegisterTypeEncoder(tRegex, ValueEncoderFunc(regexEncodeValue)) - reg.RegisterTypeEncoder(tDBPointer, ValueEncoderFunc(dbPointerEncodeValue)) - reg.RegisterTypeEncoder(tTimestamp, ValueEncoderFunc(timestampEncodeValue)) - reg.RegisterTypeEncoder(tMinKey, ValueEncoderFunc(minKeyEncodeValue)) - reg.RegisterTypeEncoder(tMaxKey, ValueEncoderFunc(maxKeyEncodeValue)) - reg.RegisterTypeEncoder(tCoreDocument, ValueEncoderFunc(coreDocumentEncodeValue)) - reg.RegisterTypeEncoder(tCodeWithScope, ValueEncoderFunc(codeWithScopeEncodeValue)) - reg.RegisterKindEncoder(reflect.Bool, ValueEncoderFunc(booleanEncodeValue)) - reg.RegisterKindEncoder(reflect.Int, intEncoder) - reg.RegisterKindEncoder(reflect.Int8, intEncoder) - reg.RegisterKindEncoder(reflect.Int16, intEncoder) - reg.RegisterKindEncoder(reflect.Int32, intEncoder) - reg.RegisterKindEncoder(reflect.Int64, intEncoder) - reg.RegisterKindEncoder(reflect.Uint, uintEncoder) - reg.RegisterKindEncoder(reflect.Uint8, uintEncoder) - reg.RegisterKindEncoder(reflect.Uint16, uintEncoder) - reg.RegisterKindEncoder(reflect.Uint32, uintEncoder) - reg.RegisterKindEncoder(reflect.Uint64, uintEncoder) - reg.RegisterKindEncoder(reflect.Float32, ValueEncoderFunc(floatEncodeValue)) - reg.RegisterKindEncoder(reflect.Float64, ValueEncoderFunc(floatEncodeValue)) - reg.RegisterKindEncoder(reflect.Array, ValueEncoderFunc(arrayEncodeValue)) - reg.RegisterKindEncoder(reflect.Map, &mapCodec{}) - reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) - reg.RegisterKindEncoder(reflect.String, &stringCodec{}) - reg.RegisterKindEncoder(reflect.Struct, newStructCodec(DefaultStructTagParser)) - reg.RegisterKindEncoder(reflect.Ptr, &pointerCodec{}) - reg.RegisterInterfaceEncoder(tValueMarshaler, ValueEncoderFunc(valueMarshalerEncodeValue)) - reg.RegisterInterfaceEncoder(tMarshaler, ValueEncoderFunc(marshalerEncodeValue)) - reg.RegisterInterfaceEncoder(tProxy, ValueEncoderFunc(proxyEncodeValue)) + intEncoder := func() ValueEncoder { return &intCodec{} } + floatEncoder := func() ValueEncoder { return ValueEncoderFunc(floatEncodeValue) } + rb.RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{} }). + RegisterTypeEncoder(tTime, func() ValueEncoder { return &timeCodec{} }). + RegisterTypeEncoder(tEmpty, func() ValueEncoder { return &emptyInterfaceCodec{} }). + RegisterTypeEncoder(tCoreArray, func() ValueEncoder { return &arrayCodec{} }). + RegisterTypeEncoder(tOID, func() ValueEncoder { return ValueEncoderFunc(objectIDEncodeValue) }). + RegisterTypeEncoder(tDecimal, func() ValueEncoder { return ValueEncoderFunc(decimal128EncodeValue) }). + RegisterTypeEncoder(tJSONNumber, func() ValueEncoder { return ValueEncoderFunc(jsonNumberEncodeValue) }). + RegisterTypeEncoder(tURL, func() ValueEncoder { return ValueEncoderFunc(urlEncodeValue) }). + RegisterTypeEncoder(tJavaScript, func() ValueEncoder { return ValueEncoderFunc(javaScriptEncodeValue) }). + RegisterTypeEncoder(tSymbol, func() ValueEncoder { return ValueEncoderFunc(symbolEncodeValue) }). + RegisterTypeEncoder(tBinary, func() ValueEncoder { return ValueEncoderFunc(binaryEncodeValue) }). + RegisterTypeEncoder(tUndefined, func() ValueEncoder { return ValueEncoderFunc(undefinedEncodeValue) }). + RegisterTypeEncoder(tDateTime, func() ValueEncoder { return ValueEncoderFunc(dateTimeEncodeValue) }). + RegisterTypeEncoder(tNull, func() ValueEncoder { return ValueEncoderFunc(nullEncodeValue) }). + RegisterTypeEncoder(tRegex, func() ValueEncoder { return ValueEncoderFunc(regexEncodeValue) }). + RegisterTypeEncoder(tDBPointer, func() ValueEncoder { return ValueEncoderFunc(dbPointerEncodeValue) }). + RegisterTypeEncoder(tTimestamp, func() ValueEncoder { return ValueEncoderFunc(timestampEncodeValue) }). + RegisterTypeEncoder(tMinKey, func() ValueEncoder { return ValueEncoderFunc(minKeyEncodeValue) }). + RegisterTypeEncoder(tMaxKey, func() ValueEncoder { return ValueEncoderFunc(maxKeyEncodeValue) }). + RegisterTypeEncoder(tCoreDocument, func() ValueEncoder { return ValueEncoderFunc(coreDocumentEncodeValue) }). + RegisterTypeEncoder(tCodeWithScope, func() ValueEncoder { return ValueEncoderFunc(codeWithScopeEncodeValue) }). + RegisterKindEncoder(reflect.Bool, func() ValueEncoder { return ValueEncoderFunc(booleanEncodeValue) }). + RegisterKindEncoder(reflect.Int, intEncoder). + RegisterKindEncoder(reflect.Int8, intEncoder). + RegisterKindEncoder(reflect.Int16, intEncoder). + RegisterKindEncoder(reflect.Int32, intEncoder). + RegisterKindEncoder(reflect.Int64, intEncoder). + RegisterKindEncoder(reflect.Uint, intEncoder). + RegisterKindEncoder(reflect.Uint8, intEncoder). + RegisterKindEncoder(reflect.Uint16, intEncoder). + RegisterKindEncoder(reflect.Uint32, intEncoder). + RegisterKindEncoder(reflect.Uint64, intEncoder). + RegisterKindEncoder(reflect.Float32, floatEncoder). + RegisterKindEncoder(reflect.Float64, floatEncoder). + RegisterKindEncoder(reflect.Array, func() ValueEncoder { return ValueEncoderFunc(arrayEncodeValue) }). + RegisterKindEncoder(reflect.Map, func() ValueEncoder { return &mapCodec{} }). + RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). + RegisterKindEncoder(reflect.String, func() ValueEncoder { return &stringCodec{} }). + RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return newStructCodec(DefaultStructTagParser) }). + RegisterKindEncoder(reflect.Ptr, func() ValueEncoder { return &pointerCodec{} }). + RegisterInterfaceEncoder(tValueMarshaler, func() ValueEncoder { return ValueEncoderFunc(valueMarshalerEncodeValue) }). + RegisterInterfaceEncoder(tMarshaler, func() ValueEncoder { return ValueEncoderFunc(marshalerEncodeValue) }). + RegisterInterfaceEncoder(tProxy, func() ValueEncoder { return ValueEncoderFunc(proxyEncodeValue) }) } // booleanEncodeValue is the ValueEncoderFunc for bool types. -func booleanEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func booleanEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Bool { return ValueEncoderError{Name: "BooleanEncodeValue", Kinds: []reflect.Kind{reflect.Bool}, Received: val} } @@ -114,7 +114,7 @@ func fitsIn32Bits(i int64) bool { } // floatEncodeValue is the ValueEncoderFunc for float types. -func floatEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func floatEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Float32, reflect.Float64: return vw.WriteDouble(val.Float()) @@ -124,7 +124,7 @@ func floatEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // objectIDEncodeValue is the ValueEncoderFunc for ObjectID. -func objectIDEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func objectIDEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tOID { return ValueEncoderError{Name: "ObjectIDEncodeValue", Types: []reflect.Type{tOID}, Received: val} } @@ -132,7 +132,7 @@ func objectIDEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // decimal128EncodeValue is the ValueEncoderFunc for Decimal128. -func decimal128EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func decimal128EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDecimal { return ValueEncoderError{Name: "Decimal128EncodeValue", Types: []reflect.Type{tDecimal}, Received: val} } @@ -140,7 +140,7 @@ func decimal128EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // jsonNumberEncodeValue is the ValueEncoderFunc for json.Number. -func jsonNumberEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func jsonNumberEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJSONNumber { return ValueEncoderError{Name: "JSONNumberEncodeValue", Types: []reflect.Type{tJSONNumber}, Received: val} } @@ -164,7 +164,7 @@ func jsonNumberEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) err } // urlEncodeValue is the ValueEncoderFunc for url.URL. -func urlEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func urlEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tURL { return ValueEncoderError{Name: "URLEncodeValue", Types: []reflect.Type{tURL}, Received: val} } @@ -173,7 +173,7 @@ func urlEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // arrayEncodeValue is the ValueEncoderFunc for array types. -func arrayEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func arrayEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Array { return ValueEncoderError{Name: "ArrayEncodeValue", Kinds: []reflect.Kind{reflect.Array}, Received: val} } @@ -243,7 +243,7 @@ func arrayEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { return aw.WriteArrayEnd() } -func lookupElementEncoder(reg *Registry, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { +func lookupElementEncoder(reg EncoderRegistry, origEncoder ValueEncoder, currVal reflect.Value) (ValueEncoder, reflect.Value, error) { if origEncoder != nil || (currVal.Kind() != reflect.Interface) { return origEncoder, currVal, nil } @@ -257,7 +257,7 @@ func lookupElementEncoder(reg *Registry, origEncoder ValueEncoder, currVal refle } // valueMarshalerEncodeValue is the ValueEncoderFunc for ValueMarshaler implementations. -func valueMarshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func valueMarshalerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement ValueMarshaler switch { case !val.IsValid(): @@ -285,7 +285,7 @@ func valueMarshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) e } // marshalerEncodeValue is the ValueEncoderFunc for Marshaler implementations. -func marshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func marshalerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Marshaler switch { case !val.IsValid(): @@ -313,7 +313,7 @@ func marshalerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // proxyEncodeValue is the ValueEncoderFunc for Proxy implementations. -func proxyEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func proxyEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Proxy switch { case !val.IsValid(): @@ -357,7 +357,7 @@ func proxyEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { } // javaScriptEncodeValue is the ValueEncoderFunc for the JavaScript type. -func javaScriptEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func javaScriptEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tJavaScript { return ValueEncoderError{Name: "JavaScriptEncodeValue", Types: []reflect.Type{tJavaScript}, Received: val} } @@ -366,7 +366,7 @@ func javaScriptEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // symbolEncodeValue is the ValueEncoderFunc for the Symbol type. -func symbolEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func symbolEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tSymbol { return ValueEncoderError{Name: "SymbolEncodeValue", Types: []reflect.Type{tSymbol}, Received: val} } @@ -375,7 +375,7 @@ func symbolEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // binaryEncodeValue is the ValueEncoderFunc for Binary. -func binaryEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func binaryEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tBinary { return ValueEncoderError{Name: "BinaryEncodeValue", Types: []reflect.Type{tBinary}, Received: val} } @@ -385,7 +385,7 @@ func binaryEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // undefinedEncodeValue is the ValueEncoderFunc for Undefined. -func undefinedEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func undefinedEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tUndefined { return ValueEncoderError{Name: "UndefinedEncodeValue", Types: []reflect.Type{tUndefined}, Received: val} } @@ -394,7 +394,7 @@ func undefinedEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // dateTimeEncodeValue is the ValueEncoderFunc for DateTime. -func dateTimeEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func dateTimeEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDateTime { return ValueEncoderError{Name: "DateTimeEncodeValue", Types: []reflect.Type{tDateTime}, Received: val} } @@ -403,7 +403,7 @@ func dateTimeEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // nullEncodeValue is the ValueEncoderFunc for Null. -func nullEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func nullEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tNull { return ValueEncoderError{Name: "NullEncodeValue", Types: []reflect.Type{tNull}, Received: val} } @@ -412,7 +412,7 @@ func nullEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // regexEncodeValue is the ValueEncoderFunc for Regex. -func regexEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func regexEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRegex { return ValueEncoderError{Name: "RegexEncodeValue", Types: []reflect.Type{tRegex}, Received: val} } @@ -423,7 +423,7 @@ func regexEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // dbPointerEncodeValue is the ValueEncoderFunc for DBPointer. -func dbPointerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func dbPointerEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tDBPointer { return ValueEncoderError{Name: "DBPointerEncodeValue", Types: []reflect.Type{tDBPointer}, Received: val} } @@ -434,7 +434,7 @@ func dbPointerEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // timestampEncodeValue is the ValueEncoderFunc for Timestamp. -func timestampEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func timestampEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTimestamp { return ValueEncoderError{Name: "TimestampEncodeValue", Types: []reflect.Type{tTimestamp}, Received: val} } @@ -445,7 +445,7 @@ func timestampEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error } // minKeyEncodeValue is the ValueEncoderFunc for MinKey. -func minKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func minKeyEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMinKey { return ValueEncoderError{Name: "MinKeyEncodeValue", Types: []reflect.Type{tMinKey}, Received: val} } @@ -454,7 +454,7 @@ func minKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // maxKeyEncodeValue is the ValueEncoderFunc for MaxKey. -func maxKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func maxKeyEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tMaxKey { return ValueEncoderError{Name: "MaxKeyEncodeValue", Types: []reflect.Type{tMaxKey}, Received: val} } @@ -463,7 +463,7 @@ func maxKeyEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { } // coreDocumentEncodeValue is the ValueEncoderFunc for bsoncore.Document. -func coreDocumentEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func coreDocumentEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCoreDocument { return ValueEncoderError{Name: "CoreDocumentEncodeValue", Types: []reflect.Type{tCoreDocument}, Received: val} } @@ -474,7 +474,7 @@ func coreDocumentEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) err } // codeWithScopeEncodeValue is the ValueEncoderFunc for CodeWithScope. -func codeWithScopeEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func codeWithScopeEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tCodeWithScope { return ValueEncoderError{Name: "CodeWithScopeEncodeValue", Types: []reflect.Type{tCodeWithScope}, Received: val} } diff --git a/bson/default_value_encoders_test.go b/bson/default_value_encoders_test.go index 0cc12ce597..cd8efe72db 100644 --- a/bson/default_value_encoders_test.go +++ b/bson/default_value_encoders_test.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "net/url" "reflect" "strings" @@ -37,11 +38,11 @@ func TestDefaultValueEncoders(t *testing.T) { var wrong = func(string, string) string { return "wrong" } type mybool bool - // type myint8 int8 - // type myint16 int16 - // type myint32 int32 - // type myint64 int64 - // type myint int + type myint8 int8 + type myint16 int16 + type myint32 int32 + type myint64 int64 + type myint int type myuint8 uint8 type myuint16 uint16 type myuint32 uint32 @@ -92,51 +93,52 @@ func TestDefaultValueEncoders(t *testing.T) { {"reflection path", mybool(true), nil, nil, writeBoolean, nil}, }, }, - /* - { - "IntEncodeValue", - ValueEncoderFunc(intEncodeValue), - []subtest{ - { - "wrong type", - wrong, - nil, - nil, - nothing, - ValueEncoderError{ - Name: "IntEncodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, - Received: reflect.ValueOf(wrong), + { + "IntEncodeValue", + &intCodec{}, + []subtest{ + { + "wrong type", + wrong, + nil, + nil, + nothing, + ValueEncoderError{ + Name: "IntEncodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, }, + Received: reflect.ValueOf(wrong), }, - {"int8/fast path", int8(127), nil, nil, writeInt32, nil}, - {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, - {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, - {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, - {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, - {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, - {"int/fast path - MinInt32", int(math.MinInt32), nil, nil, writeInt32, nil}, - {"int8/reflection path", myint8(127), nil, nil, writeInt32, nil}, - {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, - {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, - {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, - {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, - {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, - {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, - {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, - {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, - {"int/reflection path - MinInt32", myint(math.MinInt32), nil, nil, writeInt32, nil}, }, + {"int8/fast path", int8(127), nil, nil, writeInt32, nil}, + {"int16/fast path", int16(32767), nil, nil, writeInt32, nil}, + {"int32/fast path", int32(2147483647), nil, nil, writeInt32, nil}, + {"int64/fast path", int64(1234567890987), nil, nil, writeInt64, nil}, + // {"int64/fast path - minsize", int64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"int64/fast path - minsize too large", int64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"int64/fast path - minsize too small", int64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int/fast path - positive int32", int(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, + {"int/fast path - negative int32", int(math.MinInt32 + 1), nil, nil, writeInt32, nil}, + {"int/fast path - MaxInt32", int(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int/fast path - MinInt32", int(math.MinInt32), nil, nil, writeInt32, nil}, + {"int8/reflection path", myint8(127), nil, nil, writeInt32, nil}, + {"int16/reflection path", myint16(32767), nil, nil, writeInt32, nil}, + {"int32/reflection path", myint32(2147483647), nil, nil, writeInt32, nil}, + {"int64/reflection path", myint64(1234567890987), nil, nil, writeInt64, nil}, + // {"int64/reflection path - minsize", myint64(math.MaxInt32), &EncodeContext{minSize: true}, nil, writeInt32, nil}, + // {"int64/reflection path - minsize too large", myint64(math.MaxInt32 + 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + // {"int64/reflection path - minsize too small", myint64(math.MinInt32 - 1), &EncodeContext{minSize: true}, nil, writeInt64, nil}, + {"int/reflection path - positive int32", myint(math.MaxInt32 - 1), nil, nil, writeInt32, nil}, + {"int/reflection path - negative int32", myint(math.MinInt32 + 1), nil, nil, writeInt32, nil}, + {"int/reflection path - MaxInt32", myint(math.MaxInt32), nil, nil, writeInt32, nil}, + {"int/reflection path - MinInt32", myint(math.MinInt32), nil, nil, writeInt32, nil}, }, - */ + }, { "UintEncodeValue", - &uintCodec{}, + &intCodec{}, []subtest{ { "wrong type", @@ -145,8 +147,11 @@ func TestDefaultValueEncoders(t *testing.T) { nil, nothing, ValueEncoderError{ - Name: "UintEncodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, + Name: "IntEncodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: reflect.ValueOf(wrong), }, }, @@ -235,7 +240,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", map[string]int{"foo": 1}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeDocument, fmt.Errorf("no encoder found for int"), @@ -259,7 +264,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty map/success", map[string]interface{}{}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeDocumentEnd, nil, @@ -315,7 +320,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", [1]int{1}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -393,7 +398,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "Lookup Error", []int{1}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArray, fmt.Errorf("no encoder found for int"), @@ -433,7 +438,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "empty slice/success", []interface{}{}, - newTestRegistry(), + newTestRegistryBuilder().Build(), &valueReaderWriter{}, writeArrayEnd, nil, @@ -510,7 +515,7 @@ func TestDefaultValueEncoders(t *testing.T) { { "json.Number/int64/success", json.Number("1234567890"), - nil, nil, writeInt64, nil, + buildDefaultRegistry(), nil, writeInt64, nil, }, { "json.Number/float64/success", diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index 0a68c77a40..cea7dfd348 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -22,7 +22,7 @@ var ( ) // EncodeValue is the ValueEncoderFunc for interface{}. -func (eic emptyInterfaceCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (eic emptyInterfaceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tEmpty { return ValueEncoderError{Name: "EmptyInterfaceEncodeValue", Types: []reflect.Type{tEmpty}, Received: val} } diff --git a/bson/encoder.go b/bson/encoder.go index 33cb16cc13..42c1900006 100644 --- a/bson/encoder.go +++ b/bson/encoder.go @@ -70,9 +70,10 @@ func (e *Encoder) SetRegistry(r *Registry) { // ErrorOnInlineDuplicates causes the Encoder to return an error if there is a duplicate field in // the marshaled BSON when the "inline" struct tag option is set. func (e *Encoder) ErrorOnInlineDuplicates() { - if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { - if enc, ok := v.(*structCodec); ok { - enc.overwriteDuplicatedInlinedFields = false + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).overwriteDuplicatedInlinedFields = false } } } @@ -81,14 +82,20 @@ func (e *Encoder) ErrorOnInlineDuplicates() { // uint8, uint16, uint32, or uint64) as the minimum BSON int size (either 32 or 64 bits) that can // represent the integer value. func (e *Encoder) IntMinSize() { - if v, ok := e.reg.kindEncoders.Load(reflect.Int); ok { - if enc, ok := v.(*intCodec); ok { - enc.encodeToMinSize = true - } - } - if v, ok := e.reg.kindEncoders.Load(reflect.Uint); ok { - if enc, ok := v.(*uintCodec); ok { - enc.encodeToMinSize = true + // if v, ok := e.reg.kindEncoders.Load(reflect.Int); ok { + // if enc, ok := v.(*intCodec); ok { + // enc.encodeToMinSize = true + // } + // } + // if v, ok := e.reg.kindEncoders.Load(reflect.Uint); ok { + // if enc, ok := v.(*uintCodec); ok { + // enc.encodeToMinSize = true + // } + // } + t := reflect.TypeOf((*intCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*intCodec).encodeToMinSize = true } } } @@ -96,9 +103,10 @@ func (e *Encoder) IntMinSize() { // StringifyMapKeysWithFmt causes the Encoder to convert Go map keys to BSON document field name // strings using fmt.Sprint instead of the default string conversion logic. func (e *Encoder) StringifyMapKeysWithFmt() { - if v, ok := e.reg.kindEncoders.Load(reflect.Map); ok { - if enc, ok := v.(*mapCodec); ok { - enc.encodeKeysWithStringer = true + t := reflect.TypeOf((*mapCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*mapCodec).encodeKeysWithStringer = true } } } @@ -106,9 +114,10 @@ func (e *Encoder) StringifyMapKeysWithFmt() { // NilMapAsEmpty causes the Encoder to marshal nil Go maps as empty BSON documents instead of BSON // null. func (e *Encoder) NilMapAsEmpty() { - if v, ok := e.reg.kindEncoders.Load(reflect.Map); ok { - if enc, ok := v.(*mapCodec); ok { - enc.encodeNilAsEmpty = true + t := reflect.TypeOf((*mapCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*mapCodec).encodeNilAsEmpty = true } } } @@ -116,9 +125,10 @@ func (e *Encoder) NilMapAsEmpty() { // NilSliceAsEmpty causes the Encoder to marshal nil Go slices as empty BSON arrays instead of BSON // null. func (e *Encoder) NilSliceAsEmpty() { - if v, ok := e.reg.kindEncoders.Load(reflect.Slice); ok { - if enc, ok := v.(*sliceCodec); ok { - enc.encodeNilAsEmpty = true + t := reflect.TypeOf((*sliceCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*sliceCodec).encodeNilAsEmpty = true } } } @@ -126,9 +136,15 @@ func (e *Encoder) NilSliceAsEmpty() { // NilByteSliceAsEmpty causes the Encoder to marshal nil Go byte slices as empty BSON binary values // instead of BSON null. func (e *Encoder) NilByteSliceAsEmpty() { - if v, ok := e.reg.typeEncoders.Load(tByteSlice); ok { - if enc, ok := v.(*byteSliceCodec); ok { - enc.encodeNilAsEmpty = true + // if v, ok := e.reg.typeEncoders.Load(tByteSlice); ok { + // if enc, ok := v.(*byteSliceCodec); ok { + // enc.encodeNilAsEmpty = true + // } + // } + t := reflect.TypeOf((*byteSliceCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*byteSliceCodec).encodeNilAsEmpty = true } } } @@ -142,9 +158,10 @@ func (e *Encoder) NilByteSliceAsEmpty() { // Note that the Encoder only examines exported struct fields when determining if a struct is the // zero value. It considers pointers to a zero struct value (e.g. &MyStruct{}) not empty. func (e *Encoder) OmitZeroStruct() { - if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { - if enc, ok := v.(*structCodec); ok { - enc.encodeOmitDefaultStruct = true + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).encodeOmitDefaultStruct = true } } } @@ -152,9 +169,10 @@ func (e *Encoder) OmitZeroStruct() { // UseJSONStructTags causes the Encoder to fall back to using the "json" struct tag if a "bson" // struct tag is not specified. func (e *Encoder) UseJSONStructTags() { - if v, ok := e.reg.kindEncoders.Load(reflect.Struct); ok { - if enc, ok := v.(*structCodec); ok { - enc.useJSONStructTags = true + t := reflect.TypeOf((*structCodec)(nil)) + if v, ok := e.reg.encoderTypeMap[t]; ok && v != nil { + for i := range v { + v[i].(*structCodec).useJSONStructTags = true } } } diff --git a/bson/int_codec.go b/bson/int_codec.go index d0791ad70b..4d82092309 100644 --- a/bson/int_codec.go +++ b/bson/int_codec.go @@ -7,6 +7,8 @@ package bson import ( + "fmt" + "math" "reflect" ) @@ -15,10 +17,16 @@ type intCodec struct { // encodeToMinSize causes EncodeValue to marshal Go uint values (excluding uint64) as the // minimum BSON int size (either 32-bit or 64-bit) that can represent the integer value. encodeToMinSize bool + + // truncate, if true, instructs decoders to to truncate the fractional part of BSON "double" + // values when attempting to unmarshal them into a Go integer (int, int8, int16, int32, int64, + // uint, uint8, uint16, uint32, or uint64) struct field. The truncation logic does not apply to + // BSON "decimal128" values. + truncate bool } // EncodeValue is the ValueEncoder for uint types. -func (ic *intCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (ic *intCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { switch val.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32: return vw.WriteInt32(int32(val.Int())) @@ -34,11 +42,153 @@ func (ic *intCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) return vw.WriteInt32(int32(i64)) } return vw.WriteInt64(i64) + + case reflect.Uint8, reflect.Uint16: + return vw.WriteInt32(int32(val.Uint())) + case reflect.Uint, reflect.Uint32, reflect.Uint64: + u64 := val.Uint() + + // If encodeToMinSize is true for a non-uint64 value we should write val as an int32 + useMinSize := ic.encodeToMinSize && val.Kind() != reflect.Uint64 + + if u64 <= math.MaxInt32 && useMinSize { + return vw.WriteInt32(int32(u64)) + } + if u64 > math.MaxInt64 { + return fmt.Errorf("%d overflows int64", u64) + } + return vw.WriteInt64(int64(u64)) } return ValueEncoderError{ - Name: "IntEncodeValue", - Kinds: []reflect.Kind{reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int}, + Name: "IntEncodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, Received: val, } } + +// DecodeValue is the ValueDecoder for uint types. +func (ic *intCodec) DecodeValue(_ *Registry, vr ValueReader, val reflect.Value) error { + if !val.CanSet() { + return ValueDecoderError{ + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: val, + } + } + + var i64 int64 + switch vrType := vr.Type(); vrType { + case TypeInt32: + i32, err := vr.ReadInt32() + if err != nil { + return err + } + i64 = int64(i32) + case TypeInt64: + var err error + i64, err = vr.ReadInt64() + if err != nil { + return err + } + case TypeDouble: + f64, err := vr.ReadDouble() + if err != nil { + return err + } + if !ic.truncate && math.Floor(f64) != f64 { + return errCannotTruncate + } + if f64 > float64(math.MaxInt64) { + return fmt.Errorf("%g overflows int64", f64) + } + i64 = int64(f64) + case TypeBoolean: + b, err := vr.ReadBoolean() + if err != nil { + return err + } + if b { + i64 = 1 + } + case TypeNull: + if err := vr.ReadNull(); err != nil { + return err + } + case TypeUndefined: + if err := vr.ReadUndefined(); err != nil { + return err + } + default: + return fmt.Errorf("cannot decode %v into an integer type", vrType) + } + + switch t := val.Type(); t.Kind() { + case reflect.Int8: + if i64 < math.MinInt8 || i64 > math.MaxInt8 { + return fmt.Errorf("%d overflows int8", i64) + } + val.SetInt(i64) + case reflect.Int16: + if i64 < math.MinInt16 || i64 > math.MaxInt16 { + return fmt.Errorf("%d overflows int16", i64) + } + val.SetInt(i64) + case reflect.Int32: + if i64 < math.MinInt32 || i64 > math.MaxInt32 { + return fmt.Errorf("%d overflows int32", i64) + } + val.SetInt(i64) + case reflect.Int64: + val.SetInt(i64) + case reflect.Int: + if int64(int(i64)) != i64 { // Can we fit this inside of an int + return fmt.Errorf("%d overflows int", i64) + } + val.SetInt(i64) + + case reflect.Uint8: + if i64 < 0 || i64 > math.MaxUint8 { + return fmt.Errorf("%d overflows uint8", i64) + } + val.SetUint(uint64(i64)) + case reflect.Uint16: + if i64 < 0 || i64 > math.MaxUint16 { + return fmt.Errorf("%d overflows uint16", i64) + } + val.SetUint(uint64(i64)) + case reflect.Uint32: + if i64 < 0 || i64 > math.MaxUint32 { + return fmt.Errorf("%d overflows uint32", i64) + } + val.SetUint(uint64(i64)) + case reflect.Uint64: + if i64 < 0 { + return fmt.Errorf("%d overflows uint64", i64) + } + val.SetUint(uint64(i64)) + case reflect.Uint: + if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint + return fmt.Errorf("%d overflows uint", i64) + } + val.SetUint(uint64(i64)) + + default: + return ValueDecoderError{ + Name: "IntDecodeValue", + Kinds: []reflect.Kind{ + reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int, + reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, + }, + Received: reflect.Zero(t), + } + } + + return nil +} diff --git a/bson/map_codec.go b/bson/map_codec.go index ce2ac9ff9e..bfa77ca0d8 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -50,7 +50,7 @@ type KeyUnmarshaler interface { } // EncodeValue is the ValueEncoder for map[*]* types. -func (mc *mapCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (mc *mapCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Map { return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} } @@ -78,7 +78,7 @@ func (mc *mapCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value // mapEncodeValue handles encoding of the values of a map. The collisionFn returns // true if the provided key exists, this is mainly used for inline maps in the // struct codec. -func (mc *mapCodec) mapEncodeValue(reg *Registry, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { +func (mc *mapCodec) mapEncodeValue(reg EncoderRegistry, dw DocumentWriter, val reflect.Value, collisionFn func(string) bool) error { elemType := val.Type().Elem() encoder, err := reg.LookupEncoder(elemType) diff --git a/bson/marshal_test.go b/bson/marshal_test.go index 6013d7b911..5eeada562c 100644 --- a/bson/marshal_test.go +++ b/bson/marshal_test.go @@ -149,15 +149,16 @@ func TestCachingEncodersNotSharedAcrossRegistries(t *testing.T) { // different Registry is used. // Create a custom Registry that negates int32 values when encoding. - var encodeInt32 ValueEncoderFunc = func(_ *Registry, vw ValueWriter, val reflect.Value) error { + var encodeInt32 ValueEncoderFunc = func(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Int32 { return fmt.Errorf("expected kind to be int32, got %v", val.Kind()) } return vw.WriteInt32(int32(val.Int()) * -1) } - customReg := NewRegistry() - customReg.RegisterTypeEncoder(tInt32, encodeInt32) + customReg := NewRegistryBuilder(). + RegisterTypeEncoder(tInt32, func() ValueEncoder { return encodeInt32 }). + Build() // Helper function to run the test and make assertions. The provided original value should result in the document // {"x": {$numberInt: 1}} when marshalled with the default registry. diff --git a/bson/mgoregistry.go b/bson/mgoregistry.go index 398de2afb9..f0b77f4efb 100644 --- a/bson/mgoregistry.go +++ b/bson/mgoregistry.go @@ -22,10 +22,7 @@ var ( tSetter = reflect.TypeOf((*Setter)(nil)).Elem() ) -// NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. -func NewMgoRegistry() *Registry { - reg := NewRegistry() - +func newMgoRegistryBuilder() *RegistryBuilder { structcodec := &structCodec{ parser: DefaultStructTagParser, decodeZeroStruct: true, @@ -37,45 +34,47 @@ func NewMgoRegistry() *Registry { encodeNilAsEmpty: true, encodeKeysWithStringer: true, } - uintcodec := &uintCodec{encodeToMinSize: true} + intcodec := func() ValueEncoder { return &intCodec{encodeToMinSize: true} } - reg.RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}) - reg.RegisterKindDecoder(reflect.String, &stringCodec{}) - reg.RegisterKindDecoder(reflect.Struct, structcodec) - reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: true}) - reg.RegisterKindEncoder(reflect.Struct, structcodec) - reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{encodeNilAsEmpty: true}) - reg.RegisterKindEncoder(reflect.Map, mapCodec) - reg.RegisterKindEncoder(reflect.Uint, uintcodec) - reg.RegisterKindEncoder(reflect.Uint8, uintcodec) - reg.RegisterKindEncoder(reflect.Uint16, uintcodec) - reg.RegisterKindEncoder(reflect.Uint32, uintcodec) - reg.RegisterKindEncoder(reflect.Uint64, uintcodec) - reg.RegisterTypeMapEntry(TypeInt32, tInt) - reg.RegisterTypeMapEntry(TypeDateTime, tTime) - reg.RegisterTypeMapEntry(TypeArray, tInterfaceSlice) - reg.RegisterTypeMapEntry(Type(0), tM) - reg.RegisterTypeMapEntry(TypeEmbeddedDocument, tM) - reg.RegisterInterfaceEncoder(tGetter, ValueEncoderFunc(GetterEncodeValue)) - reg.RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(SetterDecodeValue)) + return NewRegistryBuilder(). + RegisterTypeDecoder(tEmpty, &emptyInterfaceCodec{decodeBinaryAsSlice: true}). + RegisterKindDecoder(reflect.String, &stringCodec{}). + RegisterKindDecoder(reflect.Struct, structcodec). + RegisterKindDecoder(reflect.Map, mapCodec). + RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: true} }). + RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return structcodec }). + RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{encodeNilAsEmpty: true} }). + RegisterKindEncoder(reflect.Map, func() ValueEncoder { return mapCodec }). + RegisterKindEncoder(reflect.Uint, intcodec). + RegisterKindEncoder(reflect.Uint8, intcodec). + RegisterKindEncoder(reflect.Uint16, intcodec). + RegisterKindEncoder(reflect.Uint32, intcodec). + RegisterKindEncoder(reflect.Uint64, intcodec). + RegisterTypeMapEntry(TypeInt32, tInt). + RegisterTypeMapEntry(TypeDateTime, tTime). + RegisterTypeMapEntry(TypeArray, tInterfaceSlice). + RegisterTypeMapEntry(Type(0), tM). + RegisterTypeMapEntry(TypeEmbeddedDocument, tM). + RegisterInterfaceEncoder(tGetter, func() ValueEncoder { return ValueEncoderFunc(GetterEncodeValue) }). + RegisterInterfaceDecoder(tSetter, ValueDecoderFunc(SetterDecodeValue)) +} - return reg +// NewMgoRegistry creates a new bson.Registry configured with the default encoders and decoders. +func NewMgoRegistry() *Registry { + return newMgoRegistryBuilder().Build() } // NewRespectNilValuesMgoRegistry creates a new bson.Registry configured to behave like mgo/bson // with RespectNilValues set to true. func NewRespectNilValuesMgoRegistry() *Registry { - reg := NewMgoRegistry() - mapCodec := &mapCodec{ decodeZerosMap: true, } - reg.RegisterKindDecoder(reflect.Map, mapCodec) - reg.RegisterTypeEncoder(tByteSlice, &byteSliceCodec{encodeNilAsEmpty: false}) - reg.RegisterKindEncoder(reflect.Slice, &sliceCodec{}) - reg.RegisterKindEncoder(reflect.Map, mapCodec) - - return reg + return newMgoRegistryBuilder(). + RegisterKindDecoder(reflect.Map, mapCodec). + RegisterTypeEncoder(tByteSlice, func() ValueEncoder { return &byteSliceCodec{encodeNilAsEmpty: false} }). + RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return &sliceCodec{} }). + RegisterKindEncoder(reflect.Map, func() ValueEncoder { return mapCodec }). + Build() } diff --git a/bson/pointer_codec.go b/bson/pointer_codec.go index 425d371d0e..af35da68b2 100644 --- a/bson/pointer_codec.go +++ b/bson/pointer_codec.go @@ -18,7 +18,7 @@ type pointerCodec struct { // EncodeValue handles encoding a pointer by either encoding it to BSON Null if the pointer is nil // or looking up an encoder for the type of value the pointer points to. -func (pc *pointerCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (pc *pointerCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.Ptr { if !val.IsValid() { return vw.WriteNull() diff --git a/bson/primitive_codecs.go b/bson/primitive_codecs.go index adbb28d601..082cd15357 100644 --- a/bson/primitive_codecs.go +++ b/bson/primitive_codecs.go @@ -16,22 +16,22 @@ var tRawValue = reflect.TypeOf(RawValue{}) var tRaw = reflect.TypeOf(Raw(nil)) // registerPrimitiveCodecs will register the encode and decode methods with the provided Registry. -func registerPrimitiveCodecs(reg *Registry) { - if reg == nil { +func registerPrimitiveCodecs(rb *RegistryBuilder) { + if rb == nil { panic(errors.New("argument to RegisterPrimitiveCodecs must not be nil")) } - reg.RegisterTypeEncoder(tRawValue, ValueEncoderFunc(rawValueEncodeValue)) - reg.RegisterTypeEncoder(tRaw, ValueEncoderFunc(rawEncodeValue)) - reg.RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue)) - reg.RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue)) + rb.RegisterTypeEncoder(tRawValue, func() ValueEncoder { return ValueEncoderFunc(rawValueEncodeValue) }). + RegisterTypeEncoder(tRaw, func() ValueEncoder { return ValueEncoderFunc(rawEncodeValue) }). + RegisterTypeDecoder(tRawValue, ValueDecoderFunc(rawValueDecodeValue)). + RegisterTypeDecoder(tRaw, ValueDecoderFunc(rawDecodeValue)) } // rawValueEncodeValue is the ValueEncoderFunc for RawValue. // // If the RawValue's Type is "invalid" and the RawValue's Value is not empty or // nil, then this method will return an error. -func rawValueEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func rawValueEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRawValue { return ValueEncoderError{ Name: "RawValueEncodeValue", @@ -65,7 +65,7 @@ func rawValueDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) err } // rawEncodeValue is the ValueEncoderFunc for Reader. -func rawEncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func rawEncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tRaw { return ValueEncoderError{Name: "RawEncodeValue", Types: []reflect.Type{tRaw}, Received: val} } diff --git a/bson/raw_value_test.go b/bson/raw_value_test.go index 67444faa61..18598ebe8f 100644 --- a/bson/raw_value_test.go +++ b/bson/raw_value_test.go @@ -25,7 +25,7 @@ func TestRawValue(t *testing.T) { t.Run("Uses registry attached to value", func(t *testing.T) { t.Parallel() - reg := newTestRegistry() + reg := newTestRegistryBuilder().Build() val := RawValue{Type: TypeString, Value: bsoncore.AppendString(nil, "foobar"), r: reg} var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -63,7 +63,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - reg := newTestRegistry() + reg := newTestRegistryBuilder().Build() var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -75,7 +75,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - reg := NewRegistry() + reg := NewRegistryBuilder().Build() val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", TypeDouble) @@ -87,7 +87,7 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - reg := NewRegistry() + reg := NewRegistryBuilder().Build() want := float64(3.14159) val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, want)} var got float64 @@ -114,7 +114,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns lookup error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: newTestRegistry()} + dc := DecodeContext{Registry: newTestRegistryBuilder().Build()} var val RawValue var s string want := ErrNoDecoder{Type: reflect.TypeOf(s)} @@ -126,7 +126,7 @@ func TestRawValue(t *testing.T) { t.Run("Returns DecodeValue error", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistry()} + dc := DecodeContext{Registry: NewRegistryBuilder().Build()} val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, 3.14159)} var s string want := fmt.Errorf("cannot decode %v into a string type", TypeDouble) @@ -138,7 +138,7 @@ func TestRawValue(t *testing.T) { t.Run("Success", func(t *testing.T) { t.Parallel() - dc := DecodeContext{Registry: NewRegistry()} + dc := DecodeContext{Registry: NewRegistryBuilder().Build()} want := float64(3.14159) val := RawValue{Type: TypeDouble, Value: bsoncore.AppendDouble(nil, want)} var got float64 diff --git a/bson/registry.go b/bson/registry.go index 71d65259d6..179b61de91 100644 --- a/bson/registry.go +++ b/bson/registry.go @@ -15,7 +15,7 @@ import ( // DefaultRegistry is the default Registry. It contains the default codecs and the // primitive codecs. -var DefaultRegistry = NewRegistry() +var DefaultRegistry = NewRegistryBuilder().Build() // ErrNilType is returned when nil is passed to either LookupEncoder or LookupDecoder. // @@ -58,75 +58,54 @@ func (entme ErrNoTypeMapEntry) Error() string { return "no type map entry found for " + entme.Type.String() } -// A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the Registry type -// documentation for examples of registering various custom encoders and decoders. A Registry can -// have four main types of codecs: -// -// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and -// RegisterTypeDecoder methods. The registered codec will be invoked when encoding/decoding a value -// whose type matches the registered type exactly. -// If the registered type is an interface, the codec will be invoked when encoding or decoding -// values whose type is the interface, but not for values with concrete types that implement the -// interface. -// -// 2. Interface encoders/decoders - These can be registered using the RegisterInterfaceEncoder and -// RegisterInterfaceDecoder methods. These methods only accept interface types and the registered codecs -// will be invoked when encoding or decoding values whose types implement the interface. An example -// of an interface defined by the driver is bson.Marshaler. The driver will call the MarshalBSON method -// for any value whose type implements bson.Marshaler, regardless of the value's concrete type. -// -// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type -// associations are used when decoding into a bson.D/bson.M or a struct field of type interface{}. -// For example, by default, BSON int32 and int64 values decode as Go int32 and int64 instances, -// respectively, when decoding into a bson.D. The following code would change the behavior so these -// values decode as Go int instances instead: -// -// intType := reflect.TypeOf(int(0)) -// registry.RegisterTypeMapEntry(bson.TypeInt32, intType).RegisterTypeMapEntry(bson.TypeInt64, intType) -// -// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and -// RegisterDefaultDecoder methods. The registered codec will be invoked when encoding or decoding -// values whose reflect.Kind matches the registered reflect.Kind as long as the value's type doesn't -// match a registered type or interface encoder/decoder first. These methods should be used to change the -// behavior for all values for a specific kind. -// -// Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure. -type Registry struct { - interfaceEncoders []interfaceValueEncoder - interfaceDecoders []interfaceValueDecoder - typeEncoders *typeEncoderCache +// A RegistryBuilder is used to build a Registry. This type is not goroutine +// safe. +type RegistryBuilder struct { + typeEncoders map[reflect.Type]EncoderFactory typeDecoders *typeDecoderCache - kindEncoders *kindEncoderCache + interfaceEncoders map[reflect.Type]EncoderFactory + interfaceDecoders []interfaceValueDecoder + kindEncoders [reflect.UnsafePointer + 1]EncoderFactory kindDecoders *kindDecoderCache - typeMap sync.Map // map[Type]reflect.Type + typeMap map[Type]reflect.Type } -// NewRegistry creates a new empty Registry. -func NewRegistry() *Registry { - reg := &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), +// NewRegistryBuilder creates a new empty RegistryBuilder. +func NewRegistryBuilder() *RegistryBuilder { + rb := &RegistryBuilder{ + typeEncoders: make(map[reflect.Type]EncoderFactory), + typeDecoders: new(typeDecoderCache), + interfaceEncoders: make(map[reflect.Type]EncoderFactory), + kindDecoders: new(kindDecoderCache), + typeMap: make(map[Type]reflect.Type), } - registerDefaultEncoders(reg) - registerDefaultDecoders(reg) - registerPrimitiveCodecs(reg) - return reg + registerDefaultEncoders(rb) + registerDefaultDecoders(rb) + registerPrimitiveCodecs(rb) + return rb } -// RegisterTypeEncoder registers the provided ValueEncoder for the provided type. +// EncoderFactory is a factory function that generates a new ValueEncoder. +type EncoderFactory func() ValueEncoder + +// DecoderFactory is a factory function that generates a new ValueDecoder. +type DecoderFactory func() ValueDecoder + +// RegisterTypeEncoder registers a ValueEncoder factory for the provided type. // -// The type will be used as provided, so an encoder can be registered for a type and a different -// encoder can be registered for a pointer to that type. +// The type will be used as provided, so an encoder factory can be registered for a type and a +// different one can be registered for a pointer to that type. // // If the given type is an interface, the encoder will be called when marshaling a type that is // that interface. It will not be called when marshaling a non-interface type that implements the -// interface. To get the latter behavior, call RegisterHookEncoder instead. +// interface. To get the latter behavior, call RegisterInterfaceEncoder instead. // // RegisterTypeEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) { - r.typeEncoders.Store(valueType, enc) +func (rb *RegistryBuilder) RegisterTypeEncoder(valueType reflect.Type, encFac EncoderFactory) *RegistryBuilder { + if encFac != nil { + rb.typeEncoders[valueType] = encFac + } + return rb } // RegisterTypeDecoder registers the provided ValueDecoder for the provided type. @@ -139,24 +118,28 @@ func (r *Registry) RegisterTypeEncoder(valueType reflect.Type, enc ValueEncoder) // implements the interface. To get the latter behavior, call RegisterHookDecoder instead. // // RegisterTypeDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) { - r.typeDecoders.Store(valueType, dec) +func (rb *RegistryBuilder) RegisterTypeDecoder(valueType reflect.Type, dec ValueDecoder) *RegistryBuilder { + rb.typeDecoders.Store(valueType, dec) + return rb } -// RegisterKindEncoder registers the provided ValueEncoder for the provided kind. +// RegisterKindEncoder registers a ValueEncoder factory for the provided kind. // -// Use RegisterKindEncoder to register an encoder for any type with the same underlying kind. For -// example, consider the type MyInt defined as +// Use RegisterKindEncoder to register an encoder factory for any type with the same underlying kind. +// For example, consider the type MyInt defined as // // type MyInt int32 // -// To define an encoder for MyInt and int32, use RegisterKindEncoder like +// To define an encoder factory for MyInt and int32, use RegisterKindEncoder like // // reg.RegisterKindEncoder(reflect.Int32, myEncoder) // // RegisterKindEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { - r.kindEncoders.Store(kind, enc) +func (rb *RegistryBuilder) RegisterKindEncoder(kind reflect.Kind, encFac EncoderFactory) *RegistryBuilder { + if encFac != nil && kind < reflect.Kind(len(rb.kindEncoders)) { + rb.kindEncoders[kind] = encFac + } + return rb } // RegisterKindDecoder registers the provided ValueDecoder for the provided kind. @@ -171,31 +154,29 @@ func (r *Registry) RegisterKindEncoder(kind reflect.Kind, enc ValueEncoder) { // reg.RegisterKindDecoder(reflect.Int32, myDecoder) // // RegisterKindDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) { - r.kindDecoders.Store(kind, dec) +func (rb *RegistryBuilder) RegisterKindDecoder(kind reflect.Kind, dec ValueDecoder) *RegistryBuilder { + rb.kindDecoders.Store(kind, dec) + return rb } -// RegisterInterfaceEncoder registers an encoder for the provided interface type iface. This encoder will -// be called when marshaling a type if the type implements iface or a pointer to the type +// RegisterInterfaceEncoder registers an encoder factory for the provided interface type iface. This +// encoder will be called when marshaling a type if the type implements iface or a pointer to the type // implements iface. If the provided type is not an interface // (i.e. iface.Kind() != reflect.Interface), this method will panic. // // RegisterInterfaceEncoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder) { +func (rb *RegistryBuilder) RegisterInterfaceEncoder(iface reflect.Type, encFac EncoderFactory) *RegistryBuilder { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceEncoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) panic(panicStr) } - for idx, encoder := range r.interfaceEncoders { - if encoder.i == iface { - r.interfaceEncoders[idx].ve = enc - return - } + if encFac != nil { + rb.interfaceEncoders[iface] = encFac } - r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{i: iface, ve: enc}) + return rb } // RegisterInterfaceDecoder registers an decoder for the provided interface type iface. This decoder will @@ -204,21 +185,23 @@ func (r *Registry) RegisterInterfaceEncoder(iface reflect.Type, enc ValueEncoder // this method will panic. // // RegisterInterfaceDecoder should not be called concurrently with any other Registry method. -func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) { +func (rb *RegistryBuilder) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder) *RegistryBuilder { if iface.Kind() != reflect.Interface { panicStr := fmt.Errorf("RegisterInterfaceDecoder expects a type with kind reflect.Interface, "+ "got type %s with kind %s", iface, iface.Kind()) panic(panicStr) } - for idx, decoder := range r.interfaceDecoders { + for idx, decoder := range rb.interfaceDecoders { if decoder.i == iface { - r.interfaceDecoders[idx].vd = dec - return + rb.interfaceDecoders[idx].vd = dec + return rb } } - r.interfaceDecoders = append(r.interfaceDecoders, interfaceValueDecoder{i: iface, vd: dec}) + rb.interfaceDecoders = append(rb.interfaceDecoders, interfaceValueDecoder{i: iface, vd: dec}) + + return rb } // RegisterTypeMapEntry will register the provided type to the BSON type. The primary usage for this @@ -230,8 +213,113 @@ func (r *Registry) RegisterInterfaceDecoder(iface reflect.Type, dec ValueDecoder // to decode to bson.Raw, use the following code: // // reg.RegisterTypeMapEntry(TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) -func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { - r.typeMap.Store(bt, rt) +// +// RegisterTypeMapEntry should not be called concurrently with any other Registry method. +func (rb *RegistryBuilder) RegisterTypeMapEntry(bt Type, rt reflect.Type) *RegistryBuilder { + rb.typeMap[bt] = rt + return rb +} + +// Build creates a Registry from the current state of this RegistryBuilder. +func (rb *RegistryBuilder) Build() *Registry { + r := &Registry{ + typeEncoders: new(sync.Map), + typeDecoders: rb.typeDecoders.Clone(), + interfaceEncoders: make([]interfaceValueEncoder, 0, len(rb.interfaceEncoders)), + interfaceDecoders: append([]interfaceValueDecoder(nil), rb.interfaceDecoders...), + kindDecoders: rb.kindDecoders.Clone(), + encoderTypeMap: make(map[reflect.Type][]ValueEncoder), + typeMap: make(map[Type]reflect.Type), + } + encoderCache := make(map[reflect.Value]ValueEncoder) + for k, v := range rb.typeEncoders { + var encoder ValueEncoder + if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { + encoder = enc + } else { + encoder = v() + encoderCache[reflect.ValueOf(v)] = encoder + et := reflect.ValueOf(encoder).Type() + r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) + } + r.typeEncoders.Store(k, encoder) + } + for k, v := range rb.interfaceEncoders { + var encoder ValueEncoder + if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { + encoder = enc + } else { + encoder = v() + encoderCache[reflect.ValueOf(v)] = encoder + et := reflect.ValueOf(encoder).Type() + r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) + } + r.interfaceEncoders = append(r.interfaceEncoders, interfaceValueEncoder{k, encoder}) + } + for i, v := range rb.kindEncoders { + if v == nil { + continue + } + var encoder ValueEncoder + if enc, ok := encoderCache[reflect.ValueOf(v)]; ok { + encoder = enc + } else { + encoder = v() + encoderCache[reflect.ValueOf(v)] = encoder + et := reflect.ValueOf(encoder).Type() + r.encoderTypeMap[et] = append(r.encoderTypeMap[et], encoder) + } + r.kindEncoders[i] = encoder + } + for k, v := range rb.typeMap { + r.typeMap[k] = v + } + return r +} + +// A Registry is a store for ValueEncoders, ValueDecoders, and a type map. See the Registry type +// documentation for examples of registering various custom encoders and decoders. A Registry can +// have four main types of codecs: +// +// 1. Type encoders/decoders - These can be registered using the RegisterTypeEncoder and +// RegisterTypeDecoder methods. The registered codec will be invoked when encoding/decoding a value +// whose type matches the registered type exactly. +// If the registered type is an interface, the codec will be invoked when encoding or decoding +// values whose type is the interface, but not for values with concrete types that implement the +// interface. +// +// 2. Interface encoders/decoders - These can be registered using the RegisterInterfaceEncoder and +// RegisterInterfaceDecoder methods. These methods only accept interface types and the registered codecs +// will be invoked when encoding or decoding values whose types implement the interface. An example +// of an interface defined by the driver is bson.Marshaler. The driver will call the MarshalBSON method +// for any value whose type implements bson.Marshaler, regardless of the value's concrete type. +// +// 3. Type map entries - This can be used to associate a BSON type with a Go type. These type +// associations are used when decoding into a bson.D/bson.M or a struct field of type interface{}. +// For example, by default, BSON int32 and int64 values decode as Go int32 and int64 instances, +// respectively, when decoding into a bson.D. The following code would change the behavior so these +// values decode as Go int instances instead: +// +// intType := reflect.TypeOf(int(0)) +// registry.RegisterTypeMapEntry(bson.TypeInt32, intType).RegisterTypeMapEntry(bson.TypeInt64, intType) +// +// 4. Kind encoder/decoders - These can be registered using the RegisterDefaultEncoder and +// RegisterDefaultDecoder methods. The registered codec will be invoked when encoding or decoding +// values whose reflect.Kind matches the registered reflect.Kind as long as the value's type doesn't +// match a registered type or interface encoder/decoder first. These methods should be used to change the +// behavior for all values for a specific kind. +// +// Read [Registry.LookupDecoder] and [Registry.LookupEncoder] for Registry lookup procedure. +type Registry struct { + typeEncoders *sync.Map // map[reflect.Type]ValueEncoder + typeDecoders *typeDecoderCache + interfaceEncoders []interfaceValueEncoder + interfaceDecoders []interfaceValueDecoder + kindEncoders [reflect.UnsafePointer + 1]ValueEncoder + kindDecoders *kindDecoderCache + typeMap map[Type]reflect.Type + + encoderTypeMap map[reflect.Type][]ValueEncoder } // LookupEncoder returns the first matching encoder in the Registry. It uses the following lookup @@ -250,36 +338,38 @@ func (r *Registry) RegisterTypeMapEntry(bt Type, rt reflect.Type) { // 3. An encoder registered using RegisterKindEncoder for the kind of value. // // If no encoder is found, an error of type ErrNoEncoder is returned. LookupEncoder is safe for -// concurrent use by multiple goroutines after all codecs and encoders are registered. +// concurrent use by multiple goroutines. func (r *Registry) LookupEncoder(valueType reflect.Type) (ValueEncoder, error) { if valueType == nil { return nil, ErrNoEncoder{Type: valueType} } - enc, found := r.lookupTypeEncoder(valueType) - if found { + + if enc, found := r.typeEncoders.Load(valueType); found { if enc == nil { return nil, ErrNoEncoder{Type: valueType} } - return enc, nil + return enc.(ValueEncoder), nil } - enc, found = r.lookupInterfaceEncoder(valueType, true) - if found { - return r.typeEncoders.LoadOrStore(valueType, enc), nil + if enc, found := r.lookupInterfaceEncoder(valueType, true); found { + r.typeEncoders.Store(valueType, enc) + return enc, nil } - if v, ok := r.kindEncoders.Load(valueType.Kind()); ok { - return r.storeTypeEncoder(valueType, v), nil + if enc, found := r.lookupKindEncoder(valueType.Kind()); found { + r.typeEncoders.Store(valueType, enc) + return enc, nil } return nil, ErrNoEncoder{Type: valueType} } -func (r *Registry) storeTypeEncoder(rt reflect.Type, enc ValueEncoder) ValueEncoder { - return r.typeEncoders.LoadOrStore(rt, enc) -} - -func (r *Registry) lookupTypeEncoder(rt reflect.Type) (ValueEncoder, bool) { - return r.typeEncoders.Load(rt) +func (r *Registry) lookupKindEncoder(valueKind reflect.Kind) (ValueEncoder, bool) { + if valueKind < reflect.Kind(len(r.kindEncoders)) { + if enc := r.kindEncoders[valueKind]; enc != nil { + return enc, true + } + } + return nil, false } func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool) (ValueEncoder, bool) { @@ -295,7 +385,7 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // ahead in interfaceEncoders defaultEnc, found := r.lookupInterfaceEncoder(valueType, false) if !found { - defaultEnc, _ = r.kindEncoders.Load(valueType.Kind()) + defaultEnc, _ = r.lookupKindEncoder(valueType.Kind()) } return &condAddrEncoder{canAddrEnc: ienc.ve, elseEnc: defaultEnc}, true } @@ -319,12 +409,12 @@ func (r *Registry) lookupInterfaceEncoder(valueType reflect.Type, allowAddr bool // 3. A decoder registered using RegisterKindDecoder for the kind of value. // // If no decoder is found, an error of type ErrNoDecoder is returned. LookupDecoder is safe for -// concurrent use by multiple goroutines after all codecs and decoders are registered. +// concurrent use by multiple goroutines. func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { if valueType == nil { return nil, ErrNilType } - dec, found := r.lookupTypeDecoder(valueType) + dec, found := r.typeDecoders.Load(valueType) if found { if dec == nil { return nil, ErrNoDecoder{Type: valueType} @@ -334,23 +424,15 @@ func (r *Registry) LookupDecoder(valueType reflect.Type) (ValueDecoder, error) { dec, found = r.lookupInterfaceDecoder(valueType, true) if found { - return r.storeTypeDecoder(valueType, dec), nil + return r.typeDecoders.LoadOrStore(valueType, dec), nil } if v, ok := r.kindDecoders.Load(valueType.Kind()); ok { - return r.storeTypeDecoder(valueType, v), nil + return r.typeDecoders.LoadOrStore(valueType, v), nil } return nil, ErrNoDecoder{Type: valueType} } -func (r *Registry) lookupTypeDecoder(valueType reflect.Type) (ValueDecoder, bool) { - return r.typeDecoders.Load(valueType) -} - -func (r *Registry) storeTypeDecoder(typ reflect.Type, dec ValueDecoder) ValueDecoder { - return r.typeDecoders.LoadOrStore(typ, dec) -} - func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool) (ValueDecoder, bool) { for _, idec := range r.interfaceDecoders { if valueType.Implements(idec.i) { @@ -371,14 +453,12 @@ func (r *Registry) lookupInterfaceDecoder(valueType reflect.Type, allowAddr bool // LookupTypeMapEntry inspects the registry's type map for a Go type for the corresponding BSON // type. If no type is found, ErrNoTypeMapEntry is returned. -// -// LookupTypeMapEntry should not be called concurrently with any other Registry method. func (r *Registry) LookupTypeMapEntry(bt Type) (reflect.Type, error) { - v, ok := r.typeMap.Load(bt) + v, ok := r.typeMap[bt] if v == nil || !ok { return nil, ErrNoTypeMapEntry{Type: bt} } - return v.(reflect.Type), nil + return v, nil } type interfaceValueEncoder struct { diff --git a/bson/registry_examples_test.go b/bson/registry_examples_test.go index b866df8cdb..4b15dde3d5 100644 --- a/bson/registry_examples_test.go +++ b/bson/registry_examples_test.go @@ -23,7 +23,7 @@ func ExampleRegistry_customEncoder() { negatedIntType := reflect.TypeOf(negatedInt(0)) negatedIntEncoder := func( - _ *bson.Registry, + _ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value, ) error { @@ -46,10 +46,13 @@ func ExampleRegistry_customEncoder() { return vw.WriteInt64(negatedVal) } - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder() reg.RegisterTypeEncoder( negatedIntType, - bson.ValueEncoderFunc(negatedIntEncoder)) + func() bson.ValueEncoder { + return bson.ValueEncoderFunc(negatedIntEncoder) + }, + ) // Define a document that includes both int and negatedInt fields with the // same value. @@ -67,7 +70,7 @@ func ExampleRegistry_customEncoder() { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) enc := bson.NewEncoder(vw) - enc.SetRegistry(reg) + enc.SetRegistry(reg.Build()) err := enc.Encode(doc) if err != nil { panic(err) @@ -129,10 +132,11 @@ func ExampleRegistry_customDecoder() { return nil } - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder() reg.RegisterTypeDecoder( lenientBoolType, - bson.ValueDecoderFunc(lenientBoolDecoder)) + bson.ValueDecoderFunc(lenientBoolDecoder), + ) // Marshal a BSON document with a single field "isOK" that is a non-zero // integer value. @@ -148,7 +152,7 @@ func ExampleRegistry_customDecoder() { IsOK lenientBool `bson:"isOK"` } var doc MyDocument - err = bson.UnmarshalWithRegistry(reg, b, &doc) + err = bson.UnmarshalWithRegistry(reg.Build(), b, &doc) if err != nil { panic(err) } @@ -156,13 +160,13 @@ func ExampleRegistry_customDecoder() { // Output: {IsOK:true} } -func ExampleRegistry_RegisterKindEncoder() { +func ExampleRegistryBuilder_RegisterKindEncoder() { // Create a custom encoder that writes any Go type that has underlying type // int32 as an a BSON int64. To do that, we register the encoder as a "kind" // encoder for kind reflect.Int32. That way, even user-defined types with // underlying type int32 will be encoded as a BSON int64. int32To64Encoder := func( - _ *bson.Registry, + _ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value, ) error { @@ -181,10 +185,13 @@ func ExampleRegistry_RegisterKindEncoder() { // Create a default registry and register our int32-to-int64 encoder for // kind reflect.Int32. - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder() reg.RegisterKindEncoder( reflect.Int32, - bson.ValueEncoderFunc(int32To64Encoder)) + func() bson.ValueEncoder { + return bson.ValueEncoderFunc(int32To64Encoder) + }, + ) // Define a document that includes an int32, an int64, and a user-defined // type "myInt" that has underlying type int32. @@ -205,7 +212,7 @@ func ExampleRegistry_RegisterKindEncoder() { buf := new(bytes.Buffer) vw := bson.NewValueWriter(buf) enc := bson.NewEncoder(vw) - enc.SetRegistry(reg) + enc.SetRegistry(reg.Build()) err := enc.Encode(doc) if err != nil { panic(err) @@ -214,7 +221,7 @@ func ExampleRegistry_RegisterKindEncoder() { // Output: {"myint": {"$numberLong":"1"},"int32": {"$numberLong":"1"},"int64": {"$numberLong":"1"}} } -func ExampleRegistry_RegisterKindDecoder() { +func ExampleRegistryBuilder_RegisterKindDecoder() { // Create a custom decoder that can decode any integer value, including // integer values encoded as floating point numbers, to any Go type // with underlying type int64. To do that, we register the decoder as a @@ -270,10 +277,11 @@ func ExampleRegistry_RegisterKindDecoder() { return nil } - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder() reg.RegisterKindDecoder( reflect.Int64, - bson.ValueDecoderFunc(flexibleInt64KindDecoder)) + bson.ValueDecoderFunc(flexibleInt64KindDecoder), + ) // Marshal a BSON document with fields that are mixed numeric types but all // hold integer values (i.e. values with no fractional part). @@ -290,7 +298,7 @@ func ExampleRegistry_RegisterKindDecoder() { Int64 int64 } var doc myDocument - err = bson.UnmarshalWithRegistry(reg, b, &doc) + err = bson.UnmarshalWithRegistry(reg.Build(), b, &doc) if err != nil { panic(err) } diff --git a/bson/registry_test.go b/bson/registry_test.go index 5375b6f444..003bb69d6b 100644 --- a/bson/registry_test.go +++ b/bson/registry_test.go @@ -15,508 +15,247 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" ) -// newTestRegistry creates a new empty Registry. -func newTestRegistry() *Registry { - return &Registry{ - typeEncoders: new(typeEncoderCache), - typeDecoders: new(typeDecoderCache), - kindEncoders: new(kindEncoderCache), - kindDecoders: new(kindDecoderCache), +// newTestRegistryBuilder creates a new empty RegistryBuilder. +func newTestRegistryBuilder() *RegistryBuilder { + return &RegistryBuilder{ + typeEncoders: make(map[reflect.Type]EncoderFactory), + typeDecoders: new(typeDecoderCache), + interfaceEncoders: make(map[reflect.Type]EncoderFactory), + kindDecoders: new(kindDecoderCache), + typeMap: make(map[Type]reflect.Type), } } func TestRegistryBuilder(t *testing.T) { + t.Parallel() + t.Run("Register", func(t *testing.T) { + t.Parallel() + fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec) t.Run("interface", func(t *testing.T) { - var t1f *testInterface1 - var t2f *testInterface2 - var t4f *testInterface4 - ips := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc1}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + t.Parallel() + + t1f, t2f, t3f, t4f := + reflect.TypeOf((*testInterface1)(nil)).Elem(), + reflect.TypeOf((*testInterface2)(nil)).Elem(), + reflect.TypeOf((*testInterface3)(nil)).Elem(), + reflect.TypeOf((*testInterface4)(nil)).Elem() + + var c1, c2, c3, c4 int + ef1 := func() ValueEncoder { + c1++ + return fc1 + } + ef2 := func() ValueEncoder { + c2++ + return fc2 + } + ef3 := func() ValueEncoder { + c3++ + return fc3 + } + ef4 := func() ValueEncoder { + c4++ + return fc4 + } + + ips := []struct { + i reflect.Type + ef EncoderFactory + }{ + {i: t1f, ef: ef1}, + {i: t2f, ef: ef2}, + {i: t1f, ef: ef3}, + {i: t3f, ef: ef2}, + {i: t4f, ef: ef4}, } want := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + {i: t1f, ve: fc3}, {i: t2f, ve: fc2}, + {i: t3f, ve: fc2}, {i: t4f, ve: fc4}, } - reg := newTestRegistry() + + rb := newTestRegistryBuilder() for _, ip := range ips { - reg.RegisterInterfaceEncoder(ip.i, ip.ve) + rb.RegisterInterfaceEncoder(ip.i, ip.ef) } + reg := rb.Build() - got := reg.interfaceEncoders - if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { - t.Errorf("the registered interfaces are not correct: got %#v, want %#v", got, want) + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) } - }) - t.Run("type", func(t *testing.T) { - ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - reg := newTestRegistry() - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) - reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) - reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) - want := []struct { - t reflect.Type - c ValueEncoder - }{ - {reflect.TypeOf(ft1), fc3}, - {reflect.TypeOf(ft2), fc2}, - {reflect.TypeOf(ft4), fc4}, + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) } - - got := reg.typeEncoders - for _, s := range want { - wantT, wantC := s.t, s.c - gotC, exists := got.Load(wantT) - if !exists { - t.Errorf("Did not find type in the type registry: %v", wantT) - } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) - } + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) } - }) - t.Run("kind", func(t *testing.T) { - k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - reg := newTestRegistry() - reg.RegisterKindEncoder(k1, fc1) - reg.RegisterKindEncoder(k2, fc2) - reg.RegisterKindEncoder(k1, fc3) - reg.RegisterKindEncoder(k4, fc4) - want := []struct { - k reflect.Kind - c ValueEncoder - }{ - {k1, fc3}, - {k2, fc2}, - {k4, fc4}, + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) + } + codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") + } + got := make(map[reflect.Type]ValueEncoder) + for _, e := range reg.interfaceEncoders { + got[e.i] = e.ve } - - got := reg.kindEncoders for _, s := range want { - wantK, wantC := s.k, s.c - gotC, exists := got.Load(wantK) + wantI, wantVe := s.i, s.ve + gotVe, exists := got[wantI] if !exists { - t.Errorf("Did not find kind in the kind registry: %v", wantK) + t.Errorf("Did not find type in the type registry: %v", wantI) } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) + if !cmp.Equal(gotVe, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v; want %#v", gotVe, wantVe) } } }) - t.Run("RegisterDefault", func(t *testing.T) { - t.Run("MapCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - - reg.RegisterKindEncoder(reflect.Map, codec) - if reg.kindEncoders.get(reflect.Map) != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) - } - - reg.RegisterKindEncoder(reflect.Map, codec2) - if reg.kindEncoders.get(reflect.Map) != codec2 { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) - } - }) - t.Run("StructCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - - reg.RegisterKindEncoder(reflect.Struct, codec) - if reg.kindEncoders.get(reflect.Struct) != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) - } - - reg.RegisterKindEncoder(reflect.Struct, codec2) - if reg.kindEncoders.get(reflect.Struct) != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) - } - }) - t.Run("SliceCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - - reg.RegisterKindEncoder(reflect.Slice, codec) - if reg.kindEncoders.get(reflect.Slice) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) - } - - reg.RegisterKindEncoder(reflect.Slice, codec2) - if reg.kindEncoders.get(reflect.Slice) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) - } - }) - t.Run("ArrayCodec", func(t *testing.T) { - codec := &fakeCodec{num: 1} - codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() + t.Run("type", func(t *testing.T) { + t.Parallel() - reg.RegisterKindEncoder(reflect.Array, codec) - if reg.kindEncoders.get(reflect.Array) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) - } + ft1, ft2, ft3, ft4 := + reflect.TypeOf(fakeType1{}), + reflect.TypeOf(fakeType2{}), + reflect.TypeOf(fakeType3{}), + reflect.TypeOf(fakeType4{}) - reg.RegisterKindEncoder(reflect.Array, codec2) - if reg.kindEncoders.get(reflect.Array) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) - } - }) - }) - t.Run("Lookup", func(t *testing.T) { - type Codec interface { - ValueEncoder - ValueDecoder + var c1, c2, c3, c4 int + ef1 := func() ValueEncoder { + c1++ + return fc1 + } + ef2 := func() ValueEncoder { + c2++ + return fc2 + } + ef3 := func() ValueEncoder { + c3++ + return fc3 + } + ef4 := func() ValueEncoder { + c4++ + return fc4 } - var ( - arrinstance [12]int - arr = reflect.TypeOf(arrinstance) - slc = reflect.TypeOf(make([]int, 12)) - m = reflect.TypeOf(make(map[string]int)) - strct = reflect.TypeOf(struct{ Foo string }{}) - ft1 = reflect.PtrTo(reflect.TypeOf(fakeType1{})) - ft2 = reflect.TypeOf(fakeType2{}) - ft3 = reflect.TypeOf(fakeType5(func(string, string) string { return "fakeType5" })) - ti1 = reflect.TypeOf((*testInterface1)(nil)).Elem() - ti2 = reflect.TypeOf((*testInterface2)(nil)).Elem() - ti1Impl = reflect.TypeOf(testInterface1Impl{}) - ti2Impl = reflect.TypeOf(testInterface2Impl{}) - ti3 = reflect.TypeOf((*testInterface3)(nil)).Elem() - ti3Impl = reflect.TypeOf(testInterface3Impl{}) - ti3ImplPtr = reflect.TypeOf((*testInterface3Impl)(nil)) - fc1, fc2 = &fakeCodec{num: 1}, &fakeCodec{num: 2} - fsc, fslcc, fmc = new(fakeStructCodec), new(fakeSliceCodec), new(fakeMapCodec) - pc = &pointerCodec{} - ) - - reg := newTestRegistry() - reg.RegisterTypeEncoder(ft1, fc1) - reg.RegisterTypeEncoder(ft2, fc2) - reg.RegisterTypeEncoder(ti1, fc1) - reg.RegisterKindEncoder(reflect.Struct, fsc) - reg.RegisterKindEncoder(reflect.Slice, fslcc) - reg.RegisterKindEncoder(reflect.Array, fslcc) - reg.RegisterKindEncoder(reflect.Map, fmc) - reg.RegisterKindEncoder(reflect.Ptr, pc) - reg.RegisterTypeDecoder(ft1, fc1) - reg.RegisterTypeDecoder(ft2, fc2) - reg.RegisterTypeDecoder(ti1, fc1) // values whose exact type is testInterface1 will use fc1 encoder - reg.RegisterKindDecoder(reflect.Struct, fsc) - reg.RegisterKindDecoder(reflect.Slice, fslcc) - reg.RegisterKindDecoder(reflect.Array, fslcc) - reg.RegisterKindDecoder(reflect.Map, fmc) - reg.RegisterKindDecoder(reflect.Ptr, pc) - reg.RegisterInterfaceEncoder(ti2, fc2) - reg.RegisterInterfaceEncoder(ti3, fc3) - reg.RegisterInterfaceDecoder(ti2, fc2) - reg.RegisterInterfaceDecoder(ti3, fc3) - - testCases := []struct { - name string - t reflect.Type - wantcodec Codec - wanterr error - testcache bool + ips := []struct { + i reflect.Type + ef EncoderFactory }{ - { - "type registry (pointer)", - ft1, - fc1, - nil, - false, - }, - { - "type registry (non-pointer)", - ft2, - fc2, - nil, - false, - }, - { - // lookup an interface type and expect that the registered encoder is returned - "interface with type encoder", - ti1, - fc1, - nil, - true, - }, - { - // lookup a type that implements an interface and expect that the default struct codec is returned - "interface implementation with type encoder", - ti1Impl, - fsc, - nil, - false, - }, - { - // lookup an interface type and expect that the registered hook is returned - "interface with hook", - ti2, - fc2, - nil, - false, - }, - { - // lookup a type that implements an interface and expect that the registered hook is returned - "interface implementation with hook", - ti2Impl, - fc2, - nil, - false, - }, - { - // lookup a pointer to a type where the pointer implements an interface and expect that the - // registered hook is returned - "interface pointer to implementation with hook (pointer)", - ti3ImplPtr, - fc3, - nil, - false, - }, - { - "default struct codec (pointer)", - reflect.PtrTo(strct), - pc, - nil, - false, - }, - { - "default struct codec (non-pointer)", - strct, - fsc, - nil, - false, - }, - { - "default array codec", - arr, - fslcc, - nil, - false, - }, - { - "default slice codec", - slc, - fslcc, - nil, - false, - }, - { - "default map", - m, - fmc, - nil, - false, - }, - { - "map non-string key", - reflect.TypeOf(map[int]int{}), - fmc, - nil, - false, - }, - { - "No Codec Registered", - ft3, - nil, - ErrNoEncoder{Type: ft3}, - false, - }, + {i: ft1, ef: ef1}, + {i: ft2, ef: ef2}, + {i: ft1, ef: ef3}, + {i: ft3, ef: ef2}, + {i: ft4, ef: ef4}, } - - allowunexported := cmp.AllowUnexported(fakeCodec{}, fakeStructCodec{}, fakeSliceCodec{}, fakeMapCodec{}) - comparepc := func(pc1, pc2 *pointerCodec) bool { return true } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Run("Encoder", func(t *testing.T) { - gotcodec, goterr := reg.LookupEncoder(tc.t) - if !cmp.Equal(goterr, tc.wanterr, cmp.Comparer(assert.CompareErrors)) { - t.Errorf("errors did not match: got %#v, want %#v", goterr, tc.wanterr) - } - if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec) - } - }) - t.Run("Decoder", func(t *testing.T) { - wanterr := tc.wanterr - if ene, ok := tc.wanterr.(ErrNoEncoder); ok { - wanterr = ErrNoDecoder(ene) - } - - gotcodec, goterr := reg.LookupDecoder(tc.t) - if !cmp.Equal(goterr, wanterr, cmp.Comparer(assert.CompareErrors)) { - t.Errorf("errors did not match: got %#v, want %#v", goterr, wanterr) - } - if !cmp.Equal(gotcodec, tc.wantcodec, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("codecs did not match: got %#v, want %#v", gotcodec, tc.wantcodec) - } - }) - }) + want := []interfaceValueEncoder{ + {i: ft1, ve: fc3}, {i: ft2, ve: fc2}, + {i: ft3, ve: fc2}, {i: ft4, ve: fc4}, } - // lookup a type whose pointer implements an interface and expect that the registered hook is - // returned - t.Run("interface implementation with hook (pointer)", func(t *testing.T) { - t.Run("Encoder", func(t *testing.T) { - gotEnc, err := reg.LookupEncoder(ti3Impl) - assert.Nil(t, err, "LookupEncoder error: %v", err) - cae, ok := gotEnc.(*condAddrEncoder) - assert.True(t, ok, "Expected CondAddrEncoder, got %T", gotEnc) - if !cmp.Equal(cae.canAddrEnc, fc3, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected canAddrEnc %#v, got %#v", cae.canAddrEnc, fc3) - } - if !cmp.Equal(cae.elseEnc, fsc, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected elseEnc %#v, got %#v", cae.elseEnc, fsc) - } - }) - t.Run("Decoder", func(t *testing.T) { - gotDec, err := reg.LookupDecoder(ti3Impl) - assert.Nil(t, err, "LookupDecoder error: %v", err) - - cad, ok := gotDec.(*condAddrDecoder) - assert.True(t, ok, "Expected CondAddrDecoder, got %T", gotDec) - if !cmp.Equal(cad.canAddrDec, fc3, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected canAddrDec %#v, got %#v", cad.canAddrDec, fc3) - } - if !cmp.Equal(cad.elseDec, fsc, allowunexported, cmp.Comparer(comparepc)) { - t.Errorf("expected elseDec %#v, got %#v", cad.elseDec, fsc) - } - }) - }) - }) - }) - t.Run("Type Map", func(t *testing.T) { - reg := newTestRegistry() - reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) - reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) - - var got, want reflect.Type - - want = reflect.TypeOf("") - got, err := reg.LookupTypeMapEntry(TypeString) - noerr(t, err) - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - - want = reflect.TypeOf(int(0)) - got, err = reg.LookupTypeMapEntry(TypeInt32) - noerr(t, err) - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - - want = nil - wanterr := ErrNoTypeMapEntry{Type: TypeObjectID} - got, err = reg.LookupTypeMapEntry(TypeObjectID) - if !errors.Is(err, wanterr) { - t.Errorf("did not get expected error: got %#v, want %#v", err, wanterr) - } - if got != want { - t.Errorf("unexpected type: got %#v, want %#v", got, want) - } - }) -} - -func TestRegistry(t *testing.T) { - t.Parallel() - - t.Run("Register", func(t *testing.T) { - t.Parallel() - - fc1, fc2, fc3, fc4 := new(fakeCodec), new(fakeCodec), new(fakeCodec), new(fakeCodec) - t.Run("interface", func(t *testing.T) { - t.Parallel() + rb := newTestRegistryBuilder() + for _, ip := range ips { + rb.RegisterTypeEncoder(ip.i, ip.ef) + } + reg := rb.Build() - var t1f *testInterface1 - var t2f *testInterface2 - var t4f *testInterface4 - ips := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc1}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) } - want := []interfaceValueEncoder{ - {i: reflect.TypeOf(t1f).Elem(), ve: fc3}, - {i: reflect.TypeOf(t2f).Elem(), ve: fc2}, - {i: reflect.TypeOf(t4f).Elem(), ve: fc4}, + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) } - reg := newTestRegistry() - for _, ip := range ips { - reg.RegisterInterfaceEncoder(ip.i, ip.ve) + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) } - got := reg.interfaceEncoders - if !cmp.Equal(got, want, cmp.AllowUnexported(interfaceValueEncoder{}, fakeCodec{}), cmp.Comparer(typeComparer)) { - t.Errorf("registered interfaces are not correct: got %#v, want %#v", got, want) + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) } - }) - t.Run("type", func(t *testing.T) { - t.Parallel() - - ft1, ft2, ft4 := fakeType1{}, fakeType2{}, fakeType4{} - reg := newTestRegistry() - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc1) - reg.RegisterTypeEncoder(reflect.TypeOf(ft2), fc2) - reg.RegisterTypeEncoder(reflect.TypeOf(ft1), fc3) - reg.RegisterTypeEncoder(reflect.TypeOf(ft4), fc4) - - want := []struct { - t reflect.Type - c ValueEncoder - }{ - {reflect.TypeOf(ft1), fc3}, - {reflect.TypeOf(ft2), fc2}, - {reflect.TypeOf(ft4), fc4}, + codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") } got := reg.typeEncoders for _, s := range want { - wantT, wantC := s.t, s.c - gotC, exists := got.Load(wantT) + wantI, wantVe := s.i, s.ve + gotVe, exists := got.Load(wantI) if !exists { - t.Errorf("type missing in registry: %v", wantT) + t.Errorf("type missing in registry: %v", wantI) } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v; want %#v", gotC, wantC) + if !cmp.Equal(gotVe, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v; want %#v", gotVe, wantVe) } } }) t.Run("kind", func(t *testing.T) { t.Parallel() - k1, k2, k4 := reflect.Struct, reflect.Slice, reflect.Map - reg := newTestRegistry() - reg.RegisterKindEncoder(k1, fc1) - reg.RegisterKindEncoder(k2, fc2) - reg.RegisterKindEncoder(k1, fc3) - reg.RegisterKindEncoder(k4, fc4) + k1, k2, k3, k4 := reflect.Struct, reflect.Slice, reflect.Int, reflect.Map + var c1, c2, c3, c4 int + ef1 := func() ValueEncoder { + c1++ + return fc1 + } + ef2 := func() ValueEncoder { + c2++ + return fc2 + } + ef3 := func() ValueEncoder { + c3++ + return fc3 + } + ef4 := func() ValueEncoder { + c4++ + return fc4 + } + + ips := []struct { + k reflect.Kind + ef EncoderFactory + }{ + {k: k1, ef: ef1}, + {k: k2, ef: ef2}, + {k: k1, ef: ef3}, + {k: k3, ef: ef2}, + {k: k4, ef: ef4}, + } want := []struct { k reflect.Kind c ValueEncoder }{ - {k1, fc3}, - {k2, fc2}, - {k4, fc4}, + {k1, fc3}, {k2, fc2}, {k4, fc4}, + } + + rb := newTestRegistryBuilder() + for _, ip := range ips { + rb.RegisterKindEncoder(ip.k, ip.ef) + } + reg := rb.Build() + + if !cmp.Equal(c1, 0) { + t.Errorf("ef1 is called %d time(s); expected 0", c1) + } + if !cmp.Equal(c2, 1) { + t.Errorf("ef2 is called %d time(s); expected 1", c2) + } + if !cmp.Equal(c3, 1) { + t.Errorf("ef3 is called %d time(s); expected 1", c3) + } + if !cmp.Equal(c4, 1) { + t.Errorf("ef4 is called %d time(s); expected 1", c4) + } + codecs, ok := reg.encoderTypeMap[reflect.TypeOf((*fakeCodec)(nil))] + if !cmp.Equal(len(reg.encoderTypeMap), 1) || !cmp.Equal(ok, true) || len(codecs) != 3 { + t.Errorf("codecs were not cached correctly") } got := reg.kindEncoders for _, s := range want { - wantK, wantC := s.k, s.c - gotC, exists := got.Load(wantK) - if !exists { - t.Errorf("type missing in registry: %v", wantK) - } - if !cmp.Equal(gotC, wantC, cmp.AllowUnexported(fakeCodec{})) { - t.Errorf("codecs did not match: got %#v, want %#v", gotC, wantC) + wantI, wantVe := s.k, s.c + gotC := got[wantI] + if !cmp.Equal(gotC, wantVe, cmp.AllowUnexported(fakeCodec{})) { + t.Errorf("codecs did not match: got %#v, want %#v", gotC, wantVe) } } }) @@ -528,14 +267,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - reg.RegisterKindEncoder(reflect.Map, codec) - if reg.kindEncoders.get(reflect.Map) != codec { - t.Errorf("map codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Map]; got != codec { + t.Errorf("map codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Map, codec2) - if reg.kindEncoders.get(reflect.Map) != codec2 { - t.Errorf("map codec properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Map), codec2) + + rb.RegisterKindEncoder(reflect.Map, func() ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Map]; got != codec2 { + t.Errorf("map codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("StructCodec", func(t *testing.T) { @@ -543,14 +286,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - reg.RegisterKindEncoder(reflect.Struct, codec) - if reg.kindEncoders.get(reflect.Struct) != codec { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Struct]; got != codec { + t.Errorf("struct codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Struct, codec2) - if reg.kindEncoders.get(reflect.Struct) != codec2 { - t.Errorf("struct codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Struct), codec2) + + rb.RegisterKindEncoder(reflect.Struct, func() ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Struct]; got != codec2 { + t.Errorf("struct codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("SliceCodec", func(t *testing.T) { @@ -558,14 +305,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - reg.RegisterKindEncoder(reflect.Slice, codec) - if reg.kindEncoders.get(reflect.Slice) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Slice]; got != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Slice, codec2) - if reg.kindEncoders.get(reflect.Slice) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Slice), codec2) + + rb.RegisterKindEncoder(reflect.Slice, func() ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Slice]; got != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec2) } }) t.Run("ArrayCodec", func(t *testing.T) { @@ -573,14 +324,18 @@ func TestRegistry(t *testing.T) { codec := &fakeCodec{num: 1} codec2 := &fakeCodec{num: 2} - reg := newTestRegistry() - reg.RegisterKindEncoder(reflect.Array, codec) - if reg.kindEncoders.get(reflect.Array) != codec { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec) + rb := newTestRegistryBuilder() + + rb.RegisterKindEncoder(reflect.Array, func() ValueEncoder { return codec }) + reg := rb.Build() + if got := reg.kindEncoders[reflect.Array]; got != codec { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec) } - reg.RegisterKindEncoder(reflect.Array, codec2) - if reg.kindEncoders.get(reflect.Array) != codec2 { - t.Errorf("slice codec not properly set: got %#v, want %#v", reg.kindEncoders.get(reflect.Array), codec2) + + rb.RegisterKindEncoder(reflect.Array, func() ValueEncoder { return codec2 }) + reg = rb.Build() + if got := reg.kindEncoders[reflect.Array]; got != codec2 { + t.Errorf("slice codec not properly set: got %#v, want %#v", got, codec2) } }) }) @@ -613,27 +368,36 @@ func TestRegistry(t *testing.T) { pc = &pointerCodec{} ) - reg := newTestRegistry() - reg.RegisterTypeEncoder(ft1, fc1) - reg.RegisterTypeEncoder(ft2, fc2) - reg.RegisterTypeEncoder(ti1, fc1) - reg.RegisterKindEncoder(reflect.Struct, fsc) - reg.RegisterKindEncoder(reflect.Slice, fslcc) - reg.RegisterKindEncoder(reflect.Array, fslcc) - reg.RegisterKindEncoder(reflect.Map, fmc) - reg.RegisterKindEncoder(reflect.Ptr, pc) - reg.RegisterTypeDecoder(ft1, fc1) - reg.RegisterTypeDecoder(ft2, fc2) - reg.RegisterTypeDecoder(ti1, fc1) // values whose exact type is testInterface1 will use fc1 encoder - reg.RegisterKindDecoder(reflect.Struct, fsc) - reg.RegisterKindDecoder(reflect.Slice, fslcc) - reg.RegisterKindDecoder(reflect.Array, fslcc) - reg.RegisterKindDecoder(reflect.Map, fmc) - reg.RegisterKindDecoder(reflect.Ptr, pc) - reg.RegisterInterfaceEncoder(ti2, fc2) - reg.RegisterInterfaceDecoder(ti2, fc2) - reg.RegisterInterfaceEncoder(ti3, fc3) - reg.RegisterInterfaceDecoder(ti3, fc3) + fc1EncFac := func() ValueEncoder { return fc1 } + fc2EncFac := func() ValueEncoder { return fc2 } + fc3EncFac := func() ValueEncoder { return fc3 } + fscEncFac := func() ValueEncoder { return fsc } + fslccEncFac := func() ValueEncoder { return fslcc } + fmcEncFac := func() ValueEncoder { return fmc } + pcEncFac := func() ValueEncoder { return pc } + + reg := newTestRegistryBuilder(). + RegisterTypeEncoder(ft1, fc1EncFac). + RegisterTypeEncoder(ft2, fc2EncFac). + RegisterTypeEncoder(ti1, fc1EncFac). + RegisterKindEncoder(reflect.Struct, fscEncFac). + RegisterKindEncoder(reflect.Slice, fslccEncFac). + RegisterKindEncoder(reflect.Array, fslccEncFac). + RegisterKindEncoder(reflect.Map, fmcEncFac). + RegisterKindEncoder(reflect.Ptr, pcEncFac). + RegisterTypeDecoder(ft1, fc1). + RegisterTypeDecoder(ft2, fc2). + RegisterTypeDecoder(ti1, fc1). // values whose exact type is testInterface1 will use fc1 encoder + RegisterKindDecoder(reflect.Struct, fsc). + RegisterKindDecoder(reflect.Slice, fslcc). + RegisterKindDecoder(reflect.Array, fslcc). + RegisterKindDecoder(reflect.Map, fmc). + RegisterKindDecoder(reflect.Ptr, pc). + RegisterInterfaceEncoder(ti2, fc2EncFac). + RegisterInterfaceEncoder(ti3, fc3EncFac). + RegisterInterfaceDecoder(ti2, fc2). + RegisterInterfaceDecoder(ti3, fc3). + Build() testCases := []struct { name string @@ -854,9 +618,10 @@ func TestRegistry(t *testing.T) { }) t.Run("Type Map", func(t *testing.T) { t.Parallel() - reg := newTestRegistry() - reg.RegisterTypeMapEntry(TypeString, reflect.TypeOf("")) - reg.RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))) + reg := newTestRegistryBuilder(). + RegisterTypeMapEntry(TypeString, reflect.TypeOf("")). + RegisterTypeMapEntry(TypeInt32, reflect.TypeOf(int(0))). + Build() var got, want reflect.Type @@ -886,12 +651,6 @@ func TestRegistry(t *testing.T) { }) } -// get is only for testing as it does return if the value was found -func (c *kindEncoderCache) get(rt reflect.Kind) ValueEncoder { - e, _ := c.Load(rt) - return e -} - func BenchmarkLookupEncoder(b *testing.B) { type childStruct struct { V1, V2, V3, V4 int @@ -908,10 +667,11 @@ func BenchmarkLookupEncoder(b *testing.B) { reflect.TypeOf(&testInterface1Impl{}), reflect.TypeOf(&nestedStruct{}), } - r := NewRegistry() + rb := NewRegistryBuilder() for _, typ := range types { - r.RegisterTypeEncoder(typ, &fakeCodec{}) + rb.RegisterTypeEncoder(typ, func() ValueEncoder { return &fakeCodec{} }) } + r := rb.Build() b.Run("Serial", func(b *testing.B) { for i := 0; i < b.N; i++ { _, err := r.LookupEncoder(types[i%len(types)]) @@ -934,6 +694,7 @@ func BenchmarkLookupEncoder(b *testing.B) { type fakeType1 struct{} type fakeType2 struct{} +type fakeType3 struct{} type fakeType4 struct{} type fakeType5 func(string, string) string type fakeStructCodec struct{ *fakeCodec } @@ -948,7 +709,7 @@ type fakeCodec struct { num int } -func (*fakeCodec) EncodeValue(*Registry, ValueWriter, reflect.Value) error { +func (*fakeCodec) EncodeValue(EncoderRegistry, ValueWriter, reflect.Value) error { return nil } func (*fakeCodec) DecodeValue(DecodeContext, ValueReader, reflect.Value) error { @@ -977,5 +738,3 @@ type testInterface3Impl struct{} var _ testInterface3 = (*testInterface3Impl)(nil) func (*testInterface3Impl) test3() {} - -func typeComparer(i1, i2 reflect.Type) bool { return i1 == i2 } diff --git a/bson/setter_getter.go b/bson/setter_getter.go index 069408c9ab..46706241be 100644 --- a/bson/setter_getter.go +++ b/bson/setter_getter.go @@ -84,7 +84,7 @@ func SetterDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error } // GetterEncodeValue is the ValueEncoderFunc for Getter types. -func GetterEncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func GetterEncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { // Either val or a pointer to val must implement Getter switch { case !val.IsValid(): diff --git a/bson/slice_codec.go b/bson/slice_codec.go index d7db3cf9da..3640cdd124 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -24,7 +24,7 @@ type sliceCodec struct { } // EncodeValue is the ValueEncoder for slice types. -func (sc sliceCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (sc sliceCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Slice { return ValueEncoderError{Name: "SliceEncodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } diff --git a/bson/string_codec.go b/bson/string_codec.go index de73fc6f0d..9f1ee76136 100644 --- a/bson/string_codec.go +++ b/bson/string_codec.go @@ -24,7 +24,7 @@ var ( ) // EncodeValue is the ValueEncoder for string types. -func (sc *stringCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (sc *stringCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if val.Kind() != reflect.String { return ValueEncoderError{ Name: "StringEncodeValue", diff --git a/bson/struct_codec.go b/bson/struct_codec.go index b489b1dc56..c3ddd4f2c6 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -88,12 +88,12 @@ func newStructCodec(p StructTagParser) *structCodec { } // EncodeValue handles encoding generic struct types. -func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Value) error { +func (sc *structCodec) EncodeValue(reg EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Struct { return ValueEncoderError{Name: "structCodec.EncodeValue", Kinds: []reflect.Kind{reflect.Struct}, Received: val} } - sd, err := sc.describeStruct(reg, val.Type(), sc.useJSONStructTags, !sc.overwriteDuplicatedInlinedFields) + sd, err := sc.describeStruct(val.Type(), sc.useJSONStructTags, !sc.overwriteDuplicatedInlinedFields) if err != nil { return err } @@ -113,7 +113,12 @@ func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Va } } - desc.encoder, rv, err = lookupElementEncoder(reg, desc.encoder, rv) + var encoder ValueEncoder + if encoder, err = reg.LookupEncoder(desc.fieldType); err != nil { + encoder = nil + } + + encoder, rv, err = lookupElementEncoder(reg, encoder, rv) if err != nil && !errors.Is(err, errInvalidValue) { return err @@ -134,12 +139,10 @@ func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Va continue } - if desc.encoder == nil { + if encoder == nil { return ErrNoEncoder{Type: rv.Type()} } - encoder := desc.encoder - var empty bool if cz, ok := encoder.(CodecZeroer); ok { empty = cz.IsTypeZero(rv.Interface()) @@ -160,12 +163,7 @@ func (sc *structCodec) EncodeValue(reg *Registry, vw ValueWriter, val reflect.Va } // defaultUIntCodec.encodeToMinSize = desc.minSize - switch v := encoder.(type) { - case *uintCodec: - encoder = &uintCodec{ - encodeToMinSize: v.encodeToMinSize || desc.minSize, - } - case *intCodec: + if v, ok := encoder.(*intCodec); ok { encoder = &intCodec{ encodeToMinSize: v.encodeToMinSize || desc.minSize, } @@ -231,7 +229,7 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type()) } - sd, err := sc.describeStruct(dc.Registry, val.Type(), dc.useJSONStructTags, false) + sd, err := sc.describeStruct(val.Type(), dc.useJSONStructTags, false) if err != nil { return err } @@ -330,11 +328,12 @@ func (sc *structCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect zeroStructs: dc.zeroStructs, } - if fd.decoder == nil { + decoder, err := dc.Registry.LookupDecoder(fd.fieldType) + if err != nil { return newDecodeError(fd.name, ErrNoDecoder{Type: field.Elem().Type()}) } - err = fd.decoder.DecodeValue(dctx, vr, field.Elem()) + err = decoder.DecodeValue(dctx, vr, field.Elem()) if err != nil { return newDecodeError(fd.name, err) } @@ -389,8 +388,7 @@ type fieldDescription struct { minSize bool truncate bool inline []int - encoder ValueEncoder - decoder ValueDecoder + fieldType reflect.Type } type byIndex []fieldDescription @@ -423,7 +421,6 @@ func (bi byIndex) Less(i, j int) bool { } func (sc *structCodec) describeStruct( - r *Registry, t reflect.Type, useJSONStructTags bool, errorOnDuplicates bool, @@ -435,7 +432,7 @@ func (sc *structCodec) describeStruct( } // TODO(charlie): Only describe the struct once when called // concurrently with the same type. - ds, err := sc.describeStructSlow(r, t, useJSONStructTags, errorOnDuplicates) + ds, err := sc.describeStructSlow(t, useJSONStructTags, errorOnDuplicates) if err != nil { return nil, err } @@ -446,7 +443,6 @@ func (sc *structCodec) describeStruct( } func (sc *structCodec) describeStructSlow( - r *Registry, t reflect.Type, useJSONStructTags bool, errorOnDuplicates bool, @@ -467,23 +463,15 @@ func (sc *structCodec) describeStructSlow( } sfType := sf.Type - encoder, err := r.LookupEncoder(sfType) - if err != nil { - encoder = nil - } - decoder, err := r.LookupDecoder(sfType) - if err != nil { - decoder = nil - } description := fieldDescription{ fieldName: sf.Name, idx: i, - encoder: encoder, - decoder: decoder, + fieldType: sfType, } var stags StructTags + var err error // If the caller requested that we use JSON struct tags, use the JSONFallbackStructTagParser // instead of the parser defined on the codec. if useJSONStructTags { @@ -520,7 +508,7 @@ func (sc *structCodec) describeStructSlow( } fallthrough case reflect.Struct: - inlinesf, err := sc.describeStruct(r, sfType, useJSONStructTags, errorOnDuplicates) + inlinesf, err := sc.describeStruct(sfType, useJSONStructTags, errorOnDuplicates) if err != nil { return nil, err } diff --git a/bson/time_codec.go b/bson/time_codec.go index 535861ed71..d9bb57404b 100644 --- a/bson/time_codec.go +++ b/bson/time_codec.go @@ -99,7 +99,7 @@ func (tc *timeCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.V } // EncodeValue is the ValueEncoderFunc for time.TIme. -func (tc *timeCodec) EncodeValue(_ *Registry, vw ValueWriter, val reflect.Value) error { +func (tc *timeCodec) EncodeValue(_ EncoderRegistry, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Type() != tTime { return ValueEncoderError{Name: "TimeEncodeValue", Types: []reflect.Type{tTime}, Received: val} } diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 0871237386..d8ef9b69ba 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -228,8 +228,9 @@ func TestCachingDecodersNotSharedAcrossRegistries(t *testing.T) { val.SetInt(int64(-1 * i32)) return nil } - customReg := NewRegistry() - customReg.RegisterTypeDecoder(tInt32, decodeInt32) + customReg := NewRegistryBuilder(). + RegisterTypeDecoder(tInt32, decodeInt32). + Build() docBytes := bsoncore.BuildDocumentFromElements( nil, diff --git a/bson/unmarshal_value_test.go b/bson/unmarshal_value_test.go index fd379b5daa..3af7578d12 100644 --- a/bson/unmarshal_value_test.go +++ b/bson/unmarshal_value_test.go @@ -75,8 +75,9 @@ func TestUnmarshalValue(t *testing.T) { bytes: bsoncore.AppendString(nil, "hello world"), }, } - reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}) + reg := NewRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}). + Build() for _, tc := range testCases { tc := tc @@ -110,8 +111,9 @@ func BenchmarkSliceCodecUnmarshal(b *testing.B) { bytes: bsoncore.AppendString(nil, strings.Repeat("t", 4096)), }, } - reg := NewRegistry() - reg.RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}) + reg := NewRegistryBuilder(). + RegisterTypeDecoder(reflect.TypeOf([]byte{}), &sliceCodec{}). + Build() for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { b.RunParallel(func(pb *testing.PB) { diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index a098880b4c..d3ebb1421b 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -39,7 +39,7 @@ type negateCodec struct { ID int64 `bson:"_id"` } -func (e *negateCodec) EncodeValue(_ *bson.Registry, vw bson.ValueWriter, val reflect.Value) error { +func (e *negateCodec) EncodeValue(_ bson.EncoderRegistry, vw bson.ValueWriter, val reflect.Value) error { return vw.WriteInt64(val.Int()) } @@ -100,9 +100,10 @@ func (sc *slowConn) Read(b []byte) (n int, err error) { func TestClient(t *testing.T) { mt := mtest.New(t, noClientOpts) - reg := bson.NewRegistry() - reg.RegisterTypeEncoder(reflect.TypeOf(int64(0)), &negateCodec{}) - reg.RegisterTypeDecoder(reflect.TypeOf(int64(0)), &negateCodec{}) + reg := bson.NewRegistryBuilder(). + RegisterTypeEncoder(reflect.TypeOf(int64(0)), func() bson.ValueEncoder { return &negateCodec{} }). + RegisterTypeDecoder(reflect.TypeOf(int64(0)), &negateCodec{}). + Build() registryOpts := options.Client(). SetRegistry(reg) mt.RunOpts("registry passed to cursors", mtest.NewOptions().ClientOptions(registryOpts), func(mt *mtest.T) { diff --git a/internal/integration/crud_spec_test.go b/internal/integration/crud_spec_test.go index e6583f8ade..996cdd27f4 100644 --- a/internal/integration/crud_spec_test.go +++ b/internal/integration/crud_spec_test.go @@ -55,11 +55,9 @@ type crudOutcome struct { Collection *outcomeCollection `bson:"collection"` } -var crudRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - return reg -}() +var crudRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + Build() func TestCrudSpec(t *testing.T) { for _, dir := range []string{crudReadDir, crudWriteDir} { diff --git a/internal/integration/database_test.go b/internal/integration/database_test.go index 12c2e0cd53..da043a6636 100644 --- a/internal/integration/database_test.go +++ b/internal/integration/database_test.go @@ -29,11 +29,9 @@ const ( ) var ( - interfaceAsMapRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.M{})) - return reg - }() + interfaceAsMapRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.M{})). + Build() ) func TestDatabase(t *testing.T) { diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index c9199f6135..487714d834 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -181,12 +181,10 @@ var directories = []string{ } var checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local()) -var specTestRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - reg.RegisterTypeDecoder(reflect.TypeOf(testData{}), bson.ValueDecoderFunc(decodeTestData)) - return reg -}() +var specTestRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + RegisterTypeDecoder(reflect.TypeOf(testData{}), bson.ValueDecoderFunc(decodeTestData)). + Build() func TestUnifiedSpecs(t *testing.T) { for _, specDir := range directories { diff --git a/mongo/database_test.go b/mongo/database_test.go index 31bd900439..1142b6df9c 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -53,7 +53,7 @@ func TestDatabase(t *testing.T) { wc2 := &writeconcern.WriteConcern{W: 10} rcLocal := readconcern.Local() rcMajority := readconcern.Majority() - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder().Build() opts := options.Database().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetWriteConcern(wc1). SetReadPreference(rpSecondary).SetReadConcern(rcMajority).SetWriteConcern(wc2).SetRegistry(reg) @@ -70,7 +70,7 @@ func TestDatabase(t *testing.T) { rpPrimary := readpref.Primary() rcLocal := readconcern.Local() wc1 := &writeconcern.WriteConcern{W: 10} - reg := bson.NewRegistry() + reg := bson.NewRegistryBuilder().Build() client := setupClient(options.Client().SetReadPreference(rpPrimary).SetReadConcern(rcLocal).SetRegistry(reg)) got := client.Database("foo", options.Database().SetWriteConcern(wc1)) diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index beba45514f..078c029308 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -80,7 +80,7 @@ func TestClientOptions(t *testing.T) { {"Monitor", (*ClientOptions).SetMonitor, &event.CommandMonitor{}, "Monitor", false}, {"ReadConcern", (*ClientOptions).SetReadConcern, readconcern.Majority(), "ReadConcern", false}, {"ReadPreference", (*ClientOptions).SetReadPreference, readpref.SecondaryPreferred(), "ReadPreference", false}, - {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistry(), "Registry", false}, + {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistryBuilder().Build(), "Registry", false}, {"ReplicaSet", (*ClientOptions).SetReplicaSet, "example-replicaset", "ReplicaSet", true}, {"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true}, {"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true}, diff --git a/mongo/read_write_concern_spec_test.go b/mongo/read_write_concern_spec_test.go index ec49bb91db..c737f76a9b 100644 --- a/mongo/read_write_concern_spec_test.go +++ b/mongo/read_write_concern_spec_test.go @@ -31,11 +31,9 @@ const ( var ( serverDefaultConcern = []byte{5, 0, 0, 0, 0} // server default read concern and write concern is empty document - specTestRegistry = func() *bson.Registry { - reg := bson.NewRegistry() - reg.RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})) - return reg - }() + specTestRegistry = bson.NewRegistryBuilder(). + RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})). + Build() ) type connectionStringTestFile struct { diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index c02600e232..dca9c0581b 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -17,7 +17,7 @@ import ( "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) -var defaultRegistry = bson.NewRegistry() +var defaultRegistry = bson.NewRegistryBuilder().Build() type serverConfig struct { clock *session.ClusterClock