Skip to content

Commit 0e88e56

Browse files
authored
Rollup merge of rust-lang#64805 - nnethercote:ObligForest-still-more, r=nikomatsakis
Still more `ObligationForest` improvements. Following on from rust-lang#64627, more readability improvements, but negligible effects on speed. r? @nikomatsakis
2 parents bd9d843 + a820672 commit 0e88e56

File tree

2 files changed

+115
-126
lines changed

2 files changed

+115
-126
lines changed

src/librustc_data_structures/obligation_forest/mod.rs

+94-119
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,8 @@ pub struct ObligationForest<O: ForestObligation> {
151151
/// comments in `process_obligation` for details.
152152
active_cache: FxHashMap<O::Predicate, usize>,
153153

154-
/// A scratch vector reused in various operations, to avoid allocating new
155-
/// vectors.
156-
scratch: RefCell<Vec<usize>>,
154+
/// A vector reused in compress(), to avoid allocating new vectors.
155+
node_rewrites: RefCell<Vec<usize>>,
157156

158157
obligation_tree_id_generator: ObligationTreeIdGenerator,
159158

@@ -235,10 +234,6 @@ enum NodeState {
235234
/// This obligation was resolved to an error. Error nodes are
236235
/// removed from the vector by the compression step.
237236
Error,
238-
239-
/// This is a temporary state used in DFS loops to detect cycles,
240-
/// it should not exist outside of these DFSes.
241-
OnDfsStack,
242237
}
243238

244239
#[derive(Debug)]
@@ -279,7 +274,7 @@ impl<O: ForestObligation> ObligationForest<O> {
279274
nodes: vec![],
280275
done_cache: Default::default(),
281276
active_cache: Default::default(),
282-
scratch: RefCell::new(vec![]),
277+
node_rewrites: RefCell::new(vec![]),
283278
obligation_tree_id_generator: (0..).map(ObligationTreeId),
284279
error_cache: Default::default(),
285280
}
@@ -305,9 +300,10 @@ impl<O: ForestObligation> ObligationForest<O> {
305300

306301
match self.active_cache.entry(obligation.as_predicate().clone()) {
307302
Entry::Occupied(o) => {
303+
let index = *o.get();
308304
debug!("register_obligation_at({:?}, {:?}) - duplicate of {:?}!",
309-
obligation, parent, o.get());
310-
let node = &mut self.nodes[*o.get()];
305+
obligation, parent, index);
306+
let node = &mut self.nodes[index];
311307
if let Some(parent_index) = parent {
312308
// If the node is already in `active_cache`, it has already
313309
// had its chance to be marked with a parent. So if it's
@@ -342,7 +338,8 @@ impl<O: ForestObligation> ObligationForest<O> {
342338
if already_failed {
343339
Err(())
344340
} else {
345-
v.insert(self.nodes.len());
341+
let new_index = self.nodes.len();
342+
v.insert(new_index);
346343
self.nodes.push(Node::new(parent, obligation, obligation_tree_id));
347344
Ok(())
348345
}
@@ -352,15 +349,16 @@ impl<O: ForestObligation> ObligationForest<O> {
352349

353350
/// Converts all remaining obligations to the given error.
354351
pub fn to_errors<E: Clone>(&mut self, error: E) -> Vec<Error<O, E>> {
355-
let mut errors = vec![];
356-
for (index, node) in self.nodes.iter().enumerate() {
357-
if let NodeState::Pending = node.state.get() {
358-
errors.push(Error {
352+
let errors = self.nodes.iter().enumerate()
353+
.filter(|(_index, node)| node.state.get() == NodeState::Pending)
354+
.map(|(index, _node)| {
355+
Error {
359356
error: error.clone(),
360357
backtrace: self.error_at(index),
361-
});
362-
}
363-
}
358+
}
359+
})
360+
.collect();
361+
364362
let successful_obligations = self.compress(DoCompleted::Yes);
365363
assert!(successful_obligations.unwrap().is_empty());
366364
errors
@@ -370,15 +368,14 @@ impl<O: ForestObligation> ObligationForest<O> {
370368
pub fn map_pending_obligations<P, F>(&self, f: F) -> Vec<P>
371369
where F: Fn(&O) -> P
372370
{
373-
self.nodes
374-
.iter()
375-
.filter(|n| n.state.get() == NodeState::Pending)
376-
.map(|n| f(&n.obligation))
371+
self.nodes.iter()
372+
.filter(|node| node.state.get() == NodeState::Pending)
373+
.map(|node| f(&node.obligation))
377374
.collect()
378375
}
379376

380-
fn insert_into_error_cache(&mut self, node_index: usize) {
381-
let node = &self.nodes[node_index];
377+
fn insert_into_error_cache(&mut self, index: usize) {
378+
let node = &self.nodes[index];
382379
self.error_cache
383380
.entry(node.obligation_tree_id)
384381
.or_default()
@@ -408,10 +405,10 @@ impl<O: ForestObligation> ObligationForest<O> {
408405
// `self.active_cache`. This means that `self.active_cache` can get
409406
// out of sync with `nodes`. It's not very common, but it does
410407
// happen, and code in `compress` has to allow for it.
411-
let result = match node.state.get() {
412-
NodeState::Pending => processor.process_obligation(&mut node.obligation),
413-
_ => continue
414-
};
408+
if node.state.get() != NodeState::Pending {
409+
continue;
410+
}
411+
let result = processor.process_obligation(&mut node.obligation);
415412

416413
debug!("process_obligations: node {} got result {:?}", index, result);
417414

@@ -476,64 +473,53 @@ impl<O: ForestObligation> ObligationForest<O> {
476473
fn process_cycles<P>(&self, processor: &mut P)
477474
where P: ObligationProcessor<Obligation=O>
478475
{
479-
let mut stack = self.scratch.replace(vec![]);
480-
debug_assert!(stack.is_empty());
476+
let mut stack = vec![];
481477

482478
debug!("process_cycles()");
483479

484480
for (index, node) in self.nodes.iter().enumerate() {
485481
// For some benchmarks this state test is extremely
486482
// hot. It's a win to handle the no-op cases immediately to avoid
487483
// the cost of the function call.
488-
match node.state.get() {
489-
// Match arms are in order of frequency. Pending, Success and
490-
// Waiting dominate; the others are rare.
491-
NodeState::Pending => {},
492-
NodeState::Success => self.find_cycles_from_node(&mut stack, processor, index),
493-
NodeState::Waiting | NodeState::Done | NodeState::Error => {},
494-
NodeState::OnDfsStack => self.find_cycles_from_node(&mut stack, processor, index),
484+
if node.state.get() == NodeState::Success {
485+
self.find_cycles_from_node(&mut stack, processor, index);
495486
}
496487
}
497488

498489
debug!("process_cycles: complete");
499490

500491
debug_assert!(stack.is_empty());
501-
self.scratch.replace(stack);
502492
}
503493

504494
fn find_cycles_from_node<P>(&self, stack: &mut Vec<usize>, processor: &mut P, index: usize)
505495
where P: ObligationProcessor<Obligation=O>
506496
{
507497
let node = &self.nodes[index];
508-
match node.state.get() {
509-
NodeState::OnDfsStack => {
510-
let rpos = stack.iter().rposition(|&n| n == index).unwrap();
511-
processor.process_backedge(stack[rpos..].iter().map(GetObligation(&self.nodes)),
512-
PhantomData);
513-
}
514-
NodeState::Success => {
515-
node.state.set(NodeState::OnDfsStack);
516-
stack.push(index);
517-
for &index in node.dependents.iter() {
518-
self.find_cycles_from_node(stack, processor, index);
498+
if node.state.get() == NodeState::Success {
499+
match stack.iter().rposition(|&n| n == index) {
500+
None => {
501+
stack.push(index);
502+
for &index in node.dependents.iter() {
503+
self.find_cycles_from_node(stack, processor, index);
504+
}
505+
stack.pop();
506+
node.state.set(NodeState::Done);
507+
}
508+
Some(rpos) => {
509+
// Cycle detected.
510+
processor.process_backedge(
511+
stack[rpos..].iter().map(GetObligation(&self.nodes)),
512+
PhantomData
513+
);
519514
}
520-
stack.pop();
521-
node.state.set(NodeState::Done);
522-
},
523-
NodeState::Waiting | NodeState::Pending => {
524-
// This node is still reachable from some pending node. We
525-
// will get to it when they are all processed.
526-
}
527-
NodeState::Done | NodeState::Error => {
528-
// Already processed that node.
529515
}
530-
};
516+
}
531517
}
532518

533519
/// Returns a vector of obligations for `p` and all of its
534520
/// ancestors, putting them into the error state in the process.
535521
fn error_at(&self, mut index: usize) -> Vec<O> {
536-
let mut error_stack = self.scratch.replace(vec![]);
522+
let mut error_stack: Vec<usize> = vec![];
537523
let mut trace = vec![];
538524

539525
loop {
@@ -554,23 +540,32 @@ impl<O: ForestObligation> ObligationForest<O> {
554540

555541
while let Some(index) = error_stack.pop() {
556542
let node = &self.nodes[index];
557-
match node.state.get() {
558-
NodeState::Error => continue,
559-
_ => node.state.set(NodeState::Error),
543+
if node.state.get() != NodeState::Error {
544+
node.state.set(NodeState::Error);
545+
error_stack.extend(node.dependents.iter());
560546
}
561-
562-
error_stack.extend(node.dependents.iter());
563547
}
564548

565-
self.scratch.replace(error_stack);
566549
trace
567550
}
568551

569552
// This always-inlined function is for the hot call site.
570553
#[inline(always)]
571554
fn inlined_mark_neighbors_as_waiting_from(&self, node: &Node<O>) {
572555
for &index in node.dependents.iter() {
573-
self.mark_as_waiting_from(&self.nodes[index]);
556+
let node = &self.nodes[index];
557+
match node.state.get() {
558+
NodeState::Waiting | NodeState::Error => {}
559+
NodeState::Success => {
560+
node.state.set(NodeState::Waiting);
561+
// This call site is cold.
562+
self.uninlined_mark_neighbors_as_waiting_from(node);
563+
}
564+
NodeState::Pending | NodeState::Done => {
565+
// This call site is cold.
566+
self.uninlined_mark_neighbors_as_waiting_from(node);
567+
}
568+
}
574569
}
575570
}
576571

@@ -596,37 +591,28 @@ impl<O: ForestObligation> ObligationForest<O> {
596591
}
597592
}
598593

599-
fn mark_as_waiting_from(&self, node: &Node<O>) {
600-
match node.state.get() {
601-
NodeState::Waiting | NodeState::Error | NodeState::OnDfsStack => return,
602-
NodeState::Success => node.state.set(NodeState::Waiting),
603-
NodeState::Pending | NodeState::Done => {},
604-
}
605-
606-
// This call site is cold.
607-
self.uninlined_mark_neighbors_as_waiting_from(node);
608-
}
609-
610-
/// Compresses the vector, removing all popped nodes. This adjusts
611-
/// the indices and hence invalidates any outstanding
612-
/// indices. Cannot be used during a transaction.
594+
/// Compresses the vector, removing all popped nodes. This adjusts the
595+
/// indices and hence invalidates any outstanding indices.
613596
///
614597
/// Beforehand, all nodes must be marked as `Done` and no cycles
615598
/// on these nodes may be present. This is done by e.g., `process_cycles`.
616599
#[inline(never)]
617600
fn compress(&mut self, do_completed: DoCompleted) -> Option<Vec<O>> {
618-
let nodes_len = self.nodes.len();
619-
let mut node_rewrites: Vec<_> = self.scratch.replace(vec![]);
620-
node_rewrites.extend(0..nodes_len);
601+
let orig_nodes_len = self.nodes.len();
602+
let mut node_rewrites: Vec<_> = self.node_rewrites.replace(vec![]);
603+
debug_assert!(node_rewrites.is_empty());
604+
node_rewrites.extend(0..orig_nodes_len);
621605
let mut dead_nodes = 0;
606+
let mut removed_done_obligations: Vec<O> = vec![];
622607

623-
// Now move all popped nodes to the end. Try to keep the order.
608+
// Now move all Done/Error nodes to the end, preserving the order of
609+
// the Pending/Waiting nodes.
624610
//
625611
// LOOP INVARIANT:
626612
// self.nodes[0..index - dead_nodes] are the first remaining nodes
627613
// self.nodes[index - dead_nodes..index] are all dead
628614
// self.nodes[index..] are unchanged
629-
for index in 0..self.nodes.len() {
615+
for index in 0..orig_nodes_len {
630616
let node = &self.nodes[index];
631617
match node.state.get() {
632618
NodeState::Pending | NodeState::Waiting => {
@@ -637,7 +623,7 @@ impl<O: ForestObligation> ObligationForest<O> {
637623
}
638624
NodeState::Done => {
639625
// This lookup can fail because the contents of
640-
// `self.active_cache` is not guaranteed to match those of
626+
// `self.active_cache` are not guaranteed to match those of
641627
// `self.nodes`. See the comment in `process_obligation`
642628
// for more details.
643629
if let Some((predicate, _)) =
@@ -647,61 +633,50 @@ impl<O: ForestObligation> ObligationForest<O> {
647633
} else {
648634
self.done_cache.insert(node.obligation.as_predicate().clone());
649635
}
650-
node_rewrites[index] = nodes_len;
636+
if do_completed == DoCompleted::Yes {
637+
// Extract the success stories.
638+
removed_done_obligations.push(node.obligation.clone());
639+
}
640+
node_rewrites[index] = orig_nodes_len;
651641
dead_nodes += 1;
652642
}
653643
NodeState::Error => {
654644
// We *intentionally* remove the node from the cache at this point. Otherwise
655645
// tests must come up with a different type on every type error they
656646
// check against.
657647
self.active_cache.remove(node.obligation.as_predicate());
658-
node_rewrites[index] = nodes_len;
659-
dead_nodes += 1;
660648
self.insert_into_error_cache(index);
649+
node_rewrites[index] = orig_nodes_len;
650+
dead_nodes += 1;
661651
}
662-
NodeState::OnDfsStack | NodeState::Success => unreachable!()
652+
NodeState::Success => unreachable!()
663653
}
664654
}
665655

666-
// No compression needed.
667-
if dead_nodes == 0 {
668-
node_rewrites.truncate(0);
669-
self.scratch.replace(node_rewrites);
670-
return if do_completed == DoCompleted::Yes { Some(vec![]) } else { None };
656+
if dead_nodes > 0 {
657+
// Remove the dead nodes and rewrite indices.
658+
self.nodes.truncate(orig_nodes_len - dead_nodes);
659+
self.apply_rewrites(&node_rewrites);
671660
}
672661

673-
// Pop off all the nodes we killed and extract the success stories.
674-
let successful = if do_completed == DoCompleted::Yes {
675-
Some((0..dead_nodes)
676-
.map(|_| self.nodes.pop().unwrap())
677-
.flat_map(|node| {
678-
match node.state.get() {
679-
NodeState::Error => None,
680-
NodeState::Done => Some(node.obligation),
681-
_ => unreachable!()
682-
}
683-
})
684-
.collect())
685-
} else {
686-
self.nodes.truncate(self.nodes.len() - dead_nodes);
687-
None
688-
};
689-
self.apply_rewrites(&node_rewrites);
690-
691662
node_rewrites.truncate(0);
692-
self.scratch.replace(node_rewrites);
663+
self.node_rewrites.replace(node_rewrites);
693664

694-
successful
665+
if do_completed == DoCompleted::Yes {
666+
Some(removed_done_obligations)
667+
} else {
668+
None
669+
}
695670
}
696671

697672
fn apply_rewrites(&mut self, node_rewrites: &[usize]) {
698-
let nodes_len = node_rewrites.len();
673+
let orig_nodes_len = node_rewrites.len();
699674

700675
for node in &mut self.nodes {
701676
let mut i = 0;
702677
while i < node.dependents.len() {
703678
let new_index = node_rewrites[node.dependents[i]];
704-
if new_index >= nodes_len {
679+
if new_index >= orig_nodes_len {
705680
node.dependents.swap_remove(i);
706681
if i == 0 && node.has_parent {
707682
// We just removed the parent.
@@ -718,7 +693,7 @@ impl<O: ForestObligation> ObligationForest<O> {
718693
// removal of nodes within `compress` can fail. See above.
719694
self.active_cache.retain(|_predicate, index| {
720695
let new_index = node_rewrites[*index];
721-
if new_index >= nodes_len {
696+
if new_index >= orig_nodes_len {
722697
false
723698
} else {
724699
*index = new_index;

0 commit comments

Comments
 (0)