| use std::future::Future; |
| use std::pin::Pin; |
| use std::task::Context; |
| use std::task::Poll; |
| use std::time::Duration; |
| |
| use futures::ready; |
| use pin_project::pin_project; |
| |
| use crate::backoff::BackoffBuilder; |
| use crate::Backoff; |
| |
| /// Retryable will add retry support for functions that produces a futures with results. |
| /// |
| /// That means all types that implement `FnMut() -> impl Future<Output = Result<T, E>>` |
| /// will be able to use `retry`. |
| /// |
| /// For example: |
| /// |
| /// - Functions without extra args: |
| /// |
| /// ```ignore |
| /// async fn fetch() -> Result<String> { |
| /// Ok(reqwest::get("https://www.rust-lang.org").await?.text().await?) |
| /// } |
| /// ``` |
| /// |
| /// - Closures |
| /// |
| /// ```ignore |
| /// || async { |
| /// let x = reqwest::get("https://www.rust-lang.org") |
| /// .await? |
| /// .text() |
| /// .await?; |
| /// |
| /// Err(anyhow::anyhow!(x)) |
| /// } |
| /// ``` |
| /// |
| /// # Example |
| /// |
| /// ```no_run |
| /// use anyhow::Result; |
| /// use backon::ExponentialBuilder; |
| /// use backon::Retryable; |
| /// |
| /// async fn fetch() -> Result<String> { |
| /// Ok(reqwest::get("https://www.rust-lang.org") |
| /// .await? |
| /// .text() |
| /// .await?) |
| /// } |
| /// |
| /// #[tokio::main] |
| /// async fn main() -> Result<()> { |
| /// let content = fetch.retry(&ExponentialBuilder::default()).await?; |
| /// println!("fetch succeeded: {}", content); |
| /// |
| /// Ok(()) |
| /// } |
| /// ``` |
| pub trait Retryable< |
| B: BackoffBuilder, |
| T, |
| E, |
| Fut: Future<Output = Result<T, E>>, |
| FutureFn: FnMut() -> Fut, |
| > |
| { |
| /// Generate a new retry |
| fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Fut, FutureFn>; |
| } |
| |
| impl<B, T, E, Fut, FutureFn> Retryable<B, T, E, Fut, FutureFn> for FutureFn |
| where |
| B: BackoffBuilder, |
| Fut: Future<Output = Result<T, E>>, |
| FutureFn: FnMut() -> Fut, |
| { |
| fn retry(self, builder: &B) -> Retry<B::Backoff, T, E, Fut, FutureFn> { |
| Retry::new(self, builder.build()) |
| } |
| } |
| |
| /// Retry struct generated by [`Retryable`]. |
| #[pin_project] |
| pub struct Retry<B: Backoff, T, E, Fut: Future<Output = Result<T, E>>, FutureFn: FnMut() -> Fut> { |
| backoff: B, |
| retryable: fn(&E) -> bool, |
| notify: fn(&E, Duration), |
| future_fn: FutureFn, |
| |
| #[pin] |
| state: State<T, E, Fut>, |
| } |
| |
| impl<B, T, E, Fut, FutureFn> Retry<B, T, E, Fut, FutureFn> |
| where |
| B: Backoff, |
| Fut: Future<Output = Result<T, E>>, |
| FutureFn: FnMut() -> Fut, |
| { |
| /// Create a new retry. |
| fn new(future_fn: FutureFn, backoff: B) -> Self { |
| Retry { |
| backoff, |
| retryable: |_: &E| true, |
| notify: |_: &E, _: Duration| {}, |
| future_fn, |
| state: State::Idle, |
| } |
| } |
| |
| /// Set the conditions for retrying. |
| /// |
| /// If not specified, we treat all errors as retryable. |
| /// |
| /// # Examples |
| /// |
| /// ```no_run |
| /// use anyhow::Result; |
| /// use backon::ExponentialBuilder; |
| /// use backon::Retryable; |
| /// |
| /// async fn fetch() -> Result<String> { |
| /// Ok(reqwest::get("https://www.rust-lang.org") |
| /// .await? |
| /// .text() |
| /// .await?) |
| /// } |
| /// |
| /// #[tokio::main] |
| /// async fn main() -> Result<()> { |
| /// let content = fetch |
| /// .retry(&ExponentialBuilder::default()) |
| /// .when(|e| e.to_string() == "EOF") |
| /// .await?; |
| /// println!("fetch succeeded: {}", content); |
| /// |
| /// Ok(()) |
| /// } |
| /// ``` |
| pub fn when(mut self, retryable: fn(&E) -> bool) -> Self { |
| self.retryable = retryable; |
| self |
| } |
| |
| /// Set to notify for everything retrying. |
| /// |
| /// If not specified, this is a no-op. |
| /// |
| /// # Examples |
| /// |
| /// ```no_run |
| /// use std::time::Duration; |
| /// |
| /// use anyhow::Result; |
| /// use backon::ExponentialBuilder; |
| /// use backon::Retryable; |
| /// |
| /// async fn fetch() -> Result<String> { |
| /// Ok(reqwest::get("https://www.rust-lang.org") |
| /// .await? |
| /// .text() |
| /// .await?) |
| /// } |
| /// |
| /// #[tokio::main] |
| /// async fn main() -> Result<()> { |
| /// let content = fetch |
| /// .retry(&ExponentialBuilder::default()) |
| /// .notify(|err: &anyhow::Error, dur: Duration| { |
| /// println!("retrying error {:?} with sleeping {:?}", err, dur); |
| /// }) |
| /// .await?; |
| /// println!("fetch succeeded: {}", content); |
| /// |
| /// Ok(()) |
| /// } |
| /// ``` |
| pub fn notify(mut self, notify: fn(&E, Duration)) -> Self { |
| self.notify = notify; |
| self |
| } |
| } |
| |
| /// State maintains internal state of retry. |
| /// |
| /// # Notes |
| /// |
| /// `tokio::time::Sleep` is a very struct that occupy 640B, so we wrap it |
| /// into a `Pin<Box<_>>` to avoid this enum too large. |
| #[pin_project(project = StateProject)] |
| enum State<T, E, Fut: Future<Output = Result<T, E>>> { |
| Idle, |
| Polling(#[pin] Fut), |
| // TODO: we need to support other sleeper |
| Sleeping(#[pin] Pin<Box<tokio::time::Sleep>>), |
| } |
| |
| impl<T, E, Fut> Default for State<T, E, Fut> |
| where |
| Fut: Future<Output = Result<T, E>>, |
| { |
| fn default() -> Self { |
| State::Idle |
| } |
| } |
| |
| impl<B, T, E, Fut, FutureFn> Future for Retry<B, T, E, Fut, FutureFn> |
| where |
| B: Backoff, |
| Fut: Future<Output = Result<T, E>>, |
| FutureFn: FnMut() -> Fut, |
| { |
| type Output = Result<T, E>; |
| |
| fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
| let mut this = self.project(); |
| loop { |
| let state = this.state.as_mut().project(); |
| match state { |
| StateProject::Idle => { |
| let fut = (this.future_fn)(); |
| this.state.set(State::Polling(fut)); |
| continue; |
| } |
| StateProject::Polling(fut) => match ready!(fut.poll(cx)) { |
| Ok(v) => return Poll::Ready(Ok(v)), |
| Err(err) => { |
| // If input error is not retryable, return error directly. |
| if !(this.retryable)(&err) { |
| return Poll::Ready(Err(err)); |
| } |
| match this.backoff.next() { |
| None => return Poll::Ready(Err(err)), |
| Some(dur) => { |
| (this.notify)(&err, dur); |
| this.state |
| .set(State::Sleeping(Box::pin(tokio::time::sleep(dur)))); |
| continue; |
| } |
| } |
| } |
| }, |
| StateProject::Sleeping(sl) => { |
| ready!(sl.poll(cx)); |
| this.state.set(State::Idle); |
| continue; |
| } |
| } |
| } |
| } |
| } |
| |
#[cfg(test)]
mod tests {
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use std::time::Duration;

    use tokio::sync::Mutex;

    use super::*;
    use crate::exponential::ExponentialBuilder;

    async fn always_error() -> anyhow::Result<()> {
        Err(anyhow::anyhow!("test_query meets error"))
    }

    #[tokio::test]
    async fn test_retry() -> anyhow::Result<()> {
        let result = always_error
            .retry(&ExponentialBuilder::default().with_min_delay(Duration::from_millis(1)))
            .await;

        assert!(result.is_err());
        assert_eq!("test_query meets error", result.unwrap_err().to_string());
        Ok(())
    }

    #[tokio::test]
    async fn test_retry_with_not_retryable_error() -> anyhow::Result<()> {
        let error_times = Mutex::new(0);

        let f = || async {
            let mut x = error_times.lock().await;
            *x += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("not retryable"))
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(&backoff)
            // Only retry if the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("not retryable", result.unwrap_err().to_string());
        // `f` always returns error "not retryable", so it should be executed
        // only once.
        assert_eq!(*error_times.lock().await, 1);
        Ok(())
    }

    #[tokio::test]
    async fn test_retry_with_retryable_error() -> anyhow::Result<()> {
        let error_times = Mutex::new(0);

        let f = || async {
            let mut x = error_times.lock().await;
            *x += 1;
            Err::<(), anyhow::Error>(anyhow::anyhow!("retryable"))
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = f
            .retry(&backoff)
            // Only retry if the error message is `retryable`.
            .when(|e| e.to_string() == "retryable")
            .await;

        assert!(result.is_err());
        assert_eq!("retryable", result.unwrap_err().to_string());
        // `f` always returns error "retryable", so it should be executed
        // 4 times (retry 3 times).
        assert_eq!(*error_times.lock().await, 4);
        Ok(())
    }

    #[tokio::test]
    async fn test_retry_succeeds_after_transient_errors() -> anyhow::Result<()> {
        let call_times = Mutex::new(0);

        // Fails twice, then succeeds: exercises the Sleeping -> Idle -> Polling path.
        let f = || async {
            let mut x = call_times.lock().await;
            *x += 1;
            if *x < 3 {
                Err(anyhow::anyhow!("retryable"))
            } else {
                Ok(*x)
            }
        };

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let value = f.retry(&backoff).await?;

        assert_eq!(value, 3);
        assert_eq!(*call_times.lock().await, 3);
        Ok(())
    }

    #[tokio::test]
    async fn test_retry_with_notify() -> anyhow::Result<()> {
        // `notify` takes a plain fn pointer, so the counter must be a static
        // rather than a captured local.
        static NOTIFY_TIMES: AtomicUsize = AtomicUsize::new(0);

        let backoff = ExponentialBuilder::default().with_min_delay(Duration::from_millis(1));
        let result = always_error
            .retry(&backoff)
            .notify(|_: &anyhow::Error, _: Duration| {
                NOTIFY_TIMES.fetch_add(1, Ordering::SeqCst);
            })
            .await;

        assert!(result.is_err());
        // The default backoff allows 3 retries (see `test_retry_with_retryable_error`),
        // and each retry is preceded by exactly one notification.
        assert_eq!(NOTIFY_TIMES.load(Ordering::SeqCst), 3);
        Ok(())
    }
}