Skip to content

Reduce pin related unsafe code #2613

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions tokio-test/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,11 @@ const SLEEP: usize = 2;

impl<T> Spawn<T> {
/// Consumes `self` returning the inner value
pub fn into_inner(mut self) -> T
pub fn into_inner(self) -> T
where
T: Unpin,
{
drop(self.task);

// Pin::into_inner is unstable, so we work around it
//
// Safety: `T` is bound by `Unpin`.
unsafe {
let ptr = Pin::get_mut(self.future.as_mut()) as *mut T;
let future = Box::from_raw(ptr);
mem::forget(self.future);
*future
}
*Pin::into_inner(self.future)
}

/// Returns `true` if the inner future has received a wake notification
Expand Down
7 changes: 5 additions & 2 deletions tokio/src/runtime/blocking/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@ impl<T> BlockingTask<T> {
}
}

// The closure `F` is never pinned
impl<T> Unpin for BlockingTask<T> {}

impl<T, R> Future for BlockingTask<T>
where
T: FnOnce() -> R,
{
type Output = R;

fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<R> {
let me = unsafe { self.get_unchecked_mut() };
fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<R> {
let me = &mut *self;
let func = me
.func
.take()
Expand Down
14 changes: 4 additions & 10 deletions tokio/src/runtime/enter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,22 +142,19 @@ cfg_block_on! {
impl Enter {
/// Blocks the thread on the specified future, returning the value with
/// which that future completes.
pub(crate) fn block_on<F>(&mut self, mut f: F) -> Result<F::Output, crate::park::ParkError>
pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, crate::park::ParkError>
where
F: std::future::Future,
{
use crate::park::{CachedParkThread, Park};
use std::pin::Pin;
use std::task::Context;
use std::task::Poll::Ready;

let mut park = CachedParkThread::new();
let waker = park.get_unpark()?.into_waker();
let mut cx = Context::from_waker(&waker);

// `block_on` takes ownership of `f`. Once it is pinned here, the original `f` binding can
// no longer be accessed, making the pinning safe.
let mut f = unsafe { Pin::new_unchecked(&mut f) };
pin!(f);

loop {
if let Ready(v) = crate::coop::budget(|| f.as_mut().poll(&mut cx)) {
Expand All @@ -179,12 +176,11 @@ cfg_blocking_impl! {
///
/// If the future completes before `timeout`, the result is returned. If
/// `timeout` elapses, then `Err` is returned.
pub(crate) fn block_on_timeout<F>(&mut self, mut f: F, timeout: Duration) -> Result<F::Output, ParkError>
pub(crate) fn block_on_timeout<F>(&mut self, f: F, timeout: Duration) -> Result<F::Output, ParkError>
where
F: std::future::Future,
{
use crate::park::{CachedParkThread, Park};
use std::pin::Pin;
use std::task::Context;
use std::task::Poll::Ready;
use std::time::Instant;
Expand All @@ -193,9 +189,7 @@ cfg_blocking_impl! {
let waker = park.get_unpark()?.into_waker();
let mut cx = Context::from_waker(&waker);

// `block_on` takes ownership of `f`. Once it is pinned here, the original `f` binding can
// no longer be accessed, making the pinning safe.
let mut f = unsafe { Pin::new_unchecked(&mut f) };
pin!(f);
let when = Instant::now() + timeout;

loop {
Expand Down
39 changes: 19 additions & 20 deletions tokio/src/time/timeout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

use crate::time::{delay_until, Delay, Duration, Instant};

use pin_project_lite::pin_project;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
Expand Down Expand Up @@ -99,12 +100,16 @@ where
}
}

/// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at).
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Timeout<T> {
value: T,
delay: Delay,
pin_project! {
/// Future returned by [`timeout`](timeout) and [`timeout_at`](timeout_at).
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct Timeout<T> {
#[pin]
value: T,
#[pin]
delay: Delay,
}
}

/// Error returned by `Timeout`.
Expand Down Expand Up @@ -146,24 +151,18 @@ where
{
type Output = Result<T::Output, Elapsed>;

fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
// First, try polling the future
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let me = self.project();

// Safety: we never move `self.value`
unsafe {
let p = self.as_mut().map_unchecked_mut(|me| &mut me.value);
if let Poll::Ready(v) = p.poll(cx) {
return Poll::Ready(Ok(v));
}
// First, try polling the future
if let Poll::Ready(v) = me.value.poll(cx) {
return Poll::Ready(Ok(v));
}

// Now check the timer
// Safety: X_X!
unsafe {
match self.map_unchecked_mut(|me| &mut me.delay).poll(cx) {
Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))),
Poll::Pending => Poll::Pending,
}
match me.delay.poll(cx) {
Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))),
Poll::Pending => Poll::Pending,
}
}
}
Expand Down
16 changes: 9 additions & 7 deletions tokio/tests/rt_basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn spawned_task_does_not_progress_without_block_on() {

#[test]
fn no_extra_poll() {
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Expand All @@ -37,9 +38,12 @@ fn no_extra_poll() {
use std::task::{Context, Poll};
use tokio::stream::{Stream, StreamExt};

struct TrackPolls<S> {
npolls: Arc<AtomicUsize>,
s: S,
pin_project! {
struct TrackPolls<S> {
npolls: Arc<AtomicUsize>,
#[pin]
s: S,
}
}

impl<S> Stream for TrackPolls<S>
Expand All @@ -48,11 +52,9 @@ fn no_extra_poll() {
{
type Item = S::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// safety: we do not move s
let this = unsafe { self.get_unchecked_mut() };
let this = self.project();
this.npolls.fetch_add(1, SeqCst);
// safety: we are pinned, and so is s
unsafe { Pin::new_unchecked(&mut this.s) }.poll_next(cx)
this.s.poll_next(cx)
}
}

Expand Down
16 changes: 9 additions & 7 deletions tokio/tests/tcp_accept.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ test_accept! {
(ip_port_tuple, ("127.0.0.1".parse::<IpAddr>().unwrap(), 0)),
}

use pin_project_lite::pin_project;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicUsize, Ordering::SeqCst},
Expand All @@ -47,9 +48,12 @@ use std::sync::{
use std::task::{Context, Poll};
use tokio::stream::{Stream, StreamExt};

struct TrackPolls<S> {
npolls: Arc<AtomicUsize>,
s: S,
pin_project! {
struct TrackPolls<S> {
npolls: Arc<AtomicUsize>,
#[pin]
s: S,
}
}

impl<S> Stream for TrackPolls<S>
Expand All @@ -58,11 +62,9 @@ where
{
type Item = S::Item;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// safety: we do not move s
let this = unsafe { self.get_unchecked_mut() };
let this = self.project();
this.npolls.fetch_add(1, SeqCst);
// safety: we are pinned, and so is s
unsafe { Pin::new_unchecked(&mut this.s) }.poll_next(cx)
this.s.poll_next(cx)
}
}

Expand Down