P6478 [NOI Online 提高组]游戏

一棵$2n$个结点的树,进行$n$个回合,每个回合在树上选择一个白色和黑色的结点,如果某个在某个子树内则不平局,问恰好有$k$个非平局的情况。

假设有$x$个平均其他乱选的非平局为$g[x]$,则有:

设$dp[u][x]$为在$u$这颗子树选择$x$对非平局的情况。$g(i)=dp[1][i]\times 2^{2n-2i}$

对于子树的转移,如果不选当前结点即$dp[u][i]=\sum dp[v_1][a]\times dp[v_2][b]\times dp[v_2][c]\times dp[v_4]d$

其实就是好多卷积,但是不需要用$FFT$,单纯$n^2$即可,可以证明复杂度为$O(n^2)$

如果选择的这颗子树$dp[u][i+1]=dp[u][i]\times(cnt[a[u]\oplus1]-i)$

  • 注意开个临时数组保存卷积
  • 注意不要被覆盖
代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define pii pair<int, int>
#define mk make_pair
const int N = 5e3 + 100;
const int mod = 998244353;

int inc(int x, int y) { return (x + y >= mod) ? (x + y - mod) : (x + y); }
int del(int x, int y) { return (x - y < 0) ? (x - y + mod) : (x - y); }
int qpow(int a, int x, int mo)
{
int res = 1;
while (x)
{
if (x & 1)
res = 1ll * res * a % mo;
a = 1ll * a * a % mo;
x >>= 1;
}
return res;
}

int fac[N], facinv[N];
void prepare()
{
fac[0] = 1;
facinv[0] = 1;
for (int i = 1; i < N; i++)
fac[i] = 1ll * fac[i - 1] * i % mod;
facinv[N - 1] = qpow(fac[N - 1], mod - 2, mod);
for (int i = N - 2; i >= 1; i--)
facinv[i] = 1ll * facinv[i + 1] * (i + 1) % mod;
}
int C(int n, int i)
{
if (i == 0)
return 1;
if (n <= 0)
return 0;
if (i > n)
return 0;

return 1ll * fac[n] * facinv[i] % mod * facinv[n - i] % mod;
}
int a[N];
int dp[N][N], cnt[N][2], siz[N];
int n;
char s[N];
vector<int> g[N];
int tmp[N];
void dfs(int x, int fa)
{
dp[x][0] = 1;

siz[x] = 1;
for (int to : g[x])
{
if (to == fa)
continue;
dfs(to, x);
for (int i = 0; i <= siz[x] + siz[to]; i++)
tmp[i] = 0;
for (int i = 0; i <= siz[x]; i++)
for (int j = 0; j <= siz[to]; j++)
{
tmp[i + j] = inc(tmp[i + j], 1ll * dp[x][i] * dp[to][j] % mod);
}

siz[x] += siz[to];
for (int i = 0; i <= siz[x]; i++)
dp[x][i] = tmp[i];

cnt[x][0] += cnt[to][0];
cnt[x][1] += cnt[to][1];
}

for (int i = siz[x]; i >= 0; i--)
if (cnt[x][a[x] ^ 1] > i)
dp[x][i + 1] = inc(dp[x][i + 1], 1ll * dp[x][i] * (cnt[x][a[x] ^ 1] - i) % mod);
cnt[x][a[x]]++;
}
int ans[N];
int main()
{
prepare();
scanf("%d", &n);
scanf("%s", s + 1);
for (int i = 1; i <= n; i++)
a[i] = s[i] - '0';

for (int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);

for (int i = 0; i <= n / 2; i++)
{
for (int k = i; k <= n / 2; k++)
ans[i] = inc(ans[i], 1ll * dp[1][k] * fac[n / 2 - k] % mod * C(k, i) % mod * ((k - i) & 1 ? mod - 1 : 1) % mod);

printf("%d\n", ans[i]);
}
}