Algorithms

This documentation is automatically generated by online-judge-tools/verification-helper.

View the Project on GitHub HyunjaeLee/Algorithms

✔ test/link_cut_tree/dynamic_tree_vertex_set_path_composite.test.cpp

Depends on:
- link_cut_tree/link_cut_tree.hpp
- monoids/affine_monoid.hpp
- monoids/reversible_monoid.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/dynamic_tree_vertex_set_path_composite"

#include "link_cut_tree/link_cut_tree.hpp"
#include "monoids/affine_monoid.hpp"
#include "monoids/reversible_monoid.hpp"
#include <atcoder/modint>
#include <bits/stdc++.h>

int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int N, Q;
    std::cin >> N >> Q;
    using mint = atcoder::modint998244353;
    using M = AffineMonoid<mint>;
    using S = ReversibleMonoid<M>;
    link_cut_tree<S, S::op, S::e, S::toggle> lct(N);
    for (auto i = 0; i < N; ++i) {
        int a, b;
        std::cin >> a >> b;
        lct.set(i, M{a, b});
    }
    for (auto i = 0; i < N - 1; ++i) {
        int u, v;
        std::cin >> u >> v;
        lct.link(u, v);
    }
    while (Q--) {
        int t;
        std::cin >> t;
        if (t == 0) {
            int u, v, w, x;
            std::cin >> u >> v >> w >> x;
            lct.cut(u, v);
            lct.link(w, x);
        } else if (t == 1) {
            int p, c, d;
            std::cin >> p >> c >> d;
            lct.set(p, M{c, d});
        } else {
            int u, v, x;
            std::cin >> u >> v >> x;
            auto [a, b] = lct.prod(u, v).val;
            auto ans = a * x + b;
            std::cout << ans.val() << '\n';
        }
    }
}
#line 1 "test/link_cut_tree/dynamic_tree_vertex_set_path_composite.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/dynamic_tree_vertex_set_path_composite"

#line 1 "link_cut_tree/link_cut_tree.hpp"



#include <cassert>
#include <type_traits>
#include <utility>
#include <vector>

template <typename S, auto op, auto e, auto toggle> struct link_cut_tree {
    link_cut_tree(int n)
        : n_(n), left_(n, -1), right_(n, -1), parent_(n, -1), data_(n, e()), sum_(n, e()),
          reversed_(n, false) {}
    int access(int u) {
        assert(0 <= u && u < n_);
        auto result = -1;
        for (auto cur = u; ~cur; cur = parent_[cur]) {
            splay(cur);
            right_[cur] = result;
            update(cur);
            result = cur;
        }
        splay(u);
        return result;
    }
    void make_root(int u) {
        assert(0 <= u && u < n_);
        access(u);
        reverse(u);
        push(u);
    }
    void link(int u, int p) {
        assert(0 <= u && u < n_ && 0 <= p && p < n_);
        make_root(u);
        access(p);
        parent_[u] = p;
        right_[p] = u;
        update(p);
    }
    void cut(int u) {
        assert(0 <= u && u < n_);
        access(u);
        auto p = left_[u];
        left_[u] = -1;
        update(u);
        parent_[p] = -1;
    }
    void cut(int u, int v) {
        assert(0 <= u && u < n_ && 0 <= v && v < n_);
        make_root(u);
        cut(v);
    }
    int lca(int u, int v) {
        assert(0 <= u && u < n_ && 0 <= v && v < n_);
        access(u);
        return access(v);
    }
    void set(int u, S x) {
        assert(0 <= u && u < n_);
        access(u);
        data_[u] = x;
        update(u);
    }
    S get(int u) {
        assert(0 <= u && u < n_);
        access(u);
        return data_[u];
    }
    S prod(int u, int v) {
        assert(0 <= u && u < n_ && 0 <= v && v < n_);
        make_root(u);
        access(v);
        return sum_[v];
    }
    bool connected(int u, int v) {
        assert(0 <= u && u < n_ && 0 <= v && v < n_);
        access(u);
        access(v);
        return u == v || ~parent_[u];
    }

private:
    bool is_root(int u) const {
        auto p = parent_[u];
        return !~p || (left_[p] != u && right_[p] != u);
    }
    void update(int u) {
        if (~u) {
            sum_[u] = data_[u];
            if (auto v = left_[u]; ~v) {
                sum_[u] = op(sum_[v], sum_[u]);
            }
            if (auto v = right_[u]; ~v) {
                sum_[u] = op(sum_[u], sum_[v]);
            }
        }
    }
    void reverse(int u) {
        if (~u) {
            std::swap(left_[u], right_[u]);
            reversed_[u] = !reversed_[u];
            sum_[u] = toggle(sum_[u]);
        }
    }
    void push(int u) {
        if (~u) {
            if (reversed_[u]) {
                reverse(left_[u]);
                reverse(right_[u]);
                reversed_[u] = false;
            }
        }
    }
    void rotate_right(int u) {
        auto p = parent_[u];
        auto g = parent_[p];
        if (left_[p] = right_[u]; ~left_[p]) {
            parent_[right_[u]] = p;
        }
        right_[u] = p;
        parent_[p] = u;
        update(p);
        update(u);
        if (parent_[u] = g; ~parent_[u]) {
            if (left_[g] == p) {
                left_[g] = u;
            }
            if (right_[g] == p) {
                right_[g] = u;
            }
            update(g);
        }
    }
    void rotate_left(int u) {
        auto p = parent_[u];
        auto g = parent_[p];
        if (right_[p] = left_[u]; ~right_[p]) {
            parent_[left_[u]] = p;
        }
        left_[u] = p;
        parent_[p] = u;
        update(p);
        update(u);
        if (parent_[u] = g; ~parent_[u]) {
            if (left_[g] == p) {
                left_[g] = u;
            }
            if (right_[g] == p) {
                right_[g] = u;
            }
            update(g);
        }
    }
    void splay(int u) {
        push(u);
        while (!is_root(u)) {
            auto p = parent_[u];
            if (is_root(p)) {
                push(p);
                push(u);
                if (left_[p] == u) {
                    rotate_right(u);
                } else {
                    rotate_left(u);
                }
            } else {
                auto g = parent_[p];
                push(g);
                push(p);
                push(u);
                if (left_[g] == p) {
                    if (left_[p] == u) {
                        rotate_right(p);
                        rotate_right(u);
                    } else {
                        rotate_left(u);
                        rotate_right(u);
                    }
                } else {
                    if (right_[p] == u) {
                        rotate_left(p);
                        rotate_left(u);
                    } else {
                        rotate_right(u);
                        rotate_left(u);
                    }
                }
            }
        }
    }
    int n_;
    std::vector<int> left_, right_, parent_;
    std::vector<S> data_, sum_;
    std::vector<char> reversed_;
};


#line 1 "monoids/affine_monoid.hpp"



/// Affine map x -> a * x + b with coefficients in Z. Under `op` these maps form
/// a monoid whose identity is x -> x.
template <typename Z> struct AffineMonoid {
    /// Composition that applies f first and then g, i.e. x -> g(f(x)):
    /// g(f(x)) = g.a * (f.a * x + f.b) + g.b.
    static AffineMonoid op(AffineMonoid f, AffineMonoid g) {
        const Z slope = f.a * g.a;
        const Z offset = f.b * g.a + g.b;
        return AffineMonoid{slope, offset};
    }
    /// Identity map x -> x.
    static AffineMonoid e() { return AffineMonoid{1, 0}; }
    Z a, b;
};


#line 1 "monoids/reversible_monoid.hpp"



/// Wraps a monoid M so that a product can be reversed by swapping two fields:
/// `val` is the product of a sequence in order, and `rev` is the product of
/// the same sequence in reverse order.
template <typename M> struct ReversibleMonoid {
    /// Single element: the forward and reverse products coincide.
    ReversibleMonoid(M x) : ReversibleMonoid(x, x) {}
    /// Explicit forward (x) and reverse (y) products.
    ReversibleMonoid(M x, M y) : val(x), rev(y) {}

    /// Concatenation l ++ r: forward is l.val * r.val, reverse is r.rev * l.rev.
    static ReversibleMonoid op(ReversibleMonoid l, ReversibleMonoid r) {
        M forward = M::op(l.val, r.val);
        M backward = M::op(r.rev, l.rev);
        return ReversibleMonoid(forward, backward);
    }
    /// Empty sequence: both products are M's identity.
    static ReversibleMonoid e() {
        M unit_val = M::e();
        M unit_rev = M::e();
        return ReversibleMonoid(unit_val, unit_rev);
    }
    /// Reversed sequence: swaps the forward and reverse products.
    static ReversibleMonoid toggle(ReversibleMonoid x) { return ReversibleMonoid(x.rev, x.val); }

    M val, rev;
};


#line 6 "test/link_cut_tree/dynamic_tree_vertex_set_path_composite.test.cpp"
#include <atcoder/modint>
#include <bits/stdc++.h>

int main() {
    std::cin.tie(0)->sync_with_stdio(0);
    int N, Q;
    std::cin >> N >> Q;
    using mint = atcoder::modint998244353;
    using M = AffineMonoid<mint>;
    using S = ReversibleMonoid<M>;
    link_cut_tree<S, S::op, S::e, S::toggle> lct(N);
    for (auto i = 0; i < N; ++i) {
        int a, b;
        std::cin >> a >> b;
        lct.set(i, M{a, b});
    }
    for (auto i = 0; i < N - 1; ++i) {
        int u, v;
        std::cin >> u >> v;
        lct.link(u, v);
    }
    while (Q--) {
        int t;
        std::cin >> t;
        if (t == 0) {
            int u, v, w, x;
            std::cin >> u >> v >> w >> x;
            lct.cut(u, v);
            lct.link(w, x);
        } else if (t == 1) {
            int p, c, d;
            std::cin >> p >> c >> d;
            lct.set(p, M{c, d});
        } else {
            int u, v, x;
            std::cin >> u >> v >> x;
            auto [a, b] = lct.prod(u, v).val;
            auto ans = a * x + b;
            std::cout << ans.val() << '\n';
        }
    }
}
Back to top page