原文鏈接:http://blog.sina.com.cn/s/blog_66ad7bba0100hm8n.html
對於搞過競賽算法的人來說,powmod可能不會陌生,它是一個計算a^b mod m的函數,
但abmod你可能不知道,它其實意思更簡單,是計算a*b mod m的函數
powmod的出現看起來很自然,因爲a^b可能非常巨大,但a*b的結果很小,有存在的必要嗎?
比如a,b,m都是int,那麼a*b的結果有可能越int,你會說,那用__int64或者long long保存就行了
但,如果a,b,m都是long long,你又不想使用麻煩的大整數呢?
其實,如果你懂了計算a^b mod m的原理,那解決這個也相當容易,不過還要先說一下:
a^b * a^c mod m == a^(b+c) mod m
這恆等式貌似是初中的內容吧,這個要是明白了,那來再下一步:
舉例子:2^9 mod 5
== 2^8 * 2^1 mod 5
== ( (2^8 mod 5) * (2^1 mod 5) ) mod 5
== ( (2^4 mod 5)^2 * (2^1 mod 5) ) mod 5
== ( (2^2 mod 5)^4 * (2^1 mod 5) ) mod 5
== ( (2^1 mod 5)^8 * (2^1 mod 5) ) mod 5
從這一系列等式中你看出些什麼?
現在,我們來假設一下,我們只懂算乘法,去通過這些等式算結果。
首先2^1 mod 5 == 2,於是有
(2^8 mod 5) * (2 mod 5)
2^8不會算,但我們會把2^8變爲(2^2)^4
於是有4^4 mod 5 * 2
然後有(4^2 mod 5)^2 mod 5 * 2
得到1^2 mod 5 * 2
最後結果爲2,而實際上,2^9 = 512, 512 mod 5 == 2
可能你會問,這樣好像程序不好實現啊?不會的,我們換個角度看
a^b,把b拆成二進制,像剛剛的a^9 mod m,可以拆爲a^(2^3 + 2^0) mod m
於是,可以變形爲(a^(2^3) mod m) * (a^(2^0) mod m) mod m 間接算得
並且,a^(2^3) mod m的結果,可以由 (a^(2^2) mod m) ^ 2 mod m 獲得
於是,我們只要保證平方運算不會越界,就能通過遞推得到a^b mod m的結果
示例代碼如下:
int powmod( int a, int n, int k )
{
int d = 1;
for (a %= k; n > 0; n >>= 1)
{
if(n & 1)
d = (d*a)%k;
a = (a*a)%k;
}
return d;
}
好了,可能有人未必看的懂前面的,現在這裏就講一下簡單一些的abmod
a*b mod m,可能會越界,這是剛剛解釋過的,我們先個個假設這三個數都是byte
現在要算99 * 66 mod 100
我們把式子改寫爲 (99 * (64 + 2)) mod 100
== ( 99 * 64 + 99 * 2 ) mod 100
其中99*64的可以由99*32遞推出,然後又可以從99*16的遞推出。。。。
於是,我們可以知道,只要m不大於char最大值的一半,就可以保證結果是正確的
參考代碼:
INT abmod(INT a, INT b, INT m)
{
a %= m; b %= m;
INT s = 0;
for (INT i=b; i>0; a = (a<<1)%m,i>>=1)
if (i&1) s = (s+a) % m;
return s;
}
講解的最後,再回答一個問題:爲什麼採用二進制分解?
答,爲了提高可適用範圍。如果是三進制,前面的powmod要保證m^3不越界,後面的abmod要保證3*m不越界,
於是m的範圍就小了,精確性低了。
好了,admod的問題解決了,而這時候你發現,前面的powmod是因爲有乘法,所以令m的平方不能越界,
如果,我們把這兩個函數同時用上呢?
同時用上的話,我們可以把powmod擴展成superpowmod!可以在__int64下不借助大整數能正常工作!!
因爲最大的瓶頸是m*2不能越界,並且計算過程我們不需要負數,於是,參數類型我們用unsigned
就能保證__int64下能正常工作!比傳統的powmod可計算範圍大大增加哦!
終極代碼:
typedef unsigned __int64 INT;
INT abmod(INT a, INT b, INT m)
{
a %= m; b %= m;
INT s = 0;
for (INT i=b; i>0; a = (a<<1)%m,i>>=1)
if (i&1) s = (s+a) % m;
return s;
}
INT superPowmod( INT a, INT n, INT k )
{
INT d = 1;
for (a %= k; n > 0; n >>= 1)
{
if(n & 1)
d = abmod(d,a,k); //(d*a)%k;
a = abmod(a,a,k); //(a*a)%k;
}
return d;
}