P2664 树上游戏

定义 $s(i,j)$ 为 $i$ 到 $j$ 的颜色数量。

考虑点分治

  • 考虑$p\rightarrow x$

如果在$dfs$往下搜一颗子树的时候在某节点$u$发现某个颜色第一次出现,那么这个颜色随后共享的就是$size[u]$。

最后$\sum$即=$ans[p]$

  • 考虑$x\rightarrow p\rightarrow y$

假设搜到某个结点$x$,$p\rightarrow x$有$k$种颜色($c[p]$不算,其他子树出现的颜色也不算)。那么$ans[x]+=k\times (size[p]-size[x所在子树])$,然后接下来就是加上$p\rightarrow y$的各种值,即就是上题少遍历$x$所在子树所形成的答案。

  • 为了方便处理,每次对一颗子树进行处理的时候,$dfs$先去除影响。
  • 最后$dfs$勿忘清空
代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define pii pair<int, int>
#define mk make_pair
const int N = 1e5 + 10;
const int mod = 1e9 + 7;
int siz[N], f[N], vis[N];
int rt, gcnt;
vector<int> g[N];
void GetRoot(int x, int fa)
{
siz[x] = 1;
f[x] = 1;
for (int to : g[x])
{
if (to == fa || vis[to])
continue;
GetRoot(to, x);
siz[x] += siz[to];
f[x] = max(f[x], siz[to]);
}
f[x] = max(f[x], gcnt - siz[x]);
if (f[x] < f[rt])
rt = x;
}
ll sum, S, ans[N], si[N], num;
ll cnt[N], col[N];
int c[N];
void getsi(int x, int fa)
{
si[x] = 1;

for (int to : g[x])
{
if (to == fa || vis[to])
continue;
getsi(to, x);
si[x] += si[to];
}
}

void Change(int x, int fa, int op)
{
cnt[c[x]]++;
if (cnt[c[x]] == 1)
{
sum += si[x] * op;
col[c[x]] += si[x] * op;
}
for (int to : g[x])
{
if (to == fa || vis[to])
continue;
Change(to, x, op);
}
cnt[c[x]]--;
}
void work(int x, int fa)
{
cnt[c[x]]++;
if (cnt[c[x]] == 1)
{
num++;
sum -= col[c[x]];
}
ans[x] += sum + num * S;

for (int to : g[x])
{
if (to == fa || vis[to])
continue;
work(to, x);
}
if (cnt[c[x]] == 1)
{
num--;
sum += col[c[x]];
}
cnt[c[x]]--;
}
void clear(int x, int fa)
{
cnt[c[x]] = col[c[x]] = 0;
for (int to : g[x])
{
if (to == fa || vis[to])
continue;
clear(to, x);
}
}
void calc(int x)
{
getsi(x, 0);
sum = 0;
cnt[c[x]]++;
sum += si[x];
for (int to : g[x])
{
if (vis[to])
continue;
Change(to, x, 1);
}
ans[x] += sum;

ll tmp = sum;
for (int to : g[x])
{
if (vis[to])
continue;

sum = tmp - si[to];

num = 0;
col[c[x]] -= si[to];
Change(to, x, -1);

S = si[x] - si[to];

work(to, x);

col[c[x]] += si[to];
Change(to, x, 1);
}
clear(x, 0);
}

void solve(int x)
{

vis[x] = 1;
calc(x);
for (int to : g[x])
{
if (vis[to])
continue;
gcnt = siz[to];
rt = 0;

GetRoot(to, x);

solve(rt);
}
}
int n;
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &c[i]);
for (int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
f[0] = 1e9;
rt = 0;
gcnt = n;
GetRoot(1, 0);

solve(rt);
for (int i = 1; i <= n; i++)
printf("%lld\n", ans[i]);
}