Algorithms

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub HyunjaeLee/Algorithms

✔️ graph/directed_mst.hpp

Depends on: disjoint_set/rollback_disjoint_set.hpp

Verified with

Code

#ifndef DIRECTED_MST_HPP
#define DIRECTED_MST_HPP

#include "disjoint_set/rollback_disjoint_set.hpp"
#include <cassert>
#include <utility>
#include <vector>

/// Minimum-cost arborescence (directed minimum spanning tree) solver.
///
/// Edges live in parallel arrays indexed by edge id. Every vertex owns a
/// mergeable min-heap of its incoming edges (skew-heap style merge with lazy
/// cost offsets); `heap_[v]` is the edge id at the top of that heap, or -1
/// when the heap is empty.
///
/// NOTE: `run` pops from these heaps and rewrites the stored edge costs, so
/// the solver state is consumed by the first call to `run`.
template <typename Cost> struct directed_mst {
    /// Creates a solver for vertices `0 .. n - 1` with no edges.
    explicit directed_mst(int n) : n_(n), heap_(n_, -1) {}

    /// Adds the directed edge `from -> to` with weight `cost`.
    /// Both endpoints must be valid vertex indices (checked with `assert`).
    void add_edge(int from, int to, Cost cost) {
        assert(0 <= from && from < n_ && 0 <= to && to < n_);
        auto id = static_cast<int>(from_.size());
        from_.push_back(from);
        to_.push_back(to);
        cost_.push_back(cost);
        left_.push_back(-1);
        right_.push_back(-1);
        lazy_.push_back(Cost{});
        // The new edge becomes a one-node heap merged into the heap of
        // edges entering `to`.
        heap_[to] = merge(heap_[to], id);
    }

    /// Computes a minimum-cost arborescence rooted at `root`.
    ///
    /// @param root  Root vertex; must satisfy `0 <= root < n` (asserted).
    /// @return `{total cost, parent}` where `parent[v]` is the source of the
    ///         edge chosen to enter `v` and -1 where no edge was chosen (the
    ///         root). If some vertex or contracted cycle has no incoming edge
    ///         left, returns `{-1, {}}`; test the vector for emptiness to
    ///         detect failure, since -1 may also be a legitimate total when
    ///         negative weights are used.
    std::pair<Cost, std::vector<int>> run(int root) {
        // An out-of-range root would write past the end of `seen` below.
        assert(0 <= root && root < n_);
        rollback_disjoint_set dsu(n_);
        Cost result{};
        // seen[v]  : outer iteration that put v on its walk, -1 if not yet.
        // path[i]  : component whose cheapest incoming edge was taken at
        //            step i of the current walk.
        // queue[i] : id of that edge.
        // in[c]    : edge chosen to enter the component represented by c.
        std::vector<int> seen(n_, -1), path(n_), queue(n_), in(n_, -1);
        seen[root] = root; // walks stop as soon as they reach the root
        // One entry per contraction: (representative, edges of the cycle).
        std::vector<std::pair<int, std::vector<int>>> cycles;
        for (auto s = 0; s < n_; ++s) {
            auto u = s, pos = 0, w = -1;
            while (!~seen[u]) {
                if (!~heap_[u]) {
                    // No edge enters this component: no arborescence exists.
                    return {-1, {}};
                }
                push(heap_[u]);
                auto e = heap_[u];
                result += cost_[e];
                // Discount every other edge of this heap by the cost just
                // paid; `pop` pushes the offset down to the children.
                lazy_[heap_[u]] -= cost_[e];
                heap_[u] = pop(heap_[u]);
                queue[pos] = e;
                path[pos++] = u;
                seen[u] = s;
                u = dsu.find(from_[e]);
                if (seen[u] == s) {
                    // The walk returned to a component already on it: a
                    // cycle. Contract it by merging heaps and components
                    // back along the path until `u` itself is reached.
                    auto cycle = -1;
                    auto end = pos;
                    do {
                        w = path[--pos];
                        cycle = merge(cycle, heap_[w]);
                    } while (dsu.merge(u, w));
                    u = dsu.find(u);
                    heap_[u] = cycle;
                    seen[u] = -1; // the contracted node still needs an in-edge
                    cycles.emplace_back(u,
                                        std::vector<int>(queue.begin() + pos,
                                                         queue.begin() + end));
                }
            }
            for (auto i = 0; i < pos; ++i) {
                in[dsu.find(to_[queue[i]])] = queue[i];
            }
        }
        // Expand the contractions newest-first. A cycle of k edges was built
        // with k - 1 successful dsu merges, which are undone here.
        for (auto it = cycles.rbegin(); it != cycles.rend(); ++it) {
            auto &[u, comp] = *it;
            auto count = static_cast<int>(comp.size()) - 1;
            dsu.rollback(count);
            auto in_edge = in[u];
            for (auto e : comp) {
                in[dsu.find(to_[e])] = e;
            }
            // The edge entering the contracted node replaces the cycle edge
            // at the component it actually points to.
            in[dsu.find(to_[in_edge])] = in_edge;
        }
        std::vector<int> parent;
        parent.reserve(n_);
        for (auto i : in) {
            parent.push_back(~i ? from_[i] : -1);
        }
        return {result, std::move(parent)};
    }

private:
    /// Applies the pending offset of node `u` to its cost and forwards the
    /// offset to both children.
    void push(int u) {
        cost_[u] += lazy_[u];
        if (auto l = left_[u]; ~l) {
            lazy_[l] += lazy_[u];
        }
        if (auto r = right_[u]; ~r) {
            lazy_[r] += lazy_[u];
        }
        lazy_[u] = Cost{};
    }
    /// Merges the heaps rooted at `u` and `v` (-1 means empty) and returns
    /// the root of the result.
    int merge(int u, int v) {
        if (!~u || !~v) {
            return ~u ? u : v;
        }
        push(u);
        push(v);
        if (cost_[u] > cost_[v]) {
            std::swap(u, v);
        }
        right_[u] = merge(v, right_[u]);
        std::swap(left_[u], right_[u]);
        return u;
    }
    /// Removes the root `u` and returns the root of the remaining heap
    /// (-1 if it becomes empty).
    int pop(int u) {
        push(u);
        return merge(left_[u], right_[u]);
    }
    const int n_;
    std::vector<int> from_, to_, left_, right_, heap_;
    std::vector<Cost> cost_, lazy_;
};

#endif // DIRECTED_MST_HPP
#line 1 "graph/directed_mst.hpp"



#line 1 "disjoint_set/rollback_disjoint_set.hpp"



#include <cassert>
#include <stack>
#include <utility>
#include <vector>

/// Disjoint-set forest with union by size and an undo log.
///
/// `parent_or_size_[v]` holds the parent of `v` when it is non-negative and
/// the negated size of the set when `v` is a root. Paths are never
/// compressed, so every successful merge changes exactly two entries and can
/// be reverted from the log.
struct rollback_disjoint_set {
    /// Creates `n` singleton sets.
    explicit rollback_disjoint_set(int n) : n_(n), parent_or_size_(n, -1) {}

    /// Returns the representative of the set containing `u`.
    int find(int u) const {
        while (parent_or_size_[u] >= 0) {
            u = parent_or_size_[u];
        }
        return u;
    }

    /// Unites the sets of `u` and `v`.
    /// Returns false (and records nothing) when they are already united.
    bool merge(int u, int v) {
        assert(0 <= u && u < n_ && 0 <= v && v < n_);
        auto keep = find(u);
        auto drop = find(v);
        if (keep == drop) {
            return false;
        }
        // Sizes are stored negated: the larger entry is the smaller set.
        if (parent_or_size_[keep] > parent_or_size_[drop]) {
            std::swap(keep, drop);
        }
        // Remember the absorbed root and its old entry so it can be undone.
        history_.emplace(drop, parent_or_size_[drop]);
        parent_or_size_[keep] += parent_or_size_[drop];
        parent_or_size_[drop] = keep;
        return true;
    }

    /// Tells whether `u` and `v` belong to the same set.
    bool same(int u, int v) const {
        assert(0 <= u && u < n_ && 0 <= v && v < n_);
        const auto root_u = find(u);
        const auto root_v = find(v);
        return root_u == root_v;
    }

    /// Returns the number of elements in the set containing `u`.
    int size(int u) const {
        assert(0 <= u && u < n_);
        const auto root = find(u);
        return -parent_or_size_[root];
    }

    /// Reverts the most recent successful merge.
    void rollback() {
        assert(!history_.empty());
        const auto [child, saved] = history_.top();
        history_.pop();
        const auto root = parent_or_size_[child];
        parent_or_size_[child] = saved;
        parent_or_size_[root] -= saved;
    }

    /// Reverts the `count` most recent successful merges.
    void rollback(int count) {
        for (; count > 0; --count) {
            rollback();
        }
    }

private:
    int n_;
    std::vector<int> parent_or_size_;
    std::stack<std::pair<int, int>> history_;
};


#line 8 "graph/directed_mst.hpp"

/// Minimum-cost arborescence (directed minimum spanning tree) solver.
///
/// Edges live in parallel arrays indexed by edge id. Every vertex owns a
/// mergeable min-heap of its incoming edges (skew-heap style merge with lazy
/// cost offsets); `heap_[v]` is the edge id at the top of that heap, or -1
/// when the heap is empty.
///
/// NOTE: `run` pops from these heaps and rewrites the stored edge costs, so
/// the solver state is consumed by the first call to `run`.
template <typename Cost> struct directed_mst {
    /// Creates a solver for vertices `0 .. n - 1` with no edges.
    explicit directed_mst(int n) : n_(n), heap_(n_, -1) {}

    /// Adds the directed edge `from -> to` with weight `cost`.
    /// Both endpoints must be valid vertex indices (checked with `assert`).
    void add_edge(int from, int to, Cost cost) {
        assert(0 <= from && from < n_ && 0 <= to && to < n_);
        auto id = static_cast<int>(from_.size());
        from_.push_back(from);
        to_.push_back(to);
        cost_.push_back(cost);
        left_.push_back(-1);
        right_.push_back(-1);
        lazy_.push_back(Cost{});
        // The new edge becomes a one-node heap merged into the heap of
        // edges entering `to`.
        heap_[to] = merge(heap_[to], id);
    }

    /// Computes a minimum-cost arborescence rooted at `root`.
    ///
    /// @param root  Root vertex; must satisfy `0 <= root < n` (asserted).
    /// @return `{total cost, parent}` where `parent[v]` is the source of the
    ///         edge chosen to enter `v` and -1 where no edge was chosen (the
    ///         root). If some vertex or contracted cycle has no incoming edge
    ///         left, returns `{-1, {}}`; test the vector for emptiness to
    ///         detect failure, since -1 may also be a legitimate total when
    ///         negative weights are used.
    std::pair<Cost, std::vector<int>> run(int root) {
        // An out-of-range root would write past the end of `seen` below.
        assert(0 <= root && root < n_);
        rollback_disjoint_set dsu(n_);
        Cost result{};
        // seen[v]  : outer iteration that put v on its walk, -1 if not yet.
        // path[i]  : component whose cheapest incoming edge was taken at
        //            step i of the current walk.
        // queue[i] : id of that edge.
        // in[c]    : edge chosen to enter the component represented by c.
        std::vector<int> seen(n_, -1), path(n_), queue(n_), in(n_, -1);
        seen[root] = root; // walks stop as soon as they reach the root
        // One entry per contraction: (representative, edges of the cycle).
        std::vector<std::pair<int, std::vector<int>>> cycles;
        for (auto s = 0; s < n_; ++s) {
            auto u = s, pos = 0, w = -1;
            while (!~seen[u]) {
                if (!~heap_[u]) {
                    // No edge enters this component: no arborescence exists.
                    return {-1, {}};
                }
                push(heap_[u]);
                auto e = heap_[u];
                result += cost_[e];
                // Discount every other edge of this heap by the cost just
                // paid; `pop` pushes the offset down to the children.
                lazy_[heap_[u]] -= cost_[e];
                heap_[u] = pop(heap_[u]);
                queue[pos] = e;
                path[pos++] = u;
                seen[u] = s;
                u = dsu.find(from_[e]);
                if (seen[u] == s) {
                    // The walk returned to a component already on it: a
                    // cycle. Contract it by merging heaps and components
                    // back along the path until `u` itself is reached.
                    auto cycle = -1;
                    auto end = pos;
                    do {
                        w = path[--pos];
                        cycle = merge(cycle, heap_[w]);
                    } while (dsu.merge(u, w));
                    u = dsu.find(u);
                    heap_[u] = cycle;
                    seen[u] = -1; // the contracted node still needs an in-edge
                    cycles.emplace_back(u,
                                        std::vector<int>(queue.begin() + pos,
                                                         queue.begin() + end));
                }
            }
            for (auto i = 0; i < pos; ++i) {
                in[dsu.find(to_[queue[i]])] = queue[i];
            }
        }
        // Expand the contractions newest-first. A cycle of k edges was built
        // with k - 1 successful dsu merges, which are undone here.
        for (auto it = cycles.rbegin(); it != cycles.rend(); ++it) {
            auto &[u, comp] = *it;
            auto count = static_cast<int>(comp.size()) - 1;
            dsu.rollback(count);
            auto in_edge = in[u];
            for (auto e : comp) {
                in[dsu.find(to_[e])] = e;
            }
            // The edge entering the contracted node replaces the cycle edge
            // at the component it actually points to.
            in[dsu.find(to_[in_edge])] = in_edge;
        }
        std::vector<int> parent;
        parent.reserve(n_);
        for (auto i : in) {
            parent.push_back(~i ? from_[i] : -1);
        }
        return {result, std::move(parent)};
    }

private:
    /// Applies the pending offset of node `u` to its cost and forwards the
    /// offset to both children.
    void push(int u) {
        cost_[u] += lazy_[u];
        if (auto l = left_[u]; ~l) {
            lazy_[l] += lazy_[u];
        }
        if (auto r = right_[u]; ~r) {
            lazy_[r] += lazy_[u];
        }
        lazy_[u] = Cost{};
    }
    /// Merges the heaps rooted at `u` and `v` (-1 means empty) and returns
    /// the root of the result.
    int merge(int u, int v) {
        if (!~u || !~v) {
            return ~u ? u : v;
        }
        push(u);
        push(v);
        if (cost_[u] > cost_[v]) {
            std::swap(u, v);
        }
        right_[u] = merge(v, right_[u]);
        std::swap(left_[u], right_[u]);
        return u;
    }
    /// Removes the root `u` and returns the root of the remaining heap
    /// (-1 if it becomes empty).
    int pop(int u) {
        push(u);
        return merge(left_[u], right_[u]);
    }
    const int n_;
    std::vector<int> from_, to_, left_, right_, heap_;
    std::vector<Cost> cost_, lazy_;
};
Back to top page