diff --git a/src/helpers.rs b/src/helpers.rs index 060aad8c..4a30d986 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -22,6 +22,7 @@ mod ui_thread; pub use ui_thread::UIThread; mod extracter; +mod stream_readable; mod readable_rx; diff --git a/src/helpers/stream_readable.rs b/src/helpers/stream_readable.rs new file mode 100644 index 00000000..17113591 --- /dev/null +++ b/src/helpers/stream_readable.rs @@ -0,0 +1,81 @@ +use std::cmp::min; +use std::io::{self, BufRead, Read}; + +use bytes::{Buf, Bytes}; +use futures_util::stream::{Stream, StreamExt}; +use tokio::runtime::Handle; + +use super::BinstallError; + +/// This wraps an AsyncIterator as a `Read`able. +/// It must be used in non-async context only, +/// meaning you have to use it with +/// `tokio::task::{block_in_place, spawn_blocking}` or +/// `std::thread::spawn`. +#[derive(Debug)] +pub(super) struct StreamReadable { + stream: S, + handle: Handle, + bytes: Bytes, +} + +impl StreamReadable { + pub(super) async fn new(stream: S) -> Self { + Self { + stream, + handle: Handle::current(), + bytes: Bytes::new(), + } + } +} + +impl Read for StreamReadable +where + S: Stream> + Unpin, + BinstallError: From, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if buf.is_empty() { + return Ok(0); + } + + if self.fill_buf()?.is_empty() { + return Ok(0); + } + + let bytes = &mut self.bytes; + + // copy_to_slice requires the bytes to have enough remaining bytes + // to fill buf. + let n = min(buf.len(), bytes.remaining()); + + bytes.copy_to_slice(&mut buf[..n]); + + Ok(n) + } +} +impl BufRead for StreamReadable +where + S: Stream> + Unpin, + BinstallError: From, +{ + fn fill_buf(&mut self) -> io::Result<&[u8]> { + let bytes = &mut self.bytes; + + if !bytes.has_remaining() { + match self.handle.block_on(async { self.stream.next().await }) { + Some(Ok(new_bytes)) => *bytes = new_bytes, + Some(Err(e)) => { + let e: BinstallError = e.into(); + return Err(io::Error::new(io::ErrorKind::Other, e)); + } + None => (), + } + } + Ok(&*bytes) + } + + fn consume(&mut self, amt: usize) { + self.bytes.advance(amt); + } +}