diff --git a/cmd/api/src/api/marshalling.go b/cmd/api/src/api/marshalling.go index a9337d69b..edf734d11 100644 --- a/cmd/api/src/api/marshalling.go +++ b/cmd/api/src/api/marshalling.go @@ -101,11 +101,6 @@ func WriteBasicResponse(ctx context.Context, inputData any, statusCode int, resp } } -// intended for 2xx responses such as http.StatusNoContent -func WriteEmptyResponse(_ context.Context, statusCode int, response http.ResponseWriter) { - response.WriteHeader(statusCode) -} - func WriteResponseWrapperWithPagination(ctx context.Context, data any, limit int, skip, count, statusCode int, response http.ResponseWriter) { wrapper := ResponseWrapper{} wrapper.Data = data diff --git a/cmd/api/src/database/auth.go b/cmd/api/src/database/auth.go index 30fe62657..eeece900c 100644 --- a/cmd/api/src/database/auth.go +++ b/cmd/api/src/database/auth.go @@ -569,6 +569,10 @@ func (s *BloodhoundDB) EndUserSession(ctx context.Context, userSession model.Use // corresponding retrival function is model.UserSession.GetFlag() func (s *BloodhoundDB) SetUserSessionFlag(ctx context.Context, userSession *model.UserSession, key model.SessionFlagKey, state bool) error { + if userSession.ID == 0 { + return errors.Error("invalid session - missing session id") + } + var auditEntry = model.AuditEntry{} var doAudit = false // only audit if the new state is true, meaning the EULA is currently being accepted diff --git a/cmd/api/src/database/auth_test.go b/cmd/api/src/database/auth_test.go index 1b57c0caf..36e5a5295 100644 --- a/cmd/api/src/database/auth_test.go +++ b/cmd/api/src/database/auth_test.go @@ -390,3 +390,25 @@ func TestDatabase_CreateUserSession(t *testing.T) { assert.Equal(t, user, newUserSession.User) } } + +func TestDatabase_SetUserSessionFlag(t *testing.T) { + var ( + testCtx = context.Background() + dbInst, user = initAndCreateUser(t) + userSession = model.UserSession{ + User: user, + UserID: user.ID, + ExpiresAt: time.Now().UTC().Add(time.Hour), + } + ) + + newUserSession, err := dbInst.CreateUserSession(testCtx, userSession) + assert.Nil(t, err) + + err = dbInst.SetUserSessionFlag(testCtx, &newUserSession, model.SessionFlagFedEULAAccepted, true) + assert.Nil(t, err) + + dbSess, err := dbInst.GetUserSession(testCtx, newUserSession.ID) + assert.Nil(t, err) + assert.True(t, dbSess.Flags[string(model.SessionFlagFedEULAAccepted)]) +}