Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix counter marshal, unmarshall #293

Merged
merged 2 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"gopkg.in/inf.v0"

"github.com/gocql/gocql/marshal/bigint"
"github.com/gocql/gocql/marshal/counter"
"github.com/gocql/gocql/marshal/cqlint"
"github.com/gocql/gocql/marshal/smallint"
"github.com/gocql/gocql/marshal/tinyint"
Expand Down Expand Up @@ -147,7 +148,7 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) {
case TypeBigInt:
return marshalBigInt(value)
case TypeCounter:
return marshalBigIntOld(info, value)
return marshalCounter(value)
case TypeFloat:
return marshalFloat(info, value)
case TypeDouble:
Expand Down Expand Up @@ -247,7 +248,7 @@ func Unmarshal(info TypeInfo, data []byte, value interface{}) error {
case TypeBigInt:
return unmarshalBigInt(data, value)
case TypeCounter:
return unmarshalCounter(info, data, value)
return unmarshalCounter(data, value)
case TypeVarint:
return unmarshalVarint(info, data, value)
case TypeSmallInt:
Expand Down Expand Up @@ -429,6 +430,14 @@ func marshalBigInt(value interface{}) ([]byte, error) {
return nil, wrapMarshalError(err, "marshal error")
}
return data, nil
}

func marshalCounter(value interface{}) ([]byte, error) {
data, err := counter.Marshal(value)
if err != nil {
return nil, wrapMarshalError(err, "marshal error")
}
return data, nil

}

Expand Down Expand Up @@ -509,8 +518,12 @@ func bytesToUint64(data []byte) (ret uint64) {
return ret
}

func unmarshalCounter(info TypeInfo, data []byte, value interface{}) error {
return unmarshalIntlike(info, decBigInt(data), data, value)
func unmarshalCounter(data []byte, value interface{}) error {
err := counter.Unmarshal(data, value)
if err != nil {
return wrapUnmarshalError(err, "unmarshal error")
}
return nil
}

func unmarshalInt(data []byte, value interface{}) error {
Expand Down
74 changes: 74 additions & 0 deletions marshal/counter/marshal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package counter

import (
"math/big"
"reflect"
)

func Marshal(value interface{}) ([]byte, error) {
switch v := value.(type) {
case nil:
return nil, nil
case int8:
return EncInt8(v)
case int16:
return EncInt16(v)
case int32:
return EncInt32(v)
case int64:
return EncInt64(v)
case int:
return EncInt(v)

case uint8:
return EncUint8(v)
case uint16:
return EncUint16(v)
case uint32:
return EncUint32(v)
case uint64:
return EncUint64(v)
case uint:
return EncUint(v)

case big.Int:
return EncBigInt(v)
case string:
return EncString(v)

case *int8:
return EncInt8R(v)
case *int16:
return EncInt16R(v)
case *int32:
return EncInt32R(v)
case *int64:
return EncInt64R(v)
case *int:
return EncIntR(v)

case *uint8:
return EncUint8R(v)
case *uint16:
return EncUint16R(v)
case *uint32:
return EncUint32R(v)
case *uint64:
return EncUint64R(v)
case *uint:
return EncUintR(v)

case *big.Int:
return EncBigIntR(v)
case *string:
return EncStringR(v)
default:
// Custom types (type MyInt int) can be serialized only via `reflect` package.
// Later, when generic-based serialization is introduced we can do that via generics.
rv := reflect.TypeOf(value)
if rv.Kind() != reflect.Ptr {
return EncReflect(reflect.ValueOf(v))
}
return EncReflectR(reflect.ValueOf(v))
}
}
201 changes: 201 additions & 0 deletions marshal/counter/marshal_utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
package counter

import (
"fmt"
"math"
"math/big"
"reflect"
"strconv"
)

var (
maxBigInt = big.NewInt(math.MaxInt64)
minBigInt = big.NewInt(math.MinInt64)
)

func EncInt8(v int8) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, 255, 255, 255, byte(v)}, nil
}
return []byte{0, 0, 0, 0, 0, 0, 0, byte(v)}, nil
}

func EncInt8R(v *int8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt8(*v)
}

func EncInt16(v int16) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, 255, 255, byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, 0, 0, 0, 0, byte(v >> 8), byte(v)}, nil
}

func EncInt16R(v *int16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt16(*v)
}

func EncInt32(v int32) ([]byte, error) {
if v < 0 {
return []byte{255, 255, 255, 255, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}
return []byte{0, 0, 0, 0, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}

func EncInt32R(v *int32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt32(*v)
}

func EncInt64(v int64) ([]byte, error) {
return encInt64(v), nil
}

func EncInt64R(v *int64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt64(*v)
}

func EncInt(v int) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}

func EncIntR(v *int) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncInt(*v)
}

func EncUint8(v uint8) ([]byte, error) {
return []byte{0, 0, 0, 0, 0, 0, 0, v}, nil
}

func EncUint8R(v *uint8) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint8(*v)
}

func EncUint16(v uint16) ([]byte, error) {
return []byte{0, 0, 0, 0, 0, 0, byte(v >> 8), byte(v)}, nil
}

func EncUint16R(v *uint16) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint16(*v)
}

func EncUint32(v uint32) ([]byte, error) {
return []byte{0, 0, 0, 0, byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}

func EncUint32R(v *uint32) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint32(*v)
}

func EncUint64(v uint64) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}

func EncUint64R(v *uint64) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint64(*v)
}

func EncUint(v uint) ([]byte, error) {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}, nil
}

func EncUintR(v *uint) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncUint(*v)
}

func EncBigInt(v big.Int) ([]byte, error) {
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal counter: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}

func EncBigIntR(v *big.Int) ([]byte, error) {
if v == nil {
return nil, nil
}
if v.Cmp(maxBigInt) == 1 || v.Cmp(minBigInt) == -1 {
return nil, fmt.Errorf("failed to marshal counter: value (%T)(%s) out of range", v, v.String())
}
return encInt64(v.Int64()), nil
}

func EncString(v string) ([]byte, error) {
if v == "" {
return nil, nil
}

n, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal counter: can not marshal %#v %s", v, err)
}
return encInt64(n), nil
}

func EncStringR(v *string) ([]byte, error) {
if v == nil {
return nil, nil
}
return EncString(*v)
}

func EncReflect(v reflect.Value) ([]byte, error) {
switch v.Kind() {
case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8:
return EncInt64(v.Int())
case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8:
return EncUint64(v.Uint())
case reflect.String:
val := v.String()
if val == "" {
return nil, nil
}
n, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return nil, fmt.Errorf("failed to marshal counter: can not marshal %#v %s", v.Interface(), err)
}
return encInt64(n), nil
default:
return nil, fmt.Errorf("failed to marshal counter: unsupported value type (%T)(%#[1]v)", v.Interface())
}
}

func EncReflectR(v reflect.Value) ([]byte, error) {
if v.IsNil() {
return nil, nil
}
return EncReflect(v.Elem())
}

func encInt64(v int64) []byte {
return []byte{byte(v >> 56), byte(v >> 48), byte(v >> 40), byte(v >> 32), byte(v >> 24), byte(v >> 16), byte(v >> 8), byte(v)}
}
Loading
Loading