diff --git a/kore/src/main/java/org/kframework/POSet.java b/kore/src/main/java/org/kframework/POSet.java index aa36428b600..82c6fc16a0f 100644 --- a/kore/src/main/java/org/kframework/POSet.java +++ b/kore/src/main/java/org/kframework/POSet.java @@ -6,13 +6,11 @@ import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; -import java.util.stream.StreamSupport; import org.apache.commons.lang3.tuple.Pair; import org.kframework.utils.Lazy; import org.kframework.utils.errorsystem.KEMException; @@ -108,14 +106,17 @@ public Set elements() { return elementsLazy.get(); } - private List computeSortedElements() { - return StreamSupport.stream(TopologicalSort.tsort(directRelations).spliterator(), false) - .toList(); + private java.util.List computeSortedElements() { + Optional> topological = TopologicalSort.tsort(directRelations); + // We already checked for cycles during construction, so the sort should succeed + assert topological.isPresent(); + return topological.get().toList(); } - private final Lazy> sortedElementsLazy = new Lazy<>(this::computeSortedElements); + private final Lazy> sortedElementsLazy = + new Lazy<>(this::computeSortedElements); - public List sortedElements() { + public java.util.List sortedElements() { return sortedElementsLazy.get(); } diff --git a/kore/src/main/java/org/kframework/TopologicalSort.java b/kore/src/main/java/org/kframework/TopologicalSort.java new file mode 100644 index 00000000000..a0052fa9918 --- /dev/null +++ b/kore/src/main/java/org/kframework/TopologicalSort.java @@ -0,0 +1,50 @@ +// Copyright (c) Runtime Verification, Inc. All Rights Reserved. +package org.kframework; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.commons.lang3.tuple.Pair; + +public class TopologicalSort { + /** Topologically sort based on the provided edges, unless a cycle is present. */ + public static Optional> tsort(Iterable> edges) { + Map> toPred = new HashMap<>(); + for (Pair edge : edges) { + if (!toPred.containsKey(edge.getLeft())) { + toPred.put(edge.getLeft(), new HashSet<>()); + } + if (!toPred.containsKey(edge.getRight())) { + toPred.put(edge.getRight(), new HashSet<>()); + } + toPred.get(edge.getRight()).add(edge.getLeft()); + } + return tsortInternal(toPred, Stream.empty()); + } + + private static Optional> tsortInternal(Map> toPreds, Stream done) { + Map>>> partition = + toPreds.entrySet().stream() + .collect(Collectors.partitioningBy((e) -> e.getValue().isEmpty())); + List>> noPreds = partition.get(true); + Map> hasPreds = + partition.get(false).stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + if (noPreds.isEmpty()) { + if (hasPreds.isEmpty()) { + return Optional.of(done); + } + return Optional.empty(); + } + Set found = noPreds.stream().map(Map.Entry::getKey).collect(Collectors.toSet()); + for (Map.Entry> entry : hasPreds.entrySet()) { + entry.getValue().removeAll(found); + } + return tsortInternal(hasPreds, Stream.concat(done, found.stream())); + } +} diff --git a/kore/src/main/scala/org/kframework/TopologicalSort.scala b/kore/src/main/scala/org/kframework/TopologicalSort.scala deleted file mode 100644 index 906325677a2..00000000000 --- a/kore/src/main/scala/org/kframework/TopologicalSort.scala +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Runtime Verification, Inc. All Rights Reserved. -package org.kframework - -import org.apache.commons.lang3.tuple.Pair -import scala.annotation.tailrec -import scala.jdk.CollectionConverters._ - -/** - * Created by dwightguth on 4/16/15. - */ -object TopologicalSort { - def tsort[A](edges: Traversable[(A, A)]): Iterable[A] = { - @tailrec - def tsort(toPreds: Map[A, Set[A]], done: Iterable[A]): Iterable[A] = { - val (noPreds, hasPreds) = toPreds.partition(_._2.isEmpty) - if (noPreds.isEmpty) { - if (hasPreds.isEmpty) done else sys.error(hasPreds.toString) - } else { - val found = noPreds.map(_._1) - tsort(hasPreds.mapValues(_ -- found), done ++ found) - } - } - - val toPred = edges.foldLeft(Map[A, Set[A]]()) { (acc, e) => - acc + (e._1 -> acc.getOrElse(e._1, Set())) + (e._2 -> (acc.getOrElse(e._2, Set()) + e._1)) - } - tsort(toPred, Seq()) - } - - def tsort[A](edges: java.lang.Iterable[Pair[A, A]]): java.lang.Iterable[A] = - tsort(edges.asScala.toSet.map((p: Pair[A, A]) => Tuple2(p.getLeft, p.getRight))).asJava -} diff --git a/kore/src/main/scala/org/kframework/collections.scala b/kore/src/main/scala/org/kframework/collections.scala index 46f79951750..f2b6de923c6 100644 --- a/kore/src/main/scala/org/kframework/collections.scala +++ b/kore/src/main/scala/org/kframework/collections.scala @@ -19,7 +19,6 @@ object Collections { def immutable[T](s: java.util.Set[T]): Set[T] = s.asScala.toSet def immutable[T](s: java.util.List[T]): Seq[T] = s.asScala def immutable[K, V](s: java.util.Map[K, V]): Map[K, V] = s.asScala - def immutable[T](s: Array[T]): Seq[T] = s def mutable[T](s: scala.List[T]): java.util.List[T] = s.asJava def mutable[T](s: Seq[T]): java.util.List[T] = s.asJava diff --git a/kore/src/main/scala/org/kframework/compile/ConfigurationInfoFromModule.scala b/kore/src/main/scala/org/kframework/compile/ConfigurationInfoFromModule.scala index aef173a0efc..22989eabd19 100644 --- a/kore/src/main/scala/org/kframework/compile/ConfigurationInfoFromModule.scala +++ b/kore/src/main/scala/org/kframework/compile/ConfigurationInfoFromModule.scala @@ -1,7 +1,6 @@ // Copyright (c) Runtime Verification, Inc. All Rights Reserved. package org.kframework.compile -import collection._ import java.util import org.kframework.attributes.Att import org.kframework.builtin.Sorts @@ -14,8 +13,10 @@ import org.kframework.kore._ import org.kframework.kore.KORE.KApply import org.kframework.kore.KORE.KLabel import org.kframework.utils.errorsystem.KEMException +import org.kframework.Collections import org.kframework.POSet import org.kframework.TopologicalSort._ +import scala.collection._ import scala.collection.JavaConverters._ object ConfigurationInfoFromModule @@ -96,7 +97,7 @@ class ConfigurationInfoFromModule(val m: Module) extends ConfigurationInfo { private lazy val topCells = cellSorts.diff(edges.map(_._2)) - private val sortedSorts: Seq[Sort] = tsort(edges).toSeq + private val sortedSorts: Seq[Sort] = Collections.immutable(edgesPoset.sortedElements()) private val sortedEdges: Seq[(Sort, Sort)] = edges.toList.sortWith((l, r) => sortedSorts.indexOf(l._1) < sortedSorts.indexOf(r._1)) val levels: Map[Sort, Int] = sortedEdges.foldLeft(topCells.map((_, 0)).toMap) {