diff --git a/src/helpers.rs b/src/helpers.rs index e735c2d0..ae7a432d 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -62,14 +62,14 @@ pub async fn download>(url: &str, path: P) -> Result<(), Binstall debug!("Downloading to file: '{}'", path.display()); let mut bytes_stream = resp.bytes_stream(); - let writer = AsyncFileWriter::new(path)?; + let mut writer = AsyncFileWriter::new(path)?; let guard = scopeguard::guard(path, |path| { fs::remove_file(path).ok(); }); while let Some(res) = bytes_stream.next().await { - writer.write(res?).await; + writer.write(res?).await?; } writer.done().await?; @@ -254,19 +254,47 @@ impl AsyncFileWriter { Ok(Self { handle, tx }) } - pub async fn write(&self, bytes: Bytes) { - self.tx - .send(bytes) - .await - .expect("Implementation bug: rx is closed before tx is dropped") + /// Upon error, this writer shall not be reused. + /// Otherwise, `Self::done` would panic. + pub async fn write(&mut self, bytes: Bytes) -> io::Result<()> { + let send_future = async { + self.tx + .send(bytes) + .await + .expect("Implementation bug: rx is closed before tx is dropped") + }; + tokio::pin!(send_future); + + let task_future = async { + Self::wait(&mut self.handle).await.map(|_| { + panic!("Implementation bug: write task finished before all writes are done") + }) + }; + tokio::pin!(task_future); + + // Use select to run them in parallel, so that if the send blocks + // the current future and the task failed with some error, the future + // returned by this function would not block forever. + tokio::select! { + // It isn't completely safe to cancel the send_future as it would + // cause us to lose our place in the queue, but if the send_future + // is cancelled, it means that the task has failed and the mpsc + // won't matter anyway. + _ = send_future => Ok(()), + res = task_future => res, + } } - pub async fn done(self) -> io::Result<()> { + pub async fn done(mut self) -> io::Result<()> { // Drop tx as soon as possible so that the task would wrap up what it // was doing and flush out all the pending data. drop(self.tx); - match self.handle.await { + Self::wait(&mut self.handle).await + } + + async fn wait(handle: &mut task::JoinHandle>) -> io::Result<()> { + match handle.await { Ok(res) => res, Err(join_err) => Err(io::Error::new(io::ErrorKind::Other, join_err)), }