diff --git a/Cargo.lock b/Cargo.lock index 137ca81f..1faf6fbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1683,9 +1683,21 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "socket2", + "tokio-macros", "winapi", ] +[[package]] +name = "tokio-macros" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9724f9a975fb987ef7a3cd9be0350edcbe130698af5b8f7a631e23d42d052484" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-native-tls" version = "0.3.0" diff --git a/Cargo.toml b/Cargo.toml index 7f41b3d9..a72083c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ tar = "0.4.38" tempfile = "3.3.0" thiserror = "1.0.32" tinytemplate = "1.2.1" -tokio = { version = "1.20.1", features = ["rt-multi-thread", "process", "sync"], default-features = false } +tokio = { version = "1.20.1", features = ["macros", "rt-multi-thread", "process", "sync", "signal"], default-features = false } toml_edit = { version = "0.14.4", features = ["easy"] } url = { version = "2.2.2", features = ["serde"] } xz2 = "0.1.7" diff --git a/src/errors.rs b/src/errors.rs index 1fbfa54e..c1d1de22 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -18,7 +18,9 @@ pub enum BinstallError { #[diagnostic(severity(error), code(binstall::internal::task_join))] TaskJoinError(#[from] task::JoinError), - /// The installation was cancelled by a user at a confirmation prompt. + /// The installation was cancelled by a user at a confirmation prompt, + /// or user send a ctrl_c on all platforms or + /// `SIGINT`, `SIGHUP`, `SIGTERM` or `SIGQUIT` on unix to the program. /// /// - Code: `binstall::user_abort` /// - Exit: 32 diff --git a/src/fetchers.rs b/src/fetchers.rs index 952b90fe..2d12e13e 100644 --- a/src/fetchers.rs +++ b/src/fetchers.rs @@ -64,7 +64,7 @@ impl MultiFetcher { pub fn add(&mut self, fetcher: Arc<dyn Fetcher>) { self.0.push(( fetcher.clone(), - AutoAbortJoinHandle::new(tokio::spawn(async move { fetcher.find().await })), + AutoAbortJoinHandle::spawn(async move { fetcher.find().await }), )); } diff --git a/src/fetchers/gh_crate_meta.rs b/src/fetchers/gh_crate_meta.rs index 6375d7df..43603760 100644 --- a/src/fetchers/gh_crate_meta.rs +++ b/src/fetchers/gh_crate_meta.rs @@ -41,13 +41,13 @@ impl super::Fetcher for GhCrateMeta { let checks = urls .map(|url| { let client = self.client.clone(); - AutoAbortJoinHandle::new(tokio::spawn(async move { + AutoAbortJoinHandle::spawn(async move { let url = url?; info!("Checking for package at: '{url}'"); remote_exists(client, url.clone(), Method::HEAD) .await .map(|exists| (url.clone(), exists)) - })) + }) }) .collect::<Vec<_>>(); diff --git a/src/helpers.rs b/src/helpers.rs index 0194ad9b..2e010244 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -47,6 +47,9 @@ pub use crate_name::CrateName; mod flock; pub use flock::FileLock; +mod signal; +pub use signal::cancel_on_user_sig_term; + pub fn cargo_home() -> Result<&'static Path, io::Error> { static CARGO_HOME: OnceCell<PathBuf> = OnceCell::new(); diff --git a/src/helpers/auto_abort_join_handle.rs b/src/helpers/auto_abort_join_handle.rs index fa476a8b..669f352a 100644 --- a/src/helpers/auto_abort_join_handle.rs +++ b/src/helpers/auto_abort_join_handle.rs @@ -5,7 +5,9 @@ use std::{ task::{Context, Poll}, }; -use tokio::task::{JoinError, JoinHandle}; +use tokio::task::JoinHandle; + +use super::BinstallError; #[derive(Debug)] pub struct AutoAbortJoinHandle<T>(JoinHandle<T>); @@ -16,6 +18,18 @@ impl<T> AutoAbortJoinHandle<T> { } } +impl<T> AutoAbortJoinHandle<T> +where + T: Send + 'static, +{ + pub fn spawn<F>(future: F) -> Self + where + F: Future<Output = T> + Send + 'static, + { + Self(tokio::spawn(future)) + } +} + impl<T> Drop for AutoAbortJoinHandle<T> { fn drop(&mut self) { self.0.abort(); @@ -37,9 +51,11 @@ impl<T> DerefMut for AutoAbortJoinHandle<T> { } impl<T> Future for AutoAbortJoinHandle<T> { - type Output = Result<T, JoinError>; + type Output = Result<T, BinstallError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { - Pin::new(&mut Pin::into_inner(self).0).poll(cx) + Pin::new(&mut Pin::into_inner(self).0) + .poll(cx) + .map(|res| res.map_err(BinstallError::TaskJoinError)) } } diff --git a/src/helpers/signal.rs b/src/helpers/signal.rs new file mode 100644 index 00000000..ecfe7c4a --- /dev/null +++ b/src/helpers/signal.rs @@ -0,0 +1,80 @@ +use futures_util::future::pending; +use std::io; +use tokio::signal; + +use super::{AutoAbortJoinHandle, BinstallError}; + +/// This function will poll the handle while listening for ctrl_c, +/// `SIGINT`, `SIGHUP`, `SIGTERM` and `SIGQUIT`. +/// +/// When signal is received, [`BinstallError::UserAbort`] will be returned. +/// +/// It would also ignore `SIGUSER1` and `SIGUSER2` on unix. +/// +/// This function uses [`tokio::signal`] and once exit, does not reset the default +/// signal handler, so be careful when using it. +pub async fn cancel_on_user_sig_term<T>( + handle: AutoAbortJoinHandle<T>, +) -> Result<T, BinstallError> { + #[cfg(unix)] + unix::ignore_signals_on_unix()?; + + tokio::select! { + res = handle => res, + res = wait_on_cancellation_signal() => { + res.map_err(BinstallError::Io).and(Err(BinstallError::UserAbort)) + } + } +} + +async fn wait_on_cancellation_signal() -> Result<(), io::Error> { + #[cfg(unix)] + async fn inner() -> Result<(), io::Error> { + unix::wait_on_cancellation_signal_unix().await + } + + #[cfg(not(unix))] + async fn inner() -> Result<(), io::Error> { + // Use pending here so that tokio::select! would just skip this branch. + pending().await + } + + tokio::select! { + res = signal::ctrl_c() => res, + res = inner() => res, + } +} + +#[cfg(unix)] +mod unix { + use super::*; + use signal::unix::{signal, SignalKind}; + + /// Same as [`wait_on_cancellation_signal`] but is only available on unix. + pub async fn wait_on_cancellation_signal_unix() -> Result<(), io::Error> { + tokio::select! { + res = wait_for_signal_unix(SignalKind::interrupt()) => res, + res = wait_for_signal_unix(SignalKind::hangup()) => res, + res = wait_for_signal_unix(SignalKind::terminate()) => res, + res = wait_for_signal_unix(SignalKind::quit()) => res, + } + } + + /// Wait for first arrival of signal. + pub async fn wait_for_signal_unix(kind: signal::unix::SignalKind) -> Result<(), io::Error> { + let mut sig_listener = signal::unix::signal(kind)?; + if sig_listener.recv().await.is_some() { + Ok(()) + } else { + // Use pending() here for the same reason as above. + pending().await + } + } + + pub fn ignore_signals_on_unix() -> Result<(), BinstallError> { + drop(signal(SignalKind::user_defined1())?); + drop(signal(SignalKind::user_defined2())?); + + Ok(()) + } +} diff --git a/src/helpers/ui_thread.rs b/src/helpers/ui_thread.rs index 8daf945c..64a20c55 100644 --- a/src/helpers/ui_thread.rs +++ b/src/helpers/ui_thread.rs @@ -1,7 +1,9 @@ -use std::io::{self, BufRead, Write}; +use std::{ + io::{self, BufRead, Write}, + thread, +}; use tokio::sync::mpsc; -use tokio::task::spawn_blocking; use crate::BinstallError; @@ -19,7 +21,7 @@ impl UIThreadInner { let (request_tx, mut request_rx) = mpsc::channel(1); let (confirm_tx, confirm_rx) = mpsc::channel(10); - spawn_blocking(move || { + thread::spawn(move || { // This task should be the only one able to // access stdin let mut stdin = io::stdin().lock(); @@ -30,14 +32,14 @@ impl UIThreadInner { break; } - // Lock stdout so that nobody can interfere - // with confirmation. - let mut stdout = io::stdout().lock(); - let res = loop { - writeln!(&mut stdout, "Do you wish to continue? yes/[no]").unwrap(); - write!(&mut stdout, "? ").unwrap(); - stdout.flush().unwrap(); + { + let mut stdout = io::stdout().lock(); + + writeln!(&mut stdout, "Do you wish to continue? yes/[no]").unwrap(); + write!(&mut stdout, "? ").unwrap(); + stdout.flush().unwrap(); + } input.clear(); stdin.read_line(&mut input).unwrap(); diff --git a/src/main.rs b/src/main.rs index 35a3d9bf..8d74a90a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -218,23 +218,20 @@ fn main() -> MainExit { let start = Instant::now(); let rt = Runtime::new().unwrap(); - let handle = rt.spawn(entry(jobserver_client)); - let result = rt.block_on(handle); + let handle = AutoAbortJoinHandle::new(rt.spawn(entry(jobserver_client))); + let result = rt.block_on(cancel_on_user_sig_term(handle)); drop(rt); let done = start.elapsed(); debug!("run time: {done:?}"); - result.map_or_else( - |join_err| MainExit::Error(BinstallError::from(join_err)), - |res| { - res.map(|_| MainExit::Success(done)).unwrap_or_else(|err| { - err.downcast::<BinstallError>() - .map(MainExit::Error) - .unwrap_or_else(MainExit::Report) - }) - }, - ) + result.map_or_else(MainExit::Error, |res| { + res.map(|()| MainExit::Success(done)).unwrap_or_else(|err| { + err.downcast::<BinstallError>() + .map(MainExit::Error) + .unwrap_or_else(MainExit::Report) + }) + }) } async fn entry(jobserver_client: LazyJobserverClient) -> Result<()> { @@ -361,7 +358,7 @@ async fn entry(jobserver_client: LazyJobserverClient) -> Result<()> { let tasks: Vec<_> = crate_names .into_iter() .map(|crate_name| { - tokio::spawn(binstall::resolve( + AutoAbortJoinHandle::spawn(binstall::resolve( binstall_opts.clone(), crate_name, temp_dir_path.clone(), @@ -375,7 +372,7 @@ async fn entry(jobserver_client: LazyJobserverClient) -> Result<()> { // Confirm let mut resolutions = Vec::with_capacity(tasks.len()); for task in tasks { - resolutions.push(await_task(task).await?); + resolutions.push(task.await??); } uithread.confirm().await?; @@ -384,7 +381,7 @@ async fn entry(jobserver_client: LazyJobserverClient) -> Result<()> { resolutions .into_iter() .map(|resolution| { - tokio::spawn(binstall::install( + AutoAbortJoinHandle::spawn(binstall::install( resolution, binstall_opts.clone(), jobserver_client.clone(), @@ -403,7 +400,7 @@ async fn entry(jobserver_client: LazyJobserverClient) -> Result<()> { let crates_io_api_client = crates_io_api_client.clone(); let install_path = install_path.clone(); - tokio::spawn(async move { + AutoAbortJoinHandle::spawn(async move { let resolution = binstall::resolve( opts.clone(), crate_name, @@ -422,7 +419,7 @@ async fn entry(jobserver_client: LazyJobserverClient) -> Result<()> { let mut metadata_vec = Vec::with_capacity(tasks.len()); for task in tasks { - if let Some(metadata) = await_task(task).await? { + if let Some(metadata) = task.await?? { metadata_vec.push(metadata); } }