CF724C - Ray Tracing

在矩形 $[0,n]\times[0,m]$ 内，光线从 $(0,0)$ 出发，沿 $(1,1)$ 方向运动（每秒横、纵坐标各变化 $1$），碰到边界时反射，到达任一顶点时停止。共 $k$ 个询问，每次给出点 $(x_i,y_i)$，求光线第一次经过该点的时刻；若永远不会经过，则输出 $-1$。

光线在时刻 $t$ 位于顶点，当且仅当 $n \mid t$ 且 $m \mid t$。因此光线在时刻 $\operatorname{lcm}(n,m)$ 停止，晚于该时刻的解都不可达。

单独考虑横坐标：它以 $2n$ 为周期，依次取 $0,1,2,\dots,n,n-1,\dots,1$。因此横坐标等于 $x_i$ 的时刻满足 $t \bmod 2n \in \{x_i,\ 2n-x_i\}$。纵坐标同理，满足 $t \bmod 2m \in \{y_i,\ 2m-y_i\}$。
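即对四种组合分别解同余方程组

$$\begin{cases} t \equiv a \pmod{2n}, & a \in \{x_i,\ 2n-x_i\} \\ t \equiv b \pmod{2m}, & b \in \{y_i,\ 2m-y_i\} \end{cases}$$

并取其中最小的非负解 $t$。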

由于 $2n$ 与 $2m$ 不一定互质，需用扩展中国剩余定理求解。若四组方程都无解，或最小解大于 $\operatorname{lcm}(n,m)$，则答案为 $-1$。


代码


#include <iostream>
#include <cstdio>
#include <algorithm>
#include <queue>
#include <cmath>
#include <cstring>
#include <vector>
#include <map>
#include <stack>
using namespace std;
using ll = long long; // 64-bit integer type used for all modular arithmetic below
#define pii pair<int, int>
#define mk make_pair
const int N = 1000000 + 10;     // capacity of the global residue/modulus arrays
const int mod = 1000000000 + 7; // not referenced by the code shown here
// Greatest common divisor via the Euclidean algorithm (iterative form).
ll gcd(ll a, ll b)
{
    while (b != 0)
    {
        const ll r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Extended Euclid: returns g = gcd(a, b) and stores in x, y a pair of
// coefficients satisfying a*x + b*y == g.
ll exgcd(ll a, ll b, ll &x, ll &y)
{
    if (!b)
    {
        x = 1, y = 0;
        return a;
    }
    // Recurse on (b, a % b) with the outputs swapped, so afterwards
    // b*y + (a % b)*x == g. Since a % b == a - (a / b) * b, this rearranges
    // to a*x + b*(y - a / b * x) == g; only y needs correcting.
    const ll g = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return g;
}

// Classic Chinese Remainder Theorem over 0-indexed arrays: returns the least
// non-negative t with t ≡ a[i] (mod p[i]) for i in [0, len).
// Assumes the moduli are positive and pairwise coprime (the inverse of
// L = res / p[i] modulo p[i] must exist), and that their product fits in ll.
ll CRT(int *p, int *a, int len)
{
    ll res = 1;
    for (int i = 0; i < len; i++)
        res *= p[i];
    ll ans = 0;
    for (int i = 0; i < len; i++)
    {
        ll L = res / p[i];
        ll x, y;
        exgcd(L, p[i], x, y); // L*x ≡ 1 (mod p[i]) when gcd(L, p[i]) == 1
        x = (x % p[i] + p[i]) % p[i];
        // Normalise the residue so negative inputs still give a result in [0, res).
        ll r = (static_cast<ll>(a[i]) % p[i] + p[i]) % p[i];
        // L * (x*r mod p[i]) ≡ L*x*r (mod L*p[i] == res). Reducing modulo
        // p[i] first keeps every product below res, avoiding the overflow of
        // the former `L * x % res * a[i]`.
        ll coef = x * r % p[i];
        ans = (ans + L * coef) % res;
    }
    return ans;
}

// Extended CRT: merges the congruences t ≡ a[i] (mod p[i]) for i = 1..len
// (1-indexed arrays; the moduli need not be pairwise coprime; positive moduli
// and len >= 1 are assumed). Returns the least non-negative solution modulo
// lcm(p[1..len]), or 1e18 when the system is inconsistent.
ll EXCRT(ll *p, ll *a, int len)
{
    ll M = p[1], R = a[1], x, y;
    for (int i = 2; i <= len; i++)
    {
        // Look for k with R + k*M ≡ a[i] (mod p[i]), i.e. M*k ≡ a[i] - R (mod p[i]).
        ll d = exgcd(M, p[i], x, y); // M*x + p[i]*y == d
        ll diff = a[i] - R;          // note the sign: target minus current residue
        if (diff % d)
            return 1e18;             // inconsistent system
        ll P = p[i] / d;
        ll LC = M / d * p[i];        // lcm(M, p[i])
        // k ≡ (diff / d) * x (mod P). Reduce both factors first so the product
        // stays below P*P rather than overflowing.
        ll k = (diff / d % P + P) % P * ((x % P + P) % P) % P;
        R = ((R + k * M) % LC + LC) % LC; // k*M < P*M == LC
        M = LC;
    }
    return (R % M + M) % M;
}
ll a[N]; // residues fed to EXCRT (1-indexed)
ll p[N]; // moduli fed to EXCRT (1-indexed); main fills p[1] and p[2]

// Builds the two-congruence system t ≡ u (mod p[1]), t ≡ v (mod p[2]) in the
// global arrays and returns EXCRT's result for it: the least non-negative
// solution, or EXCRT's 1e18 sentinel when no solution exists.
ll solve(int u, int v)
{
    const int equations = 2;
    a[1] = u, a[2] = v;
    const ll t = EXCRT(p, a, equations);
    return t;
}
int n, m, k;
int main()
{
scanf("%d%d%d", &n, &m, &k);
ll Lim = 1ll * n * m / gcd(n, m);

p[1] = 2 * n;
p[2] = 2 * m;

for (int i = 1; i <= k; i++)
{
int u, v;
scanf("%d%d", &u, &v);
ll ans = 1e18;
// cout << "---" << endl;
ans = min(ans, solve(u, v));
ans = min(ans, solve(u, 2 * m - v));
ans = min(ans, solve(2 * n - u, v));
ans = min(ans, solve(2 * n - u, 2 * m - v));
if (ans > Lim)
printf("-1\n");
else
printf("%lld\n", ans);
}
}

</details>