CF888G Xor-MST

给定 $n$ 个结点的无向完全图。每个点有一个点权为 $a_i$ 。连接 $i$ 号结点和 $j$ 号结点的边的边权为 $a_i\oplus a_j$

求这个图的 MST 的权值

用$01$字典树找最小的边,显然根据异或的性质,就是不断找$lca$深度最大的的两个点。

$n$点正好有$n-1$个$lca$。

即对于所以可以成为$lca$的点暴力搜索两侧点对的最小值,选择小的那边启发式合并。$O(n\log^2n)$

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define pii pair<int, int>
#define mk make_pair
const int N = 2e5 + 10;
const int mod = 1e9 + 7;
int read()
{
int x = 0, f = 1;
char c = getchar();
while (c < '0' || c > '9')
{
if (c == '-')
f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
x = (x << 1) + (x << 3) + c - '0', c = getchar();
return x * f;
}
int a[N];
ll ans;
struct Tri
{
int scnt, ch[N * 30][2], li[N * 30], ri[N * 30];
void insert(int x, int id)
{

int tmp = 0;
if (!li[tmp])
li[tmp] = id;
ri[tmp] = id;
for (int i = 30; i >= 0; i--)
{
int c = (x & (1 << i)) >> i;
if (!ch[tmp][c])
ch[tmp][c] = ++scnt;
tmp = ch[tmp][c];
if (!li[tmp])
li[tmp] = id;
ri[tmp] = id;
}

//cout << tmp << endl;
}
int query(int x, int tmp, int dep)
{
if (dep < 0)
return 0;
int c = (x & (1 << dep)) >> dep;
if (ch[tmp][c])
return (query(x, ch[tmp][c], dep - 1));
else
return (query(x, ch[tmp][c ^ 1], dep - 1) + (1 << dep));
}
void solve(int tmp, int dep)
{
if (dep < 0)
return;
int l = ch[tmp][0];
int r = ch[tmp][1];

if (l && r)
{

int cnt0 = ri[l] - li[l];
int cnt1 = ri[r] - li[r];
if (cnt0 <= cnt1)
{
int res = 2e9;
for (int i = li[l]; i <= ri[l]; i++)
{

res = min(query(a[i], r, dep - 1) + (1 << dep), res);
}

ans += res;
}
else
{
int res = 2e9;
for (int i = li[r]; i <= ri[r]; i++)
res = min(query(a[i], l, dep - 1) + (1 << dep), res);
ans += res;
}
solve(l, dep - 1);
solve(r, dep - 1);
}
else if (r)
solve(r, dep - 1);
else if (l)
solve(l, dep - 1);
}
} T;
int main()
{
int n = read();
for (int i = 1; i <= n; i++)
a[i] = read();
sort(1 + a, 1 + a + n);
for (int i = 1; i <= n; i++)
T.insert(a[i], i);
T.solve(0, 30);

printf("%lld\n", ans);
}