// HDU 5977 (blog post dated 2019-11-04).
// Problem: again a tree; every node has one of k types. Count the paths that
// contain all k types. Bounds as stated in the post: n <= 1e5, k <= 10.
//
// Approach: centroid decomposition. Every path that hangs off the current
// centre is summarised by the bitmask of types it contains. Two such
// half-paths form a valid path exactly when the OR of their masks is the full
// mask ss = (1 << k) - 1.
// This program counts ordered endpoint pairs (u, v) with u != v. For k == 1
// every ordered pair, including u == v, qualifies, so the answer is n * n.
#include <iostream>
#include <queue>
#include <vector>
#include <cstring>
#include <algorithm>
#include <cstdio>
using namespace std;
typedef long long ll;

const int INF = 0x7fffffff;
const int N = 5e5 + 10;          // node capacity; each edge is stored twice

struct node {
    int to, next;
} e[N << 1];                     // linked-list adjacency storage

vector<int> s;                   // masks gathered by getdeep for one calc call
vector<int> seenMasks;           // distinct values of s (reused across calls)
int head[N], cnt, n, k, w[N], ss;
ll ans, dp[1100];                // dp[mask] = occurrences of mask in s; all-zero between calls
int f[N], son[N], vis[N], root, sum;

// Append the directed edge u -> v to u's adjacency list.
void addedge(int u, int v) {
    e[++cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt;
}

// Number of unvisited nodes reachable from x without stepping back into fa.
// Used so that `sum` is the true size of the component getroot searches.
int getsize(int x, int fa) {
    int size = 1;
    for (int i = head[x]; i; i = e[i].next) {
        int v = e[i].to;
        if (v != fa && !vis[v]) size += getsize(v, x);
    }
    return size;
}

// Find the centroid of the unvisited component containing x (whose size must
// already be stored in `sum`). Fills son[] with subtree sizes rooted at the
// entry node and keeps in `root` the node whose largest remaining piece f[] is
// smallest. The caller sets root = 0 beforehand; f[0] == INF acts as a sentinel.
// Recursion depth can reach the component size (e.g. on a path graph).
void getroot(int x, int fa) {
    son[x] = 1;
    f[x] = 0;
    for (int i = head[x]; i; i = e[i].next) {
        int v = e[i].to;
        if (v == fa || vis[v]) continue;
        getroot(v, x);
        son[x] += son[v];
        f[x] = max(f[x], son[v]);
    }
    f[x] = max(f[x], sum - son[x]);    // the piece "above" x
    if (f[x] < f[root]) root = x;
}

// Push, for x and every unvisited node below it, the OR of type bits on the
// path from the starting node; `st` is the mask accumulated up to x.
void getdeep(int x, int fa, int st) {
    s.push_back(st);
    for (int i = head[x]; i; i = e[i].next) {
        int v = e[i].to;
        if (v != fa && !vis[v]) getdeep(v, x, st | (1 << w[v]));
    }
}

// Collect masks for all paths starting at x (with initial mask `now`). Return
// the number of ordered pairs (i, j), i != j, of collected masks whose OR is ss.
ll calc(int x, int now) {
    s.clear();
    getdeep(x, 0, now);

    // Group equal masks so the subset enumeration below runs once per
    // distinct mask instead of once per node.
    seenMasks.clear();
    for (size_t i = 0; i < s.size(); i++)
        if (dp[s[i]]++ == 0) seenMasks.push_back(s[i]);

    ll res = 0;
    for (size_t i = 0; i < seenMasks.size(); i++) {
        int m = seenMasks[i];
        // Partners T must satisfy m | T == ss, i.e. T contains ss ^ m.
        // Those are exactly T = ss ^ j for every subset j of m (j = 0 -> ss).
        ll partners = dp[ss];
        for (int j = m; j; j = (j - 1) & m) partners += dp[ss ^ j];
        res += dp[m] * partners;
    }
    // Remove the pairs (i, i): an entry pairs with itself only if it is ss.
    res -= dp[ss];

    // Restore the all-zero invariant by clearing only the slots touched here.
    for (size_t i = 0; i < seenMasks.size(); i++) dp[seenMasks[i]] = 0;
    return res;
}

// Centroid decomposition step at centre x. Add all pairs through x, then
// subtract pairs whose two ends lie in the same child subtree. Those pairs were
// counted with x's type included but do not really pass through x. Then
// recurse into each child component at its own centroid.
void slove(int x) {
    ans += calc(x, 1 << w[x]);
    vis[x] = 1;
    for (int i = head[x]; i; i = e[i].next) {
        int v = e[i].to;
        if (vis[v]) continue;
        sum = getsize(v, 0);           // true size of v's component (x is now visited)
        root = 0;
        getroot(v, 0);
        ans -= calc(v, (1 << w[v]) | (1 << w[x]));
        slove(root);
    }
}

int main() {
    while (scanf("%d%d", &n, &k) == 2) {
        ans = 0;
        root = 0;
        cnt = 0;
        sum = n;
        f[0] = INF;
        memset(vis, 0, sizeof(vis));
        memset(son, 0, sizeof(son));
        memset(head, 0, sizeof(head));
        ss = (1 << k) - 1;
        for (int i = 1; i <= n; i++) {
            scanf("%d", &w[i]);
            w[i]--;                    // types become 0-based bit positions
        }
        for (int i = 1; i < n; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            addedge(u, v);
            addedge(v, u);
        }
        if (k == 1) {
            // Every ordered pair (including u == v) qualifies. n * n exceeds
            // int range for n around 5e4 and above, so compute in 64 bits.
            printf("%lld\n", (ll)n * n);
            continue;
        }
        getroot(1, 0);
        slove(root);
        printf("%lld\n", ans);
    }
    return 0;
}