HDU5977

还是一棵树:共 $n$ 个点,每个点属于 $k$ 种颜色之一。求有多少有序点对 $(u,v)$,使得 $u$ 到 $v$ 的路径上出现了全部 $k$ 种颜色。$(n\leq10^5,\ k\leq10)$

代码
#include <iostream>
#include <queue>
#include <vector>
#include <cstring>
#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long ll;
// Sentinel "infinitely large" piece size. main() stores it in f[0] so that,
// with root == 0, the first real node examined by getroot always wins.
const int inf = 0x7fffffff;
// Node capacity; the edge array holds twice as many entries because every
// undirected edge is stored in both directions.
const int N = 500000 + 10;
struct node
{
    int to, next;
} e[N << 1];
// Colour masks collected by getdeep during the current calc() call.
vector<int> s;
// head: first edge index per node (0 = none); cnt: edges used so far;
// n, k: node / colour counts; w: 0-based node colours; ss: mask of all k colours.
int head[N], cnt, n, k, w[N], ss;
// ans: running answer; dp[mask]: number of collected paths with that mask
// (1100 entries cover every mask for k <= 10).
ll ans, dp[1100];
// Centroid bookkeeping: f[x] = largest piece left if x were removed,
// son[x] = subtree size, vis[x] = already used as a centroid,
// root = best centroid found so far, sum = size of the current component.
int f[N], son[N], vis[N], root, sum;
// Appends the directed edge u -> v to the front of u's adjacency list.
// Edge indices start at 1, so index 0 marks the end of a list.
void addedge(int u, int v)
{
    ++cnt;
    e[cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt;
}
// Centroid search over the unvisited component containing x.
// Fills son[] with subtree sizes and f[] with the largest piece that would
// remain if each node were removed; `sum` must hold the component size.
// Keeps in `root` the node whose largest remaining piece is smallest.
void getroot(int x, int fa)
{
    int biggest = 0;
    son[x] = 1;
    for (int i = head[x]; i != 0; i = e[i].next)
    {
        const int y = e[i].to;
        if (y == fa || vis[y])
            continue;
        getroot(y, x);
        son[x] += son[y];
        biggest = max(biggest, son[y]);
    }
    // The piece "above" x is everything in the component outside x's subtree.
    f[x] = max(biggest, sum - son[x]);
    if (f[x] < f[root])
        root = x;
}
// Depth-first walk from x through unvisited nodes, pushing onto `s` the
// colour bitmask of the path that reaches each node. `st` is the mask for x
// itself; each child adds its own colour bit.
void getdeep(int x, int fa, int st)
{
    s.push_back(st);
    for (int i = head[x]; i; i = e[i].next)
    {
        const int y = e[i].to;
        if (y == fa || vis[y])
            continue;
        getdeep(y, x, st | (1 << w[y]));
    }
}
// Collects the path masks of every unvisited node reachable from x (x's own
// mask is `now`). Returns the number of ordered pairs of distinct collected
// entries whose masks OR together to ss.
ll calc(int x, int now)
{
    s.clear();
    getdeep(x, 0, now);
    memset(dp, 0, sizeof(dp));
    for (int mask : s)
        ++dp[mask];

    ll pairs = 0;
    for (int mask : s)
    {
        --dp[mask]; // do not pair an entry with itself
        // Assuming mask is a subset of ss, each partner t with (mask | t) == ss
        // equals ss ^ sub for exactly one sub that is a subset of mask
        // (including sub == 0).
        int sub = mask;
        while (true)
        {
            pairs += dp[ss ^ sub];
            if (sub == 0)
                break;
            sub = (sub - 1) & mask;
        }
        ++dp[mask];
    }
    return pairs;
}
// One centroid-decomposition step at centroid x.
// calc(x, ...) counts ordered pairs in x's component whose masks, both taken
// from x, cover every colour. For two nodes in the same child subtree, that
// mask describes a walk through x rather than their real path. calc(child, ...)
// seeded with the colours of child and x reproduces exactly those masks, so
// subtracting it removes the false pairs before recursing into the child's
// own centroid.
void slove(int x)
{
    ans += calc(x, 1 << w[x]);
    vis[x] = 1;
    for (int i = head[x]; i; i = e[i].next)
    {
        const int child = e[i].to;
        if (vis[child])
            continue;
        sum = son[child];
        root = 0;
        getroot(child, 0);
        const int both = (1 << w[child]) | (1 << w[x]);
        ans -= calc(child, both);
        slove(root);
    }
}
int main()
{
while (scanf("%d%d", &n, &k) == 2)
{
ans = 0;
root = 0;
cnt = 0;
sum = n;
f[0] = inf;
memset(vis, 0, sizeof(vis));
memset(son, 0, sizeof(son));
memset(head, 0, sizeof(head));
ss = (1 << k) - 1;
for (int i = 1; i <= n; i++)
{
scanf("%d", &w[i]);
w[i]--;
}
for (int i = 1; i < n; i++)
{
int u, v, w;
scanf("%d%d", &u, &v);
addedge(u, v);
addedge(v, u);
}
if (k == 1)
{
printf("%d\n", n * n);
continue;
}
getroot(1, root);
//cout << root << endl;
slove(root);
printf("%lld\n", ans);
}
}