POJ3415

$S$与$T$长度至少为$k$的公共子串个数。

套路坑定合并两个字符串,中间加上奇怪的字符(保证不会有前缀一边属于$S$,一边属于$T$),生成后缀数组。

对于一个$height[i]>=k$属于$S$串的后缀,统计比他排名小的有多少符合条件的$T$串。字串个数=$(lcp-k+1)$

对于他排名大的有多少符合条件的$T$串,可以通过$height[i]>=k$属于$T$串的后缀,统计比他排名小的有多少符合条件的$S$串。发现同样算法。

实现

利用栈,每次有新的元素加入,判断是否为$T$串:$sum+=hight[i]-k+1$。

当压入栈的$height\leq s[top]$,栈内所有$\geq height$的元素都需要弹出来,$sum-=(height-s.val)*s.num,s.num$表示栈内$s.val$有多少个。当压入栈中位$S$串的时候累加进答案即可。

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cmath>
#include <cstring>
#include <vector>
#include <stack>
#include <map>
using namespace std;
#define pii pair<int, int>
#define mk make_pair
typedef long long ll;
const int N = 3e5 + 10;
const int mod = 1e9 + 7;

int sa[N], x[N], y[N], tn[N], he[N];

void get_SA(char *s, int len, int size)
{

int tmp[N]; //辅助数组
for (int i = 1; i <= size; i++)
tn[i] = 0;
for (int i = 1; i <= len; i++)
x[i] = s[i], tn[x[i]]++;
for (int i = 1; i <= size; i++)
tn[i] += tn[i - 1];
for (int i = len; i >= 1; i--)
sa[tn[x[i]]--] = i;
for (int k = 1; k <= len; k <<= 1)
{

int cnt = 0;
for (int i = len - k + 1; i <= len; i++)
y[++cnt] = i;
for (int i = 1; i <= len; i++)
if (sa[i] > k)
y[++cnt] = sa[i] - k;

for (int i = 1; i <= size; i++)
tn[i] = 0;
for (int i = 1; i <= len; i++)
tn[x[i]]++;
for (int i = 1; i <= size; i++)
tn[i] += tn[i - 1];
for (int i = len; i >= 1; i--) //倒叙原因是因为tn[x[y[i]]]是桶里面最大的
sa[tn[x[y[i]]]--] = y[i];
for (int i = 1; i <= len; i++)
tmp[i] = x[i];

cnt = 1;
x[sa[1]] = 1;
for (int i = 2; i <= len; i++)
x[sa[i]] = (tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + k] == tmp[sa[i - 1] + k]) ? cnt : ++cnt;
if (cnt == len)
break;
size = cnt;
}

int k = 0;
for (int i = 1; i <= len; i++)
x[sa[i]] = i;
for (int i = 1; i <= len; i++)
{
if (x[i] == 1)
continue;
if (k)
k--;
int j = sa[x[i] - 1];
while (i + k <= len && j + k <= len && s[j + k] == s[i + k])
k++;
he[x[i]] = k;
}
}

char s[N], t[N];
int main()
{
int k;
while (~scanf("%d", &k) && k)
{
scanf("%s", s + 1);
scanf("%s", t + 1);
int len1 = strlen(s + 1);
int len2 = strlen(t + 1);
int len = len1 + len2 + 2;
for (int i = 1; i <= len2; i++)
s[len1 + i + 1] = t[i];
s[len1 + 1] = 1;
s[len] = 0;
get_SA(s, len, 160);
stack<pii> st;
ll sum = 0, ans = 0;

for (int i = 1; i <= len; i++)
{
if (he[i] < k)
{
while (!st.empty())
st.pop();
sum = 0;
}
else
{
int cnt = 0;
if (sa[i - 1] <= len1)
{
sum += he[i] - k + 1;
cnt = 1;
}
while (!st.empty() && st.top().first >= he[i])
{
sum -= (1ll * st.top().second * (st.top().first - he[i]));
cnt += st.top().second;
st.pop();
}
st.push(mk(he[i], cnt));
if (sa[i] > len1 + 1)
ans += sum;
}
}
for (int i = 1; i <= len; i++)
{
if (he[i] < k)
{
while (!st.empty())
st.pop();
sum = 0;
}
else
{
int cnt = 0;
if (sa[i - 1] > len1 + 1)
{
sum += he[i] - k + 1;
cnt = 1;
}
while (!st.empty() && st.top().first >= he[i])
{
sum -= (1ll * st.top().second * (st.top().first - he[i]));
cnt += st.top().second;
st.pop();
}
st.push(mk(he[i], cnt));
if (sa[i] <= len1)
ans += sum;
}
}
printf("%lld\n", ans);
}
}