From aa6012baaeb6710f91a620c5be573a227696eaad Mon Sep 17 00:00:00 2001
From: Jiahao XU <Jiahao_XU@outlook.com>
Date: Thu, 13 Oct 2022 11:31:13 +1100
Subject: [PATCH] Make extraction cancellable for `bin` and `tar` based formats
 (#481)

Extraction wasn't cancellable by `cancel_on_user_sig_term` used in `entry` since it calls `block_in_place`.

This PR adds cancellation support to it by adding a `static` `OnceCell` variable to `wait_on_cancellation_signal`, so that once it returns `Ok(())`, all other calls to it after that point also return `Ok(())` immediately.

`StreamReadable`, which is used in the cancellation process, then stores a boxed future of `wait_on_cancellation_signal` and polls it in `BufRead::fill_buf`.

Note that for zip, the extraction process takes `File` instead of `StreamReadable` due to the `io::Seek` requirement, so cancelling during extraction for zip is still not possible.

This PR also optimized `extract_bin` and `extract_zip` by using `StreamReadable::copy`, introduced in this PR, instead of `io::copy`, which allocates an internal buffer on the stack and thereby imposes an extra copy.

It also fixed `StreamReadable::fill_buf` by ensuring that an empty buffer is only returned on EOF.

* Make `wait_on_cancellation_signal` pub
* Enable feature `parking_lot` of dep tokio
* Mod `wait_on_cancellation_signal`: Use `OnceCell` internally
   to achieve the effect that once a call to it returns `Ok(())`, all calls
   to it after that also return `Ok(())`.
* Impl `From<BinstallError>` for `io::Error`
* Impl cancellation on user signal in `StreamReadable`
* Fix err msg when cancelling during extraction in `ops::resolve`
* Optimize: Impl & use `StreamReadable::copy`
   which is the same as `io::copy` but does not allocate any internal buffer
   since `StreamReadable` is buffered.
* Fix `next_stream`: Return non-empty bytes on `Ok(Some(bytes))`

Signed-off-by: Jiahao XU <Jiahao_XU@outlook.com>
---
 Cargo.lock                                    |  1 +
 crates/binstalk/Cargo.toml                    |  3 +-
 crates/binstalk/src/errors.rs                 | 18 ++++-
 .../src/helpers/download/async_extracter.rs   |  6 +-
 .../src/helpers/download/stream_readable.rs   | 79 ++++++++++++++++---
 crates/binstalk/src/helpers/signal.rs         | 19 ++++-
 crates/binstalk/src/ops/resolve.rs            |  3 +
 7 files changed, 109 insertions(+), 20 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index c2ffecfa..64765419 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2042,6 +2042,7 @@ dependencies = [
  "memchr",
  "mio",
  "num_cpus",
+ "parking_lot",
  "pin-project-lite",
  "signal-hook-registry",
  "socket2",
diff --git a/crates/binstalk/Cargo.toml b/crates/binstalk/Cargo.toml
index f1054b42..24d09797 100644
--- a/crates/binstalk/Cargo.toml
+++ b/crates/binstalk/Cargo.toml
@@ -46,7 +46,8 @@ tar = { package = "binstall-tar", version = "0.4.39" }
 tempfile = "3.3.0"
 thiserror = "1.0.37"
 tinytemplate = "1.2.1"
-tokio = { version = "1.21.2", features = ["macros", "rt", "process", "sync", "signal", "time"], default-features = false }
+# parking_lot - for OnceCell::const_new
+tokio = { version = "1.21.2", features = ["macros", "rt", "process", "sync", "signal", "time", "parking_lot"], default-features = false }
 toml_edit = { version = "0.14.4", features = ["easy"] }
 tower = { version = "0.4.13", features = ["limit", "util"] }
 trust-dns-resolver = { version = "0.21.2", optional = true, default-features = false, features = ["dnssec-ring"] }
diff --git a/crates/binstalk/src/errors.rs b/crates/binstalk/src/errors.rs
index 19b8554b..0fc36970 100644
--- a/crates/binstalk/src/errors.rs
+++ b/crates/binstalk/src/errors.rs
@@ -1,4 +1,5 @@
 use std::{
+    io,
     path::PathBuf,
     process::{ExitCode, ExitStatus, Termination},
 };
@@ -99,7 +100,7 @@ pub enum BinstallError {
     /// - Exit: 74
     #[error(transparent)]
     #[diagnostic(severity(error), code(binstall::io))]
-    Io(std::io::Error),
+    Io(io::Error),
 
     /// An error interacting with the crates.io API.
     ///
@@ -392,8 +393,8 @@ impl Termination for BinstallError {
     }
 }
 
-impl From<std::io::Error> for BinstallError {
-    fn from(err: std::io::Error) -> Self {
+impl From<io::Error> for BinstallError {
+    fn from(err: io::Error) -> Self {
         if err.get_ref().is_some() {
             let kind = err.kind();
 
@@ -404,9 +405,18 @@ impl From<std::io::Error> for BinstallError {
             inner
                 .downcast()
                 .map(|b| *b)
-                .unwrap_or_else(|err| BinstallError::Io(std::io::Error::new(kind, err)))
+                .unwrap_or_else(|err| BinstallError::Io(io::Error::new(kind, err)))
         } else {
             BinstallError::Io(err)
         }
     }
 }
+
+impl From<BinstallError> for io::Error {
+    fn from(e: BinstallError) -> io::Error {
+        match e {
+            BinstallError::Io(io_error) => io_error, // already an `io::Error`: return it unwrapped
+            e => io::Error::new(io::ErrorKind::Other, e), // any other variant is wrapped as `Other`
+        }
+    }
+}
diff --git a/crates/binstalk/src/helpers/download/async_extracter.rs b/crates/binstalk/src/helpers/download/async_extracter.rs
index 840cbf21..21311b3d 100644
--- a/crates/binstalk/src/helpers/download/async_extracter.rs
+++ b/crates/binstalk/src/helpers/download/async_extracter.rs
@@ -1,7 +1,7 @@
 use std::{
     fmt::Debug,
     fs,
-    io::{copy, Read, Seek},
+    io::{Read, Seek},
     path::Path,
 };
 
@@ -33,7 +33,7 @@ where
             fs::remove_file(path).ok();
         });
 
-        copy(&mut reader, &mut file)?;
+        reader.copy(&mut file)?;
 
         // Operation isn't aborted and all writes succeed,
         // disarm the remove_guard.
@@ -54,7 +54,7 @@ where
 
         let mut file = tempfile()?;
 
-        copy(&mut reader, &mut file)?;
+        reader.copy(&mut file)?;
 
         // rewind it so that we can pass it to unzip
         file.rewind()?;
diff --git a/crates/binstalk/src/helpers/download/stream_readable.rs b/crates/binstalk/src/helpers/download/stream_readable.rs
index bc450fb5..6685c6bf 100644
--- a/crates/binstalk/src/helpers/download/stream_readable.rs
+++ b/crates/binstalk/src/helpers/download/stream_readable.rs
@@ -1,24 +1,26 @@
 use std::{
     cmp::min,
-    io::{self, BufRead, Read},
+    future::Future,
+    io::{self, BufRead, Read, Write},
+    pin::Pin,
 };
 
 use bytes::{Buf, Bytes};
 use futures_util::stream::{Stream, StreamExt};
 use tokio::runtime::Handle;
 
-use crate::errors::BinstallError;
+use crate::{errors::BinstallError, helpers::signal::wait_on_cancellation_signal};
 
 /// This wraps an AsyncIterator as a `Read`able.
 /// It must be used in non-async context only,
 /// meaning you have to use it with
 /// `tokio::task::{block_in_place, spawn_blocking}` or
 /// `std::thread::spawn`.
-#[derive(Debug)]
 pub struct StreamReadable<S> {
     stream: S,
     handle: Handle,
     bytes: Bytes,
+    cancellation_future: Pin<Box<dyn Future<Output = Result<(), io::Error>> + Send>>,
 }
 
 impl<S> StreamReadable<S> {
@@ -27,6 +29,39 @@ impl<S> StreamReadable<S> {
             stream,
             handle: Handle::current(),
             bytes: Bytes::new(),
+            cancellation_future: Box::pin(wait_on_cancellation_signal()),
+        }
+    }
+}
+
+impl<S, E> StreamReadable<S>
+where
+    S: Stream<Item = Result<Bytes, E>> + Unpin,
+    BinstallError: From<E>,
+{
+    /// Copies from `self` to `writer` until `self` reaches eof or an error occurs.
+    ///
+    /// Same as `io::copy` but does not allocate any internal buffer
+    /// since `self` is buffered; the loop itself lives in `copy_inner`.
+    pub(super) fn copy<W>(&mut self, mut writer: W) -> io::Result<()>
+    where
+        W: Write,
+    {
+        self.copy_inner(&mut writer) // non-generic helper taking `&mut dyn Write`
+    }
+
+    fn copy_inner(&mut self, writer: &mut dyn Write) -> io::Result<()> {
+        loop {
+            let buf = self.fill_buf()?;
+            if buf.is_empty() {
+                // Eof: `fill_buf` handed back an empty buffer, nothing left to copy.
+                break Ok(());
+            }
+
+            writer.write_all(buf)?; // write the whole buffered chunk
+
+            let n = buf.len();
+            self.consume(n); // mark the chunk just written as consumed
         }
     }
 }
@@ -56,6 +91,27 @@ where
         Ok(n)
     }
 }
+
+/// If `Ok(Some(bytes))` is returned, then `bytes.is_empty() == false`.
+async fn next_stream<S, E>(stream: &mut S) -> io::Result<Option<Bytes>>
+where
+    S: Stream<Item = Result<Bytes, E>> + Unpin,
+    BinstallError: From<E>,
+{
+    loop {
+        let option = stream
+            .next()
+            .await
+            .transpose() // `Option<Result<_, E>>` -> `Result<Option<_>, E>`
+            .map_err(BinstallError::from)?; // `?` then converts it into `io::Error`
+
+        match option {
+            Some(bytes) if bytes.is_empty() => continue, // skip empty chunks, poll again
+            option => break Ok(option), // non-empty chunk, or `None` once the stream ends
+        }
+    }
+}
+
 impl<S, E> BufRead for StreamReadable<S>
 where
     S: Stream<Item = Result<Bytes, E>> + Unpin,
@@ -65,13 +121,18 @@ where
         let bytes = &mut self.bytes;
 
         if !bytes.has_remaining() {
-            match self.handle.block_on(async { self.stream.next().await }) {
-                Some(Ok(new_bytes)) => *bytes = new_bytes,
-                Some(Err(e)) => {
-                    let e: BinstallError = e.into();
-                    return Err(io::Error::new(io::ErrorKind::Other, e));
+            let option = self.handle.block_on(async {
+                tokio::select! {
+                    res = next_stream(&mut self.stream) => res,
+                    res = self.cancellation_future.as_mut() => {
+                        Err(res.err().unwrap_or_else(|| io::Error::from(BinstallError::UserAbort)))
+                    },
                 }
-                None => (),
+            })?;
+
+            if let Some(new_bytes) = option {
+                // new_bytes are guaranteed to be non-empty.
+                *bytes = new_bytes;
             }
         }
         Ok(&*bytes)
diff --git a/crates/binstalk/src/helpers/signal.rs b/crates/binstalk/src/helpers/signal.rs
index e15ed8e1..d01041df 100644
--- a/crates/binstalk/src/helpers/signal.rs
+++ b/crates/binstalk/src/helpers/signal.rs
@@ -1,7 +1,7 @@
 use std::io;
 
 use futures_util::future::pending;
-use tokio::signal;
+use tokio::{signal, sync::OnceCell};
 
 use super::tasks::AutoAbortJoinHandle;
 use crate::errors::BinstallError;
@@ -24,12 +24,25 @@ pub async fn cancel_on_user_sig_term<T>(
     tokio::select! {
         res = handle => res,
         res = wait_on_cancellation_signal() => {
-            res.map_err(BinstallError::Io).and(Err(BinstallError::UserAbort))
+            res
+        .map_err(BinstallError::Io)
+                .and(Err(BinstallError::UserAbort))
         }
     }
 }
 
-async fn wait_on_cancellation_signal() -> Result<(), io::Error> {
+/// Once a call to this function returns `Ok(())`, every later call to it
+/// also returns `Ok(())`.
+pub async fn wait_on_cancellation_signal() -> Result<(), io::Error> {
+    static CANCELLED: OnceCell<()> = OnceCell::const_new(); // shared by every caller
+
+    CANCELLED
+        .get_or_try_init(wait_on_cancellation_signal_inner) // an `Err` leaves the cell unset
+        .await
+        .copied() // `Result<&(), _>` -> `Result<(), _>`
+}
+
+async fn wait_on_cancellation_signal_inner() -> Result<(), io::Error> {
     #[cfg(unix)]
     async fn inner() -> Result<(), io::Error> {
         unix::wait_on_cancellation_signal_unix().await
diff --git a/crates/binstalk/src/ops/resolve.rs b/crates/binstalk/src/ops/resolve.rs
index 59fa9ad8..3335a130 100644
--- a/crates/binstalk/src/ops/resolve.rs
+++ b/crates/binstalk/src/ops/resolve.rs
@@ -260,6 +260,9 @@ async fn resolve_inner(
                         }
                     }
                     Err(err) => {
+                        if let BinstallError::UserAbort = err {
+                            return Err(err);
+                        }
                         warn!(
                             "Error while downloading and extracting from fetcher {}: {}",
                             fetcher.source_name(),