diff --git a/src/helpers.rs b/src/helpers.rs index 0194ad9b..2e010244 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -47,6 +47,9 @@ pub use crate_name::CrateName; mod flock; pub use flock::FileLock; +mod signal; +pub use signal::cancel_on_user_sig_term; + pub fn cargo_home() -> Result<&'static Path, io::Error> { static CARGO_HOME: OnceCell = OnceCell::new(); diff --git a/src/helpers/signal.rs b/src/helpers/signal.rs new file mode 100644 index 00000000..ecfe7c4a --- /dev/null +++ b/src/helpers/signal.rs @@ -0,0 +1,80 @@ +use futures_util::future::pending; +use std::io; +use tokio::signal; + +use super::{AutoAbortJoinHandle, BinstallError}; + +/// This function will poll the handle while listening for ctrl_c, +/// `SIGINT`, `SIGHUP`, `SIGTERM` and `SIGQUIT`. +/// +/// When signal is received, [`BinstallError::UserAbort`] will be returned. +/// +/// It would also ignore `SIGUSER1` and `SIGUSER2` on unix. +/// +/// This function uses [`tokio::signal`] and once exit, does not reset the default +/// signal handler, so be careful when using it. +pub async fn cancel_on_user_sig_term( + handle: AutoAbortJoinHandle, +) -> Result { + #[cfg(unix)] + unix::ignore_signals_on_unix()?; + + tokio::select! { + res = handle => res, + res = wait_on_cancellation_signal() => { + res.map_err(BinstallError::Io).and(Err(BinstallError::UserAbort)) + } + } +} + +async fn wait_on_cancellation_signal() -> Result<(), io::Error> { + #[cfg(unix)] + async fn inner() -> Result<(), io::Error> { + unix::wait_on_cancellation_signal_unix().await + } + + #[cfg(not(unix))] + async fn inner() -> Result<(), io::Error> { + // Use pending here so that tokio::select! would just skip this branch. + pending().await + } + + tokio::select! { + res = signal::ctrl_c() => res, + res = inner() => res, + } +} + +#[cfg(unix)] +mod unix { + use super::*; + use signal::unix::{signal, SignalKind}; + + /// Same as [`wait_on_cancellation_signal`] but is only available on unix. + pub async fn wait_on_cancellation_signal_unix() -> Result<(), io::Error> { + tokio::select! { + res = wait_for_signal_unix(SignalKind::interrupt()) => res, + res = wait_for_signal_unix(SignalKind::hangup()) => res, + res = wait_for_signal_unix(SignalKind::terminate()) => res, + res = wait_for_signal_unix(SignalKind::quit()) => res, + } + } + + /// Wait for first arrival of signal. + pub async fn wait_for_signal_unix(kind: signal::unix::SignalKind) -> Result<(), io::Error> { + let mut sig_listener = signal::unix::signal(kind)?; + if sig_listener.recv().await.is_some() { + Ok(()) + } else { + // Use pending() here for the same reason as above. + pending().await + } + } + + pub fn ignore_signals_on_unix() -> Result<(), BinstallError> { + drop(signal(SignalKind::user_defined1())?); + drop(signal(SignalKind::user_defined2())?); + + Ok(()) + } +}