From 58cd2a430a172b964f546ca7a26299c4be92889b Mon Sep 17 00:00:00 2001 From: Cifko Date: Fri, 20 Sep 2024 14:49:39 +0200 Subject: [PATCH] fix: properly refresh expired token --- atoma-helpers/src/firebase/auth.rs | 52 ++++++++++++++++++------ atoma-helpers/src/firebase/mod.rs | 36 ++++++++++++---- atoma-input-manager/src/firebase/mod.rs | 14 +++---- atoma-output-manager/src/firebase/mod.rs | 20 ++++----- atoma-streamer/src/lib.rs | 13 +++--- 5 files changed, 84 insertions(+), 51 deletions(-) diff --git a/atoma-helpers/src/firebase/auth.rs b/atoma-helpers/src/firebase/auth.rs index 69dfd8a2..4980da42 100644 --- a/atoma-helpers/src/firebase/auth.rs +++ b/atoma-helpers/src/firebase/auth.rs @@ -59,6 +59,22 @@ pub struct SignInResponse { local_id: String, } +enum Response { + SignIn(SignInResponse), + Refresh(RefreshResponse), +} + +#[derive(Debug, Deserialize)] +pub struct RefreshResponse { + /// These are not all the fields returned, but these are all the fields we need + #[serde(rename = "id_token")] + id_token: String, + #[serde(rename = "refresh_token")] + refresh_token: String, + #[serde(rename = "expires_in")] + expires_in: String, +} + impl FirebaseAuth { pub(crate) async fn new(api_key: String) -> Result { let mut res = Self { @@ -76,7 +92,7 @@ impl FirebaseAuth { } /// Sign up with email and password - pub async fn sign_up(&self) -> Result { + async fn sign_up(&self) -> Result { let client = Client::new(); let url = SIGN_UP_URL(&self.api_key); let res = client @@ -84,11 +100,11 @@ impl FirebaseAuth { .json(&json!({"returnSecureToken": true})) .send() .await?; - Ok(res.json::().await?) + Ok(Response::SignIn(res.json::().await?)) } // The token is about to expire (or it already has), refresh it - pub async fn refresh(&mut self) -> Result<(), FirebaseAuthError> { + async fn refresh(&mut self) -> Result<(), FirebaseAuthError> { let client = Client::new(); let url = REFRESH_URL(&self.api_key); let res = client @@ -97,27 +113,38 @@ impl FirebaseAuth { .send() .await?; let response = if res.status().is_success() { - res.json::().await? + Response::Refresh(res.json::().await?) } else { // In rare occasions, the refresh token may expire, in which case we need to sign in again self.sign_up().await? }; + self.set_from_response(response)?; Ok(()) } /// Set the fields from a firebase response - pub fn set_from_response(&mut self, response: SignInResponse) -> Result<(), FirebaseAuthError> { - self.expires_in = Some(response.expires_in.parse()?); - self.id_token = Some(response.id_token); - self.refresh_token = Some(response.refresh_token); - self.requested_at = Some(Instant::now()); - self.local_id = Some(response.local_id); + fn set_from_response(&mut self, response: Response) -> Result<(), FirebaseAuthError> { + match response { + Response::SignIn(response) => { + self.expires_in = Some(response.expires_in.parse()?); + self.id_token = Some(response.id_token); + self.refresh_token = Some(response.refresh_token); + self.requested_at = Some(Instant::now()); + self.local_id = Some(response.local_id); + } + Response::Refresh(response) => { + self.expires_in = Some(response.expires_in.parse()?); + self.id_token = Some(response.id_token); + self.refresh_token = Some(response.refresh_token); + self.requested_at = Some(Instant::now()); + } + }; Ok(()) } /// Get the id_token - pub async fn get_id_token(&mut self) -> Result { + pub(crate) async fn get_id_token(&mut self) -> Result { // If the id_token is None, we need to sign in if self.id_token.is_none() { let response = self.sign_up().await?; @@ -127,8 +154,7 @@ impl FirebaseAuth { if self.requested_at.unwrap().elapsed().as_secs() as usize >= self.expires_in.unwrap() - EXPIRATION_DELTA { - let response = self.sign_up().await.unwrap(); - self.set_from_response(response)?; + self.refresh().await?; } // Return the id_token that is valid at least `EXPIRATION_DELTA` seconds Ok(self.id_token.clone().unwrap()) diff --git a/atoma-helpers/src/firebase/mod.rs b/atoma-helpers/src/firebase/mod.rs index bccb2a47..27721020 100644 --- a/atoma-helpers/src/firebase/mod.rs +++ b/atoma-helpers/src/firebase/mod.rs @@ -13,6 +13,7 @@ pub struct Firebase { auth: Arc>, realtime_db_url: Url, storage_url: Url, + node_id: SmallId, } impl Firebase { @@ -22,10 +23,22 @@ impl Firebase { storage_url: Url, node_id: SmallId, ) -> Result { - let mut auth = FirebaseAuth::new(api_key).await?; + let auth = FirebaseAuth::new(api_key).await?; + let firebase = Self { + auth: Arc::new(Mutex::new(auth)), + realtime_db_url, + storage_url, + node_id, + }; + firebase.store_node_id().await?; + Ok(firebase) + } + + pub async fn store_node_id(&self) -> Result<(), FirebaseAuthError> { let client = Client::new(); + let mut auth = self.auth.lock().await; let token = auth.get_id_token().await?; - let mut add_node_url = realtime_db_url.clone(); + let mut add_node_url = self.realtime_db_url.clone(); { let mut path_segment = add_node_url.path_segments_mut().unwrap(); path_segment.push("nodes"); @@ -33,15 +46,22 @@ impl Firebase { } add_node_url.set_query(Some(&format!("auth={token}"))); let data = json!({ - "id":node_id.to_string() + "id":self.node_id.to_string() }); client.put(add_node_url).json(&data).send().await?; + Ok(()) + } - Ok(Self { - auth: Arc::new(Mutex::new(auth)), - realtime_db_url, - storage_url, - }) + pub async fn get_id_token(&self) -> Result { + let mut auth = self.auth.lock().await; + let old_local_id = auth.get_local_id()?; + let id_token = auth.get_id_token().await?; + let new_local_id = auth.get_local_id()?; + if old_local_id != new_local_id { + // The local id has changed, so we need to store the new node id + self.store_node_id().await?; + } + Ok(id_token) } pub fn get_auth(&self) -> Arc> { diff --git a/atoma-input-manager/src/firebase/mod.rs b/atoma-input-manager/src/firebase/mod.rs index 5ffc251c..81425758 100644 --- a/atoma-input-manager/src/firebase/mod.rs +++ b/atoma-input-manager/src/firebase/mod.rs @@ -15,17 +15,13 @@ const SLEEP_BETWEEN_REQUESTS_SEC: u64 = 1; /// `FirebaseInputManager` - Responsible for getting the prompt from the user pub struct FirebaseInputManager { /// The Atoma's firebase URL - firebase_url: Url, - auth: Arc>, + firebase: Firebase, } impl FirebaseInputManager { /// Constructor pub fn new(firebase: Firebase) -> Self { - Self { - auth: firebase.get_auth(), - firebase_url: firebase.get_realtime_db_url(), - } + Self { firebase } } /// Handles a new chat request. Encapsulates the logic necessary @@ -36,8 +32,8 @@ impl FirebaseInputManager { request_id: String, ) -> Result { let client = Client::new(); - let token = self.auth.lock().await.get_id_token().await?; - let mut url = self.firebase_url.clone(); + let token = self.firebase.get_id_token().await?; + let mut url = self.firebase.get_realtime_db_url().clone(); { let mut path_segment = url .path_segments_mut() @@ -57,7 +53,7 @@ impl FirebaseInputManager { if let Some(previous_transaction) = json.get("previous_transaction") { let previous_transaction = previous_transaction.as_str().unwrap(); // There is a previous transaction from which we can get the context tokens - let mut url = self.firebase_url.clone(); + let mut url = self.firebase.get_realtime_db_url().clone(); { let mut path_segment = url.path_segments_mut().map_err(|_| { AtomaInputManagerError::UrlError("URL is not valid".to_string()) diff --git a/atoma-output-manager/src/firebase/mod.rs b/atoma-output-manager/src/firebase/mod.rs index 90ef28e2..b06aa4de 100644 --- a/atoma-output-manager/src/firebase/mod.rs +++ b/atoma-output-manager/src/firebase/mod.rs @@ -16,19 +16,13 @@ use crate::AtomaOutputManagerError; /// tech stack, it is fine for applications such as chat applications. pub struct FirebaseOutputManager { /// The Atoma's firebase URL - realtime_db_url: Url, - storage_url: Url, - auth: Arc>, + firebase: Firebase, } impl FirebaseOutputManager { /// Constructor pub fn new(firebase: Firebase) -> Self { - Self { - auth: firebase.get_auth(), - realtime_db_url: firebase.get_realtime_db_url(), - storage_url: firebase.get_storage_url(), - } + Self { firebase } } /// Handles a new post request. Encapsulates the logic necessary @@ -41,11 +35,11 @@ impl FirebaseOutputManager { ipfs_cid: Option, ) -> Result<(), AtomaOutputManagerError> { let client = Client::new(); - let token = self.auth.lock().await.get_id_token().await?; + let token = self.firebase.get_id_token().await?; match output_metadata.output_type { OutputType::Text => { - let mut url = self.realtime_db_url.clone(); + let mut url = self.firebase.get_realtime_db_url(); { let mut path_segment = url.path_segments_mut().map_err(|_| { AtomaOutputManagerError::UrlError("URL is not valid".to_string()) @@ -67,7 +61,7 @@ impl FirebaseOutputManager { }); submit_put_request(&client, url, &data).await?; if !output_metadata.tokens.is_empty() { - let mut url = self.realtime_db_url.clone(); + let mut url = self.firebase.get_realtime_db_url(); { let mut path_segment = url.path_segments_mut().map_err(|_| { AtomaOutputManagerError::UrlError("URL is not valid".to_string()) @@ -83,7 +77,7 @@ impl FirebaseOutputManager { } OutputType::Image => { // First store the metadata - let mut realtime_db_url = self.realtime_db_url.clone(); + let mut realtime_db_url = self.firebase.get_realtime_db_url(); { let mut path_segment = realtime_db_url.path_segments_mut().map_err(|_| { AtomaOutputManagerError::UrlError("URL is not valid".to_string()) @@ -113,7 +107,7 @@ impl FirebaseOutputManager { }; submit_put_request(&client, realtime_db_url, &data).await?; // Then store the image - let mut storage_url = self.storage_url.clone(); + let mut storage_url = self.firebase.get_storage_url(); storage_url.set_query(Some(&format!( "name=images/{}.png", output_metadata.output_destination.request_id() diff --git a/atoma-streamer/src/lib.rs b/atoma-streamer/src/lib.rs index 372ffc2c..b04d854f 100644 --- a/atoma-streamer/src/lib.rs +++ b/atoma-streamer/src/lib.rs @@ -10,16 +10,14 @@ use tracing::{debug, error, info, instrument}; /// `AtomaStreamer` instance pub struct AtomaStreamer { - /// Firebase url - firebase_url: Url, + /// Firebase + firebase: Firebase, /// A `mpsc::Receiver` channel, listening to newly /// AI generated outputs streamer_rx: mpsc::Receiver, /// Last streamed index mapping, for each /// `Digest` last_streamed_index: HashMap, - /// Firebase authentication - auth: Arc>, } impl AtomaStreamer { @@ -29,10 +27,9 @@ impl AtomaStreamer { firebase: Firebase, ) -> Result { Ok(Self { - firebase_url: firebase.get_realtime_db_url(), + firebase, streamer_rx, last_streamed_index: HashMap::new(), - auth: firebase.get_auth(), }) } @@ -62,8 +59,8 @@ impl AtomaStreamer { data: String, ) -> Result<(), AtomaStreamerError> { let client = Client::new(); - let mut url = self.firebase_url.clone(); - let token = self.auth.lock().await.get_id_token().await?; + let mut url = self.firebase.get_realtime_db_url().clone(); + let token = self.firebase.get_id_token().await?; { let mut path_segment = url .path_segments_mut()