Modulo Multiplication
Fri, Aug 3, 2018たまに大きな数同士のかけ算をオーバーフローさせずに計算し、その modulo が取りたい時があります。 多倍長演算を実装すればいいわけですが、大変面倒なので何か簡単な方法はないかなーと調べていたらありました。
C++ で書くと以下のようになります。
#include <cstdio>
#include <cinttypes>
/**
* if b is even then
* a * b = 2 * a * (b / 2)
* otherwise
* a * b = a + a * (b - 1)
*
* (a * b) % mod = (2a * b / 2) % mod
* = (a + a + a + ...(b times)... + a) % mod
*/
std::int64_t modpow(std::int64_t a, std::int64_t b, std::int64_t m)
{
std::int32_t r = 0;
a %= m;
while (b) {
if (b & 1) {
r = (r + a) % m;
}
a = (2 * a) % m;
b >>= 1;
}
return r;
}
int main( void )
{
using namespace std;
printf("%lu\n", modpow(32933923828, 219009238, 1 << 31));
return 0;
}
基本的な原理はコードのコメントにも記載していますが、
- a * b の乗算を a + a + … + a の b 回の加算と考える。
- 漸化式的に前回の結果との差分に対する演算をする。
- 加算の度に modulo を取ることでオーバーフローを避ける。
b >>= 1
を繰り返すことで b の bit 数回だけループが回るようにする。- まともに計算するとループ回数は $O(N)$ だけど $O(logN)$ まで減る。
という感じです。
同じ考え方を使って、$x^{n} \bmod m$ を効率良く計算する方法もあり、以下のようにします。
std::int32_t mmod(std::int32_t x, std::int32_t n, std::int32_t m)
{
if (n == 0) {
return 1;
}
if (n % 2 ==0) {
auto t = mmod(x, n / 2, m);
return (t * t) % m;
}
return (mmod(x, n - 1, m) * x) % m;
}
これらの計算手法はどうもプログラミングコンテスト界隈では必須の知識なようで、日本語では繰り返し二乗法、英語では exponentiation by squaring と呼ぶらしいです。