diff --git a/cam/src/main/kotlin/org/dreamexposure/discal/cam/business/OauthStateService.kt b/cam/src/main/kotlin/org/dreamexposure/discal/cam/business/OauthStateService.kt new file mode 100644 index 000000000..5352a697f --- /dev/null +++ b/cam/src/main/kotlin/org/dreamexposure/discal/cam/business/OauthStateService.kt @@ -0,0 +1,19 @@ +package org.dreamexposure.discal.cam.business + +import org.dreamexposure.discal.OauthStateCache +import org.dreamexposure.discal.core.crypto.KeyGenerator +import org.springframework.stereotype.Component + +@Component +class OauthStateService( + private val stateCache: OauthStateCache, +) { + suspend fun generateState(): String { + val state = KeyGenerator.csRandomAlphaNumericString(64) + stateCache.put(state, state) + + return state + } + + suspend fun validateState(state: String) = stateCache.getAndRemove(state) != null +} diff --git a/cam/src/main/kotlin/org/dreamexposure/discal/cam/endpoints/v1/oauth2/DiscordOauthEndpoint.kt b/cam/src/main/kotlin/org/dreamexposure/discal/cam/endpoints/v1/oauth2/DiscordOauthEndpoint.kt index 2ddd428f5..113e420d9 100644 --- a/cam/src/main/kotlin/org/dreamexposure/discal/cam/endpoints/v1/oauth2/DiscordOauthEndpoint.kt +++ b/cam/src/main/kotlin/org/dreamexposure/discal/cam/endpoints/v1/oauth2/DiscordOauthEndpoint.kt @@ -1,11 +1,11 @@ package org.dreamexposure.discal.cam.endpoints.v1.oauth2 import kotlinx.coroutines.reactor.awaitSingle +import org.dreamexposure.discal.cam.business.OauthStateService import org.dreamexposure.discal.cam.discord.DiscordOauthHandler import org.dreamexposure.discal.cam.json.discal.LoginResponse import org.dreamexposure.discal.cam.json.discal.TokenRequest import org.dreamexposure.discal.cam.json.discal.TokenResponse -import org.dreamexposure.discal.cam.service.StateService import org.dreamexposure.discal.core.annotations.Authentication import org.dreamexposure.discal.core.business.SessionService import org.dreamexposure.discal.core.config.Config @@ -21,8 +21,8 @@ import java.nio.charset.Charset.defaultCharset @RestController @RequestMapping("/oauth2/discord/") class DiscordOauthEndpoint( - private val stateService: StateService, private val sessionService: SessionService, + private val oauthStateService: OauthStateService, private val discordOauthHandler: DiscordOauthHandler, ) { private val redirectUrl = Config.URL_DISCORD_REDIRECT.getString() @@ -34,8 +34,8 @@ class DiscordOauthEndpoint( @GetMapping("login") @Authentication(access = Authentication.AccessLevel.PUBLIC) - fun login(): LoginResponse { - val state = stateService.generateState() + suspend fun login(): LoginResponse { + val state = oauthStateService.generateState() val link = "$oauthLinkWithoutState&state=$state" @@ -52,7 +52,7 @@ class DiscordOauthEndpoint( @Authentication(access = Authentication.AccessLevel.PUBLIC) suspend fun token(@RequestBody body: TokenRequest): TokenResponse { // Validate state - if (!stateService.validateState(body.state)) { + if (!oauthStateService.validateState(body.state)) { // State invalid - 400 throw ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid state") } diff --git a/cam/src/main/kotlin/org/dreamexposure/discal/cam/google/GoogleAuth.kt b/cam/src/main/kotlin/org/dreamexposure/discal/cam/google/GoogleAuth.kt index 3d7b575de..516db592b 100644 --- a/cam/src/main/kotlin/org/dreamexposure/discal/cam/google/GoogleAuth.kt +++ b/cam/src/main/kotlin/org/dreamexposure/discal/cam/google/GoogleAuth.kt @@ -79,7 +79,7 @@ class GoogleAuth( .flatMap(this::doAccessTokenRequest) .flatMap { credential.setAccessToken(it.accessToken).thenReturn(it) } .doOnNext { credential.credential.expiresAt = it.validUntil } - .flatMap { DatabaseManager.updateCredentialData(credential.credentialData).thenReturn(it) }//TODO: Replace this + .flatMap(mono { credentialService.updateCredential(credential.credential) }::thenReturn) }.switchIfEmpty(Mono.error(EmptyNotAllowedException())) } diff --git a/cam/src/main/kotlin/org/dreamexposure/discal/cam/service/StateService.kt b/cam/src/main/kotlin/org/dreamexposure/discal/cam/service/StateService.kt deleted file mode 100644 index ed5ffebb9..000000000 --- a/cam/src/main/kotlin/org/dreamexposure/discal/cam/service/StateService.kt +++ /dev/null @@ -1,44 +0,0 @@ -package org.dreamexposure.discal.cam.service - -import org.dreamexposure.discal.core.crypto.KeyGenerator -import org.springframework.stereotype.Component -import reactor.core.publisher.Flux -import java.time.Duration -import java.time.Instant -import java.time.temporal.ChronoUnit -import java.util.concurrent.ConcurrentHashMap - -@Component -class StateService { - - private val states: MutableMap = ConcurrentHashMap() - - init { - // occasionally remove expired/unused states - Flux.interval(Duration.ofHours(1)) - .doOnNext { - val toRemove = mutableListOf() - - states.forEach { (state, expires) -> - if (expires.isBefore(Instant.now())) - toRemove.add(state) - } - - toRemove.forEach(states::remove) - }.subscribe() - } - - fun generateState(): String { - val state = KeyGenerator.csRandomAlphaNumericString(64) - states[state] = Instant.now().plus(5, ChronoUnit.MINUTES) - - return state - } - - fun validateState(state: String): Boolean { - val expiresAt = states[state] - states.remove(state) // Remove state immediately to prevent replay attacks - - return expiresAt != null && expiresAt.isAfter(Instant.now()) - } -} diff --git a/core/src/main/kotlin/org/dreamexposure/discal/core/business/CredentialService.kt b/core/src/main/kotlin/org/dreamexposure/discal/core/business/CredentialService.kt index e7b421a4d..ee8c919e0 100644 --- a/core/src/main/kotlin/org/dreamexposure/discal/core/business/CredentialService.kt +++ b/core/src/main/kotlin/org/dreamexposure/discal/core/business/CredentialService.kt @@ -2,6 +2,7 @@ package org.dreamexposure.discal.core.business import kotlinx.coroutines.reactor.awaitSingle import org.dreamexposure.discal.CredentialsCache +import org.dreamexposure.discal.core.database.CredentialData import org.dreamexposure.discal.core.database.CredentialsRepository import org.dreamexposure.discal.core.`object`.new.Credential import org.springframework.stereotype.Component @@ -12,6 +13,19 @@ class DefaultCredentialService( private val credentialsRepository: CredentialsRepository, private val credentialsCache: CredentialsCache, ) : CredentialService { + + override suspend fun createCredential(credential: Credential): Credential { + val saved = credentialsRepository.save(CredentialData( + credentialNumber = credential.credentialNumber, + accessToken = credential.encryptedAccessToken, + refreshToken = credential.encryptedRefreshToken, + expiresAt = credential.expiresAt.toEpochMilli(), + )).map(::Credential).awaitSingle() + + credentialsCache.put(saved.credentialNumber, saved) + return saved + } + override suspend fun getCredential(number: Int): Credential? { var credential = credentialsCache.get(number) if (credential != null) return credential @@ -24,8 +38,20 @@ class DefaultCredentialService( return credential } + override suspend fun updateCredential(credential: Credential) { + credentialsRepository.updateByCredentialNumber( + credentialNumber = credential.credentialNumber, + refreshToken = credential.encryptedRefreshToken, + accessToken = credential.encryptedAccessToken, + expiresAt = credential.expiresAt.toEpochMilli(), + ).awaitSingle() + + credentialsCache.put(credential.credentialNumber, credential) + } } interface CredentialService { + suspend fun createCredential(credential: Credential): Credential suspend fun getCredential(number: Int): Credential? + suspend fun updateCredential(credential: Credential) } diff --git a/core/src/main/kotlin/org/dreamexposure/discal/core/cache/JdkCacheRepository.kt b/core/src/main/kotlin/org/dreamexposure/discal/core/cache/JdkCacheRepository.kt index 223343ad1..e251ca68a 100644 --- a/core/src/main/kotlin/org/dreamexposure/discal/core/cache/JdkCacheRepository.kt +++ b/core/src/main/kotlin/org/dreamexposure/discal/core/cache/JdkCacheRepository.kt @@ -19,14 +19,20 @@ class JdkCacheRepository(override val ttl: Duration) : CacheReposito } override suspend fun get(key: K): V? { - return cache[key]?.second + val cached = cache[key] ?: return null + if (Instant.now().isAfter(cached.first)) { + evict(key) + return null + } + return cached.second } override suspend fun getAndRemove(key: K): V? { - val cached = cache[key]?.second - + val cached = cache[key] ?: return null evict(key) - return cached + + return if (Instant.now().isAfter(cached.first)) null else cached.second + } override suspend fun evict(key: K) { @@ -36,5 +42,4 @@ class JdkCacheRepository(override val ttl: Duration) : CacheReposito private fun evictOld() { cache.forEach { (key, pair) -> if (Duration.between(pair.first, Instant.now()) >= ttl) cache.remove(key) } } - } diff --git a/core/src/main/kotlin/org/dreamexposure/discal/core/config/CacheConfig.kt b/core/src/main/kotlin/org/dreamexposure/discal/core/config/CacheConfig.kt index 462386304..2621dd497 100644 --- a/core/src/main/kotlin/org/dreamexposure/discal/core/config/CacheConfig.kt +++ b/core/src/main/kotlin/org/dreamexposure/discal/core/config/CacheConfig.kt @@ -2,6 +2,7 @@ package org.dreamexposure.discal.core.config import com.fasterxml.jackson.databind.ObjectMapper import org.dreamexposure.discal.CredentialsCache +import org.dreamexposure.discal.OauthStateCache import org.dreamexposure.discal.core.cache.JdkCacheRepository import org.dreamexposure.discal.core.cache.RedisCacheRepository import org.dreamexposure.discal.core.extensions.asMinutes @@ -19,9 +20,11 @@ class CacheConfig { private val prefix = Config.CACHE_PREFIX.getString() private val settingsCacheName = "$prefix.settingsCache" private val credentialsCacheName = "$prefix.credentialsCache" + private val oauthStateCacheName = "$prefix.oauthStateCache" private val settingsTtl = Config.CACHE_TTL_SETTINGS_MINUTES.getLong().asMinutes() private val credentialsTll = Config.CACHE_TTL_CREDENTIALS_MINUTES.getLong().asMinutes() + private val oauthStateTtl = Config.CACHE_TTL_OAUTH_STATE_MINUTES.getLong().asMinutes() // Redis caching @@ -35,6 +38,8 @@ class CacheConfig { .withCacheConfiguration(credentialsCacheName, RedisCacheConfiguration.defaultCacheConfig().entryTtl(credentialsTll) ) + .withCacheConfiguration(oauthStateCacheName, + RedisCacheConfiguration.defaultCacheConfig().entryTtl(oauthStateTtl)) .build() } @@ -44,8 +49,17 @@ class CacheConfig { fun credentialsRedisCache(cacheManager: RedisCacheManager, objectMapper: ObjectMapper): CredentialsCache = RedisCacheRepository(cacheManager, objectMapper, credentialsCacheName) + @Bean + @Primary + @ConditionalOnProperty("bot.cache.redis", havingValue = "true") + fun oauthStateRedisCache(cacheManager: RedisCacheManager, objectMapper: ObjectMapper): OauthStateCache = + RedisCacheRepository(cacheManager, objectMapper, oauthStateCacheName) + // In-memory fallback caching @Bean fun credentialsFallbackCache(): CredentialsCache = JdkCacheRepository(settingsTtl) + + @Bean + fun oauthStateFallbackCache(): OauthStateCache = JdkCacheRepository(settingsTtl) } diff --git a/core/src/main/kotlin/org/dreamexposure/discal/core/config/Config.kt b/core/src/main/kotlin/org/dreamexposure/discal/core/config/Config.kt index a9c2e4803..e24df7463 100644 --- a/core/src/main/kotlin/org/dreamexposure/discal/core/config/Config.kt +++ b/core/src/main/kotlin/org/dreamexposure/discal/core/config/Config.kt @@ -18,6 +18,7 @@ enum class Config(private val key: String, private var value: Any? = null) { CACHE_TTL_SETTINGS_MINUTES("bot.cache.ttl-minutes.settings", 60), CACHE_TTL_CREDENTIALS_MINUTES("bot.cache.ttl-minutes.credentials", 120), CACHE_TTL_ACCOUNTS_MINUTES("bot.cache.ttl-minutes.accounts", 60), + CACHE_TTL_OAUTH_STATE_MINUTES("bot.cache.ttl-minutes.oauth.state", 5), // Security configuration diff --git a/core/src/main/kotlin/org/dreamexposure/discal/core/database/CredentialsRepository.kt b/core/src/main/kotlin/org/dreamexposure/discal/core/database/CredentialsRepository.kt index b4a4682b6..55f4bad90 100644 --- a/core/src/main/kotlin/org/dreamexposure/discal/core/database/CredentialsRepository.kt +++ b/core/src/main/kotlin/org/dreamexposure/discal/core/database/CredentialsRepository.kt @@ -1,5 +1,6 @@ package org.dreamexposure.discal.core.database +import org.springframework.data.r2dbc.repository.Query import org.springframework.data.r2dbc.repository.R2dbcRepository import reactor.core.publisher.Mono @@ -7,5 +8,17 @@ interface CredentialsRepository : R2dbcRepository { fun findByCredentialNumber(credentialNumber: Int): Mono - // TODO: Finish impl??? + @Query(""" + UPDATE credentials + SET refresh_token = :refreshToken, + access_token = :accessToken, + expires_at = :expiresAt + WHERE credential_number = :credentialNumber + """) + fun updateByCredentialNumber( + credentialNumber: Int, + refreshToken: String, + accessToken: String, + expiresAt: Long, + ): Mono } diff --git a/core/src/main/kotlin/org/dreamexposure/discal/core/database/DatabaseManager.kt b/core/src/main/kotlin/org/dreamexposure/discal/core/database/DatabaseManager.kt index b4b4eb67a..9fd023116 100644 --- a/core/src/main/kotlin/org/dreamexposure/discal/core/database/DatabaseManager.kt +++ b/core/src/main/kotlin/org/dreamexposure/discal/core/database/DatabaseManager.kt @@ -26,7 +26,6 @@ import org.dreamexposure.discal.core.`object`.announcement.Announcement import org.dreamexposure.discal.core.`object`.calendar.CalendarData import org.dreamexposure.discal.core.`object`.event.EventData import org.dreamexposure.discal.core.`object`.event.RsvpData -import org.dreamexposure.discal.core.`object`.google.GoogleCredentialData import org.dreamexposure.discal.core.`object`.web.UserAPIAccount import org.dreamexposure.discal.core.utils.GlobalVal.DEFAULT import org.intellij.lang.annotations.Language @@ -459,53 +458,6 @@ object DatabaseManager { } } - fun updateCredentialData(credData: GoogleCredentialData): Mono { - return connect { c -> - Mono.from( - c.createStatement(Queries.SELECT_CREDENTIAL_DATA) - .bind(0, credData.credentialNumber) - .execute() - ).flatMapMany { res -> - res.map { row, _ -> row } - }.hasElements().flatMap { exists -> - if (exists) { - val updateCommand = """UPDATE ${Tables.CREDS} SET - REFRESH_TOKEN = ?, ACCESS_TOKEN = ?, EXPIRES_AT = ? - WHERE CREDENTIAL_NUMBER = ?""".trimMargin() - - Mono.from( - c.createStatement(updateCommand) - .bind(0, credData.encryptedRefreshToken) - .bind(1, credData.encryptedAccessToken) - .bind(2, credData.expiresAt.toEpochMilli()) - .bind(3, credData.credentialNumber) - .execute() - ).flatMapMany(Result::getRowsUpdated) - .hasElements() - .thenReturn(true) - } else { - val insertCommand = """INSERT INTO ${Tables.CREDS} - |(CREDENTIAL_NUMBER, REFRESH_TOKEN, ACCESS_TOKEN, EXPIRES_AT) - |VALUES(?, ?, ?, ?)""".trimMargin() - - Mono.from( - c.createStatement(insertCommand) - .bind(0, credData.credentialNumber) - .bind(1, credData.encryptedRefreshToken) - .bind(2, credData.encryptedAccessToken) - .bind(3, credData.expiresAt.toEpochMilli()) - .execute() - ).flatMapMany(Result::getRowsUpdated) - .hasElements() - .thenReturn(true) - }.doOnError { - LOGGER.error(DEFAULT, "Failed to update credential data", it) - }.onErrorResume { Mono.just(false) } - } - - } - } - fun getAPIAccount(APIKey: String): Mono { return connect { c -> Mono.from( @@ -1578,11 +1530,6 @@ private object Queries { @Language("MySQL") val SELECT_ALL_ANNOUNCEMENT_COUNT = """SELECT COUNT(*) FROM ${Tables.ANNOUNCEMENTS}""" - @Language("MySQL") - val SELECT_CREDENTIAL_DATA = """SELECT * FROM ${Tables.CREDS} - WHERE CREDENTIAL_NUMBER = ? - """.trimMargin() - @Language("MySQL") val DELETE_ANNOUNCEMENT = """DELETE FROM ${Tables.ANNOUNCEMENTS} WHERE ANNOUNCEMENT_ID = ? @@ -1744,12 +1691,6 @@ private object Tables { @Language("Kotlin") const val RSVP = "rsvp" - @Language("Kotlin") - const val CREDS = "credentials" - @Language("Kotlin") const val STATIC_MESSAGES = "static_messages" - - @Language("Kotlin") - const val SESSIONS = "sessions" } diff --git a/core/src/main/kotlin/org/dreamexposure/discal/typealiases.kt b/core/src/main/kotlin/org/dreamexposure/discal/typealiases.kt index 07546054f..21e6f5047 100644 --- a/core/src/main/kotlin/org/dreamexposure/discal/typealiases.kt +++ b/core/src/main/kotlin/org/dreamexposure/discal/typealiases.kt @@ -6,3 +6,4 @@ import org.dreamexposure.discal.core.`object`.new.Credential // Cache //typealias GuildSettingsCache = CacheRepository typealias CredentialsCache = CacheRepository +typealias OauthStateCache = CacheRepository