CF715C

给出一个树$(n\leq 10^5)$,每条边上写了一个数字,给出一个$p$,求有多少条路径按顺序读出的数字可以被$p$整除。保证$p$与10互质。

这题可以用$dsu$和点分治做。

点分治

显然我们只需要知道如何快速计算一个子树里合理方案树。因为这道题可以路径顺逆读不一样$(14\% 7=p)$但$(41\% !=0)$。定义两个数组$ls[N]$从根节点到叶子节点组成的数字,$rs[N]$从叶子节点到根节点组成的数字。

这部分以$key=ls[i]*inv(10^{len(ls[i])})$$map$进行匹配即可

求逆元部分需要注意!!

因为$gcd(a,p)=1$

然后就是套上点分治的模板就好了

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <utility>
#include <map>
#define inf 0x7fffffff
using namespace std;
typedef long long ll;
const int N = 2e5 + 10;
struct edge
{
int v, w, nxt;
} e[N << 2];

int n, head[N], vis[N], invpow10[N], pow10[N];
int cnt, root, sum, p;
int son[N], f[N];

ll ans;

map<int, ll> mr;
map<int, ll> ml;
vector<pair<int, int> > ls;
vector<pair<int, int> > rs;

int exgcd(int a, int b, int &x, int &y)
{
if (!b)
{
x = 1;
y = 0;
return a;
}
int k = exgcd(b, a % b, x, y);
int xx = x;
x = y;
y = (xx - a / b * y);
return k;
}

void addedge(int u, int v, int w)
{
e[++cnt].v = v;
e[cnt].w = w;
e[cnt].nxt = head[u];
head[u] = cnt;
}

void getroot(int x, int fa)
{
f[x] = 0;
son[x] = 1;
for (int i = head[x]; i; i = e[i].nxt)
{
int v = e[i].v;
if (v == fa || vis[e[i].v])
continue;
getroot(v, x);
son[x] += son[v];
f[x] = max(son[v], f[x]);
}
f[x] = max(f[x], sum - son[x]);
if (f[x] < f[root])
root = x;
}

void getnum(int x, int ld, int rd, int len,int fa)
{
ls.push_back(make_pair(ld%p, len));
rs.push_back(make_pair(rd%p, len));
//cout<<x<<endl;
for (int i = head[x]; i; i = e[i].nxt)
{
int v = e[i].v;
int w = e[i].w;
if (v == fa || vis[e[i].v])
continue;
getnum(v, (1ll*ld * 10%p + w)%p , (rd + 1ll * pow10[len] * w %p)%p , len + 1,x);
}
}

ll calc(int x, int d)
{
ll res = 0;
ls.clear();
rs.clear();
ml.clear();
mr.clear();
if (!d)
getnum(x, 0, 0, 0,0);
else
getnum(x, d, d, 1,0);

for (int i = 0; i < ls.size(); i++)
{
//if(d)
//cout<<ls[i].first%p<<" "<<ls[i].second<<endl;
int key = 1ll * (-ls[i].first + p)%p * invpow10[ls[i].second] % p;
//cout<<key<<" "<<endl;
ml[key]++;
}

for (int i = 0; i < rs.size(); i++)
res += ml[rs[i].first];
if(!d) res--;
//cout<<"---"<<x<<" "<<res<<endl;
return res;
}

void solve(int x)
{
ans += calc(x, 0);
//cout<<ans<<endl;
vis[x] = 1;
for (int i = head[x]; i; i = e[i].nxt)
{
int v = e[i].v;
if (vis[v])
continue;
ans -= calc(v, e[i].w);
sum = son[v];
root = 0;
getroot(v, x);
solve(root);
}
}

int main()
{
scanf("%d%d", &n, &p);
for (int i = 1; i < n; i++)
{
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
u++,v++;
addedge(u, v, w);
addedge(v, u, w);
}
pow10[0]=1;
for (int i = 1; i <= n; i++)
pow10[i] = 1ll * pow10[i-1] * 10%p ;
for (int i = 0; i <= n; i++)
{
int x, y;
//cout<<pow10[i]<<endl;
exgcd(pow10[i], p, x, y);
//cout<<x<<endl;
invpow10[i] = (x + p) % p;
//cout<<invpow10[i]<<endl;
}

f[0] = inf;
sum = n;
getroot(1, 0);
//cout<<root<<endl;
solve(root);
printf("%lld\n", ans);
}