Fix error when reading from stdout splits UTF-8 codepoint #144

Merged · 1 commit · Oct 30, 2023
3 changes: 3 additions & 0 deletions src/buffers.rs
@@ -13,3 +13,6 @@ pub const LINE_BUFFER_CAPACITY: usize = 1024;
 
 /// The default capacity (in entries) of buffers storing a collection of items, usually lines.
 pub const VEC_BUFFER_CAPACITY: usize = 1024;
+
+/// If we need to split a codepoint in half, we know it won't have more than 4 bytes total.
+pub const SPLIT_UTF8_CODEPOINT_CAPACITY: usize = 4;
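Aside (not part of the diff): UTF-8 encodes every Unicode scalar value in one to four bytes, which is where the 4 comes from. A standalone check:

fn main() {
    // `len_utf8()` reports how many bytes a char occupies when encoded as UTF-8.
    assert_eq!('A'.len_utf8(), 1); // ASCII
    assert_eq!('©'.len_utf8(), 2); // U+00A9
    assert_eq!('↔'.len_utf8(), 3); // U+2194
    assert_eq!('🐶'.len_utf8(), 4); // U+1F436
    // Even the largest scalar value fits in four bytes.
    assert_eq!(char::MAX.len_utf8(), 4);
}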
7 changes: 2 additions & 5 deletions src/fake_reader.rs
@@ -19,12 +19,9 @@ pub struct FakeReader {
 
 impl FakeReader {
     /// Construct a `FakeReader` from an iterator of strings.
-    pub fn with_str_chunks(chunks: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
+    pub fn with_byte_chunks<const N: usize>(chunks: [&[u8]; N]) -> Self {
         Self {
-            chunks: chunks
-                .into_iter()
-                .map(|chunk| chunk.as_ref().bytes().collect())
-                .collect(),
+            chunks: chunks.into_iter().map(|chunk| chunk.to_vec()).collect(),
         }
     }
 }
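The constructor change above is the point of this file: a &str chunk is valid UTF-8 by construction, so string-based fake data can never end mid-codepoint, while raw byte chunks can split anywhere. A standalone sketch of the distinction (not part of the diff):

fn main() {
    let dog = "🐶"; // U+1F436, four bytes in UTF-8
    let bytes = dog.as_bytes(); // [0xf0, 0x9f, 0x90, 0xb6]

    // The full sequence decodes fine...
    assert!(std::str::from_utf8(bytes).is_ok());

    // ...but a byte-oriented reader may hand out a prefix that doesn't.
    assert!(std::str::from_utf8(&bytes[..2]).is_err());
}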
210 changes: 170 additions & 40 deletions src/incremental_reader.rs
@@ -1,19 +1,20 @@
 //! The [`IncrementalReader`] struct, which handles reading to delimiters without line buffering.
 
+use std::borrow::Cow;
 use std::pin::Pin;
 
 use aho_corasick::AhoCorasick;
 use line_span::LineSpans;
 use miette::miette;
 use miette::IntoDiagnostic;
 use miette::WrapErr;
 use tokio::io::AsyncRead;
 use tokio::io::AsyncReadExt;
 use tokio::io::AsyncWrite;
 use tokio::io::AsyncWriteExt;
 
 use crate::aho_corasick::AhoCorasickExt;
 use crate::buffers::LINE_BUFFER_CAPACITY;
+use crate::buffers::SPLIT_UTF8_CODEPOINT_CAPACITY;
 use crate::buffers::VEC_BUFFER_CAPACITY;
 
 /// A tool for incrementally reading from a stream like stdout (and forwarding that stream to a
@@ -32,12 +33,15 @@ use crate::buffers::VEC_BUFFER_CAPACITY;
 pub struct IncrementalReader<R, W> {
     /// The wrapped reader.
     reader: Pin<Box<R>>,
-    /// The wrapped writer, if any.
-    writer: Option<Pin<Box<W>>>,
     /// Lines we've already read since the last marker/chunk.
     lines: String,
     /// The line currently being written to.
     line: String,
+    /// The wrapped writer, if any.
+    writer: Option<Pin<Box<W>>>,
+    /// We're not guaranteed that the data we read at one time is aligned on a UTF-8 boundary. If
+    /// that's the case, we store the data here until we get more data.
+    non_utf8: Vec<u8>,
 }
 
 impl<R, W> IncrementalReader<R, W>
@@ -49,9 +53,10 @@ where
     pub fn new(reader: R) -> Self {
         Self {
             reader: Box::pin(reader),
-            writer: None,
             lines: String::with_capacity(VEC_BUFFER_CAPACITY * LINE_BUFFER_CAPACITY),
             line: String::with_capacity(LINE_BUFFER_CAPACITY),
+            writer: None,
+            non_utf8: Vec::with_capacity(SPLIT_UTF8_CODEPOINT_CAPACITY),
         }
     }

[Review comment]: If you're never going to exceed 4 bytes here, would it be more efficient to use &[u8; 4] here instead of Vec<u8>?

[Review comment]: Or more correct, because you couldn't accidentally allocate more space.
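For illustration, a fixed-size stash along the lines of the reviewer's suggestion might look like the sketch below. This is hypothetical code, not part of the PR; the type and method names are invented:

/// Hypothetical alternative to the `Vec<u8>` field: an inline four-byte buffer
/// plus a length, since a split codepoint never needs more than four bytes.
struct SplitCodepoint {
    bytes: [u8; 4],
    len: usize,
}

impl SplitCodepoint {
    fn new() -> Self {
        Self { bytes: [0; 4], len: 0 }
    }

    /// Stash the trailing bytes of an incomplete codepoint.
    fn push(&mut self, partial: &[u8]) {
        self.bytes[self.len..self.len + partial.len()].copy_from_slice(partial);
        self.len += partial.len();
    }

    /// The bytes currently stashed.
    fn as_bytes(&self) -> &[u8] {
        &self.bytes[..self.len]
    }

    /// Forget the stashed bytes.
    fn clear(&mut self) {
        self.len = 0;
    }
}

fn main() {
    let mut stash = SplitCodepoint::new();
    stash.push(b"\xf0\x9f"); // first half of U+1F436
    assert_eq!(stash.as_bytes(), b"\xf0\x9f");
    stash.clear();
    assert!(stash.as_bytes().is_empty());
}

The PR keeps the Vec<u8> with a four-byte initial capacity, which is simpler and should never grow past that capacity in practice.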

@@ -97,15 +102,8 @@ where
                 Err(miette!("End-of-file reached"))
             }
             Ok(n) => {
-                let decoded = std::str::from_utf8(&opts.buffer[..n])
-                    .into_diagnostic()
-                    .wrap_err_with(|| {
-                        format!(
-                            "Read invalid UTF-8: {:?}",
-                            String::from_utf8_lossy(&opts.buffer[..n])
-                        )
-                    })?;
-                match self.consume_str(decoded, opts).await? {
+                let decoded = self.decode(&opts.buffer[..n]);
+                match self.consume_str(&decoded, opts).await? {
                     Some(lines) => {
                         tracing::trace!(data = decoded, "Decoded data");
                         tracing::trace!(lines = lines.len(), "Got chunk");
@@ -121,6 +119,49 @@
        }
    }

    fn decode(&mut self, buffer: &[u8]) -> String {
        // Do we have data we failed to decode?
        let buffer = if self.non_utf8.is_empty() {
            Cow::Borrowed(buffer)
        } else {
            // We have some data that failed to decode when we read it; add the data we just
            // read and hope that completes a UTF-8 boundary:
            let mut non_utf8 = std::mem::replace(
                &mut self.non_utf8,
                Vec::with_capacity(SPLIT_UTF8_CODEPOINT_CAPACITY),
            );
            non_utf8.extend(buffer);
            Cow::Owned(non_utf8)
        };

        match std::str::from_utf8(&buffer) {
            Ok(data) => data.to_owned(),
            Err(err) => {
                match err.error_len() {
                    Some(_) => {
                        // An unexpected byte was encountered; this is a "real" UTF-8 decode
                        // failure that we can't recover from by reading more data.
                        //
                        // As a backup, we'll log an error and decode the rest lossily.
                        tracing::error!("Failed to decode UTF-8 from `ghci`: {err}.\n\
                            This is a bug, please report it upstream: https://github.com/MercuryTechnologies/ghciwatch/issues/new");
                        String::from_utf8_lossy(&buffer).into_owned()
                    }
                    None => {
                        // End of input reached unexpectedly.
                        let valid_utf8 = &buffer[..err.valid_up_to()];
                        self.non_utf8.extend(&buffer[err.valid_up_to()..]);
                        unsafe {
[Review comment from the author]: I think this is the first unsafe block in the codebase? Some dubious honor there.

                            // Safety: We already confirmed the input contains valid UTF-8 up to
                            // this index.
                            std::str::from_utf8_unchecked(valid_utf8).to_owned()
                        }
                    }
                }
            }
        }
    }
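The branch on err.error_len() above is the heart of the fix: std::str::Utf8Error distinguishes input that merely ended mid-codepoint (None, recoverable by reading more) from input containing a byte that can never begin or continue a valid sequence (Some(_)). A standalone demonstration (not part of the diff):

fn main() {
    // Split U+1F436 🐶 (four bytes) after two bytes.
    let err = std::str::from_utf8(b"hi \xf0\x9f").unwrap_err();
    assert_eq!(err.valid_up_to(), 3); // "hi " decodes fine
    assert_eq!(err.error_len(), None); // incomplete: more data might fix it

    // 0xFF can never appear in UTF-8, so no amount of extra data helps.
    let err = std::str::from_utf8(b"hi \xff").unwrap_err();
    assert_eq!(err.valid_up_to(), 3);
    assert_eq!(err.error_len(), Some(1)); // one unrecoverably bad byte
}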

/// Consumes a string, adding it to the internal buffer.
///
/// If one of the lines in `data` begins with `end_marker`, the lines in the internal buffer
@@ -372,8 +413,8 @@ mod tests {
     /// Basic test. Reads data from the reader, gets the first chunk.
     #[tokio::test]
     async fn test_read_until() {
-        let fake_reader = FakeReader::with_str_chunks([indoc!(
-            "Build profile: -w ghc-9.6.1 -O0
+        let fake_reader = FakeReader::with_byte_chunks([indoc!(
+            b"Build profile: -w ghc-9.6.1 -O0
             In order, the following will be built (use -v for more details):
              - mwb-0 (lib:test-dev) (ephemeral targets)
             Preprocessing library 'test-dev' for mwb-0..
@@ -411,8 +452,8 @@ mod tests {
     /// Same as `test_read_until` but with `FindAt::Anywhere`.
     #[tokio::test]
     async fn test_read_until_find_anywhere() {
-        let fake_reader = FakeReader::with_str_chunks([indoc!(
-            "Build profile: -w ghc-9.6.1 -O0
+        let fake_reader = FakeReader::with_byte_chunks([indoc!(
+            b"Build profile: -w ghc-9.6.1 -O0
             In order, the following will be built (use -v for more details):
              - mwb-0 (lib:test-dev) (ephemeral targets)
             Preprocessing library 'test-dev' for mwb-0..
@@ -557,29 +598,29 @@ mod tests {
     /// chunks.
     #[tokio::test]
     async fn test_read_until_incremental() {
-        let fake_reader = FakeReader::with_str_chunks([
-            "Build profile: -w ghc-9.6.1 -O0\n",
-            "In order, the following will be built (use -v for more details):\n",
-            " - mwb-0 (lib:test-dev) (ephemeral targets)\n",
-            "Preprocessing library 'test-dev' for mwb-0..\n",
-            "GH",
-            "C",
-            "i",
-            ",",
-            " ",
-            "v",
-            "e",
-            "r",
-            "s",
-            "i",
-            "o",
-            "n",
-            " ",
-            "9",
-            ".6.1: https://www.haskell.org/ghc/ :? for help\n",
-            "Loaded GHCi configuration from .ghci-mwb",
-            "Ok, 5699 modules loaded.",
-            "ghci> ",
+        let fake_reader = FakeReader::with_byte_chunks([
+            b"Build profile: -w ghc-9.6.1 -O0\n",
+            b"In order, the following will be built (use -v for more details):\n",
+            b" - mwb-0 (lib:test-dev) (ephemeral targets)\n",
+            b"Preprocessing library 'test-dev' for mwb-0..\n",
+            b"GH",
+            b"C",
+            b"i",
+            b",",
+            b" ",
+            b"v",
+            b"e",
+            b"r",
+            b"s",
+            b"i",
+            b"o",
+            b"n",
+            b" ",
+            b"9",
+            b".6.1: https://www.haskell.org/ghc/ :? for help\n",
+            b"Loaded GHCi configuration from .ghci-mwb",
+            b"Ok, 5699 modules loaded.",
+            b"ghci> ",
         ]);
         let mut reader = IncrementalReader::new(fake_reader).with_writer(tokio::io::sink());
         let end_marker = AhoCorasick::from_anchored_patterns(["GHCi, version "]);
@@ -607,4 +648,93 @@

assert_eq!(reader.buffer(), String::new());
}

/// Test that we can keep reading when a chunk from `read()` splits a UTF-8 boundary.
async fn utf8_boundary<const N: usize>(chunks: [&'static [u8]; N], decoded: &'static str) {
let fake_reader = FakeReader::with_byte_chunks(chunks);
let mut reader = IncrementalReader::new(fake_reader).with_writer(tokio::io::sink());
let end_marker = AhoCorasick::from_anchored_patterns(["ghci> "]);
let mut buffer = vec![0; LINE_BUFFER_CAPACITY];

assert_eq!(
reader
.read_until(&mut ReadOpts {
end_marker: &end_marker,
find: FindAt::LineStart,
writing: WriteBehavior::Hide,
buffer: &mut buffer
})
.await
.unwrap(),
decoded,
"Failed to decode codepoint {decoded:?} when split across two chunks: {chunks:?}",
);

assert_eq!(reader.buffer(), String::new());
}

#[tokio::test]
async fn test_read_utf8_boundary_u_00a9() {
// U+00A9 ©
// 2 bytes, 1 test case.
utf8_boundary([b"\xc2", b"\xa9\nghci> "], "©\n").await;
}

#[tokio::test]
async fn test_read_utf8_boundary_u_2194() {
// U+2194 ↔
// 3 bytes, 2 test cases.
utf8_boundary([b"\xe2", b"\x86\x94\nghci> "], "↔\n").await;
utf8_boundary([b"\xe2\x86", b"\x94\nghci> "], "↔\n").await;
}

#[tokio::test]
async fn test_read_utf8_boundary_u_1f436() {
// U+1F436 🐶
// 4 bytes, 3 test cases.
utf8_boundary([b"\xf0", b"\x9f\x90\xb6\nghci> "], "🐶\n").await;
utf8_boundary([b"\xf0\x9f", b"\x90\xb6\nghci> "], "🐶\n").await;
utf8_boundary([b"\xf0\x9f\x90", b"\xb6\nghci> "], "🐶\n").await;
}

#[tokio::test]
async fn test_read_invalid_utf8_overlong() {
// Overlong sequence, U+20AC € encoded as 4 bytes.
// We get four U+FFFD � replacement characters out, one for each byte in the sequence.
utf8_boundary([b"\xf0", b"\x82\x82\xac\nghci> "], "����\n").await;
utf8_boundary([b"\xf0\x82", b"\x82\xac\nghci> "], "����\n").await;
utf8_boundary([b"\xf0\x82\x82", b"\xac\nghci> "], "����\n").await;
}

#[tokio::test]
async fn test_read_invalid_utf8_surrogate_pair_half() {
// Half of a surrogate pair, invalid in UTF-8. (U+D800)
utf8_boundary([b"\xed", b"\xa0\x80\nghci> "], "���\n").await;
utf8_boundary([b"\xed\xa0", b"\x80\nghci> "], "���\n").await;
}

#[tokio::test]
async fn test_read_invalid_utf8_unexpected_continuation() {
// An unexpected continuation byte.
utf8_boundary([b"\xa0\x80\nghci> "], "��\n").await;
utf8_boundary([b"\xa0", b"\x80\nghci> "], "��\n").await;
}

#[tokio::test]
async fn test_read_invalid_utf8_missing_continuation() {
// Missing continuation byte.
// Weirdly, these only come out as one replacement character, not the three we might
// naïvely expect: lossy decoding emits one U+FFFD per maximal subpart of an
// ill-formed sequence, and the truncated prefix here is a single maximal subpart
// (see the standalone check after this test).
utf8_boundary([b"\xf0", b"\x9f\x90\nghci> "], "�\n").await;
utf8_boundary([b"\xf0\x9f", b"\x90\nghci> "], "�\n").await;
}
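A standalone check of that behavior (not part of the diff): lossy decoding emits one U+FFFD per maximal subpart of an ill-formed sequence, so a truncated-but-well-formed prefix becomes a single replacement character while stray bytes each get their own.

fn main() {
    // The truncated prefix of a four-byte sequence is one maximal subpart.
    assert_eq!(String::from_utf8_lossy(b"\xf0\x9f\x90"), "\u{FFFD}");
    // Three stray continuation bytes are three separate maximal subparts.
    assert_eq!(
        String::from_utf8_lossy(b"\x90\x90\x90"),
        "\u{FFFD}\u{FFFD}\u{FFFD}"
    );
}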

#[tokio::test]
async fn test_read_invalid_utf8_invalid_byte() {
// Invalid byte (no defined meaning in UTF-8).
utf8_boundary([b"\xc0\nghci> "], "�\n").await;
utf8_boundary([b"\xc1\nghci> "], "�\n").await;
utf8_boundary([b"\xf5\nghci> "], "�\n").await;
utf8_boundary([b"\xff\nghci> "], "�\n").await;
}
}