CF1327G. Letters and Question Marks

给$t$个总长度不超过$T$的字符串$t$,每个字符串有对应的代价$c_i$

给一个$S$,里面包含$k$个问号$?$,不重复地填$a-l,14$个字符,问最后形成的字符串大价值=$\sum (t)出现的次数*c_i$

$k\leq 14,T\leq 10^3,S\leq 4\times10^5$

如果都是完整的直接$AC$自动机跑步一遍就好了。考虑到不重复,用$2^{14}$当每个字符产生的状态。

$dp[i][j][mask]$搜到$s[i]$,在$AC$自动机第$k$个结点,字符使用状态$mask$。显然复杂度很高。

  • 考虑没有问号的时候,$mask$不改变,枚举所有$j$在$AC$自动机,就可不用转移$mask$。
  • 转移$mask$的时候只转移二进制数量=之前问号数量

这样复杂度就变成$O(T\times S+k\times T\times(C_{k}^0+C_{k}^1+…C_{k}^{k-1}))=O(k2^{k}T+TS)$

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cmath>
#include <cstring>
#include <vector>
#include <map>
using namespace std;
typedef long long ll;
const int N = 2e3 + 10;
const int mod = 1e9 + 7;
const int M = 1 << 14;
const int _M = 4e5 + 10;
const ll inf = 1e18;
const int N_AC = 2e3, M_AC = 14;
struct AC
{
int ch[M_AC], fail;
} tri[N_AC];
ll val[N_AC];
int tcnt, root;
int creat()
{
val[++tcnt] = 0;
tri[tcnt].fail = 0;
memset(tri[tcnt].ch, 0, sizeof(tri[tcnt].ch));
return tcnt;
}

void insert(char *t, ll w)
{
int tmp = root;
int len = strlen(t);
for (int i = 0; i < len; i++)
{
int k = t[i] - 'a';
if (!tri[tmp].ch[k])
tri[tmp].ch[k] = creat();
tmp = tri[tmp].ch[k];
}
val[tmp] += w;
}
void get_AC()
{
queue<int> q;
for (int i = 0; i < M_AC; i++)
if (tri[root].ch[i])
{
int k = tri[root].ch[i];
tri[k].fail = root;
q.push(k);
}
while (!q.empty())
{
int x = q.front();
q.pop();
val[x] += val[tri[x].fail];
for (int i = 0; i < M_AC; i++)
{
int k = tri[x].ch[i];
if (k)
{
tri[k].fail = tri[tri[x].fail].ch[i];
q.push(k);
}
else
{
tri[x].ch[i] = tri[tri[x].fail].ch[i];
}
}
}
}
int n;
char s[_M];

ll dp[N_AC][M];
int pos[N_AC];
ll w[N_AC];
int main()
{
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
char t[N_AC];
ll w;
scanf("%s%lld", t, &w);
insert(t, w);
}
get_AC();

scanf("%s", s + 1);

int len = strlen(s + 1);
for (int i = 0; i <= tcnt; i++)
for (int j = 0; j < M; j++)
dp[i][j] = -inf;
dp[0][0] = 0;
for (int p = 0; p <= tcnt; p++)
{
pos[p] = p;
w[p] = 0;
}
int cnt = 0;
for (int i = 1; i <= len; i++)
{
// cout << "---" << endl;
if (s[i] == '?')
{

for (int p = 0; p <= tcnt; p++)
for (int j = 0; j < M; j++)
if (dp[p][j] != -inf && __builtin_popcount(j) == cnt)
{
//cout << p << " " << j << " " << __builtin_popcount(j) << endl;
for (int t = 0; t < M_AC; t++)
if (!(j & (1 << t)))
{

int ts = j | (1 << t);
int ng = tri[pos[p]].ch[t];

dp[ng][ts] = max(dp[ng][ts], dp[p][j] + w[p] + val[ng]);
}
}
cnt++;
for (int p = 0; p <= tcnt; p++)
{
pos[p] = p;
w[p] = 0;
}
}
else
{
for (int p = 0; p <= tcnt; p++)
{
pos[p] = tri[pos[p]].ch[s[i] - 'a'];
w[p] += val[pos[p]];
}
}
}
ll ans = -inf;
for (int p = 0; p <= tcnt; p++)
for (int j = 0; j < M; j++)
if (dp[p][j] != -inf && __builtin_popcount(j) == cnt)
ans = max(ans, dp[p][j] + w[p]);

printf("%lld\n", ans);
}