Skip to content

Commit

Permalink
fix cache (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
xlc authored Sep 25, 2023
1 parent d60e46f commit 1c5fa4e
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/middlewares/methods/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl Middleware<CallRequest, Result<JsonValue, ErrorObjectOwned>> for CacheMiddl

let result = self
.cache
.get_or_insert_with(&key, || next(request, context).boxed())
.get_or_insert_with(key.clone(), || next(request, context).boxed())
.await;

if let Ok(ref value) = result {
Expand Down
93 changes: 89 additions & 4 deletions src/utils/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,25 @@ impl<D: Digest + 'static> Cache<D> {

pub async fn get_or_insert_with<F>(
&self,
key: &CacheKey<D>,
key: CacheKey<D>,
f: F,
) -> Result<JsonValue, ErrorObjectOwned>
where
F: FnOnce() -> BoxFuture<'static, Result<JsonValue, ErrorObjectOwned>>,
{
match self.cache.get(key) {
match self.cache.get(&key) {
Some(CacheValue::Value(value)) => Ok(value),
Some(CacheValue::Pending(rx)) => rx.borrow().clone().unwrap(),
Some(CacheValue::Pending(mut rx)) => {
{
let value = rx.borrow();
if value.is_some() {
return value.clone().unwrap();
}
}
let _ = rx.changed().await;
let value = rx.borrow();
value.clone().expect("Cache: should always be Some")
}
None => {
let (tx, rx) = watch::channel(None);
self.cache
Expand All @@ -119,7 +129,7 @@ impl<D: Digest + 'static> Cache<D> {
.await;
}
Err(_) => {
self.cache.remove(key).await;
self.cache.remove(&key).await;
}
};
value
Expand All @@ -137,3 +147,78 @@ impl<D: Digest + 'static> Cache<D> {
self.cache.sync();
}
}

#[cfg(test)]
mod tests {
use super::*;
use futures::FutureExt as _;
use serde_json::json;

#[tokio::test]
async fn get_insert_remove() {
let cache = Cache::<blake2::Blake2b512>::new(NonZeroUsize::new(1).unwrap(), None);

let key = CacheKey::<blake2::Blake2b512>::new(&"key".to_string(), &[]);

assert_eq!(cache.get(&key).await, None);

cache.insert(key.clone(), json!("value")).await;

assert_eq!(cache.get(&key).await, Some(json!("value")));

cache.remove(&key).await;

assert_eq!(cache.get(&key).await, None);
}

#[tokio::test]
async fn get_or_insert_with_basic() {
let cache = Cache::<blake2::Blake2b512>::new(NonZeroUsize::new(1).unwrap(), None);

let key = CacheKey::<blake2::Blake2b512>::new(&"key".to_string(), &[]);

let (tx, rx) = tokio::sync::oneshot::channel::<()>();

let cache2 = cache.clone();
let key2 = key.clone();
let h1 = tokio::spawn(async move {
let value = cache2
.get_or_insert_with(key2.clone(), || {
async move {
let _ = rx.await;
Ok(json!("value"))
}
.boxed()
})
.await;
assert_eq!(value, Ok(json!("value")));
});

tokio::task::yield_now().await;

let cache2 = cache.clone();
let key2 = key.clone();
let h2 = tokio::spawn(async move {
println!("5");

let value = cache2
.get_or_insert_with(key2, || {
async {
panic!();
}
.boxed()
})
.await;
assert_eq!(value, Ok(json!("value")));
});

tokio::task::yield_now().await;

tx.send(()).unwrap();

h1.await.unwrap();
h2.await.unwrap();

assert_eq!(cache.get(&key).await, Some(json!("value")));
}
}

0 comments on commit 1c5fa4e

Please sign in to comment.