Skip to content

Commit

Permalink
feat: better error and nil handling from redis
Browse files Browse the repository at this point in the history
  • Loading branch information
speed2exe committed Oct 4, 2024
1 parent 77b298c commit 6dc9366
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 38 deletions.
1 change: 1 addition & 0 deletions admin_frontend/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async fn main() {
tracing_subscriber::fmt()
.json()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_line_number(true)
.init();

let config = Config::from_env().unwrap();
Expand Down
65 changes: 39 additions & 26 deletions admin_frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,22 @@ impl SessionStorage {
Self { redis_client }
}

pub async fn get_user_session(&self, session_id: &str) -> Option<UserSession> {
pub async fn get_user_session(
&self,
session_id: &str,
) -> Result<Option<UserSession>, redis::RedisError> {
let key = session_id_key(session_id);
let s: Result<UserSession, redis::RedisError> = self.redis_client.clone().get(&key).await;
match s {
Ok(s) => Some(s),
Err(e) => {
tracing::info!("get user session in redis error: {:?}", e);
None
},
}
let user_session_optional: UserSessionOptional = self.redis_client.clone().get(&key).await?;
Ok(user_session_optional.0)
}

pub async fn get_code_session(&self, session_id: &str) -> Option<CodeSession> {
let key = code_session_key(session_id);
let s: Result<CodeSession, redis::RedisError> = self.redis_client.clone().get(&key).await;
match s {
Ok(s) => Some(s),
Err(e) => {
tracing::info!("get user session in redis error: {:?}", e);
None
},
}
pub async fn get_code_session(
&self,
code: &str,
) -> Result<Option<CodeSession>, redis::RedisError> {
let key = code_session_key(code);
let code_session_optional: CodeSessionOptional = self.redis_client.clone().get(&key).await?;
Ok(code_session_optional.0)
}

pub async fn put_user_session(&self, user_session: &UserSession) -> redis::RedisResult<()> {
Expand Down Expand Up @@ -84,7 +78,7 @@ impl SessionStorage {
code: &str,
code_session: &CodeSession,
) -> redis::RedisResult<()> {
let key = format!("session::code::{}", code);
let key = code_session_key(code);
self
.redis_client
.clone()
Expand All @@ -103,6 +97,7 @@ pub struct CodeSession {
pub code_challenge: Option<String>,
pub code_challenge_method: Option<String>,
}
pub struct CodeSessionOptional(Option<CodeSession>);

impl ToRedisArgs for CodeSession {
fn write_redis_args<W>(&self, out: &mut W)
Expand All @@ -114,10 +109,16 @@ impl ToRedisArgs for CodeSession {
}
}

impl FromRedisValue for CodeSession {
impl FromRedisValue for CodeSessionOptional {
fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
let bytes = expect_redis_value_data(v)?;
expect_redis_json_bytes(bytes)
match bytes {
Some(bytes) => {
let session = expect_redis_json_bytes(bytes).unwrap();
Ok(CodeSessionOptional(Some(session)))
},
None => Ok(CodeSessionOptional(None)),
}
}
}

Expand All @@ -126,6 +127,7 @@ pub struct UserSession {
pub session_id: String,
pub token: GotrueTokenResponse,
}
pub struct UserSessionOptional(Option<UserSession>);

#[async_trait]
impl FromRequestParts<AppState> for UserSession {
Expand Down Expand Up @@ -170,6 +172,10 @@ async fn get_session_from_store(
let mut session = session_store
.get_user_session(session_id)
.await
.map_err(|err| {
tracing::info!("failed to get session from store: {}", err);
SessionRejectionKind::SessionNotFound
})?
.ok_or(SessionRejectionKind::SessionNotFound)?;

if has_expired(session.token.access_token.as_str()) {
Expand Down Expand Up @@ -275,10 +281,16 @@ impl ToRedisArgs for UserSession {
}
}

impl FromRedisValue for UserSession {
impl FromRedisValue for UserSessionOptional {
fn from_redis_value(v: &redis::Value) -> redis::RedisResult<Self> {
let bytes = expect_redis_value_data(v)?;
expect_redis_json_bytes(bytes)
match bytes {
Some(bytes) => {
let session = expect_redis_json_bytes(bytes).unwrap();
Ok(UserSessionOptional(Some(session)))
},
None => Ok(UserSessionOptional(None)),
}
}
}

Expand All @@ -297,9 +309,10 @@ where
}
}

fn expect_redis_value_data(v: &redis::Value) -> redis::RedisResult<&[u8]> {
fn expect_redis_value_data(v: &redis::Value) -> redis::RedisResult<Option<&[u8]>> {
match v {
redis::Value::Data(ref bytes) => Ok(bytes),
redis::Value::Data(ref bytes) => Ok(Some(bytes)),
redis::Value::Nil => Ok(None),
x => Err(redis::RedisError::from((
redis::ErrorKind::TypeError,
"unexpected value from redis",
Expand Down
17 changes: 5 additions & 12 deletions admin_frontend/src/web_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,17 +436,13 @@ async fn oauth_redirect_token_handler(
State(state): State<AppState>,
Query(token_req): Query<OAuthRedirectToken>,
) -> Result<axum::response::Response, WebApiError<'static>> {
// TODO: check client_id and secret
// TODO: handle code challenge and code challenge method
// TODO: check client secret

let code_session = state
.session_store
.get_code_session(&token_req.code)
.await
.ok_or(WebApiError::new(
status::StatusCode::BAD_REQUEST,
"invalid code or expired",
))?;
.await?
.ok_or_else(|| WebApiError::new(StatusCode::BAD_REQUEST, "invalid code"))?;

if let Some(code_challenge) = code_session.code_challenge {
match code_session.code_challenge_method.as_deref() {
Expand Down Expand Up @@ -488,11 +484,8 @@ async fn oauth_redirect_token_handler(
let user_session = state
.session_store
.get_user_session(&code_session.session_id)
.await
.ok_or(WebApiError::new(
status::StatusCode::BAD_REQUEST,
"invalid session_id or expired session",
))?;
.await?
.ok_or_else(|| WebApiError::new(StatusCode::BAD_REQUEST, "invalid session"))?;

let resp = axum::Json::from(user_session.token);
Ok(resp.into_response())
Expand Down

0 comments on commit 6dc9366

Please sign in to comment.