CF1254D Tree Queries

给定一棵$N$个节点的树,有$Q$次操作

  • $1\space v\space d$：给定一个点$v$和一个权值$d$，等概率地随机选择一个点$r$，对每一个点$u$，若$v$在$u$到$r$的路径上，则$u$的权值加上$d$（权值一开始均为$0$）。
  • $2\space v$：查询点$v$的权值期望，对$998244353$取模。

数据范围：$1\leqslant N,Q\leqslant 150000$。

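以点$1$为根。对一次修改$(v,d)$，先统计使点$u$被加上$d$的$r$的个数，最后乘$\dfrac{d}{n}$即为期望：

  • $u=v$：任意$r$都满足，共$n$个；
  • $u$在$v$的某个儿子$c$的子树内：只要$r$不在$c$的子树内即可，共$n-siz[c]$个；
  • $u$在$v$的子树外：只要$r$在$v$的子树内即可，共$siz[v]$个。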

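修改时，只在线段树上对两部分做区间加：重儿子的子树，以及$v$的子树外。线段树按 dfs 序建立，且重儿子优先，因此这两部分都是 dfs 序上的区间。$v$自己和所有轻儿子的贡献不直接修改，而是把$d$累加到标记$tag[v]$上。任意一点到根的路径只经过$O(\log n)$条轻边，而每条轻边$(fa[top[x]],top[x])$的贡献只取决于$tag[fa[top[x]]]$，所以查询时沿重链往上跳即可补上这些贡献。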

查询点$u$时，答案为下面几部分之和，再乘$n^{-1}$：

  • 首先是$u$自身标记的贡献$n\times tag[u]$，再加上$u$在线段树（树链剖分序）上的单点值；
  • 其次沿重链往上跳：每跳过一条轻边，$top[x]$是$fa[top[x]]$的轻儿子，贡献为$tag[fa[top[x]]]\times (n-siz[top[x]])$，然后令$x\leftarrow fa[top[x]]$。
代码


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define pii pair<int, int>
#define mk make_pair
// Capacity of the per-vertex tables (vertex ids are 1-based).
constexpr int N = 1000010;

// Every answer is reported modulo this prime.
constexpr int mod = 998244353;

// Modular addition; both operands are expected to already lie in [0, mod).
int inc(int x, int y)
{
    const int s = x + y;
    return s < mod ? s : s - mod;
}
// Modular subtraction; both operands are expected to already lie in [0, mod).
int del(int x, int y)
{
    const int diff = x - y;
    return diff >= 0 ? diff : diff + mod;
}

// Number of positions the segment tree is sized for (it allocates 4 * SEN nodes).
constexpr int SEN = 200010;
struct Segment
{
    // Lazy segment tree over positions [1, n] supporting "add w to a range" and
    // "sum of a range", with every stored value kept modulo `mod`.
    // Node `pos` covers [l, r]; its children are 2*pos (left) and 2*pos+1 (right).
    int sum[SEN << 2], lazy[SEN << 2];

    // Adds w to every position of node [l, r]: the node's sum grows by w * length
    // and the addition still owed to its children grows by w.
    void addlazy(int pos, int l, int r, int w)
    {
        const long long grow = 1ll * w * (r - l + 1) % mod;
        sum[pos] = inc(sum[pos], grow);
        lazy[pos] = inc(lazy[pos], w);
    }

    // Recomputes a node's sum from its two children.
    void pushup(int pos)
    {
        const int left = pos << 1;
        const int right = left | 1;
        sum[pos] = inc(sum[left], sum[right]);
    }

    // Hands any pending addition of node [l, r] down to both children.
    void pushdown(int pos, int l, int r)
    {
        if (lazy[pos] == 0)
            return;
        const int mid = (l + r) >> 1;
        const int pending = lazy[pos];
        addlazy(pos << 1, l, mid, pending);
        addlazy(pos << 1 | 1, mid + 1, r, pending);
        lazy[pos] = 0;
    }

    // Sum (mod `mod`) of the positions lying in both [ql, qr] and node [l, r].
    int query(int ql, int qr, int pos, int l, int r)
    {
        if (l >= ql && r <= qr)
            return sum[pos];
        pushdown(pos, l, r);
        const int mid = (l + r) >> 1;
        int total = 0;
        if (ql <= mid)
            total = query(ql, qr, pos << 1, l, mid);
        if (mid < qr)
            total = inc(total, query(ql, qr, pos << 1 | 1, mid + 1, r));
        return total;
    }

    // Adds w to every position in [ql, qr]; an empty range (ql > qr) is a no-op.
    // w is expected to already be reduced into [0, mod).
    void update(int ql, int qr, int w, int pos, int l, int r)
    {
        if (qr < ql)
            return;
        if (l >= ql && r <= qr)
        {
            addlazy(pos, l, r, w);
            return;
        }
        pushdown(pos, l, r);
        const int mid = (l + r) >> 1;
        if (ql <= mid)
            update(ql, qr, w, pos << 1, l, mid);
        if (mid < qr)
            update(ql, qr, w, pos << 1 | 1, mid + 1, r);
        pushup(pos);
    }

} t;

namespace treepo
{
int top[N], son[N], dep[N], f[N], siz[N], dfn[N], ed[N];
vector<int> g[N];
int tag[N];

int tot;
void dfs1(int x, int fa)
{
dep[x] = dep[fa] + 1;
siz[x] = 1;
f[x] = fa;
for (int to : g[x])
{

if (to == fa)
continue;

dfs1(to, x);
siz[x] += siz[to];
if (siz[to] > siz[son[x]])
{
son[x] = to;
}
}
}
void dfs2(int x, int fa, int tp)
{
top[x] = tp;
dfn[x] = ++tot;
if (son[x] != 0)
{
dfs2(son[x], x, tp);
}
for (int to : g[x])
{

if (to == fa)
continue;
if (to == fa || to == son[x])
continue;
dfs2(to, x, to);
}
ed[x] = tot;
}

int query(int u)
{

int res = inc(1ll * tot * tag[u] % mod, t.query(dfn[u], dfn[u], 1, 1, tot));

while (u)
{
res = inc(res, 1ll * tag[f[top[u]]] * (tot - siz[top[u]]) % mod);
u = f[top[u]];
}
return res;
}
void update(int u, int d)
{
tag[u] = inc(tag[u], d);
if (son[u])
t.update(dfn[son[u]], ed[son[u]], 1ll * d * (tot - siz[son[u]]) % mod, 1, 1, tot);
t.update(1, dfn[u] - 1, 1ll * d * siz[u] % mod, 1, 1, tot);
t.update(ed[u] + 1, tot, 1ll * d * siz[u] % mod, 1, 1, tot);
}
void init(int n)
{

dfs1(1, 0);
dfs2(1, 0, 1);
}
} // namespace treepo

using namespace treepo;

// Reads the next decimal integer from stdin using getchar().
// Non-digit characters before the number are skipped. A '-' among them makes the
// result negative. If end of input is reached before any digit, the function
// returns 0 instead of spinning forever waiting for a digit.
int read()
{
    int x = 0, sign = 1;
    int c = getchar(); // int, not char: must be able to hold EOF
    while (c != EOF && (c < '0' || c > '9'))
    {
        if (c == '-')
            sign = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9')
    {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * sign;
}
// Computes a^x modulo mo by binary exponentiation (square-and-multiply).
// Returns 1 for x == 0 (not reduced modulo mo). x is expected to be non-negative.
int qpow(int a, int x, int mo)
{
    long long result = 1;
    long long base = a;
    for (; x != 0; x >>= 1)
    {
        if (x & 1)
            result = result * base % mo;
        base = base * base % mo;
    }
    return static_cast<int>(result);
}

// Reads the tree (n vertices, n - 1 edges) and q operations, then answers each
// "2 u" query with the expected value of u modulo `mod`.
int main()
{
    const int n = read(), q = read();
    for (int i = 1; i < n; i++)
    {
        const int u = read(), v = read();
        g[u].push_back(v);
        g[v].push_back(u);
    }

    init(n);

    // query() returns n times the expectation, so divide once via the inverse of n.
    const int inv = qpow(n, mod - 2, mod);
    for (int i = 1; i <= q; i++)
    {
        const int op = read();
        if (op == 1)
        {
            const int u = read(), d = read();
            update(u, d);
        }
        else
        {
            const int u = read();
            // '\n' instead of endl: endl would flush after every one of up to 1.5e5 answers.
            cout << 1ll * query(u) * inv % mod << '\n';
        }
    }
}