Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: properly refresh expired token #155

Merged
merged 2 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 39 additions & 13 deletions atoma-helpers/src/firebase/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ pub struct SignInResponse {
local_id: String,
}

enum Response {
SignIn(SignInResponse),
Refresh(RefreshResponse),
}

#[derive(Debug, Deserialize)]
pub struct RefreshResponse {
/// These are not all the fields returned, but these are all the fields we need
#[serde(rename = "id_token")]
id_token: String,
#[serde(rename = "refresh_token")]
refresh_token: String,
#[serde(rename = "expires_in")]
expires_in: String,
}

impl FirebaseAuth {
pub(crate) async fn new(api_key: String) -> Result<Self, FirebaseAuthError> {
let mut res = Self {
Expand All @@ -76,19 +92,19 @@ impl FirebaseAuth {
}

/// Sign up with email and password
pub async fn sign_up(&self) -> Result<SignInResponse, FirebaseAuthError> {
async fn sign_up(&self) -> Result<Response, FirebaseAuthError> {
let client = Client::new();
let url = SIGN_UP_URL(&self.api_key);
let res = client
.post(url)
.json(&json!({"returnSecureToken": true}))
.send()
.await?;
Ok(res.json::<SignInResponse>().await?)
Ok(Response::SignIn(res.json::<SignInResponse>().await?))
}

// The token is about to expire (or it already has), refresh it
pub async fn refresh(&mut self) -> Result<(), FirebaseAuthError> {
async fn refresh(&mut self) -> Result<(), FirebaseAuthError> {
let client = Client::new();
let url = REFRESH_URL(&self.api_key);
let res = client
Expand All @@ -97,27 +113,38 @@ impl FirebaseAuth {
.send()
.await?;
let response = if res.status().is_success() {
res.json::<SignInResponse>().await?
Response::Refresh(res.json::<RefreshResponse>().await?)
} else {
// In rare occasions, the refresh token may expire, in which case we need to sign in again
self.sign_up().await?
};

self.set_from_response(response)?;
Ok(())
}

/// Set the fields from a firebase response
pub fn set_from_response(&mut self, response: SignInResponse) -> Result<(), FirebaseAuthError> {
self.expires_in = Some(response.expires_in.parse()?);
self.id_token = Some(response.id_token);
self.refresh_token = Some(response.refresh_token);
self.requested_at = Some(Instant::now());
self.local_id = Some(response.local_id);
fn set_from_response(&mut self, response: Response) -> Result<(), FirebaseAuthError> {
match response {
Response::SignIn(response) => {
self.expires_in = Some(response.expires_in.parse()?);
self.id_token = Some(response.id_token);
self.refresh_token = Some(response.refresh_token);
self.requested_at = Some(Instant::now());
self.local_id = Some(response.local_id);
}
Response::Refresh(response) => {
self.expires_in = Some(response.expires_in.parse()?);
self.id_token = Some(response.id_token);
self.refresh_token = Some(response.refresh_token);
self.requested_at = Some(Instant::now());
}
};
Ok(())
}

/// Get the id_token
pub async fn get_id_token(&mut self) -> Result<String, FirebaseAuthError> {
pub(crate) async fn get_id_token(&mut self) -> Result<String, FirebaseAuthError> {
// If the id_token is None, we need to sign in
if self.id_token.is_none() {
let response = self.sign_up().await?;
Expand All @@ -127,8 +154,7 @@ impl FirebaseAuth {
if self.requested_at.unwrap().elapsed().as_secs() as usize
>= self.expires_in.unwrap() - EXPIRATION_DELTA
{
let response = self.sign_up().await.unwrap();
self.set_from_response(response)?;
self.refresh().await?;
}
// Return the id_token that is valid at least `EXPIRATION_DELTA` seconds
Ok(self.id_token.clone().unwrap())
Expand Down
36 changes: 28 additions & 8 deletions atoma-helpers/src/firebase/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub struct Firebase {
auth: Arc<Mutex<FirebaseAuth>>,
realtime_db_url: Url,
storage_url: Url,
node_id: SmallId,
}

impl Firebase {
Expand All @@ -22,26 +23,45 @@ impl Firebase {
storage_url: Url,
node_id: SmallId,
) -> Result<Self, FirebaseAuthError> {
let mut auth = FirebaseAuth::new(api_key).await?;
let auth = FirebaseAuth::new(api_key).await?;
let firebase = Self {
auth: Arc::new(Mutex::new(auth)),
realtime_db_url,
storage_url,
node_id,
};
firebase.store_node_id().await?;
Ok(firebase)
}

pub async fn store_node_id(&self) -> Result<(), FirebaseAuthError> {
let client = Client::new();
let mut auth = self.auth.lock().await;
let token = auth.get_id_token().await?;
let mut add_node_url = realtime_db_url.clone();
let mut add_node_url = self.realtime_db_url.clone();
{
let mut path_segment = add_node_url.path_segments_mut().unwrap();
path_segment.push("nodes");
path_segment.push(&format!("{}.json", auth.get_local_id()?));
}
add_node_url.set_query(Some(&format!("auth={token}")));
let data = json!({
"id":node_id.to_string()
"id":self.node_id.to_string()
});
client.put(add_node_url).json(&data).send().await?;
Ok(())
}

Ok(Self {
auth: Arc::new(Mutex::new(auth)),
realtime_db_url,
storage_url,
})
pub async fn get_id_token(&self) -> Result<String, FirebaseAuthError> {
let mut auth = self.auth.lock().await;
let old_local_id = auth.get_local_id()?;
let id_token = auth.get_id_token().await?;
let new_local_id = auth.get_local_id()?;
if old_local_id != new_local_id {
// The local id has changed, so we need to store the new node id
self.store_node_id().await?;
}
Ok(id_token)
}

pub fn get_auth(&self) -> Arc<Mutex<FirebaseAuth>> {
Expand Down
14 changes: 5 additions & 9 deletions atoma-input-manager/src/firebase/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@ const SLEEP_BETWEEN_REQUESTS_SEC: u64 = 1;
/// `FirebaseInputManager` - Responsible for getting the prompt from the user
pub struct FirebaseInputManager {
/// The Atoma's firebase URL
firebase_url: Url,
auth: Arc<Mutex<FirebaseAuth>>,
firebase: Firebase,
}

impl FirebaseInputManager {
/// Constructor
pub fn new(firebase: Firebase) -> Self {
Self {
auth: firebase.get_auth(),
firebase_url: firebase.get_realtime_db_url(),
}
Self { firebase }
}

/// Handles a new chat request. Encapsulates the logic necessary
Expand All @@ -36,8 +32,8 @@ impl FirebaseInputManager {
request_id: String,
) -> Result<ModelInput, AtomaInputManagerError> {
let client = Client::new();
let token = self.auth.lock().await.get_id_token().await?;
let mut url = self.firebase_url.clone();
let token = self.firebase.get_id_token().await?;
let mut url = self.firebase.get_realtime_db_url().clone();
{
let mut path_segment = url
.path_segments_mut()
Expand All @@ -57,7 +53,7 @@ impl FirebaseInputManager {
if let Some(previous_transaction) = json.get("previous_transaction") {
let previous_transaction = previous_transaction.as_str().unwrap();
// There is a previous transaction from which we can get the context tokens
let mut url = self.firebase_url.clone();
let mut url = self.firebase.get_realtime_db_url().clone();
{
let mut path_segment = url.path_segments_mut().map_err(|_| {
AtomaInputManagerError::UrlError("URL is not valid".to_string())
Expand Down
20 changes: 7 additions & 13 deletions atoma-output-manager/src/firebase/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,13 @@ use crate::AtomaOutputManagerError;
/// tech stack, it is fine for applications such as chat applications.
pub struct FirebaseOutputManager {
/// The Atoma's firebase URL
realtime_db_url: Url,
storage_url: Url,
auth: Arc<Mutex<FirebaseAuth>>,
firebase: Firebase,
}

impl FirebaseOutputManager {
/// Constructor
pub fn new(firebase: Firebase) -> Self {
Self {
auth: firebase.get_auth(),
realtime_db_url: firebase.get_realtime_db_url(),
storage_url: firebase.get_storage_url(),
}
Self { firebase }
}

/// Handles a new post request. Encapsulates the logic necessary
Expand All @@ -41,11 +35,11 @@ impl FirebaseOutputManager {
ipfs_cid: Option<String>,
) -> Result<(), AtomaOutputManagerError> {
let client = Client::new();
let token = self.auth.lock().await.get_id_token().await?;
let token = self.firebase.get_id_token().await?;

match output_metadata.output_type {
OutputType::Text => {
let mut url = self.realtime_db_url.clone();
let mut url = self.firebase.get_realtime_db_url();
{
let mut path_segment = url.path_segments_mut().map_err(|_| {
AtomaOutputManagerError::UrlError("URL is not valid".to_string())
Expand All @@ -67,7 +61,7 @@ impl FirebaseOutputManager {
});
submit_put_request(&client, url, &data).await?;
if !output_metadata.tokens.is_empty() {
let mut url = self.realtime_db_url.clone();
let mut url = self.firebase.get_realtime_db_url();
{
let mut path_segment = url.path_segments_mut().map_err(|_| {
AtomaOutputManagerError::UrlError("URL is not valid".to_string())
Expand All @@ -83,7 +77,7 @@ impl FirebaseOutputManager {
}
OutputType::Image => {
// First store the metadata
let mut realtime_db_url = self.realtime_db_url.clone();
let mut realtime_db_url = self.firebase.get_realtime_db_url();
{
let mut path_segment = realtime_db_url.path_segments_mut().map_err(|_| {
AtomaOutputManagerError::UrlError("URL is not valid".to_string())
Expand Down Expand Up @@ -113,7 +107,7 @@ impl FirebaseOutputManager {
};
submit_put_request(&client, realtime_db_url, &data).await?;
// Then store the image
let mut storage_url = self.storage_url.clone();
let mut storage_url = self.firebase.get_storage_url();
storage_url.set_query(Some(&format!(
"name=images/{}.png",
output_metadata.output_destination.request_id()
Expand Down
13 changes: 5 additions & 8 deletions atoma-streamer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@ use tracing::{debug, error, info, instrument};

/// `AtomaStreamer` instance
pub struct AtomaStreamer {
/// Firebase url
firebase_url: Url,
/// Firebase
firebase: Firebase,
/// A `mpsc::Receiver` channel, listening to newly
/// AI generated outputs
streamer_rx: mpsc::Receiver<AtomaStreamingData>,
/// Last streamed index mapping, for each
/// `Digest`
last_streamed_index: HashMap<String, usize>,
/// Firebase authentication
auth: Arc<Mutex<FirebaseAuth>>,
}

impl AtomaStreamer {
Expand All @@ -29,10 +27,9 @@ impl AtomaStreamer {
firebase: Firebase,
) -> Result<Self, AtomaStreamerError> {
Ok(Self {
firebase_url: firebase.get_realtime_db_url(),
firebase,
streamer_rx,
last_streamed_index: HashMap::new(),
auth: firebase.get_auth(),
})
}

Expand Down Expand Up @@ -62,8 +59,8 @@ impl AtomaStreamer {
data: String,
) -> Result<(), AtomaStreamerError> {
let client = Client::new();
let mut url = self.firebase_url.clone();
let token = self.auth.lock().await.get_id_token().await?;
let mut url = self.firebase.get_realtime_db_url().clone();
let token = self.firebase.get_id_token().await?;
{
let mut path_segment = url
.path_segments_mut()
Expand Down
Loading