CF1313E

给两个长度为$n$的$a,b$字符串。和长度$m$的字符串$s$。

有多少组$[l_a,r_a],[l_b,r_b]$满足

  • 两个区间有交集。
  • $a_{l_a}+a_{l_a+1}…a_{r_a}+b_{l_b}+b_{l_b+1}…b_{r_b}=s$

两个字符串两部分与$s$去匹配,$exkmp$处理$a$与$s$的最长前缀,$b$与$s$的最长后缀(这里倒置$exkmp$)。

我们就可以得到$a_i$最多向后拓展多少,$b_j$最多向前拓展多少。不关心第一个条件。对于一组$a_i,b_j$,可以组成$b_j-(m-a_i+1)$。

考虑第一个条件,需要有交集。对于每个$i$,最差情况$j-i+1=m-1$,此时它们相交临界值。那么只需要查询$x\in[i,min(m,i+m-2)],\sum (b_x-(m-a_i+1))=\sum b_x-num*\sum(m-a_i+1)),num$=满足条件的数,,用树状数组维护下就$ok$。

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cmath>
#include <cstring>
#include <vector>
#include <map>
using namespace std;
#define lowbit(x) ((-x) & x)
typedef long long ll;
const int N = 1e6 + 10;
const int mod = 1e9 + 7;
ll tr[2][N];
int m;
ll query(int x, int op)
{

ll res = 0;
while (x)
{
res += tr[op][x];
x -= lowbit(x);
}
return res;
}
void update(int x, int w, int op)
{
if (x == 0)
return;
while (x <= m)
{
tr[op][x] += w;
x += lowbit(x);
}
}
int e1[N], e2[N];
int ne1[N], ne2[N];
void preexkmp(char *t, int mlen, int nxt[])
{

nxt[1] = mlen;
int exlen = 0;
while (exlen + 2 <= mlen && exlen + 1 <= mlen && t[exlen + 2] == t[exlen + 1])
exlen++;
nxt[2] = exlen;
int pl = 2;
for (int i = 3; i <= mlen; i++)
{
int pr = nxt[pl] + pl - 1;
int l2 = i - pl + 1;
int r2 = nxt[pl];
if (i + nxt[l2] - 1 < pr)
{
nxt[i] = nxt[l2];
}
else
{
exlen = max(0, pr - i + 1);
//cout<<i<<" "<<pl<<" "<<exlen<<endl;
while (exlen + i <= mlen && exlen + 1 <= mlen && t[exlen + i] == t[exlen + 1])
exlen++;
nxt[i] = exlen;
pl = i;
}
}
}
void EXKMP(char s[], int nlen, char t[], int mlen, int nxt[], int extend[])
{

preexkmp(t, mlen, nxt);
int exlen = 0;
while (exlen + 1 <= nlen && exlen + 1 <= mlen && s[exlen + 1] == t[exlen + 1])
exlen++;
extend[1] = exlen;
int pl = 1;
for (int i = 2; i <= nlen; i++)
{
int pr = extend[pl] + pl - 1;
int l2 = i - pl + 1;
int r2 = extend[pl];
if (i + nxt[l2] - 1 < pr)
{
extend[i] = nxt[l2];
}
else
{
exlen = max(0, pr - i + 1);
// cout<<i<<" "<<exlen<<endl;
while (exlen + i <= nlen && exlen + 1 <= mlen && s[exlen + i] == t[exlen + 1])
exlen++;
extend[i] = exlen;
pl = i;
}
}
}
int n;
char s[N], t1[N], t2[N];
int main()
{
scanf("%d%d", &n, &m);
scanf("%s", t1 + 1);
scanf("%s", t2 + 1);
scanf("%s", s + 1);
EXKMP(t1, n, s, m, ne1, e1);

reverse(s + 1, s + 1 + m);
reverse(t2 + 1, t2 + 1 + n);
EXKMP(t2, n, s, m, ne2, e2);
reverse(e2 + 1, e2 + 1 + n);
for (int i = 1; i <= n; i++)
e1[i] = (e1[i] == m ? m - 1 : e1[i]), e2[i] = (e2[i] == m ? m - 1 : e2[i]);
ll ans = 0;
for (int i = 1, j = 0; i <= n; i++)
{

while (j + 1 <= n && j + 1 - i + 1 < m)
{
//cout << i << j << endl;
j++;
//cout << e2[j] << endl;
update(e2[j], 1, 0);
update(e2[j], e2[j], 1);
}
//cout << i << endl;
int l = m - e1[i], r = m;
ans += (query(r, 1) - query(l - 1, 1)) - (query(r, 0) - query(l - 1, 0)) * (l - 1);
update(e2[i], -1, 0);
update(e2[i], -e2[i], 1);
}
printf("%lld\n", ans);
}