diff --git a/lib/graph.ex b/lib/graph.ex index b66670f..8e8c959 100644 --- a/lib/graph.ex +++ b/lib/graph.ex @@ -1579,9 +1579,15 @@ defmodule Graph do ...> g = Graph.add_edges(g, [{:a, :b}, {:a, :c}, {:b, :c}, {:c, :d}]) ...> Graph.reachable(g, [:a]) [:d, :c, :b, :a] + + iex> g = Graph.new(type: :undirected) |> Graph.add_vertices([:a, :b, :c, :d]) + ...> g = Graph.add_edges(g, [{:a, :b}, {:c, :d}]) + ...> Graph.reachable(g, [:a]) + [:b, :a] """ @spec reachable(t, [vertex]) :: [[vertex]] - defdelegate reachable(g, vs), to: Graph.Directed + def reachable(%__MODULE__{type: :undirected} = g, vs), do: Graph.Undirected.reachable(g, vs) + def reachable(%__MODULE__{} = g, vs), do: Graph.Directed.reachable(g, vs) @doc """ Returns an unsorted list of vertices from the graph, such that for each vertex in the list (call it `v`), diff --git a/lib/graph/undirected.ex b/lib/graph/undirected.ex new file mode 100644 index 0000000..bc4f3d8 --- /dev/null +++ b/lib/graph/undirected.ex @@ -0,0 +1,22 @@ +defmodule Graph.Undirected do + @moduledoc false + + def reachable(%Graph{type: :undirected} = graph, vertices) when is_list(vertices) do + vertices + |> collect_reachable(graph, MapSet.new()) + |> MapSet.to_list() + end + + defp collect_reachable([], _graph, seen), do: seen + + defp collect_reachable([vertex | rest], graph, seen) do + if MapSet.member?(seen, vertex) do + collect_reachable(rest, graph, seen) + else + new_seen = MapSet.put(seen, vertex) + neighbors = Graph.neighbors(graph, vertex) + neighbors_seen = collect_reachable(neighbors, graph, new_seen) + collect_reachable(rest, graph, neighbors_seen) + end + end +end diff --git a/test/directed_test.exs b/test/directed_test.exs new file mode 100644 index 0000000..512bb4c --- /dev/null +++ b/test/directed_test.exs @@ -0,0 +1,36 @@ +defmodule DirectedTest do + use ExUnit.Case, async: true + + describe "reachable/2" do + test "returns empty list for empty graph" do + assert Graph.new() |> Graph.reachable([]) == [] + end + + test "returns starting vertex if no edges" do + graph = Graph.new() |> Graph.add_vertex(:a) + assert Graph.reachable(graph, [:a]) == [:a] + end + + test "returns reachable vertices" do + graph = + Graph.new() + |> Graph.add_edge(:a, :b) + |> Graph.add_edge(:b, :c) + |> Graph.add_edge(:d, :e) + + assert Graph.reachable(graph, [:a]) |> Enum.sort() == [:a, :b, :c] + assert Graph.reachable(graph, [:d]) |> Enum.sort() == [:d, :e] + assert Graph.reachable(graph, [:a, :d]) |> Enum.sort() == [:a, :b, :c, :d, :e] + end + + test "handles cycles" do + graph = + Graph.new() + |> Graph.add_edge(:a, :b) + |> Graph.add_edge(:b, :c) + |> Graph.add_edge(:c, :a) + + assert Graph.reachable(graph, [:a]) |> Enum.sort() == [:a, :b, :c] + end + end +end diff --git a/test/undirected_test.exs b/test/undirected_test.exs new file mode 100644 index 0000000..a58d8a0 --- /dev/null +++ b/test/undirected_test.exs @@ -0,0 +1,70 @@ +defmodule UndirectedTest do + use ExUnit.Case, async: true + + describe "reachable/2" do + test "includes all vertices in connected component" do + graph = + Graph.new(type: :undirected) + |> Graph.add_edges([ + {:a, :b}, + {:b, :c} + ]) + + assert Enum.sort(Graph.Undirected.reachable(graph, [:a])) == [:a, :b, :c] + end + + test "handles multiple starting vertices" do + graph = + Graph.new(type: :undirected) + |> Graph.add_edges([ + {:a, :b}, + {:c, :d} + ]) + + assert Enum.sort(Graph.Undirected.reachable(graph, [:a, :c])) == [:a, :b, :c, :d] + end + + test "returns only starting vertex if isolated" do + graph = + Graph.new(type: :undirected) + |> Graph.add_vertex(:a) + |> Graph.add_vertex(:b) + + assert Graph.Undirected.reachable(graph, [:a]) == [:a] + end + + test "handles empty starting set" do + graph = + Graph.new(type: :undirected) + |> Graph.add_vertex(:a) + + assert Graph.Undirected.reachable(graph, []) == [] + end + + test "handles multiple components" do + graph = + Graph.new(type: :undirected) + |> Graph.add_edges([ + {:a, :b}, + {:b, :c}, + {:d, :e} + ]) + + result = + Graph.vertices(graph) + |> Enum.map(fn vertex -> + graph + |> Graph.Undirected.reachable([vertex]) + |> Enum.sort() + end) + + assert result == [ + [:a, :b, :c], + [:a, :b, :c], + [:a, :b, :c], + [:d, :e], + [:d, :e] + ] + end + end +end