:heavy_check_mark: Kruskal (src/graph/kruskal.hpp)

Depends on

Verified with

Code

#ifndef GRAPH_KRUSKAL_HPP
#define GRAPH_KRUSKAL_HPP

#include "src/datastructure/disjoint_set_union.hpp"

#include <algorithm>
#include <vector>

// Kruskal's algorithm: computes the total cost of a minimum spanning tree
// (or minimum spanning forest, if the graph is disconnected) of an
// undirected weighted graph.
//
// T is the edge-cost type. It must be constructible from 0 and support
// operator< and operator+=.
template<class T>
class Kruskal {
private:
	// One undirected edge joining `from` and `to` with weight `cost`.
	struct Edge {
		int from, to;
		T cost;

		Edge(int from, int to, T cost) : from(from), to(to), cost(cost) {}

		// Orders edges by cost only. Const-qualified so that it can be used
		// on const edges and satisfies the standard Compare requirements.
		bool operator<(Edge const &a) const { return cost < a.cost; }
	};

	int n;               // number of vertices, passed to DisjointSetUnion
	std::vector<Edge> g; // every edge added so far

public:
	// Creates an edgeless graph with n vertices. Vertex ids given to
	// add_edge are used as DisjointSetUnion indices, so they are expected
	// to lie in [0, n).
	Kruskal(int n) : n(n) {}

	// Registers an undirected edge between `from` and `to` with weight `cost`.
	void add_edge(int from, int to, T cost) { g.emplace_back(from, to, cost); }

	// Returns the sum of the costs of the edges chosen by Kruskal's algorithm.
	// Every edge that joins two different components is taken, so for a
	// disconnected graph the result is the cost of a minimum spanning forest.
	// Note: sorts the stored edge list in place (ascending cost).
	T mst_cost() {
		std::sort(g.begin(), g.end());
		DisjointSetUnion dsu(n);
		T cost = 0;
		for (Edge const &e : g) {
			// Skip edges whose endpoints are already connected: they would
			// close a cycle.
			if (!dsu.same(e.from, e.to)) {
				cost += e.cost;
				dsu.unite(e.from, e.to);
			}
		}
		return cost;
	}
};

#endif // GRAPH_KRUSKAL_HPP
#line 1 "src/graph/kruskal.hpp"



#line 1 "src/datastructure/disjoint_set_union.hpp"



#include <numeric>
#include <utility>
#include <vector>

// Disjoint-set forest over the elements 0..n-1, using union by rank and
// path compression.
class DisjointSetUnion {
private:
	std::vector<int> height_; // rank of each root (upper bound on tree height)
	std::vector<int> count_;  // number of elements in the set, valid at roots
	std::vector<int> parent_; // parent link; a root is its own parent
	int components_{};        // current number of disjoint sets

public:
	// Starts with n singleton sets {0}, {1}, ..., {n-1}.
	DisjointSetUnion(int n)
		: height_(n), count_(n, 1), parent_(n), components_(n) {
		for (int i = 0; i < n; ++i) parent_[i] = i;
	}

	// True when x and y currently belong to the same set.
	bool same(int x, int y) {
		const int rx = root(x);
		const int ry = root(y);
		return rx == ry;
	}

	// Merges the sets containing x and y; does nothing if they already match.
	void unite(int x, int y) {
		int keep = root(x);
		int drop = root(y);
		if (keep == drop) return;
		--components_;
		// Attach the lower-rank tree beneath the higher-rank one.
		if (height_[keep] < height_[drop]) std::swap(keep, drop);
		if (height_[keep] == height_[drop]) ++height_[keep];
		parent_[drop] = keep;
		count_[keep] += count_[drop];
	}

	// Returns the representative of x's set and points every node on the
	// visited path directly at it.
	int root(int x) {
		int top = x;
		while (parent_[top] != top) top = parent_[top];
		while (parent_[x] != top) {
			const int next = parent_[x];
			parent_[x] = top;
			x = next;
		}
		return top;
	}

	// Number of elements in the set that contains x.
	int get_size(int x) { return count_[root(x)]; }

	// Number of disjoint sets remaining.
	[[nodiscard]] int forest_size() const { return components_; }
};


#line 5 "src/graph/kruskal.hpp"

#include <algorithm>
#line 8 "src/graph/kruskal.hpp"

// Kruskal's algorithm: computes the total cost of a minimum spanning tree
// (or minimum spanning forest, if the graph is disconnected) of an
// undirected weighted graph.
//
// T is the edge-cost type. It must be constructible from 0 and support
// operator< and operator+=.
template<class T>
class Kruskal {
private:
	// One undirected edge joining `from` and `to` with weight `cost`.
	struct Edge {
		int from, to;
		T cost;

		Edge(int from, int to, T cost) : from(from), to(to), cost(cost) {}

		// Orders edges by cost only. Const-qualified so that it can be used
		// on const edges and satisfies the standard Compare requirements.
		bool operator<(Edge const &a) const { return cost < a.cost; }
	};

	int n;               // number of vertices, passed to DisjointSetUnion
	std::vector<Edge> g; // every edge added so far

public:
	// Creates an edgeless graph with n vertices. Vertex ids given to
	// add_edge are used as DisjointSetUnion indices, so they must lie in
	// [0, n).
	Kruskal(int n) : n(n) {}

	// Registers an undirected edge between `from` and `to` with weight `cost`.
	void add_edge(int from, int to, T cost) { g.emplace_back(from, to, cost); }

	// Returns the sum of the costs of the edges chosen by Kruskal's algorithm.
	// Every edge that joins two different components is taken, so for a
	// disconnected graph the result is the cost of a minimum spanning forest.
	// Note: sorts the stored edge list in place (ascending cost).
	T mst_cost() {
		std::sort(g.begin(), g.end());
		DisjointSetUnion dsu(n);
		T cost = 0;
		for (Edge const &e : g) {
			// Skip edges whose endpoints are already connected: they would
			// close a cycle.
			if (!dsu.same(e.from, e.to)) {
				cost += e.cost;
				dsu.unite(e.from, e.to);
			}
		}
		return cost;
	}
};


Back to top page