✔ Interval Set (src/datastructure/interval_set.hpp)

Verified with

Code

#ifndef DATASTRUCTURE_INTERVAL_SET
#define DATASTRUCTURE_INTERVAL_SET

#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>
#include <set>
#include <utility>

/**
 * @brief Set of integer points, stored as disjoint closed intervals [first, second].
 *
 * Two sentinel intervals [MIN, MIN] and [MAX, MAX] (the limits of T) are kept
 * permanently at the ends of the underlying std::set so that every lookup has a
 * valid predecessor and successor. They are bookkeeping only: begin()/end()
 * skip them, find()/covered() never report them, and the values MIN and MAX
 * themselves cannot be stored. Every range passed to insert()/erase() must
 * therefore satisfy MIN < l <= r < MAX (for an unsigned T this excludes 0).
 *
 * With adjacent merging enabled (the default), inserting [1, 3] and then [4, 5]
 * leaves the single interval [1, 5]; with it disabled the two stay separate and
 * only overlapping intervals are merged.
 *
 * Counts are returned as std::uint64_t and computed modulo 2^64, so they are
 * exact whenever the true count fits in 64 bits.
 *
 * @tparam T integral coordinate type.
 */
template<typename T>
class IntervalSet {
private:
	static constexpr T MIN = std::numeric_limits<T>::min();
	static constexpr T MAX = std::numeric_limits<T>::max();

	std::set<std::pair<T, T>> intervals;
	T adjacent_offset; // 1 when adjacent intervals are merged, 0 otherwise

	// Number of points in [a, b] (requires a <= b). Converting to uint64 first
	// makes the subtraction wrap modulo 2^64 instead of overflowing T, so wide
	// ranges such as [-2e9, 2e9] for T = int are counted correctly.
	static std::uint64_t length(T a, T b) {
		return static_cast<std::uint64_t>(b) - static_cast<std::uint64_t>(a) + 1;
	}

	// Total number of points covered by the intervals in [first, last).
	template<typename It>
	static std::uint64_t total_length(It first, It last) {
		return std::accumulate(first, last, std::uint64_t{0},
			[](std::uint64_t acc, const std::pair<T, T> &iv) {
				return acc + length(iv.first, iv.second);
			});
	}

public:
	// std::set only hands out constant iterators. Naming const_iterator lets the
	// const begin()/end()/find() below compile on every standard library (some
	// do not convert set::const_iterator to set::iterator); where the two are
	// the same type nothing changes for callers.
	using iterator = typename std::set<std::pair<T, T>>::const_iterator;

	/// @param enable_merge_adjacent also merge intervals that merely touch
	///        (e.g. [1, 3] and [4, 5]), not only overlapping ones.
	IntervalSet(bool enable_merge_adjacent = true) :
		intervals{{MIN, MIN}, {MAX, MAX}},
		adjacent_offset(enable_merge_adjacent ? 1 : 0) {}

	/// First stored interval (skips the [MIN, MIN] sentinel).
	iterator begin() const { return std::next(intervals.begin()); }

	/// Past-the-end position for stored intervals (the [MAX, MAX] sentinel).
	iterator end() const { return std::prev(intervals.end()); }

	/// @return the stored interval containing x, or end() if x is not covered.
	[[nodiscard]] iterator find(T x) const {
		// Last interval whose start is <= x; the MIN sentinel guarantees one exists.
		auto it = std::prev(intervals.upper_bound({x, MAX}));
		// Never report a sentinel (only reachable for x == MIN or x == MAX).
		if (it == intervals.begin() || it->second < x) return end();
		return it;
	}

	/// Covers the single point x. @return 1 if x was newly covered, else 0.
	std::uint64_t insert(T x) { return insert(x, x); }

	/**
	 * @brief Covers every point of [l, r], merging it with the overlapping (and,
	 *        if enabled, adjacent) stored intervals.
	 * @pre MIN < l <= r < MAX
	 * @return number of points of [l, r] that were not covered before.
	 */
	std::uint64_t insert(T l, T r) { // [l, r]
		assert(l <= r);
		assert(MIN < l && r < MAX); // keeps the sentinels intact and l - 1 / r + 1 in range

		// First interval to merge: the one containing (or touching) l, otherwise
		// the first one starting after l. prev() is safe because the MIN sentinel
		// always precedes {l, MAX}.
		auto left_it = intervals.upper_bound({l, MAX});
		if (l - adjacent_offset <= std::prev(left_it)->second) --left_it;
		// One past the last interval starting at or before r (+1 when merging adjacent).
		auto right_it = intervals.upper_bound({r + adjacent_offset, MAX});

		if (left_it == right_it) { // nothing to merge with
			intervals.emplace_hint(right_it, l, r);
			return length(l, r);
		}

		std::uint64_t already_covered = total_length(left_it, right_it);
		T new_l = left_it->first < l ? left_it->first : l;
		T rr = std::prev(right_it)->second;
		T new_r = r < rr ? rr : r;
		intervals.erase(left_it, right_it); // right_it stays valid
		intervals.emplace_hint(right_it, new_l, new_r);
		// Every absorbed interval lies inside [new_l, new_r], so the difference is
		// exactly the number of points that became covered.
		return length(new_l, new_r) - already_covered;
	}

	/// Uncovers the single point x. @return 1 if x was covered, else 0.
	std::uint64_t erase(T x) { return erase(x, x); }

	/**
	 * @brief Uncovers every point of [l, r], splitting stored intervals that
	 *        straddle l or r.
	 * @pre MIN < l <= r < MAX
	 * @return number of points of [l, r] that were covered before.
	 */
	std::uint64_t erase(T l, T r) { // [l, r]
		assert(l <= r);
		assert(MIN < l && r < MAX); // keeps the sentinels intact and l - 1 / r + 1 in range

		// [left_it, right_it) = stored intervals intersecting [l, r]. prev() is
		// safe because the MIN sentinel always precedes {l, MAX}.
		auto left_it = intervals.upper_bound({l, MAX});
		if (l <= std::prev(left_it)->second) --left_it;
		auto right_it = intervals.upper_bound({r, MAX});
		if (left_it == right_it) return 0;

		std::uint64_t count = total_length(left_it, right_it);
		T ll = left_it->first;
		T rr = std::prev(right_it)->second;
		intervals.erase(left_it, right_it); // right_it stays valid
		// Put back the parts of the boundary intervals that stick out of [l, r].
		if (ll < l) {
			count -= length(ll, l - 1);
			intervals.emplace_hint(right_it, ll, l - 1);
		}
		if (r < rr) {
			count -= length(r + 1, rr);
			intervals.emplace_hint(right_it, r + 1, rr);
		}
		return count;
	}

	/// @return whether the point x is covered.
	[[nodiscard]] bool covered(T x) const { return covered(x, x); }

	/**
	 * @return true iff a single stored interval contains all of [l, r]. With
	 *         adjacent merging disabled, a range spanning two touching intervals
	 *         therefore reports false.
	 */
	[[nodiscard]] bool covered(T l, T r) const { // [l, r]
		assert(l <= r);
		// Last interval starting at or before r: the only candidate to contain [l, r].
		auto it = std::prev(intervals.upper_bound({r, MAX}));
		// The sentinels are bookkeeping only and never count as covered.
		if (it == intervals.begin() || it == end()) return false;
		return it->first <= l && r <= it->second;
	}

	/**
	 * @return x if x is not covered, otherwise one past the end of the stored
	 *         interval containing x. With adjacent merging enabled this is the
	 *         smallest uncovered value >= x.
	 */
	[[nodiscard]] T mex(T x = 0) const {
		auto it = find(x);
		return it == end() ? x : it->second + 1;
	}
};

#endif // DATASTRUCTURE_INTERVAL_SET
#line 1 "src/datastructure/interval_set.hpp"



#include <cassert>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>
#include <set>
#include <utility>

/**
 * @brief Set of integer points, stored as disjoint closed intervals [first, second].
 *
 * Two sentinel intervals [MIN, MIN] and [MAX, MAX] (the limits of T) are kept
 * permanently at the ends of the underlying std::set so that every lookup has a
 * valid predecessor and successor. They are bookkeeping only: begin()/end()
 * skip them, find()/covered() never report them, and the values MIN and MAX
 * themselves cannot be stored. Every range passed to insert()/erase() must
 * therefore satisfy MIN < l <= r < MAX (for an unsigned T this excludes 0).
 *
 * With adjacent merging enabled (the default), inserting [1, 3] and then [4, 5]
 * leaves the single interval [1, 5]; with it disabled the two stay separate and
 * only overlapping intervals are merged.
 *
 * Counts are returned as std::uint64_t and computed modulo 2^64, so they are
 * exact whenever the true count fits in 64 bits.
 *
 * @tparam T integral coordinate type.
 */
template<typename T>
class IntervalSet {
private:
	static constexpr T MIN = std::numeric_limits<T>::min();
	static constexpr T MAX = std::numeric_limits<T>::max();

	std::set<std::pair<T, T>> intervals;
	T adjacent_offset; // 1 when adjacent intervals are merged, 0 otherwise

	// Number of points in [a, b] (requires a <= b). Converting to uint64 first
	// makes the subtraction wrap modulo 2^64 instead of overflowing T, so wide
	// ranges such as [-2e9, 2e9] for T = int are counted correctly.
	static std::uint64_t length(T a, T b) {
		return static_cast<std::uint64_t>(b) - static_cast<std::uint64_t>(a) + 1;
	}

	// Total number of points covered by the intervals in [first, last).
	template<typename It>
	static std::uint64_t total_length(It first, It last) {
		return std::accumulate(first, last, std::uint64_t{0},
			[](std::uint64_t acc, const std::pair<T, T> &iv) {
				return acc + length(iv.first, iv.second);
			});
	}

public:
	// std::set only hands out constant iterators. Naming const_iterator lets the
	// const begin()/end()/find() below compile on every standard library (some
	// do not convert set::const_iterator to set::iterator); where the two are
	// the same type nothing changes for callers.
	using iterator = typename std::set<std::pair<T, T>>::const_iterator;

	/// @param enable_merge_adjacent also merge intervals that merely touch
	///        (e.g. [1, 3] and [4, 5]), not only overlapping ones.
	IntervalSet(bool enable_merge_adjacent = true) :
		intervals{{MIN, MIN}, {MAX, MAX}},
		adjacent_offset(enable_merge_adjacent ? 1 : 0) {}

	/// First stored interval (skips the [MIN, MIN] sentinel).
	iterator begin() const { return std::next(intervals.begin()); }

	/// Past-the-end position for stored intervals (the [MAX, MAX] sentinel).
	iterator end() const { return std::prev(intervals.end()); }

	/// @return the stored interval containing x, or end() if x is not covered.
	[[nodiscard]] iterator find(T x) const {
		// Last interval whose start is <= x; the MIN sentinel guarantees one exists.
		auto it = std::prev(intervals.upper_bound({x, MAX}));
		// Never report a sentinel (only reachable for x == MIN or x == MAX).
		if (it == intervals.begin() || it->second < x) return end();
		return it;
	}

	/// Covers the single point x. @return 1 if x was newly covered, else 0.
	std::uint64_t insert(T x) { return insert(x, x); }

	/**
	 * @brief Covers every point of [l, r], merging it with the overlapping (and,
	 *        if enabled, adjacent) stored intervals.
	 * @pre MIN < l <= r < MAX
	 * @return number of points of [l, r] that were not covered before.
	 */
	std::uint64_t insert(T l, T r) { // [l, r]
		assert(l <= r);
		assert(MIN < l && r < MAX); // keeps the sentinels intact and l - 1 / r + 1 in range

		// First interval to merge: the one containing (or touching) l, otherwise
		// the first one starting after l. prev() is safe because the MIN sentinel
		// always precedes {l, MAX}.
		auto left_it = intervals.upper_bound({l, MAX});
		if (l - adjacent_offset <= std::prev(left_it)->second) --left_it;
		// One past the last interval starting at or before r (+1 when merging adjacent).
		auto right_it = intervals.upper_bound({r + adjacent_offset, MAX});

		if (left_it == right_it) { // nothing to merge with
			intervals.emplace_hint(right_it, l, r);
			return length(l, r);
		}

		std::uint64_t already_covered = total_length(left_it, right_it);
		T new_l = left_it->first < l ? left_it->first : l;
		T rr = std::prev(right_it)->second;
		T new_r = r < rr ? rr : r;
		intervals.erase(left_it, right_it); // right_it stays valid
		intervals.emplace_hint(right_it, new_l, new_r);
		// Every absorbed interval lies inside [new_l, new_r], so the difference is
		// exactly the number of points that became covered.
		return length(new_l, new_r) - already_covered;
	}

	/// Uncovers the single point x. @return 1 if x was covered, else 0.
	std::uint64_t erase(T x) { return erase(x, x); }

	/**
	 * @brief Uncovers every point of [l, r], splitting stored intervals that
	 *        straddle l or r.
	 * @pre MIN < l <= r < MAX
	 * @return number of points of [l, r] that were covered before.
	 */
	std::uint64_t erase(T l, T r) { // [l, r]
		assert(l <= r);
		assert(MIN < l && r < MAX); // keeps the sentinels intact and l - 1 / r + 1 in range

		// [left_it, right_it) = stored intervals intersecting [l, r]. prev() is
		// safe because the MIN sentinel always precedes {l, MAX}.
		auto left_it = intervals.upper_bound({l, MAX});
		if (l <= std::prev(left_it)->second) --left_it;
		auto right_it = intervals.upper_bound({r, MAX});
		if (left_it == right_it) return 0;

		std::uint64_t count = total_length(left_it, right_it);
		T ll = left_it->first;
		T rr = std::prev(right_it)->second;
		intervals.erase(left_it, right_it); // right_it stays valid
		// Put back the parts of the boundary intervals that stick out of [l, r].
		if (ll < l) {
			count -= length(ll, l - 1);
			intervals.emplace_hint(right_it, ll, l - 1);
		}
		if (r < rr) {
			count -= length(r + 1, rr);
			intervals.emplace_hint(right_it, r + 1, rr);
		}
		return count;
	}

	/// @return whether the point x is covered.
	[[nodiscard]] bool covered(T x) const { return covered(x, x); }

	/**
	 * @return true iff a single stored interval contains all of [l, r]. With
	 *         adjacent merging disabled, a range spanning two touching intervals
	 *         therefore reports false.
	 */
	[[nodiscard]] bool covered(T l, T r) const { // [l, r]
		assert(l <= r);
		// Last interval starting at or before r: the only candidate to contain [l, r].
		auto it = std::prev(intervals.upper_bound({r, MAX}));
		// The sentinels are bookkeeping only and never count as covered.
		if (it == intervals.begin() || it == end()) return false;
		return it->first <= l && r <= it->second;
	}

	/**
	 * @return x if x is not covered, otherwise one past the end of the stored
	 *         interval containing x. With adjacent merging enabled this is the
	 *         smallest uncovered value >= x.
	 */
	[[nodiscard]] T mex(T x = 0) const {
		auto it = find(x);
		return it == end() ? x : it->second + 1;
	}
};


Back to top page