How to implement DFS using immutable data types

I'm trying to find a Scala-style way to traverse a graph, preferably using vals and immutable data types.


val graph = Map(0 -> Set(1),
                1 -> Set(2),
                2 -> Set(0, 3, 4),
                3 -> Set[Int](),
                4 -> Set(3))

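I want the output to be a depth-first traversal starting from a given node. For example, starting from 1 it should produce something like 1 2 3 0 4.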


def traverse(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
  // Children of parent that have not been visited yet
  def childrenNotVisited(parent: Int, visited: List[Int]) =
    graph(parent) filter (x => !visited.contains(x))

  // stack holds the nodes still to visit; visited is built in reverse order
  def loop(stack: Set[Int], visited: List[Int]): List[Int] = {
    if (stack.isEmpty) visited
    else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
      stack.head :: visited)
  }
  loop(Set(start), Nil).reverse
}

def traverse(graph: Map[Int, Set[Int]], node: Int, visited: Set[Int] = Set()): List[Int] = 
    List(node) ++ (graph(node) -- visited flatMap(traverse(graph, _, visited + node)))

traverse(graph, 1)


Also note that this function is NOT tail recursive.
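
The one-liner above does not pass visited from one child's subtree on to the next sibling, so parts of the graph reachable from several siblings are explored again; and because flatMap on a Set returns a Set, the output order follows that Set's iteration order, which for larger sets need not match the traversal order. A minimal sketch of a variant that is still not tail recursive, assuming the same Map[Int, Set[Int]] representation, threads a (visited set, visit order) pair through a foldLeft instead. RecursiveDfs and visit are illustrative names, and missing keys are treated as vertices without outgoing edges:

```scala
object RecursiveDfs {
  // Visits node, then its unvisited children left to right.
  // The state pairs the set of visited vertices (fast membership checks)
  // with the visit order, kept most-recent-first.
  private def visit(graph: Map[Int, Set[Int]], node: Int,
                    state: (Set[Int], List[Int])): (Set[Int], List[Int]) = {
    val (seen, order) = state
    val entered = (seen + node, node :: order)
    // The updated state flows from one child into the next, so a vertex
    // reached inside an earlier sibling's subtree is not visited again.
    graph.getOrElse(node, Set.empty[Int]).foldLeft(entered) { (st, child) =>
      if (st._1.contains(child)) st else visit(graph, child, st)
    }
  }

  // Depth-first visit order starting from start.
  def traverse(graph: Map[Int, Set[Int]], start: Int): List[Int] =
    visit(graph, start, (Set.empty[Int], Nil))._2.reverse
}
```

The recursion depth grows with the longest path the search follows, so a very deep graph can still overflow the call stack.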

graph.foldLeft((List[Int](), 1)){
  (s, e) => if (e._2.size == 0) (0 :: s._1, s._2) else (s._2 :: s._1, (s._2 + 1))
}

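Update: this is an expanded version. Here I fold left over the elements of the map, starting with a tuple of an empty list and the number 1. For each element I check the size of its adjacency set and create a new tuple accordingly. The resulting list comes out in reverse order.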

val init = (List[Int](), 1)
val (result, _) = graph.foldLeft(init) {
  (s, elem) =>
    val (stack, count) = s
    // Vertices without outgoing edges contribute 0; the others get the next counter value
    if (elem._2.size == 0)
      (0 :: stack, count)
    else
      (count :: stack, count + 1)
}

Not sure whether you're still looking for an answer after 6 years, but here it is :)


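case class Node(label: Int)

case class Graph(adj: Map[Node, Set[Node]]) {
  case class DfsState(discovered: Set[Node] = Set(), activeNodes: Set[Node] = Set(), tsOrder: List[Node] = List(),
                      isCylic: Boolean = false)

  // Returns the nodes in reverse post-order (a topological order when the graph is acyclic)
  // together with a flag telling whether a cycle was found.
  def dfs: (List[Node], Boolean) = {
    def dfsVisit(currState: DfsState, src: Node): DfsState = {
      // Nodes that only appear as edge targets have no entry in adj
      val children = adj.getOrElse(src, Set.empty[Node])
      // A child that is still active (on the current DFS path) closes a cycle
      val newState = currState.copy(discovered = currState.discovered + src, activeNodes = currState.activeNodes + src,
        isCylic = currState.isCylic || children.exists(currState.activeNodes))

      // Check discovered inside the fold: an earlier child's subtree may already have reached a later child
      val finalState = children.foldLeft(newState) { (state, child) =>
        if (state.discovered(child)) state else dfsVisit(state, child)
      }
      finalState.copy(tsOrder = src :: finalState.tsOrder, activeNodes = finalState.activeNodes - src)
    }

    val stateAfterSearch = adj.keys.foldLeft(DfsState()) { (state, n) => if (state.discovered(n)) state else dfsVisit(state, n) }
    (stateAfterSearch.tsOrder, stateAfterSearch.isCylic)
  }
}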

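It seems this problem is more complex than I originally thought. I wrote another recursive solution. It is still not tail recursive. I also tried to make it a one-liner, but readability suffers a lot in that case, so this time I decided to declare a few vals.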

def traverse(graph: Map[Int, Set[Int]], node: Int, result: List[Int] = Nil): List[Int] = {
  val newResult = result :+ node
  val currentEdges = graph(node) -- newResult
  // No unvisited successor: continue from any vertex that has not been visited yet
  val realEdges = if (currentEdges.isEmpty) graph.keySet -- newResult else currentEdges

  realEdges.foldLeft(newResult)((r, n) => if (r contains n) r else traverse(graph, n, r))
}


In my previous answer I tried to find all paths from the given node in a directed graph, but that did not match the requirements. This answer follows directed edges where it can; when it can't, it picks some unvisited node and continues from there.

I haven't fully understood your solution, but if I'm not mistaken its time complexity is at least O(|V|^2), because the following line is O(|V|):

val newResult = result :+ node
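
The reasoning above rests on the cost of :+ on an immutable List: appending copies the whole list, so doing it once per visited vertex adds up to quadratic work. The usual remedy is to prepend with :: (constant time) and reverse once at the end. A small illustration, where AppendCost and its methods are hypothetical names:

```scala
object AppendCost {
  // Appending with :+ copies the existing list on every step:
  // O(n) per append, O(n^2) for n elements.
  def buildByAppend(n: Int): List[Int] =
    (0 until n).foldLeft(List.empty[Int])((acc, x) => acc :+ x)

  // Prepending with :: is O(1) per step; a single reverse at the end is O(n),
  // so the whole build is O(n).
  def buildByPrepend(n: Int): List[Int] =
    (0 until n).foldLeft(List.empty[Int])((acc, x) => x :: acc).reverse
}
```

Both produce the same list; only the amount of copying differs.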



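The following code solves several DFS-related problems on directed graphs. It is not the most elegant code, but if I'm not mistaken it is:

  1. Tail recursive.
  2. Built only on immutable collections (and their iterators).
  3. Optimal in time, O(|V| + |E|), and in space, O(|V|).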


import scala.annotation.tailrec
import scala.util.Try

/**
 * Created with IntelliJ IDEA.
 * User: mishaelr
 * Date: 5/14/14
 * Time: 5:18 PM
 */
object DirectedGraphTraversals {

  type Graph[Vertex] = Map[Vertex, Set[Vertex]]

  // The start vertex counts as entered, so a cycle leading back to it is not followed again
  def dfs[Vertex](graph: Graph[Vertex], initialVertex: Vertex) =
    dfsRec(DfsNeighbours)(graph, List(DfsNeighbours(graph, initialVertex, Set(), Set())), Set(initialVertex), Set(), List())

  def topologicalSort[Vertex](graph: Graph[Vertex]) =
    graphDfsRec(TopologicalSortNeighbours)(graph, graph.keySet, Set(), Set(), List())

  def stronglyConnectedComponents[Vertex](graph: Graph[Vertex]) = {
    val exitOrder = graphDfsRec(DfsNeighbours)(graph, graph.keySet, Set(), Set(), List())
    val reversedGraph = reverse(graph)

    // Start from an empty list of components (not a list holding one empty set)
    exitOrder.foldLeft((Set[Vertex](), List[Set[Vertex]]())) {
      case (acc @ (visitedAcc, connectedComponentsAcc), vertex) =>
        if (visitedAcc(vertex))
          acc
        else {
          val connectedComponent = dfsRec(DfsNeighbours)(reversedGraph, List(DfsNeighbours(reversedGraph, vertex, visitedAcc, visitedAcc)),
            visitedAcc + vertex, visitedAcc, List()).toSet
          (visitedAcc ++ connectedComponent, connectedComponent :: connectedComponentsAcc)
        }
    }._2
  }

  def reverse[Vertex](graph: Graph[Vertex]): Graph[Vertex] = {
    val reverseList = for {
      (vertex, neighbours) <- graph.toList
      neighbour <- neighbours
    } yield (neighbour, vertex)

    // Group the reversed edges by their new source vertex
    reverseList.groupBy(_._1).map { case (vertex, edges) => vertex -> edges.map(_._2).toSet }
  }

  private sealed trait NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex])
  }

  private object DfsNeighbours extends NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex]) =
      (vertex, graph.getOrElse(vertex, Set.empty[Vertex]).iterator)
  }

  private object TopologicalSortNeighbours extends NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex]) = {
      val neighbours = graph.getOrElse(vertex, Set.empty[Vertex])
      // A neighbour that was entered but not exited is on the current path: a back edge, hence a cycle
      if (neighbours.exists(neighbour => entered(neighbour) && !exited(neighbour)))
        throw new IllegalArgumentException("The graph is not a DAG, it contains cycles: " + graph)
      else
        (vertex, neighbours.iterator)
    }
  }

  @tailrec
  private def dfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], toVisit: List[(Vertex, Iterator[Vertex])],
                                                             entered: Set[Vertex], exited: Set[Vertex],
                                                             exitStack: List[Vertex]): List[Vertex] = {
    toVisit match {
      case List() => exitStack
      case (currentVertex, neighbours) :: tl =>
        // Relies on the underlying iterator keeping its position between calls
        val filtered = neighbours.filterNot(entered)
        if (filtered.hasNext) {
          val nextNeighbour = filtered.next()
          dfsRec(neighboursFunc)(graph, neighboursFunc(graph, nextNeighbour, entered, exited) :: toVisit,
            entered + nextNeighbour, exited, exitStack)
        } else
          dfsRec(neighboursFunc)(graph, tl, entered, exited + currentVertex, currentVertex :: exitStack)
    }
  }

  @tailrec
  private def graphDfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], notVisited: Set[Vertex],
                                                                  entered: Set[Vertex], exited: Set[Vertex], order: List[Vertex]): List[Vertex] = {
    if (notVisited.isEmpty)
      order
    else {
      val root = notVisited.head
      val orderSuffix = dfsRec(neighboursFunc)(graph, List(neighboursFunc(graph, root, entered, exited)), entered + root, exited, List())
      graphDfsRec(neighboursFunc)(graph, notVisited -- orderSuffix, entered ++ orderSuffix, exited ++ orderSuffix, orderSuffix ::: order)
    }
  }
}

object DirectedGraphTraversalsExamples extends App {
  import DirectedGraphTraversals._

  val graph = Map(
    "B" -> Set("D", "C"),
    "A" -> Set("B", "D"),
    "D" -> Set("E"),
    "E" -> Set("C"))

  println("dfs A " +  dfs(graph, "A"))
  println("dfs B " +  dfs(graph, "B"))

  println("topologicalSort " +  topologicalSort(graph))

  println("reverse " + reverse(graph))
  println("stronglyConnectedComponents graph " + stronglyConnectedComponents(graph))

  val graph2 = graph + ("C" -> Set("D"))
  println("stronglyConnectedComponents graph2 " + stronglyConnectedComponents(graph2))
  println("topologicalSort graph2 " + Try(topologicalSort(graph2)))
}


Marimuthu Madasamy's answer does work; here is a generic version of it.


val graph = Map(0 -> Set(1),
  1 -> Set(2),
  2 -> Set(0, 3, 4),
  3 -> Set[Int](),
  4 -> Set(3))

def traverse[T](graph: Map[T, Set[T]], start: T): List[T] = {
  def childrenNotVisited(parent: T, visited: List[T]) =
    graph(parent) filter (x => !visited.contains(x))

  def loop(stack: Set[T], visited: List[T]): List[T] = {
    if (stack.isEmpty) visited
    else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
      stack.head :: visited)
  }
  loop(Set(start), Nil).reverse
}


Note: you have to make sure the instances of T implement equals and hashCode correctly. Using case classes with primitive values is an easy way to get there.
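
To see why this matters: traverse looks vertices up in the Map and in the visited collection, and both lookups rely on equals (and, for the Map, hashCode). A plain class inherits reference equality, so two separately created vertices with the same id count as different vertices, while a case class compares its fields. A minimal illustration, with hypothetical class names:

```scala
object EqualityDemo {
  // A plain class inherits reference equality from AnyRef
  class PlainVertex(val id: Int)

  // A case class gets structural equals/hashCode generated by the compiler
  case class CaseVertex(id: Int)

  val plainSeen: Set[PlainVertex] = Set(new PlainVertex(1))
  val caseSeen: Set[CaseVertex] = Set(CaseVertex(1))

  // false: a new instance is a different reference
  val plainFound: Boolean = plainSeen.contains(new PlainVertex(1))
  // true: the lookup compares the id field
  val caseFound: Boolean = caseSeen.contains(CaseVertex(1))
}
```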


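I'd like to revise Marimuthu Madasamy's answer. Its code uses a Set as the stack, which is an unordered data structure, and a List for the visited nodes, on which contains takes linear time, so the overall time complexity is O(E * V), which is inefficient (E is the number of edges, V the number of vertices). I'd rather use a List as the stack, a Set for the visited nodes (named discovered), and a separate List that holds the nodes in the order they were visited.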

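def dfs(stack: List[Int], discovered: Set[Int], orderedVisited: List[Int]): List[Int] = {
  // Children of start that have not been discovered yet, in the Set's iteration order
  def childrenNotVisited(start: Int) =
    graph(start).filterNot(discovered).toList

  if (stack.isEmpty)
    orderedVisited.reverse
  else {
    val nextNodes = childrenNotVisited(stack.head)
    // Children are marked as discovered as soon as they are pushed
    dfs(nextNodes ::: stack.tail, discovered ++ nextNodes, stack.head :: orderedVisited)
  }
}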
val start = 0
val visitOrder = dfs(List(start), Set(start), Nil)
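
One caveat with marking children as discovered when they are pushed: it keeps duplicates off the stack, but it can yield an order that no depth-first search would produce. For example, with 0 -> Set(1, 5, 2), 1 -> Set(2) and empty sets for 2 and 5, if the children of 0 are pushed in the order 1, 5, 2, this approach visits 0 1 5 2, although 2 is an unvisited neighbour of 1 and a depth-first search has to reach it before backtracking to 5. A hedged sketch of a variant that keeps the List stack and the Set of discovered nodes but marks a node only when it is popped; stale duplicates may then sit on the stack (up to O(|E|) entries) and are skipped when popped. MarkOnPopDfs is an illustrative name, and missing keys are treated as vertices without outgoing edges:

```scala
import scala.annotation.tailrec

object MarkOnPopDfs {
  // Iterative, tail-recursive DFS over immutable collections.
  // stack is a List used as a LIFO stack, discovered is a Set for fast
  // membership checks, and order is built by prepending, then reversed once.
  def dfs[T](graph: Map[T, Set[T]], start: T): List[T] = {
    @tailrec
    def loop(stack: List[T], discovered: Set[T], order: List[T]): List[T] =
      stack match {
        case Nil => order.reverse
        case node :: rest if discovered(node) =>
          // Already visited through another path: drop this stale entry
          loop(rest, discovered, order)
        case node :: rest =>
          val children = graph.getOrElse(node, Set.empty[T]).toList.filterNot(discovered)
          // Children go on top of the stack, so the first child is explored next
          loop(children ::: rest, discovered + node, node :: order)
      }

    loop(List(start), Set.empty[T], Nil)
  }
}
```

As with the other answers, the order among siblings follows the Set's iteration order.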

