Minimum Spanning Tree Algorithm Implementations
Minimum Spanning Tree Algorithms
Author: Ben Rosenberg
Imports
We begin by importing some relevant libraries. We'll use random to generate some large graphs, and we'll use time to see how well our algorithms perform. We'll use matplotlib.pyplot to plot the performance of the algorithms against each other.
import random
import time
from matplotlib import pyplot as plt
Input data
Next, we define our input data. Recall that in the Minimum Spanning Tree problem the input is a graph G = (N, E), where N is the set of nodes and E is the set of edges, each of which has a cost $c(i,j)$ for all $(i,j) \in E$.
The input data below corresponds to a small example graph, seen below, in which the numbers inside the nodes are arbitrary indices, and the number on each edge denotes the cost of that edge:
Note that the graph appears directed, but this is not intentional; please pretend that the arrows are not there (MST is defined on an undirected graph).
nodes = range(8)
edges = {
    (0,1), (0,3), (1,2), (1,3), (1,7), (2,3), (2,5), (3,4),
    (3,6), (3,7), (4,0), (4,5), (5,1), (5,3), (5,6), (6,4),
    (7,0)
}
# cost[i,j] is the cost of edge (i,j)
cost = {
    (0,1) : 5, (0,3) : 3, (1,2) : 3, (1,3) : 1, (1,7) : 2,
    (2,3) : 2, (2,5) : 6, (3,4) : 4, (3,6) : 6, (3,7) : 3,
    (4,0) : 5, (4,5) : 3, (5,1) : 4, (5,3) : 2, (5,6) : 4,
    (6,4) : 4, (7,0) : 1
}
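Both algorithms below assume that the graph is connected when its edges are read as undirected. A quick check of that property might look like the following minimal sketch. The helper name is_connected and the two tiny test graphs are assumptions introduced here for illustration; they are not part of the example data above.

```python
from collections import deque

def is_connected(nodes, edges):
    """Return True if every node is reachable from the first one,
    treating each edge (i, j) as undirected."""
    nodes = list(nodes)
    if not nodes:
        return True
    # build an undirected adjacency map from the edge set
    neighbors = {node: set() for node in nodes}
    for i, j in edges:
        neighbors[i].add(j)
        neighbors[j].add(i)
    # breadth-first search from an arbitrary start node
    seen = {nodes[0]}
    queue = deque([nodes[0]])
    while queue:
        node = queue.popleft()
        for other in neighbors[node] - seen:
            seen.add(other)
            queue.append(other)
    return len(seen) == len(nodes)

# tiny illustrative graphs (hypothetical, not the example above)
print(is_connected(range(3), {(0, 1), (2, 1)}))  # True
print(is_connected(range(4), {(0, 1), (2, 3)}))  # False
```

A spanning tree only exists when this check passes, so it is a cheap guard to run before either algorithm.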
Prim's Algorithm
We'll start with Prim's Algorithm. In Prim's Algorithm, we do the following:
- Choose some node i in the graph; our subgraph initially consists of just this node
- Find the cheapest edge that has exactly one endpoint in our subgraph and add it to our subgraph (such an edge cannot create a cycle, because its other endpoint is new)
- Repeat the previous step until it is no longer possible to add such edges (for a connected graph, this is exactly when all nodes have been reached)
def prim(nodes, edges, cost):
    # grow the tree from an arbitrary starting node
    current_nodes = {nodes[0]}
    current_edges = set()

    # edges with exactly one endpoint in the tree; adding one cannot create a cycle
    def adj_edges(current_nodes):
        return {edge for edge in edges if len({*edge} & current_nodes) == 1}

    while True:
        candidate_edges = adj_edges(current_nodes) - current_edges
        if len(candidate_edges) == 0:
            break
        # greedily take the cheapest edge leaving the tree
        best_edge = min(candidate_edges, key=lambda edge: cost[edge])
        current_edges.add(best_edge)
        current_nodes |= {*best_edge}

    total_cost = sum(cost[edge] for edge in current_edges)
    return current_edges, total_cost
print(prim(nodes, edges, cost))
({(1, 3), (7, 0), (4, 5), (5, 6), (2, 3), (1, 7), (5, 3)}, 15)
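The implementation above rescans every edge on each iteration. A common alternative keeps the frontier edges in a priority queue. The following is a minimal sketch of that variant using the standard heapq module; the function name prim_heap and the adjacency-list layout are assumptions made for this illustration, not part of the original notebook. The example data is repeated so that the block runs on its own.

```python
import heapq

def prim_heap(nodes, edges, cost):
    """Prim's algorithm with a heap of candidate edges (illustrative sketch)."""
    nodes = list(nodes)
    # undirected adjacency list; each entry is (cost, neighbor, original edge key)
    adjacent = {node: [] for node in nodes}
    for edge in edges:
        i, j = edge
        adjacent[i].append((cost[edge], j, edge))
        adjacent[j].append((cost[edge], i, edge))

    start = nodes[0]
    visited = {start}
    heap = list(adjacent[start])
    heapq.heapify(heap)
    tree_edges = set()

    while heap and len(visited) < len(nodes):
        _, node, edge = heapq.heappop(heap)
        if node in visited:
            continue  # both endpoints already in the tree: would create a cycle
        visited.add(node)
        tree_edges.add(edge)
        for item in adjacent[node]:
            if item[1] not in visited:
                heapq.heappush(heap, item)

    return tree_edges, sum(cost[e] for e in tree_edges)

# same example graph as above, repeated so this block is self-contained
cost = {
    (0,1): 5, (0,3): 3, (1,2): 3, (1,3): 1, (1,7): 2,
    (2,3): 2, (2,5): 6, (3,4): 4, (3,6): 6, (3,7): 3,
    (4,0): 5, (4,5): 3, (5,1): 4, (5,3): 2, (5,6): 4,
    (6,4): 4, (7,0): 1,
}
tree, total = prim_heap(range(8), set(cost), cost)
print(total)  # 15
```

The chosen edges may differ from the output above when several edges tie on cost, but the total cost of a minimum spanning tree is the same.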
Kruskal's Algorithm
Now let's look at Kruskal's Algorithm. In Kruskal's Algorithm, we do the following:
- Find the cheapest edge in the graph that we have not yet considered and, if adding it to our subgraph doesn't create a cycle, add it to our subgraph
- Repeat the previous step until it is no longer possible to add such edges (for a connected graph, this is exactly when all nodes have been reached)
In the below implementation, the union-find data structure is used to make checking for cycles more efficient, but at its heart the algorithm is the same.
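def kruskal(nodes, edges, cost):
    # find the root (representative) of the component containing node
    def find(node, parent):
        while node != parent[node]:
            node = parent[node]
        return node

    # merge the components containing the two endpoints of edge
    def union(edge, parent):
        i, j = edge
        if find(i, parent) != find(j, parent):
            parent[find(i, parent)] = find(j, parent)
        return parent

    # True if both endpoints are already in the same component,
    # i.e. adding the edge would create a cycle
    def connected(edge, parent):
        i, j = edge
        return find(i, parent) == find(j, parent)

    sorted_edges = sorted(edges, key=lambda edge: cost[edge])
    # parent[i] == i means that node i is the root of its component
    parent = list(nodes)
    current_edges = set()

    # stop once we have a spanning tree, or once we run out of edges
    # (the latter only happens if the graph is disconnected)
    while sorted_edges and len(current_edges) < len(nodes) - 1:
        edge = sorted_edges.pop(0)
        if not connected(edge, parent):
            current_edges.add(edge)
            parent = union(edge, parent)

    total_cost = sum(cost[edge] for edge in current_edges)
    return current_edges, total_cost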
print(kruskal(nodes, edges, cost))
({(6, 4), (1, 3), (7, 0), (4, 5), (2, 3), (1, 7), (5, 3)}, 15)
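The find helper above walks up the parent chain each time it is called. Union-find implementations often add path compression and union by size to keep those chains short. The sketch below shows one way this could look; the class name DisjointSet and the function name kruskal_dsu are hypothetical names chosen for this illustration, and the three-node graph is a made-up example.

```python
class DisjointSet:
    """Union-find with path compression and union by size (illustrative sketch)."""

    def __init__(self, items):
        self.parent = {x: x for x in items}
        self.size = {x: 1 for x in items}

    def find(self, x):
        # first pass: locate the root of x's component
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # second pass: point every node on the path directly at the root
        while self.parent[x] != root:
            next_x = self.parent[x]
            self.parent[x] = root
            x = next_x
        return root

    def union(self, a, b):
        """Merge the components of a and b; return False if already merged."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return False
        # attach the smaller component under the larger one
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]
        return True

def kruskal_dsu(nodes, edges, cost):
    dsu = DisjointSet(nodes)
    tree = set()
    # consider edges from cheapest to most expensive
    for edge in sorted(edges, key=cost.get):
        if dsu.union(*edge):  # False means the edge would close a cycle
            tree.add(edge)
    return tree, sum(cost[e] for e in tree)

# small hypothetical triangle graph
demo_cost = {(0, 1): 1, (1, 2): 2, (0, 2): 3}
print(kruskal_dsu(range(3), set(demo_cost), demo_cost))
```

Iterating over the sorted list directly also avoids the repeated pop(0) calls, each of which shifts the rest of the list.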
Stress tests
Now let's see how our algorithms stack up against larger instances. We'll be using the time module here to measure how long each algorithm takes. Each instance will have some number of nodes $n$ and $10 \cdot n$ edges. The edges chosen will be random, and costs will also be randomized in the range 1..30.
Note that this function (generate_instance) will hang for $n \leq 10$: because we disallow edges from a node to itself, there are at most $n(n-1)$ distinct edges $(i,j)$, which is fewer than the $10 \cdot n$ edges we ask for.
def generate_instance(n):
    # make results reproducible
    random.seed(1)
    nodes = range(n)
    edges = set()
    while len(edges) < 10 * n:
        i,j = random.randint(0,n-1), random.randint(0,n-1)
        if i != j:
            edges.add((i,j))
    cost = {edge : random.randint(1,30) for edge in edges}
    return nodes, edges, cost
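The size restriction noted for generate_instance can be checked with a line of arithmetic: n nodes admit n * (n - 1) ordered pairs (i, j) with i != j, and the generator asks for 10 * n of them. The helper below is a hypothetical addition for illustration only (the name can_generate and the edges_per_node parameter are assumptions, not part of the notebook).

```python
def can_generate(n, edges_per_node=10):
    """True if n nodes admit edges_per_node * n distinct ordered pairs (i, j), i != j."""
    return edges_per_node * n <= n * (n - 1)

# smallest instance sizes for which the edge-sampling loop can terminate
print([n for n in range(1, 15) if can_generate(n)])  # [11, 12, 13, 14]
```

A check like this could be used to raise an error early instead of looping forever on a too-small n.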
tests = [
    25, 50, 100, 150, 200, 250, 500, 750, 1000, 1500, 2000
]
prim_times = []
kruskal_times = []
for test in tests:
    instance = generate_instance(test)
    start = time.time()
    prim(*instance)
    end = time.time()
    prim_times.append(end - start)
    start = time.time()
    kruskal(*instance)
    end = time.time()
    kruskal_times.append(end - start)
plt.title('Kruskal\'s vs. Prim\'s')
plt.xlabel('Instance size')
plt.ylabel('Time to solve')
plt.plot(tests, prim_times)
plt.plot(tests, kruskal_times)
plt.legend(['Prim\'s', 'Kruskal\'s'], loc='upper left')
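Single wall-clock measurements like the ones above can be noisy, especially for the smaller instances. One possible refinement, sketched here under the assumption that repeating each run is affordable, is to take the best of several runs using time.perf_counter. The helper name best_time and its repeats parameter are hypothetical; the call to sorted is only a stand-in workload.

```python
import time

def best_time(func, *args, repeats=3):
    """Run func(*args) several times and return the shortest elapsed time in seconds."""
    best = float('inf')
    for _ in range(repeats):
        start = time.perf_counter()
        func(*args)
        best = min(best, time.perf_counter() - start)
    return best

# stand-in workload: sorting a reversed list
elapsed = best_time(sorted, list(range(10000, 0, -1)))
print(elapsed)
```

In the loop above, a call such as best_time(prim, *instance) could take the place of each start/end pair.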