CF1334F -Strange Function

定义序列函数 $f(x)$ :所有满足 $x_i>x_{1\cdots i-1}$
组成的序列。如 $f[3,1,2,7,7,3,6,7,8]=[3,7,8]$

给出三个序列 $a,p,b$ ,删除 $a_i$的代价为 $p_i$($p_i$可能为负)。求使 $f(a)=b$ 的最小代价。无解输出$NO$。

一切都要从简单开始,可以考虑$b[i]->b[i+1]$这么转移。

枚举$i=1..m$。设$dp[pos[b[i]]]$表示若选中需要的删除的数字最小值。

显然枚举所有$dp[pos[b[i-1]]$转移过来。为了表示清楚,我设$r$为某个$b[i]$的位置。$l$为$b[i-1]$的某个位置则
$dp[r]=dp[l]+sum(l,r,v\geq b[i-1] \ or \ p[i]<0)$

观察到这一步因为$b[i]$单调,所有可以每次删去一部分$<b[i-1]$,然后树状数组查询。但是这样会$TLE$

因为对于每个状态$dp[pos[b[i]]]$需要重复枚举。由于$pos[b[i]]$是单调增的,并且满足要求的$pos[b[i-1]]$也随之增多,这里可以单调判断。考虑$dp[pos[b[i]][j]]=dp[pos[b[i]][j-1]]+sum(pos[b[i]][j-1]],pos[b[i]][j]],v\geq b[i-1] \ or \ p[i]<0)$。直接转移即可。

  • 注意转移边界开闭
  • 注意启始的状态边界
代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define pii pair<int, int>
#define mk make_pair
const int N = 1e6 + 10;
const int mod = 1e9 + 7;
const ll infll = 1e18;
ll tr[N], n;
void modify(int x, int v)
{
while (x <= n)
tr[x] += v, x += x & -x;
}
ll query(int x)
{
ll ans = 0;
while (x)
ans += tr[x], x -= x & -x;
return ans;
}
ll cal(int l, int r) { return query(r) - query(l - 1); }
ll dp[N];
vector<int> pos[N];
int p[N], a[N], b[N], m;
int main()
{
bool flag = 1;
scanf("%lld", &n);
for (int i = 1; i <= n; i++)
scanf("%d", &a[i]), pos[a[i]].push_back(i);
for (int i = 1; i <= n; i++)
scanf("%d", &p[i]), modify(i, p[i]);
scanf("%d", &m);
for (int i = 1; i <= m; i++)
scanf("%d", &b[i]);
int cnt = 1;
for (int i = 1; i <= n; i++)
if (cnt <= m && a[i] == b[cnt])
++cnt;

if (cnt <= m)
return puts("NO"), 0;
b[0] = 0;
pos[0].push_back(0);
a[n + 1] = n + 1;
pos[a[n + 1]].push_back(n + 1);
b[m + 1] = n + 1;
for (int i = 1; i <= m + 1; i++)
{
if (i > 1)
for (int j = b[i - 2] + 1; j <= b[i - 1]; j++)
for (int k : pos[j])
if (p[k] > 0)
modify(k, -p[k]);
int now = b[i];
int pre = b[i - 1];

//cout << "---" << b[i] << endl;
int t = 0;
for (int j = 0; j < pos[now].size(); j++)
{
int r = pos[now][j];

dp[pos[now][j]] = infll;
if (j)
dp[pos[now][j]] = dp[pos[now][j - 1]] + cal(pos[now][j - 1], pos[now][j] - 1);
int cnt = 0;
for (; t < pos[pre].size() && pos[pre][t] < pos[now][j]; t++)
{
int l = pos[pre][t];
//cout << r << " " << l << endl;
if (dp[l] != infll)
dp[r] = min(dp[r], dp[l] + cal(l + 1, r - 1));
}
}
}

printf("YES\n%lld\n", dp[n + 1]);
}