线段树合并

线段树合并

前置

动态开点权值线段树。

思想

用权值线段树保持一个子树的状态。暴力往上合并。

复杂度

  • 加点$log n$
  • 合并,若重复点$m$,$mlog n$ ,显然重复点不会那么多,如果 $n$ 与加入的总点数规模基本相同,我们就可以把它理解成每次操作 $O(logn)$

例题

P3605

找到每个父亲的儿子权值大于父亲的个数,离散化合并就可。

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define pii pair<int, int>
#define mk make_pair
const int N = 1e5 + 10;
const int mod = 1e9 + 7;

inline int read()
{
char ch = getchar();
int num = 0;
bool flag = false;
while (ch < '0' || ch > '9')
{
if (ch == '-')
flag = true;
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
num = num * 10 + ch - '0';
ch = getchar();
}
return flag ? -num : num;
}
/*----OI--------------*/
struct seg
{
int l, r, sum;
} tr[N * 20];
int scnt;
void build(int &pos, int l, int r, int v)
{
if (!pos)
pos = ++scnt;
tr[pos].sum++;
if (l == r)
return;
int mid = (l + r) >> 1;
if (v <= mid)
build(tr[pos].l, l, mid, v);
else
build(tr[pos].r, mid + 1, r, v);
}
int query(int pos, int l, int r, int w)
{
if (!pos)
return 0;
if (w <= l)
return tr[pos].sum;

int mid = (l + r) >> 1;
if (w <= mid)
return query(tr[pos].l, l, mid, w) + query(tr[pos].r, mid + 1, r, w);
else
return query(tr[pos].r, mid + 1, r, w);
}
int merge(int u, int v)
{
if (!u || !v)
return u + v;
int pos = ++scnt;
tr[pos].sum = tr[u].sum + tr[v].sum;
tr[pos].l = merge(tr[u].l, tr[v].l);
tr[pos].r = merge(tr[u].r, tr[v].r);
return pos;
}
vector<int> g[N];
int num, li[N], a[N];
int rt[N], ans[N];
void dfs(int x)
{
// cout << x << "-----" << endl;
for (int to : g[x])
{
dfs(to);
rt[x] = merge(rt[x], rt[to]);
}
//cout << "----" << endl;
ans[x] = query(rt[x], 1, num, a[x] + 1);
build(rt[x], 1, num, a[x]);
}

int main()
{
int n = read();
for (int i = 1; i <= n; i++)
a[i] = read(), li[++num] = a[i];
sort(li + 1, li + 1 + num);
num = unique(li + 1, li + 1 + num) - li - 1;
for (int i = 1; i <= n; i++)
a[i] = lower_bound(li + 1, li + 1 + num, a[i]) - li;
for (int i = 2; i <= n; i++)
{
int fa = read();
g[fa].push_back(i);
}
dfs(1);
for (int i = 1; i <= n; ++i)
printf("%d\n", ans[i]);
}