Skip to content

Commit

Permalink
Refactor remaining db code relating to credentials table
Browse files Browse the repository at this point in the history
Include fix for in-memory cache impl, and refactored stateService
  • Loading branch information
NovaFox161 committed Aug 29, 2023
1 parent 2d292bf commit 2b9eea4
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 115 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package org.dreamexposure.discal.cam.business

import org.dreamexposure.discal.OauthStateCache
import org.dreamexposure.discal.core.crypto.KeyGenerator
import org.springframework.stereotype.Component

@Component
class OauthStateService(
private val stateCache: OauthStateCache,
) {
suspend fun generateState(): String {
val state = KeyGenerator.csRandomAlphaNumericString(64)
stateCache.put(state, state)

return state
}

suspend fun validateState(state: String) = stateCache.getAndRemove(state) != null
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package org.dreamexposure.discal.cam.endpoints.v1.oauth2

import kotlinx.coroutines.reactor.awaitSingle
import org.dreamexposure.discal.cam.business.OauthStateService
import org.dreamexposure.discal.cam.discord.DiscordOauthHandler
import org.dreamexposure.discal.cam.json.discal.LoginResponse
import org.dreamexposure.discal.cam.json.discal.TokenRequest
import org.dreamexposure.discal.cam.json.discal.TokenResponse
import org.dreamexposure.discal.cam.service.StateService
import org.dreamexposure.discal.core.annotations.Authentication
import org.dreamexposure.discal.core.business.SessionService
import org.dreamexposure.discal.core.config.Config
Expand All @@ -21,8 +21,8 @@ import java.nio.charset.Charset.defaultCharset
@RestController
@RequestMapping("/oauth2/discord/")
class DiscordOauthEndpoint(
private val stateService: StateService,
private val sessionService: SessionService,
private val oauthStateService: OauthStateService,
private val discordOauthHandler: DiscordOauthHandler,
) {
private val redirectUrl = Config.URL_DISCORD_REDIRECT.getString()
Expand All @@ -34,8 +34,8 @@ class DiscordOauthEndpoint(

@GetMapping("login")
@Authentication(access = Authentication.AccessLevel.PUBLIC)
fun login(): LoginResponse {
val state = stateService.generateState()
suspend fun login(): LoginResponse {
val state = oauthStateService.generateState()

val link = "$oauthLinkWithoutState&state=$state"

Expand All @@ -52,7 +52,7 @@ class DiscordOauthEndpoint(
@Authentication(access = Authentication.AccessLevel.PUBLIC)
suspend fun token(@RequestBody body: TokenRequest): TokenResponse {
// Validate state
if (!stateService.validateState(body.state)) {
if (!oauthStateService.validateState(body.state)) {
// State invalid - 400
throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid state")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class GoogleAuth(
.flatMap(this::doAccessTokenRequest)
.flatMap { credential.setAccessToken(it.accessToken).thenReturn(it) }
.doOnNext { credential.credential.expiresAt = it.validUntil }
.flatMap { DatabaseManager.updateCredentialData(credential.credentialData).thenReturn(it) }//TODO: Replace this
.flatMap(mono { credentialService.updateCredential(credential.credential) }::thenReturn)
}.switchIfEmpty(Mono.error(EmptyNotAllowedException()))

}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.dreamexposure.discal.core.business

import kotlinx.coroutines.reactor.awaitSingle
import org.dreamexposure.discal.CredentialsCache
import org.dreamexposure.discal.core.database.CredentialData
import org.dreamexposure.discal.core.database.CredentialsRepository
import org.dreamexposure.discal.core.`object`.new.Credential
import org.springframework.stereotype.Component
Expand All @@ -12,6 +13,19 @@ class DefaultCredentialService(
private val credentialsRepository: CredentialsRepository,
private val credentialsCache: CredentialsCache,
) : CredentialService {

override suspend fun createCredential(credential: Credential): Credential {
val saved = credentialsRepository.save(CredentialData(
credentialNumber = credential.credentialNumber,
accessToken = credential.encryptedAccessToken,
refreshToken = credential.encryptedRefreshToken,
expiresAt = credential.expiresAt.toEpochMilli(),
)).map(::Credential).awaitSingle()

credentialsCache.put(saved.credentialNumber, saved)
return saved
}

override suspend fun getCredential(number: Int): Credential? {
var credential = credentialsCache.get(number)
if (credential != null) return credential
Expand All @@ -24,8 +38,20 @@ class DefaultCredentialService(
return credential
}

override suspend fun updateCredential(credential: Credential) {
credentialsRepository.updateByCredentialNumber(
credentialNumber = credential.credentialNumber,
refreshToken = credential.encryptedRefreshToken,
accessToken = credential.encryptedAccessToken,
expiresAt = credential.expiresAt.toEpochMilli(),
).awaitSingle()

credentialsCache.put(credential.credentialNumber, credential)
}
}

interface CredentialService {
suspend fun createCredential(credential: Credential): Credential
suspend fun getCredential(number: Int): Credential?
suspend fun updateCredential(credential: Credential)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,20 @@ class JdkCacheRepository<K : Any, V>(override val ttl: Duration) : CacheReposito
}

override suspend fun get(key: K): V? {
return cache[key]?.second
val cached = cache[key] ?: return null
if (Instant.now().isAfter(cached.first)) {
evict(key)
return null
}
return cached.second
}

override suspend fun getAndRemove(key: K): V? {
val cached = cache[key]?.second

val cached = cache[key] ?: return null
evict(key)
return cached

return if (Instant.now().isAfter(cached.first)) null else cached.second

}

override suspend fun evict(key: K) {
Expand All @@ -36,5 +42,4 @@ class JdkCacheRepository<K : Any, V>(override val ttl: Duration) : CacheReposito
private fun evictOld() {
cache.forEach { (key, pair) -> if (Duration.between(pair.first, Instant.now()) >= ttl) cache.remove(key) }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.dreamexposure.discal.core.config

import com.fasterxml.jackson.databind.ObjectMapper
import org.dreamexposure.discal.CredentialsCache
import org.dreamexposure.discal.OauthStateCache
import org.dreamexposure.discal.core.cache.JdkCacheRepository
import org.dreamexposure.discal.core.cache.RedisCacheRepository
import org.dreamexposure.discal.core.extensions.asMinutes
Expand All @@ -19,9 +20,11 @@ class CacheConfig {
private val prefix = Config.CACHE_PREFIX.getString()
private val settingsCacheName = "$prefix.settingsCache"
private val credentialsCacheName = "$prefix.credentialsCache"
private val oauthStateCacheName = "$prefix.oauthStateCache"

private val settingsTtl = Config.CACHE_TTL_SETTINGS_MINUTES.getLong().asMinutes()
private val credentialsTll = Config.CACHE_TTL_CREDENTIALS_MINUTES.getLong().asMinutes()
private val oauthStateTtl = Config.CACHE_TTL_OAUTH_STATE_MINUTES.getLong().asMinutes()


// Redis caching
Expand All @@ -35,6 +38,8 @@ class CacheConfig {
.withCacheConfiguration(credentialsCacheName,
RedisCacheConfiguration.defaultCacheConfig().entryTtl(credentialsTll)
)
.withCacheConfiguration(oauthStateCacheName,
RedisCacheConfiguration.defaultCacheConfig().entryTtl(oauthStateTtl))
.build()
}

Expand All @@ -44,8 +49,17 @@ class CacheConfig {
fun credentialsRedisCache(cacheManager: RedisCacheManager, objectMapper: ObjectMapper): CredentialsCache =
RedisCacheRepository(cacheManager, objectMapper, credentialsCacheName)

@Bean
@Primary
@ConditionalOnProperty("bot.cache.redis", havingValue = "true")
fun oauthStateRedisCache(cacheManager: RedisCacheManager, objectMapper: ObjectMapper): OauthStateCache =
RedisCacheRepository(cacheManager, objectMapper, oauthStateCacheName)


// In-memory fallback caching
@Bean
fun credentialsFallbackCache(): CredentialsCache = JdkCacheRepository(settingsTtl)

@Bean
fun oauthStateFallbackCache(): OauthStateCache = JdkCacheRepository(settingsTtl)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ enum class Config(private val key: String, private var value: Any? = null) {
CACHE_TTL_SETTINGS_MINUTES("bot.cache.ttl-minutes.settings", 60),
CACHE_TTL_CREDENTIALS_MINUTES("bot.cache.ttl-minutes.credentials", 120),
CACHE_TTL_ACCOUNTS_MINUTES("bot.cache.ttl-minutes.accounts", 60),
CACHE_TTL_OAUTH_STATE_MINUTES("bot.cache.ttl-minutes.oauth.state", 5),

// Security configuration

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
package org.dreamexposure.discal.core.database

import org.springframework.data.r2dbc.repository.Query
import org.springframework.data.r2dbc.repository.R2dbcRepository
import reactor.core.publisher.Mono

interface CredentialsRepository : R2dbcRepository<CredentialData, Int> {

fun findByCredentialNumber(credentialNumber: Int): Mono<CredentialData>

// TODO: Finish impl???
@Query("""
UPDATE credentials
SET refresh_token = :refreshToken,
access_token = :accessToken,
expires_at = :expiresAt
WHERE credential_number = :credentialNumber
""")
fun updateByCredentialNumber(
credentialNumber: Int,
refreshToken: String,
accessToken: String,
expiresAt: Long,
): Mono<Int>
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import org.dreamexposure.discal.core.`object`.announcement.Announcement
import org.dreamexposure.discal.core.`object`.calendar.CalendarData
import org.dreamexposure.discal.core.`object`.event.EventData
import org.dreamexposure.discal.core.`object`.event.RsvpData
import org.dreamexposure.discal.core.`object`.google.GoogleCredentialData
import org.dreamexposure.discal.core.`object`.web.UserAPIAccount
import org.dreamexposure.discal.core.utils.GlobalVal.DEFAULT
import org.intellij.lang.annotations.Language
Expand Down Expand Up @@ -459,53 +458,6 @@ object DatabaseManager {
}
}

fun updateCredentialData(credData: GoogleCredentialData): Mono<Boolean> {
return connect { c ->
Mono.from(
c.createStatement(Queries.SELECT_CREDENTIAL_DATA)
.bind(0, credData.credentialNumber)
.execute()
).flatMapMany { res ->
res.map { row, _ -> row }
}.hasElements().flatMap { exists ->
if (exists) {
val updateCommand = """UPDATE ${Tables.CREDS} SET
REFRESH_TOKEN = ?, ACCESS_TOKEN = ?, EXPIRES_AT = ?
WHERE CREDENTIAL_NUMBER = ?""".trimMargin()

Mono.from(
c.createStatement(updateCommand)
.bind(0, credData.encryptedRefreshToken)
.bind(1, credData.encryptedAccessToken)
.bind(2, credData.expiresAt.toEpochMilli())
.bind(3, credData.credentialNumber)
.execute()
).flatMapMany(Result::getRowsUpdated)
.hasElements()
.thenReturn(true)
} else {
val insertCommand = """INSERT INTO ${Tables.CREDS}
|(CREDENTIAL_NUMBER, REFRESH_TOKEN, ACCESS_TOKEN, EXPIRES_AT)
|VALUES(?, ?, ?, ?)""".trimMargin()

Mono.from(
c.createStatement(insertCommand)
.bind(0, credData.credentialNumber)
.bind(1, credData.encryptedRefreshToken)
.bind(2, credData.encryptedAccessToken)
.bind(3, credData.expiresAt.toEpochMilli())
.execute()
).flatMapMany(Result::getRowsUpdated)
.hasElements()
.thenReturn(true)
}.doOnError {
LOGGER.error(DEFAULT, "Failed to update credential data", it)
}.onErrorResume { Mono.just(false) }
}

}
}

fun getAPIAccount(APIKey: String): Mono<UserAPIAccount> {
return connect { c ->
Mono.from(
Expand Down Expand Up @@ -1578,11 +1530,6 @@ private object Queries {
@Language("MySQL")
val SELECT_ALL_ANNOUNCEMENT_COUNT = """SELECT COUNT(*) FROM ${Tables.ANNOUNCEMENTS}"""

@Language("MySQL")
val SELECT_CREDENTIAL_DATA = """SELECT * FROM ${Tables.CREDS}
WHERE CREDENTIAL_NUMBER = ?
""".trimMargin()

@Language("MySQL")
val DELETE_ANNOUNCEMENT = """DELETE FROM ${Tables.ANNOUNCEMENTS}
WHERE ANNOUNCEMENT_ID = ?
Expand Down Expand Up @@ -1744,12 +1691,6 @@ private object Tables {
@Language("Kotlin")
const val RSVP = "rsvp"

@Language("Kotlin")
const val CREDS = "credentials"

@Language("Kotlin")
const val STATIC_MESSAGES = "static_messages"

@Language("Kotlin")
const val SESSIONS = "sessions"
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ import org.dreamexposure.discal.core.`object`.new.Credential
// Cache
//typealias GuildSettingsCache = CacheRepository<Long, GuildSettings>
typealias CredentialsCache = CacheRepository<Int, Credential>
typealias OauthStateCache = CacheRepository<String, String>

0 comments on commit 2b9eea4

Please sign in to comment.