This documentation is automatically generated by online-judge-tools/verification-helper
#define PROBLEM "https://judge.yosupo.jp/problem/tree_path_composite_sum"
#include "dp/rerooting.hpp"
#include "graph/csr_graph.hpp"
#include <atcoder/modint>
#include <bits/stdc++.h>
using Z = atcoder::modint998244353;
// Library Checker "Tree Path Composite Sum": reads the vertex values a[i] and the
// edges (u, v, b, c), then prints one rerooted answer per vertex.
//
// DP state (sum, cnt): sum of the path values of a subtree's cnt vertices.
// Crossing an edge applies x -> b*x + c to each of the cnt values, which gives
// b*sum + c*cnt. Attaching vertex i adds its own a[i] and one more vertex.
int main() {
  std::cin.tie(nullptr)->sync_with_stdio(false);
  int N;
  std::cin >> N;
  using EdgeWeight = std::pair<int, int>;  // (b, c) of the edge map x -> b*x + c
  CSRGraph<EdgeWeight> g(N);
  std::vector<int> a(N);
  // Plain loop: std::copy_n over std::istream_iterator reads eagerly on
  // construction and its increment count is a known portability trap.
  for (auto &x : a) {
    std::cin >> x;
  }
  for (int i = 0; i < N - 1; ++i) {
    int u, v, b, c;
    std::cin >> u >> v >> b >> c;
    g.add_edge(u, v, {b, c});
  }
  g.build_undirected();
  using Subtree = std::pair<Z, int>;
  using Child = std::pair<Z, int>;
  auto rake = [](Child l, Child r) -> Child { return {l.first + r.first, l.second + r.second}; };
  auto add_edge = [](Subtree d, EdgeWeight w) -> Child {
    // Z::raw skips reduction; assumes c is already < 998244353 (problem constraint).
    return {w.first * d.first + Z::raw(w.second) * d.second, d.second};
  };
  auto add_vertex = [&](Child d, int i) -> Subtree { return {d.first + a[i], d.second + 1}; };
  auto e = []() -> Child { return {0, 0}; };
  const auto dp = rerooting(g, rake, add_edge, add_vertex, e);
  for (int i = 0; i < N; ++i) {
    std::cout << dp[i].first.val() << (i + 1 < N ? ' ' : '\n');
  }
}
#line 1 "test/dp/rerooting.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/tree_path_composite_sum"
#line 1 "dp/rerooting.hpp"
#include <ranges>
#include <vector>
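/*
Interface expected by rerooting(g, rake, add_edge, add_vertex, e):
vector<vector<pair<int, EdgeWeight>>> g;  // or any graph (e.g. CSRGraph) with g.size() and
                                          // g[u] yielding (neighbor, EdgeWeight) pairs;
                                          // g must be a forest with every edge listed at both ends
struct Subtree {};  // DP value of a rooted subtree; rerooting() returns one per vertex
struct Child {};    // a Subtree value lifted across the edge to its parent
auto rake = [&](Child l, Child r) -> Child {};             // merge siblings; assumed associative
auto add_edge = [&](Subtree d, EdgeWeight w) -> Child {};  // lift d across an edge of weight w
auto add_vertex = [&](Child d, int i) -> Subtree {};       // put vertex i on top of raked children
auto e = []() -> Child {};                                 // identity element of rake
*/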
// Rerooting DP on a forest. For every vertex u, returns the DP value of u's whole
// connected component as if it were rooted at u.
//
//   g          - g.size() vertices; g[u] iterates (neighbor, EdgeWeight) pairs. g must
//                be a forest whose edges are listed at both endpoints. The BFS below only
//                skips the parent, so a cycle or self-loop would revisit vertices.
//   rake       - merges two Child values. Assumed associative with e() as identity,
//                because results are assembled from prefix and suffix folds.
//   add_edge   - lifts a neighbor's Subtree value across the connecting edge.
//   add_vertex - turns the raked Child values around vertex i into a Subtree value.
//   e          - returns the identity Child.
//
// Each component gets two passes:
//   1. A bottom-up pass in reversed BFS order.
//   2. A top-down pass that lifts every incident edge once per vertex, then combines
//      prefix and suffix folds. Each neighbor's "everything except me" value therefore
//      costs a single rake.
template <class Graph, class Rake, class AddEdge, class AddVertex, class Identity>
auto rerooting(const Graph &g, Rake rake, AddEdge add_edge, AddVertex add_vertex, Identity e) {
  const int n = int(g.size());
  using Child = decltype(e());
  using Subtree = decltype(add_vertex(e(), 0));
  std::vector<Subtree> dp(n), dp_parent(n);
  std::vector<int> order, parent(n, -1);
  order.reserve(n);
  // Scratch buffers for the top-down pass, reused across vertices.
  std::vector<int> nbrs;
  std::vector<Child> lifted, pref;
  for (int root = 0; root < n; ++root) {
    if (parent[root] != -1) {
      continue;  // already reached from an earlier root
    }
    // BFS from root; parent[root] == root marks the root. The queue is walked by
    // index, so push_back reallocation can never invalidate the cursor.
    parent[root] = root;
    order.assign(1, root);
    for (int head = 0; head < int(order.size()); ++head) {
      const int u = order[head];
      for (auto [v, w] : g[u]) {
        if (v != parent[u]) {
          parent[v] = u;
          order.push_back(v);
        }
      }
    }
    // Bottom-up: dp[u] becomes the value of u's subtree with respect to `root`.
    for (int k = int(order.size()) - 1; k >= 0; --k) {
      const int u = order[k];
      Child sum = e();
      for (auto [v, w] : g[u]) {
        if (v != parent[u]) {
          sum = rake(sum, add_edge(dp[v], w));
        }
      }
      dp[u] = add_vertex(sum, u);
    }
    // Top-down: parents are visited before children, so for each u, dp[child]
    // still holds the child's subtree value and dp_parent[u] is already known.
    for (const int u : order) {
      lifted.clear();
      nbrs.clear();
      for (auto [v, w] : g[u]) {
        // The parent's side of u is dp_parent[u]; a child's side is its subtree.
        const auto &side = (v == parent[u]) ? dp_parent[u] : dp[v];
        lifted.push_back(add_edge(side, w));
        nbrs.push_back(v);
      }
      const int deg = int(nbrs.size());
      // pref[i] = rake of lifted[0..i).
      pref.clear();
      pref.push_back(e());
      for (int i = 0; i < deg; ++i) {
        pref.push_back(rake(pref[i], lifted[i]));
      }
      // suff = rake of lifted(i..deg), grown from the right.
      Child suff = e();
      for (int i = deg - 1; i >= 0; --i) {
        const int v = nbrs[i];
        if (v != parent[u]) {
          // u with every neighbor except child v, as seen from v.
          dp_parent[v] = add_vertex(rake(pref[i], suff), u);
        }
        suff = rake(lifted[i], suff);
      }
      dp[u] = add_vertex(suff, u);
    }
  }
  return dp;
}
#line 1 "graph/csr_graph.hpp"
#include <cassert>
#line 6 "graph/csr_graph.hpp"
#include <type_traits>
#include <utility>
#include <variant>
#line 10 "graph/csr_graph.hpp"
// Compressed-sparse-row (CSR) graph with optional edge and node weights.
//
// Usage:
//   1. Construct with the vertex count.
//   2. Call add_edge() for every edge.
//   3. Call build_undirected() or build_directed() exactly once.
// operator[] may only be used after building. add_edge() after building is rejected
// by an assert, because such an edge could never be stored.
//
// EdgeWeight / NodeWeight default to std::monostate, meaning "no weight". With
// [[no_unique_address]], a monostate edge weight may occupy no storage.
template <typename EdgeWeight = std::monostate, typename NodeWeight = std::monostate>
struct CSRGraph {
  static constexpr bool HasNodeWeight = !std::is_same_v<NodeWeight, std::monostate>;

  // Stored adjacency entry: target vertex and edge weight.
  struct Edge {
    int to;
    [[no_unique_address]] EdgeWeight w;
  };
  // Edge as passed to add_edge(), kept until build_*() runs.
  struct RawEdge {
    int u, v;
    [[no_unique_address]] EdgeWeight w;
  };

  // Creates an edgeless graph on vertices 0..n-1.
  CSRGraph(int n) : n_(n), start_(n + 1) {
    if constexpr (HasNodeWeight) {
      nodes_.resize(n_);
    }
  }

  // Sets the weight of vertex u (a no-op when NodeWeight is std::monostate).
  void set_node(int u, NodeWeight w) {
    assert(0 <= u && u < n_);
    if constexpr (HasNodeWeight) {
      nodes_[u] = w;
    }
  }

  // Returns the weight of vertex u (a default NodeWeight when weights are disabled).
  NodeWeight node_weight(int u) const {
    assert(0 <= u && u < n_);
    if constexpr (HasNodeWeight) {
      return nodes_[u];
    } else {
      return {};
    }
  }

  // Records edge (u, v, w). Must be called before build_undirected()/build_directed().
  void add_edge(int u, int v, EdgeWeight w = {}) {
    assert(!built_ && "add_edge() after build_*(): the edge would be silently dropped");
    assert(0 <= u && u < n_ && 0 <= v && v < n_);
    raw_edges_.push_back({u, v, w});
  }

  // Builds the CSR arrays, storing each recorded edge in both endpoints' lists.
  void build_undirected() { build_csr_(true); }

  // Builds the CSR arrays, storing each recorded edge only in its source's list.
  void build_directed() { build_csr_(false); }

  // Returns a view of u's (neighbor, weight) pairs, in the order the edges were added.
  auto operator[](int u) const {
    assert(built_);
    assert(0 <= u && u < n_);
    constexpr auto to_pair = [](const Edge &e) { return std::pair(e.to, e.w); };
    return std::ranges::subrange(edges_.begin() + start_[u], edges_.begin() + start_[u + 1]) |
           std::views::transform(to_pair);
  }

  // Number of vertices.
  int size() const { return n_; }

  int n_;
  bool built_ = false;
  std::vector<Edge> edges_;
  std::vector<int> start_;  // after build, u's edges are edges_[start_[u] .. start_[u + 1])
  std::vector<RawEdge> raw_edges_;
  std::vector<NodeWeight> nodes_;

 private:
  // Counting-sorts the recorded edges by source into edges_/start_, then releases
  // raw_edges_. With `undirected`, the reverse direction of every edge is stored too.
  void build_csr_(bool undirected) {
    assert(!built_);
    edges_.resize(undirected ? 2 * raw_edges_.size() : raw_edges_.size());
    for (const auto &e : raw_edges_) {
      ++start_[e.u + 1];
      if (undirected) {
        ++start_[e.v + 1];
      }
    }
    for (int i = 0; i < n_; ++i) {
      start_[i + 1] += start_[i];
    }
    auto counter = start_;
    for (const auto &e : raw_edges_) {
      edges_[counter[e.u]++] = {e.v, e.w};
      if (undirected) {
        edges_[counter[e.v]++] = {e.u, e.w};
      }
    }
    std::vector<RawEdge>().swap(raw_edges_);
    built_ = true;
  }
};
#line 5 "test/dp/rerooting.test.cpp"
#include <atcoder/modint>
#include <bits/stdc++.h>
using Z = atcoder::modint998244353;
// Library Checker "Tree Path Composite Sum": reads the vertex values a[i] and the
// edges (u, v, b, c), then prints one rerooted answer per vertex.
//
// DP state (sum, cnt): sum of the path values of a subtree's cnt vertices.
// Crossing an edge applies x -> b*x + c to each of the cnt values, which gives
// b*sum + c*cnt. Attaching vertex i adds its own a[i] and one more vertex.
int main() {
  std::cin.tie(nullptr)->sync_with_stdio(false);
  int N;
  std::cin >> N;
  using EdgeWeight = std::pair<int, int>;  // (b, c) of the edge map x -> b*x + c
  CSRGraph<EdgeWeight> g(N);
  std::vector<int> a(N);
  // Plain loop: std::copy_n over std::istream_iterator reads eagerly on
  // construction and its increment count is a known portability trap.
  for (auto &x : a) {
    std::cin >> x;
  }
  for (int i = 0; i < N - 1; ++i) {
    int u, v, b, c;
    std::cin >> u >> v >> b >> c;
    g.add_edge(u, v, {b, c});
  }
  g.build_undirected();
  using Subtree = std::pair<Z, int>;
  using Child = std::pair<Z, int>;
  auto rake = [](Child l, Child r) -> Child { return {l.first + r.first, l.second + r.second}; };
  auto add_edge = [](Subtree d, EdgeWeight w) -> Child {
    // Z::raw skips reduction; assumes c is already < 998244353 (problem constraint).
    return {w.first * d.first + Z::raw(w.second) * d.second, d.second};
  };
  auto add_vertex = [&](Child d, int i) -> Subtree { return {d.first + a[i], d.second + 1}; };
  auto e = []() -> Child { return {0, 0}; };
  const auto dp = rerooting(g, rake, add_edge, add_vertex, e);
  for (int i = 0; i < N; ++i) {
    std::cout << dp[i].first.val() << (i + 1 < N ? ' ' : '\n');
  }
}