From 6ce940f24d2b8d196ea97c744e1ff6f82e35de9c Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Fri, 10 May 2024 15:09:50 +0530 Subject: [PATCH] Add code-based error mapping for retaining old global error types. --- session.go | 85 ++++++++++++++++++++++++++++++++++++++----------- session_test.go | 29 +++++++++++++++++ 2 files changed, 96 insertions(+), 18 deletions(-) diff --git a/session.go b/session.go index d828adc..e34d742 100644 --- a/session.go +++ b/session.go @@ -48,6 +48,10 @@ var ( ErrNil = errors.New("simplesession: nil returned") ) +type errCode interface { + Code() int +} + // NewSession creates a new session. Reads cookie info from `GetCookie“ callback // and validate the session with current store. If cookie not set then it creates // new session and calls `SetCookie“ callback. If `DisableAutoSet` is set then it @@ -78,7 +82,7 @@ func NewSession(m *Manager, r, w interface{}) (*Session, error) { // Store also calls `WriteCookie`` to write to http interface cv, err := m.store.Create() if err != nil { - return nil, err + return nil, errAs(err) } // Write cookie @@ -139,7 +143,7 @@ func (s *Session) Create() error { // Create new cookie in store and write to front. cv, err := s.manager.store.Create() if err != nil { - return err + return errAs(err) } // Write cookie @@ -187,7 +191,8 @@ func (s *Session) GetAll() (map[string]interface{}, error) { return s.values, nil } - return s.manager.store.GetAll(s.cookie.Value) + out, err := s.manager.store.GetAll(s.cookie.Value) + return out, errAs(err) } // GetMulti gets a map of values for multiple session keys. @@ -209,7 +214,8 @@ func (s *Session) GetMulti(keys ...string) (map[string]interface{}, error) { return vals, nil } - return s.manager.store.GetMulti(s.cookie.Value, keys...) + out, err := s.manager.store.GetMulti(s.cookie.Value, keys...) + return out, errAs(err) } // Get gets a value for given key in session. @@ -229,7 +235,8 @@ func (s *Session) Get(key string) (interface{}, error) { } // Get from backend if not found in previous step - return s.manager.store.Get(s.cookie.Value, key) + out, err := s.manager.store.Get(s.cookie.Value, key) + return out, errAs(err) } // Set sets a value for given key in session. Its up to store to commit @@ -240,7 +247,8 @@ func (s *Session) Set(key string, val interface{}) error { return ErrInvalidSession } - return s.manager.store.Set(s.cookie.Value, key, val) + err := s.manager.store.Set(s.cookie.Value, key, val) + return errAs(err) } // SetMulti sets all values in the session. @@ -253,8 +261,8 @@ func (s *Session) SetMulti(values map[string]interface{}) error { } for k, v := range values { - if err := s.manager.store.Set(s, s.cookie.Value, k, v); err != nil { - return err + if err := s.manager.store.Set(s.cookie.Value, k, v); err != nil { + return errAs(err) } } @@ -269,7 +277,11 @@ func (s *Session) Commit() error { return ErrInvalidSession } - return s.manager.store.Commit(s.cookie.Value) + if err := s.manager.store.Commit(s.cookie.Value); err != nil { + return errAs(err) + } + + return nil } // Delete deletes a field from session. @@ -279,7 +291,11 @@ func (s *Session) Delete(key string) error { return ErrInvalidSession } - return s.manager.store.Delete(s.cookie.Value, key) + if err := s.manager.store.Delete(s.cookie.Value, key); err != nil { + return errAs(err) + } + + return nil } // Clear clears session data from store and clears the cookie @@ -290,7 +306,7 @@ func (s *Session) Clear() error { } if err := s.manager.store.Clear(s.cookie.Value); err != nil { - return err + return errAs(err) } return s.clearCookie() @@ -298,35 +314,68 @@ func (s *Session) Clear() error { // Int is a helper to get values as integer func (s *Session) Int(r interface{}, err error) (int, error) { - return s.manager.store.Int(r, err) + out, err := s.manager.store.Int(r, err) + return out, errAs(err) } // Int64 is a helper to get values as Int64 func (s *Session) Int64(r interface{}, err error) (int64, error) { - return s.manager.store.Int64(r, err) + out, err := s.manager.store.Int64(r, err) + return out, errAs(err) } // UInt64 is a helper to get values as UInt64 func (s *Session) UInt64(r interface{}, err error) (uint64, error) { - return s.manager.store.UInt64(r, err) + out, err := s.manager.store.UInt64(r, err) + return out, errAs(err) } // Float64 is a helper to get values as Float64 func (s *Session) Float64(r interface{}, err error) (float64, error) { - return s.manager.store.Float64(r, err) + out, err := s.manager.store.Float64(r, err) + return out, errAs(err) } // String is a helper to get values as String func (s *Session) String(r interface{}, err error) (string, error) { - return s.manager.store.String(r, err) + out, err := s.manager.store.String(r, err) + return out, errAs(err) } // Bytes is a helper to get values as Bytes func (s *Session) Bytes(r interface{}, err error) ([]byte, error) { - return s.manager.store.Bytes(r, err) + out, err := s.manager.store.Bytes(r, err) + return out, errAs(err) } // Bool is a helper to get values as Bool func (s *Session) Bool(r interface{}, err error) (bool, error) { - return s.manager.store.Bool(r, err) + out, err := s.manager.store.Bool(r, err) + return out, errAs(err) +} + +// errAs takes an error coming from a store and maps it to an error +// defined in the sessions package based on its code, if it's available at all. +func errAs(err error) error { + if err == nil { + return nil + } + + e, ok := err.(errCode) + if !ok { + return err + } + + switch e.Code() { + case 1: + return ErrInvalidSession + case 2: + return ErrFieldNotFound + case 3: + return ErrAssertType + case 4: + return ErrNil + } + + return err } diff --git a/session_test.go b/session_test.go index d85d45e..6a130f1 100644 --- a/session_test.go +++ b/session_test.go @@ -668,3 +668,32 @@ func TestSessionClearInvalidSession(t *testing.T) { err = sess.Clear() assert.Error(err, ErrInvalidSession.Error()) } + +type Err struct { + code int + msg string +} + +func (e *Err) Error() string { + return e.msg +} + +func (e *Err) Code() int { + return e.code +} + +func TestErrorTypes(t *testing.T) { + var ( + // Error codes for store errors. This should match the codes + // defined in the /simplesessions package exactly. + errInvalidSession = &Err{code: 1, msg: "invalid session"} + errFieldNotFound = &Err{code: 2, msg: "field not found"} + errAssertType = &Err{code: 3, msg: "assertion failed"} + errNil = &Err{code: 4, msg: "nil returned"} + ) + + assert.Equal(t, errAs(errInvalidSession), ErrInvalidSession) + assert.Equal(t, errAs(errFieldNotFound), ErrFieldNotFound) + assert.Equal(t, errAs(errAssertType), ErrAssertType) + assert.Equal(t, errAs(errNil), ErrNil) +}