Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dfir_lang/src/graph/ops/source_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub const SOURCE_FILE: OperatorConstraints = OperatorConstraints {
let file = #root::tokio::fs::File::from_std(file);
let bufread = #root::tokio::io::BufReader::new(file);
let lines = #root::tokio::io::AsyncBufReadExt::lines(bufread);
#root::tokio_stream::wrappers::LinesStream::new(lines)
#root::futures::stream::StreamExt::fuse(#root::tokio_stream::wrappers::LinesStream::new(lines))
};
};
let wc = WriteContextArgs {
Expand Down
8 changes: 5 additions & 3 deletions dfir_lang/src/graph/ops/source_interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ pub const SOURCE_INTERVAL: OperatorConstraints = OperatorConstraints {
let ident_intervalstream = wc.make_ident("intervalstream");
let mut write_prologue = quote_spanned! {op_span=>
let #ident_intervalstream =
#root::tokio_stream::StreamExt::map(
#root::tokio_stream::wrappers::IntervalStream::new(#root::tokio::time::interval(#arguments)),
|_| { }
#root::futures::stream::StreamExt::fuse(
#root::tokio_stream::StreamExt::map(
#root::tokio_stream::wrappers::IntervalStream::new(#root::tokio::time::interval(#arguments)),
|_| { }
)
);
};
let wc = WriteContextArgs {
Expand Down
2 changes: 1 addition & 1 deletion dfir_lang/src/graph/ops/source_stdin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub const SOURCE_STDIN: OperatorConstraints = OperatorConstraints {
use #root::tokio::io::AsyncBufReadExt;
let reader = #root::tokio::io::BufReader::new(#root::tokio::io::stdin());
let stdin_lines = #root::tokio_stream::wrappers::LinesStream::new(reader.lines());
stdin_lines
#root::futures::stream::StreamExt::fuse(stdin_lines)
};
};
let write_iterator = quote_spanned! {op_span=>
Expand Down
4 changes: 2 additions & 2 deletions dfir_lang/src/graph/ops/source_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ pub const SOURCE_STREAM: OperatorConstraints = OperatorConstraints {
// TODO(mingwei): use `::std::pin::pin!(..)`?
let mut #stream_ident = {
#[inline(always)]
fn check_stream<Stream: #root::futures::stream::Stream<Item = Item> + ::std::marker::Unpin, Item>(stream: Stream)
-> impl #root::futures::stream::Stream<Item = Item> + ::std::marker::Unpin
fn check_stream<Stream: #root::futures::stream::Stream<Item = Item> + #root::futures::stream::FusedStream + ::std::marker::Unpin, Item>(stream: Stream)
-> impl #root::futures::stream::Stream<Item = Item> + #root::futures::stream::FusedStream + ::std::marker::Unpin
Comment on lines +58 to +59
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
fn check_stream<Stream: #root::futures::stream::Stream<Item = Item> + #root::futures::stream::FusedStream + ::std::marker::Unpin, Item>(stream: Stream)
-> impl #root::futures::stream::Stream<Item = Item> + #root::futures::stream::FusedStream + ::std::marker::Unpin
fn check_stream<Stream: #root::futures::stream::FusedStream<Item = Item> + ::std::marker::Unpin, Item>(stream: Stream)
-> impl #root::futures::stream::FusedStream<Item = Item> + ::std::marker::Unpin

{
stream
}
Expand Down
3 changes: 2 additions & 1 deletion dfir_rs/examples/deadlock_detector/peer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::net::SocketAddr;

use dfir_rs::dfir_syntax;
use dfir_rs::scheduled::graph::Dfir;
use futures::StreamExt;
use tokio::io::AsyncBufReadExt;
use tokio::net::UdpSocket;
use tokio_stream::wrappers::LinesStream;
Expand All @@ -18,7 +19,7 @@ pub(crate) async fn run_detector(opts: Opts, peer_list: Vec<String>) {

// We provide a command line for users to type waits-for edges (u32,u32).
let reader = tokio::io::BufReader::new(tokio::io::stdin());
let stdin_lines = LinesStream::new(reader.lines());
let stdin_lines = LinesStream::new(reader.lines()).fuse();

let mut hf: Dfir = dfir_syntax! {
// fetch peers from file, convert ip:port to a SocketAddr, and tee
Expand Down
2 changes: 1 addition & 1 deletion dfir_rs/examples/rga/adjacency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::protocol::{Timestamp, Token};

pub(crate) fn rga_adjacency(
input_recv: UnboundedReceiverStream<(Token, Timestamp)>,
input_recv: futures::stream::Fuse<UnboundedReceiverStream<(Token, Timestamp)>>,
rga_send: UnboundedSender<(Token, Timestamp)>,
list_send: UnboundedSender<(Timestamp, Timestamp)>,
) -> Dfir<'static> {
Expand Down
2 changes: 1 addition & 1 deletion dfir_rs/examples/rga/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::protocol::{Timestamp, Token};

pub(crate) fn rga_datalog(
input_recv: UnboundedReceiverStream<(Token, Timestamp)>,
input_recv: futures::stream::Fuse<UnboundedReceiverStream<(Token, Timestamp)>>,
rga_send: UnboundedSender<(Token, Timestamp)>,
list_send: UnboundedSender<(Timestamp, Timestamp)>,
) -> Dfir<'static> {
Expand Down
2 changes: 1 addition & 1 deletion dfir_rs/examples/rga/datalog_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::protocol::{Timestamp, Token};

pub(crate) fn rga_datalog_agg(
input_recv: UnboundedReceiverStream<(Token, Timestamp)>,
input_recv: futures::stream::Fuse<UnboundedReceiverStream<(Token, Timestamp)>>,
rga_send: UnboundedSender<(Token, Timestamp)>,
list_send: UnboundedSender<(Timestamp, Timestamp)>,
) -> Dfir<'static> {
Expand Down
4 changes: 2 additions & 2 deletions dfir_rs/examples/rga/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ fn keystroke(
}

async fn write_to_dot(
rga_recv: &mut UnboundedReceiverStream<(Token, Timestamp)>,
list_recv: &mut UnboundedReceiverStream<(Timestamp, Timestamp)>,
rga_recv: &mut futures::stream::Fuse<UnboundedReceiverStream<(Token, Timestamp)>>,
list_recv: &mut futures::stream::Fuse<UnboundedReceiverStream<(Timestamp, Timestamp)>>,
w: &mut impl std::fmt::Write,
) {
let tree_edges: BTreeSet<_> = collect_ready_async(rga_recv).await;
Expand Down
2 changes: 1 addition & 1 deletion dfir_rs/examples/rga/minimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::{Timestamp, Token};

pub(crate) fn rga_minimal(
input_recv: UnboundedReceiverStream<(Token, Timestamp)>,
input_recv: futures::stream::Fuse<UnboundedReceiverStream<(Token, Timestamp)>>,
rga_send: UnboundedSender<(Token, Timestamp)>,
_list_send: UnboundedSender<(Timestamp, Timestamp)>,
) -> Dfir<'static> {
Expand Down
18 changes: 10 additions & 8 deletions dfir_rs/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use std::net::SocketAddr;
use std::num::NonZeroUsize;
use std::task::{Context, Poll};

use futures::Stream;
use futures::{Stream, StreamExt};
use serde::de::DeserializeOwned;
use serde::ser::Serialize;
#[cfg(unix)]
Expand All @@ -56,18 +56,19 @@ pub enum PersistenceKeyed<K, V> {
/// Returns a channel as a (1) unbounded sender and (2) unbounded receiver `Stream` for use in DFIR.
pub fn unbounded_channel<T>() -> (
tokio::sync::mpsc::UnboundedSender<T>,
tokio_stream::wrappers::UnboundedReceiverStream<T>,
futures::stream::Fuse<tokio_stream::wrappers::UnboundedReceiverStream<T>>,
) {
let (send, recv) = tokio::sync::mpsc::unbounded_channel();
let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv).fuse();
(send, recv)
}

/// Returns an unsync channel as a (1) sender and (2) receiver `Stream` for use in DFIR.
pub fn unsync_channel<T>(
capacity: Option<NonZeroUsize>,
) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
unsync::mpsc::channel(capacity)
) -> (unsync::mpsc::Sender<T>, futures::stream::Fuse<unsync::mpsc::Receiver<T>>) {
let (sender, receiver) = unsync::mpsc::channel(capacity);
(sender, receiver.fuse())
}

/// Returns an [`Iterator`] of any immediately available items from the [`Stream`].
Expand Down Expand Up @@ -184,7 +185,7 @@ pub async fn bind_tcp_bytes(
addr: SocketAddr,
) -> (
unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
futures::stream::Fuse<unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>>,
SocketAddr,
) {
bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
Expand All @@ -198,7 +199,7 @@ pub async fn bind_tcp_lines(
addr: SocketAddr,
) -> (
unsync::mpsc::Sender<(String, SocketAddr)>,
unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
futures::stream::Fuse<unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>>,
SocketAddr,
) {
bind_tcp(addr, tokio_util::codec::LinesCodec::new())
Expand Down Expand Up @@ -248,7 +249,7 @@ where
pub fn iter_batches_stream<I>(
iter: I,
n: usize,
) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
) -> futures::stream::Fuse<futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>>
where
I: IntoIterator + Unpin,
{
Expand All @@ -264,6 +265,7 @@ where
Poll::Ready(iter.next())
}
})
.fuse()
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion dfir_rs/src/util/simulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ impl ProcessBuilderContext<'_> {
pub fn new_inbox<T: 'static>(
&mut self,
interface: InterfaceName,
) -> UnboundedReceiverStream<(T, Address)> {
) -> futures::stream::Fuse<UnboundedReceiverStream<(T, Address)>> {
let (sender, receiver) = unbounded_channel::<(T, Address)>();
self.inboxes.insert(
interface,
Expand Down
14 changes: 7 additions & 7 deletions dfir_rs/src/util/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ pub fn tcp_framed<Codec>(
codec: Codec,
) -> (
FramedWrite<OwnedWriteHalf, Codec>,
FramedRead<OwnedReadHalf, Codec>,
futures::stream::Fuse<FramedRead<OwnedReadHalf, Codec>>,
)
where
Codec: Clone + Decoder,
{
let (recv, send) = stream.into_split();
let send = FramedWrite::new(send, codec.clone());
let recv = FramedRead::new(recv, codec);
(send, recv)
(send, StreamExt::fuse(recv))
}

/// Helper creates a TCP `Stream` and `Sink` for `Bytes` strings where each string is
Expand All @@ -42,7 +42,7 @@ pub fn tcp_bytes(
stream: TcpStream,
) -> (
FramedWrite<OwnedWriteHalf, LengthDelimitedCodec>,
FramedRead<OwnedReadHalf, LengthDelimitedCodec>,
futures::stream::Fuse<FramedRead<OwnedReadHalf, LengthDelimitedCodec>>,
) {
tcp_framed(stream, LengthDelimitedCodec::new())
}
Expand All @@ -52,7 +52,7 @@ pub fn tcp_bytestream(
stream: TcpStream,
) -> (
FramedWrite<OwnedWriteHalf, BytesCodec>,
FramedRead<OwnedReadHalf, BytesCodec>,
futures::stream::Fuse<FramedRead<OwnedReadHalf, BytesCodec>>,
) {
tcp_framed(stream, BytesCodec::new())
}
Expand All @@ -62,7 +62,7 @@ pub fn tcp_lines(
stream: TcpStream,
) -> (
FramedWrite<OwnedWriteHalf, LinesCodec>,
FramedRead<OwnedReadHalf, LinesCodec>,
futures::stream::Fuse<FramedRead<OwnedReadHalf, LinesCodec>>,
) {
tcp_framed(stream, LinesCodec::new())
}
Expand All @@ -72,7 +72,7 @@ pub type TcpFramedSink<T> = Sender<(T, SocketAddr)>;
/// A framed TCP `Stream` (receiving).
#[expect(type_alias_bounds, reason = "code readability")]
pub type TcpFramedStream<Codec: Decoder> =
Receiver<Result<(<Codec as Decoder>::Item, SocketAddr), <Codec as Decoder>::Error>>;
futures::stream::Fuse<Receiver<Result<(<Codec as Decoder>::Item, SocketAddr), <Codec as Decoder>::Error>>>;

// TODO(mingwei): this temporary code should be replaced with a properly thought out networking system.
/// Create a listening tcp socket, and then as new connections come in, receive their data and forward it to a queue.
Expand All @@ -98,7 +98,7 @@ where
let mut peers_send = HashMap::new();
// `StreamMap` of `addr -> peers`, to receive messages from. Automatically removes streams
// when they disconnect.
let mut peers_recv = StreamMap::<SocketAddr, FramedRead<OwnedReadHalf, Codec>>::new();
let mut peers_recv = StreamMap::<SocketAddr, futures::stream::Fuse<FramedRead<OwnedReadHalf, Codec>>>::new();

loop {
// Calling methods in a loop, futures must be cancel-safe.
Expand Down
8 changes: 4 additions & 4 deletions dfir_rs/src/util/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use std::net::SocketAddr;

use bytes::Bytes;
use futures::stream::{SplitSink, SplitStream};
use futures::stream::{SplitSink, SplitStream, StreamExt};
use tokio::net::UdpSocket;
use tokio_util::codec::length_delimited::LengthDelimitedCodec;
use tokio_util::codec::{BytesCodec, Decoder, Encoder, LinesCodec};
Expand All @@ -12,7 +12,7 @@ use tokio_util::udp::UdpFramed;
/// A framed UDP `Sink` (sending).
pub type UdpFramedSink<Codec, Item> = SplitSink<UdpFramed<Codec>, (Item, SocketAddr)>;
/// A framed UDP `Stream` (receiving).
pub type UdpFramedStream<Codec> = SplitStream<UdpFramed<Codec>>;
pub type UdpFramedStream<Codec> = futures::stream::Fuse<SplitStream<UdpFramed<Codec>>>;
/// Returns a UDP `Stream`, `Sink`, and address for the given socket, using the given `Codec` to
/// handle delineation between inputs/outputs.
pub fn udp_framed<Codec, Item>(
Expand All @@ -28,8 +28,8 @@ where
{
let framed = UdpFramed::new(socket, codec);
let addr = framed.get_ref().local_addr().unwrap();
let split = futures::stream::StreamExt::split(framed);
(split.0, split.1, addr)
let split = StreamExt::split(framed);
(split.0, split.1.fuse(), addr)
}

/// A UDP length-delimited frame `Sink` (sending).
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
error[E0277]: the trait bound `NotFusedStream: FusedStream` is not satisfied
--> tests/compile-fail-stable/surface_source_stream_not_fused.rs:29:23
|
29 | source_stream(stream) -> for_each(std::mem::drop);
| ------------- ^^^^^^ unsatisfied trait bound
| |
| required by a bound introduced by this call
|
help: the trait `FusedStream` is not implemented for `NotFusedStream`
--> tests/compile-fail-stable/surface_source_stream_not_fused.rs:9:1
|
9 | struct NotFusedStream {
| ^^^^^^^^^^^^^^^^^^^^^
= help: the following other types implement trait `FusedStream`:
&mut F
Box<S>
BufferUnordered<St>
Buffered<St>
FlatMapUnordered<St, U, F>
FlattenSink<Fut, Si>
FlattenStream<F>
FuturesOrdered<Fut>
and $N others
note: required by a bound in `check_stream`
--> tests/compile-fail-stable/surface_source_stream_not_fused.rs:29:9
|
29 | source_stream(stream) -> for_each(std::mem::drop);
| ^^^^^^^^^^^^^ required by this bound in `check_stream`

error[E0277]: the trait bound `NotFusedStream: FusedStream` is not satisfied
--> tests/compile-fail-stable/surface_source_stream_not_fused.rs:29:9
|
29 | source_stream(stream) -> for_each(std::mem::drop);
| ^^^^^^^^^^^^^ unsatisfied trait bound
|
help: the trait `FusedStream` is not implemented for `NotFusedStream`
--> tests/compile-fail-stable/surface_source_stream_not_fused.rs:9:1
|
9 | struct NotFusedStream {
| ^^^^^^^^^^^^^^^^^^^^^
= help: the following other types implement trait `FusedStream`:
&mut F
Box<S>
BufferUnordered<St>
Buffered<St>
FlatMapUnordered<St, U, F>
FlattenSink<Fut, Si>
FlattenStream<F>
FuturesOrdered<Fut>
and $N others
note: required by a bound in `check_stream`
--> tests/compile-fail-stable/surface_source_stream_not_fused.rs:29:9
|
29 | source_stream(stream) -> for_each(std::mem::drop);
| ^^^^^^^^^^^^^ required by this bound in `check_stream`
32 changes: 32 additions & 0 deletions dfir_rs/tests/compile-fail/surface_source_stream_not_fused.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use dfir_rs::dfir_syntax;
use futures::stream::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};

// A simple stream that is Unpin but NOT FusedStream
// It will panic if polled after returning None
#[derive(Default)]
struct NotFusedStream {
done: bool,
}

impl Stream for NotFusedStream {
type Item = i32;

fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
panic!("NotFusedStream polled after returning None!");
}
self.done = true;
Poll::Ready(None)
}
}

fn main() {
let stream = NotFusedStream::default();

let mut df = dfir_syntax! {
source_stream(stream) -> for_each(std::mem::drop);
};
df.run_available_sync();
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn check_cartesian_product_multi_tick(
mut df: Dfir,
lhs_send: UnboundedSender<u32>,
rhs_send: UnboundedSender<u32>,
mut out_recv: UnboundedReceiverStream<SetUnionHashSet<(u32, u32)>>,
mut out_recv: futures::stream::Fuse<UnboundedReceiverStream<SetUnionHashSet<(u32, u32)>>>,
) {
df.run_available_sync();
assert_eq!(0, collect_ready::<Vec<_>, _>(&mut out_recv).len());
Expand Down