Skip to content

Commit

Permalink
Add code-based error mapping for retaining old global error types.
Browse files Browse the repository at this point in the history
  • Loading branch information
knadh committed May 10, 2024
1 parent 001d993 commit 6ce940f
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 18 deletions.
85 changes: 67 additions & 18 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ var (
ErrNil = errors.New("simplesession: nil returned")
)

type errCode interface {
Code() int
}

// NewSession creates a new session. Reads cookie info from `GetCookie“ callback
// and validate the session with current store. If cookie not set then it creates
// new session and calls `SetCookie“ callback. If `DisableAutoSet` is set then it
Expand Down Expand Up @@ -78,7 +82,7 @@ func NewSession(m *Manager, r, w interface{}) (*Session, error) {
// Store also calls `WriteCookie`` to write to http interface
cv, err := m.store.Create()
if err != nil {
return nil, err
return nil, errAs(err)
}

// Write cookie
Expand Down Expand Up @@ -139,7 +143,7 @@ func (s *Session) Create() error {
// Create new cookie in store and write to front.
cv, err := s.manager.store.Create()
if err != nil {
return err
return errAs(err)
}

// Write cookie
Expand Down Expand Up @@ -187,7 +191,8 @@ func (s *Session) GetAll() (map[string]interface{}, error) {
return s.values, nil
}

return s.manager.store.GetAll(s.cookie.Value)
out, err := s.manager.store.GetAll(s.cookie.Value)
return out, errAs(err)
}

// GetMulti gets a map of values for multiple session keys.
Expand All @@ -209,7 +214,8 @@ func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) {
return vals, nil
}

return s.manager.store.GetMulti(s.cookie.Value, keys...)
out, err := s.manager.store.GetMulti(s.cookie.Value, keys...)
return out, errAs(err)
}

// Get gets a value for given key in session.
Expand All @@ -229,7 +235,8 @@ func (s *Session) Get(key string) (interface{}, error) {
}

// Get from backend if not found in previous step
return s.manager.store.Get(s.cookie.Value, key)
out, err := s.manager.store.Get(s.cookie.Value, key)
return out, errAs(err)
}

// Set sets a value for given key in session. Its up to store to commit
Expand All @@ -240,7 +247,8 @@ func (s *Session) Set(key string, val interface{}) error {
return ErrInvalidSession
}

return s.manager.store.Set(s.cookie.Value, key, val)
err := s.manager.store.Set(s.cookie.Value, key, val)
return errAs(err)
}

// SetMulti sets all values in the session.
Expand All @@ -253,8 +261,8 @@ func (s *Session) SetMulti(values map[string]interface{}) error {
}

for k, v := range values {
if err := s.manager.store.Set(s, s.cookie.Value, k, v); err != nil {
return err
if err := s.manager.store.Set(s.cookie.Value, k, v); err != nil {
return errAs(err)
}
}

Expand All @@ -269,7 +277,11 @@ func (s *Session) Commit() error {
return ErrInvalidSession
}

return s.manager.store.Commit(s.cookie.Value)
if err := s.manager.store.Commit(s.cookie.Value); err != nil {
return errAs(err)
}

return nil
}

// Delete deletes a field from session.
Expand All @@ -279,7 +291,11 @@ func (s *Session) Delete(key string) error {
return ErrInvalidSession
}

return s.manager.store.Delete(s.cookie.Value, key)
if err := s.manager.store.Delete(s.cookie.Value, key); err != nil {
return errAs(err)
}

return nil
}

// Clear clears session data from store and clears the cookie
Expand All @@ -290,43 +306,76 @@ func (s *Session) Clear() error {
}

if err := s.manager.store.Clear(s.cookie.Value); err != nil {
return err
return errAs(err)
}

return s.clearCookie()
}

// Int is a helper to get values as integer
func (s *Session) Int(r interface{}, err error) (int, error) {
return s.manager.store.Int(r, err)
out, err := s.manager.store.Int(r, err)
return out, errAs(err)
}

// Int64 is a helper to get values as Int64
func (s *Session) Int64(r interface{}, err error) (int64, error) {
return s.manager.store.Int64(r, err)
out, err := s.manager.store.Int64(r, err)
return out, errAs(err)
}

// UInt64 is a helper to get values as UInt64
func (s *Session) UInt64(r interface{}, err error) (uint64, error) {
return s.manager.store.UInt64(r, err)
out, err := s.manager.store.UInt64(r, err)
return out, errAs(err)
}

// Float64 is a helper to get values as Float64
func (s *Session) Float64(r interface{}, err error) (float64, error) {
return s.manager.store.Float64(r, err)
out, err := s.manager.store.Float64(r, err)
return out, errAs(err)
}

// String is a helper to get values as String
func (s *Session) String(r interface{}, err error) (string, error) {
return s.manager.store.String(r, err)
out, err := s.manager.store.String(r, err)
return out, errAs(err)
}

// Bytes is a helper to get values as Bytes
func (s *Session) Bytes(r interface{}, err error) ([]byte, error) {
return s.manager.store.Bytes(r, err)
out, err := s.manager.store.Bytes(r, err)
return out, errAs(err)
}

// Bool is a helper to get values as Bool
func (s *Session) Bool(r interface{}, err error) (bool, error) {
return s.manager.store.Bool(r, err)
out, err := s.manager.store.Bool(r, err)
return out, errAs(err)
}

// errAs takes an error coming from a store and maps it to an error
// defined in the sessions package based on its code, if it's available at all.
func errAs(err error) error {
if err == nil {
return nil
}

e, ok := err.(errCode)
if !ok {
return err
}

switch e.Code() {
case 1:
return ErrInvalidSession
case 2:
return ErrFieldNotFound
case 3:
return ErrAssertType
case 4:
return ErrNil
}

return err
}
29 changes: 29 additions & 0 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -668,3 +668,32 @@ func TestSessionClearInvalidSession(t *testing.T) {
err = sess.Clear()
assert.Error(err, ErrInvalidSession.Error())
}

type Err struct {
code int
msg string
}

func (e *Err) Error() string {
return e.msg
}

func (e *Err) Code() int {
return e.code
}

func TestErrorTypes(t *testing.T) {
var (
// Error codes for store errors. This should match the codes
// defined in the /simplesessions package exactly.
errInvalidSession = &Err{code: 1, msg: "invalid session"}
errFieldNotFound = &Err{code: 2, msg: "field not found"}
errAssertType = &Err{code: 3, msg: "assertion failed"}
errNil = &Err{code: 4, msg: "nil returned"}
)

assert.Equal(t, errAs(errInvalidSession), ErrInvalidSession)
assert.Equal(t, errAs(errFieldNotFound), ErrFieldNotFound)
assert.Equal(t, errAs(errAssertType), ErrAssertType)
assert.Equal(t, errAs(errNil), ErrNil)
}

0 comments on commit 6ce940f

Please sign in to comment.