AtCoderのためのGo言語用mod付き計算ライブラリ|実装と活用方法
Go言語にはC++のような演算子オーバーロード機能がないので、mod計算を行うコードは複雑になりがちです。本記事では、AtCoderのコンテストで頻繁に出題される「〇〇で割ったあまりを計算しなさい」という問題に利用可能ならmod付きの演算ライブラリの実装方法を解説します。
Go言語でmod計算をする
Go言語には、オペレータ(演算子)のオーバーロード機能がありません。これは言語設計上のポリシーの一環であり、いまのところ変更される予定はなさそうです。
一方で、C++ではオペレータオーバーロードが可能で、たとえばAtCoder提供のACL(AtCoder Library)では、+
や-
といった通常の演算子を使って簡単にmod計算を行うことができます。しかし、Goではこのような簡便な方法は使えないため、mod計算を行うためのコードが煩雑になりがちです。
例えばGo言語は以下のように式中にmod計算を書くか、mod計算を合わせて行う関数を呼び出すしかありません。
直接式に書く例
const mod = int(1e9 + 7)
x = ((x * y)%mod + z)%mod
関数を呼び出す例
mod = newModint(1e9+7)
x = mod.add(mod.mul(x, y), z)
どちらも面倒ですが、関数形式の方が幾分楽な気がします。ということで、上記のような形で呼び出すことのできるライブラリを作成しました。
使い方
初期化
初期化は、以下のように書きます。
引数は、modをとる値です。AtCoderだと998244353か1e9+7のどちらかになると思います。一応、逆元の部分は拡張オイラーを使っているので、素数以外でも条件付きで使えるとは思います。
mod = newModint(998244353)
四則演算等
基本演算としては、以下を用意しています。
Goの場合、別ファイルのライブラリの関数を呼び出す場合は、関数名は大文字というルールがありますが、このライブラリはmainパッケージに貼り付ける前提で作っています(提出時を想定)
関数名 | 処理 |
add(x, y) | $x+y$ |
sub(x, y) | $x-y$ |
mul(x, y) | $x*y$ |
div(x, y) | $x/y$ |
pow(x, y) | $x^y$ |
このうち、div
とpow
は、少し面倒なので関数化しておくと便利です。
mod.mod
自分でmodを書く場合は、以下のようにすることもできます。
mod = newModint(998244353)
x = (x + y)%mod.mod
その他の演算
その他の演算として、nCr
を用意しています。組み合わせ(コンビネーション)の計算です。結構使うことがあるので、入れておきました。
これのせいでコードが長くなっています。
ライブラリコード
私が使っているmod計算用のライブラリです。いまいちな実装と思われる部分もあるかもしれませんが、大抵の場合はこれで何とかなっています。
使う場合は、下のコードをコピーしてコードの適当な場所に貼り付けてください。それで使えます。
Go言語で競技プログラミング、速度は十分なのですが、情報量不足と標準ライブラリ不足で何かと不利ですが、使っている方、一緒に頑張りましょう。
//----------------------------------------
// modint
//----------------------------------------
type modint struct {
mod int
factMemo []int
ifactMemo []int
}
func newModint(m int) *modint {
var ret modint
ret.mod = m
ret.factMemo = []int{1, 1}
ret.ifactMemo = []int{1, 1}
return &ret
}
func (m *modint) add(a, b int) int {
ret := (a + b) % m.mod
if ret < 0 {
ret += m.mod
}
return ret
}
func (m *modint) sub(a, b int) int {
ret := (a - b) % m.mod
if ret < 0 {
ret += m.mod
}
return ret
}
func (m *modint) mul(a, b int) int {
a %= m.mod
b %= m.mod
ret := a * b % m.mod
if ret < 0 {
ret += m.mod
}
return ret
}
func (m *modint) div(a, b int) int {
a %= m.mod
ret := a * m.modinv(b)
ret %= m.mod
return ret
}
func (m *modint) pow(p, n int) int {
ret := 1
x := p % m.mod
for n != 0 {
if n%2 == 1 {
ret *= x
ret %= m.mod
}
n /= 2
x = x * x % m.mod
}
return ret
}
// 拡張オイラーの互除法で逆元を求める
func (mm *modint) modinv(a int) int {
m := mm.mod
b, u, v := m, 1, 0
for b != 0 {
t := a / b
a -= t * b
a, b = b, a
u -= t * v
u, v = v, u
}
u %= m
if u < 0 {
u += m
}
return u
}
//-----------------------------------------------
// 行列累乗
// A[][]のp乗を求める
//-----------------------------------------------
func (m *modint) powModMatrix(A [][]int, p int) [][]int {
N := len(A)
ret := make([][]int, N)
for i := 0; i < N; i++ {
ret[i] = make([]int, N)
ret[i][i] = 1
}
for p > 0 {
if p&1 == 1 {
ret = m.mulMod(ret, A)
}
A = m.mulMod(A, A)
p >>= 1
}
return ret
}
func (m *modint) mulMod(A, B [][]int) [][]int {
H := len(A)
W := len(B[0])
K := len(A[0])
C := make([][]int, W)
for i := 0; i < W; i++ {
C[i] = make([]int, W)
}
for i := 0; i < H; i++ {
for j := 0; j < W; j++ {
for k := 0; k < K; k++ {
C[i][j] += A[i][k] * B[k][j]
C[i][j] %= m.mod
}
}
}
return C
}
//---------------------------------------------------
// nCk 計算関連: TLEすることがあるかも
// ※pow(x, p-2)を何度も取るので
// 厳しそうな場合は、ここを削除して高速なのを使う
//---------------------------------------------------
func (m *modint) mfact(n int) int {
if len(m.factMemo) > n {
return m.factMemo[n]
}
if len(m.factMemo) == 0 {
m.factMemo = append(m.factMemo, 1)
}
for len(m.factMemo) <= n {
size := len(m.factMemo)
m.factMemo = append(m.factMemo, m.factMemo[size-1]*size%m.mod)
}
return m.factMemo[n]
}
func (m *modint) mifact(n int) int {
if len(m.ifactMemo) > n {
return m.ifactMemo[n]
}
if len(m.ifactMemo) == 0 {
m.factMemo = append(m.ifactMemo, 1)
}
for len(m.ifactMemo) <= n {
size := len(m.ifactMemo)
m.ifactMemo = append(m.ifactMemo, m.ifactMemo[size-1]*m.pow(size, m.mod-2)%m.mod)
}
return m.ifactMemo[n]
}
func (m *modint) nCr(n, r int) int {
if n == r {
return 1
}
if n < r || r < 0 {
return 0
}
ret := 1
ret *= m.mfact(n)
ret %= m.mod
ret *= m.mifact(r)
ret %= m.mod
ret *= m.mifact(n - r)
ret %= m.mod
return (ret)
}
まとめ
Go言語用のmod付き計算のライブラリを実装してみました。このままカット&ペーストして利用できるようにしていますので、活用ください。