Minimum Spanning Tree Algorithms

Author: Ben Rosenberg

Imports

We begin by importing some relevant libraries. We'll use random to generate some large graphs, and we'll use time to see how well our algorithms perform. We'll use matplotlib.pyplot to plot the performance of the algorithms against each other.

import random
import time
from matplotlib import pyplot as plt

Input data

Next, we define our input data. Recall that in the Minimum Spanning Tree problem the input is a graph $G = (N, E)$, where $N$ is the set of nodes and $E$ is the set of edges, each edge $(i,j) \in E$ having a cost $c(i,j)$.

The input data below corresponds to a small example graph, shown in the figure below, in which the numbers inside the nodes are arbitrary indices and the number next to each edge denotes the cost of that edge:

Note that the graph appears directed, but this is not intentional; please pretend that the arrows are not there (MST is defined on an undirected graph).

nodes = range(8)

edges = {
    (0,1), (0,3), (1,2), (1,3), (1,7), (2,3), (2,5), (3,4), 
    (3,6), (3,7), (4,0), (4,5), (5,1), (5,3), (5,6), (6,4), 
    (7,0)
}

# cost[i,j] is the cost of edge (i,j)
cost = {
    (0,1) : 5, (0,3) : 3, (1,2) : 3, (1,3) : 1, (1,7) : 2, 
    (2,3) : 2, (2,5) : 6, (3,4) : 4, (3,6) : 6, (3,7) : 3,
    (4,0) : 5, (4,5) : 3, (5,1) : 4, (5,3) : 2, (5,6) : 4,
    (6,4) : 4, (7,0) : 1
}
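
Since the edges above are stored as ordered tuples even though the graph is meant to be undirected, it can be handy to look up an edge's cost regardless of orientation. Here's a minimal sketch of such a helper (edge_cost is our own addition and isn't used by the implementations below):

# Hypothetical helper (not used below): look up an edge's cost regardless
# of the orientation in which the edge tuple happens to be stored.
def edge_cost(i, j, cost):
    if (i, j) in cost:
        return cost[(i, j)]
    return cost[(j, i)]

assert edge_cost(4, 0, cost) == edge_cost(0, 4, cost) == 5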

Prim's Algorithm

We'll start with Prim's Algorithm. In Prim's Algorithm, we do the following:

  1. Choose some node i in the graph
  2. Find the smallest edge connected to our current subgraph (initially just i) and, if adding it doesn't create a cycle, add it to our subgraph
  3. Repeat (2) until it is no longer possible to add such edges (or until all nodes have been reached, which should be equivalent)

def prim(nodes, edges, cost):
    # start from an arbitrary node
    current_nodes = {nodes[0]}
    current_edges = set()
    # edges with exactly one endpoint in the current subgraph
    # (adding such an edge can never create a cycle)
    def adj_edges(current_nodes):
        return {edge for edge in edges if len({*edge} & current_nodes) == 1}
    while True:
        candidate_edges = adj_edges(current_nodes) - current_edges
        if len(candidate_edges) == 0:
            break
        # greedily take the cheapest edge leaving the current subgraph
        best_edge = min(candidate_edges, key=lambda edge: cost[edge])
        current_edges.add(best_edge)
        current_nodes |= {*best_edge}
    total_cost = sum(cost[edge] for edge in current_edges)
    return current_edges, total_cost
print(prim(nodes, edges, cost))
({(1, 3), (7, 0), (4, 5), (5, 6), (2, 3), (1, 7), (5, 3)}, 15)
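
As a quick sanity check (this check is our own addition, not part of the algorithm), we can verify that the returned edge set really is a spanning tree: it should touch every node and contain exactly one fewer edge than there are nodes.

# Hypothetical sanity check: a spanning tree of 8 nodes must touch every
# node and contain exactly len(nodes) - 1 = 7 edges.
tree_edges, tree_cost = prim(nodes, edges, cost)
touched = {node for edge in tree_edges for node in edge}
assert touched == set(nodes)
assert len(tree_edges) == len(nodes) - 1
assert tree_cost == 15  # matches the total cost printed above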

Kruskal's Algorithm

Now let's look at Kruskal's Algorithm. In Kruskal's Algorithm, we do the following:

  1. Find the cheapest edge in the graph that we haven't yet considered and, if adding it to our subgraph doesn't create a cycle, add it to our subgraph
  2. Repeat (1) until it is no longer possible to add such edges (or until all nodes have been reached, which should be equivalent)

In the implementation below, a union-find data structure is used to make checking for cycles more efficient, but at its heart the algorithm is the same.

def kruskal(nodes, edges, cost):
    # find the representative (root) of the component containing node
    def find(node, parent):
        while not node == parent[node]:
            node = parent[node]
        return node
    # merge the components containing the endpoints of edge
    def union(edge, parent):
        i,j = edge
        if find(i, parent) != find(j, parent):
            parent[find(i, parent)] = find(j, parent)
        return parent
    # two endpoints are already connected iff they share a representative
    def connected(edge, parent):
        i,j = edge
        return find(i, parent) == find(j, parent)
    # consider edges in order of increasing cost
    sorted_edges = sorted(edges, key=lambda edge: cost[edge])
    # initially, every node is its own representative
    parent = list(nodes)
    # the cheapest edge can always be taken
    edge = sorted_edges.pop(0)
    current_edges = {edge}
    parent = union(edge, parent)
    # a spanning tree has exactly len(nodes) - 1 edges
    while len(current_edges) < len(nodes) - 1:
        edge = sorted_edges.pop(0)
        if not connected(edge, parent):
            current_edges.add(edge)
            parent = union(edge, parent)
    total_cost = sum(cost[edge] for edge in current_edges)
    return current_edges, total_cost
print(kruskal(nodes, edges, cost))
({(6, 4), (1, 3), (7, 0), (4, 5), (2, 3), (1, 7), (5, 3)}, 15)
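
The find function above walks the whole parent chain on every call. A common optimization that we don't use here is path compression, which re-points every visited node directly at the root so that later lookups are nearly constant time. A minimal sketch (find_compressed is our own addition):

# Hypothetical variant of find with path compression (an optimization not
# used above): every node visited on the way to the root is re-pointed
# directly at the root.
def find_compressed(node, parent):
    root = node
    while root != parent[root]:
        root = parent[root]
    while node != root:
        parent[node], node = root, parent[node]
    return root

parent = list(range(5))
parent[1] = 0; parent[2] = 1; parent[3] = 2   # chain 3 -> 2 -> 1 -> 0
assert find_compressed(3, parent) == 0
assert parent[3] == 0   # the chain has been flattened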

Stress tests

Now let's see how our algorithms stack up against larger instances. We'll be using the time module here to benchmark efficiency. Each instance will have some number of nodes $n$ and $10n$ edges. The edges will be chosen at random, and costs will also be randomized in the range 1 to 30.

Note that this function (generate_instance) will hang for $n \leq 10$: since we disallow edges from a node to itself, there are at most $n^2 - n$ possible edges, which is fewer than the $10n$ edges we ask for whenever $n \leq 10$.

def generate_instance(n):
    # make results reproducible
    random.seed(1)
    nodes = range(n)
    edges = set()
    # keep sampling random node pairs until we have 10n distinct edges
    while len(edges) < 10 * n:
        i,j = random.randint(0,n-1), random.randint(0,n-1)
        if i != j:  # disallow self-loops
            edges.add((i,j))
    cost = {edge : random.randint(1,30) for edge in edges}
    return nodes, edges, cost
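
Before timing anything, it's worth checking (again, our own addition) that both implementations agree on the total cost of a generated instance. The edge sets themselves may differ when costs are tied, but the minimum total cost is unique as long as the generated graph is connected, which is essentially certain at this edge density.

# Hypothetical consistency check: both algorithms should find a spanning
# tree of the same total cost (the edge sets may differ when costs tie).
# Assumes the generated graph is connected, which is essentially certain
# with 10n random edges.
sample = generate_instance(50)
_, prim_cost = prim(*sample)
_, kruskal_cost = kruskal(*sample)
assert prim_cost == kruskal_cost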
tests = [
    25, 50, 100, 150, 200, 250, 500, 750, 1000, 1500, 2000
]

prim_times = []
kruskal_times = []
for test in tests:
    instance = generate_instance(test)
    start = time.time()
    prim(*instance)
    end = time.time()
    prim_times.append(end - start)
    start = time.time()
    kruskal(*instance)
    end = time.time()
    kruskal_times.append(end - start)
plt.title('Kruskal\'s vs. Prim\'s')
plt.xlabel('Instance size (number of nodes)')
plt.ylabel('Time to solve (seconds)')
plt.plot(tests, prim_times)
plt.plot(tests, kruskal_times)
plt.legend(['Prim\'s', 'Kruskal\'s'], loc='upper left')
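
If you run this as a standalone script rather than in a notebook, the figure won't render on its own; something like the following (our own addition, with an arbitrary filename) saves it to disk and displays it:

# If running outside a notebook, explicitly save and render the figure;
# the filename here is arbitrary.
plt.savefig('mst_benchmark.png', dpi=150)
plt.show()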