プログラミング
記事内に商品プロモーションを含む場合があります

AtCoderのためのGo言語用mod付き計算ライブラリ|実装と活用方法

Go言語でMOD演算
Aru

Go言語にはC++のような演算子オーバーロード機能がないので、mod計算を行うコードは複雑になりがちです。本記事では、AtCoderのコンテストで頻繁に出題される「〇〇で割ったあまりを計算しなさい」という問題に利用可能ならmod付きの演算ライブラリの実装方法を解説します。

その他のAtCoderに役立つ記事の一覧

Go言語でmod計算をする

Go言語には、オペレータ(演算子)のオーバーロード機能がありません。これは言語設計上のポリシーの一環であり、いまのところ変更される予定はなさそうです。

一方で、C++ではオペレータオーバーロードが可能で、たとえばAtCoder提供のACL(AtCoder Library)では、+-といった通常の演算子を使って簡単にmod計算を行うことができます。しかし、Goではこのような簡便な方法は使えないため、mod計算を行うためのコードが煩雑になりがちです。

例えばGo言語は以下のように式中にmod計算を書くか、mod計算を合わせて行う関数を呼び出すしかありません。

直接式に書く例

const mod = int(1e9 + 7)
x = ((x * y)%mod + z)%mod 

関数を呼び出す例

mod = newModint(1e9+7)
x = mod.add(mod.mul(x, y), z)

どちらも面倒ですが、関数形式の方が幾分楽な気がします。ということで、上記のような形で呼び出すことのできるライブラリを作成しました。

使い方

初期化

初期化は、以下のように書きます。

引数は、modをとる値です。AtCoderだと9982443531e9+7のどちらかになると思います。一応、逆元の部分は拡張オイラーを使っているので、素数以外でも条件付きで使えるとは思います。

mod = newModint(998244353)

四則演算等

基本演算としては、以下を用意しています。

Goの場合、別ファイルのライブラリの関数を呼び出す場合は、関数名は大文字というルールがありますが、このライブラリはmainパッケージに貼り付ける前提で作っています(提出時を想定)

関数名処理
add(x, y)$x+y$
sub(x, y)$x-y$
mul(x, y)$x*y$
div(x, y)$x/y$
pow(x, y)$x^y$

このうち、divpowは、少し面倒なので関数化しておくと便利です。

mod.mod

自分でmodを書く場合は、以下のようにすることもできます。

mod = newModint(998244353)
x = (x + y)%mod.mod

その他の演算

その他の演算として、nCrを用意しています。組み合わせ(コンビネーション)の計算です。結構使うことがあるので、入れておきました。

これのせいでコードが長くなっています。

ライブラリコード

私が使っているmod計算用のライブラリです。いまいちな実装と思われる部分もあるかもしれませんが、大抵の場合はこれで何とかなっています。

使う場合は、下のコードをコピーしてコードの適当な場所に貼り付けてください。それで使えます。

Go言語で競技プログラミング、速度は十分なのですが、情報量不足と標準ライブラリ不足で何かと不利ですが、使っている方、一緒に頑張りましょう。

//----------------------------------------
// modint
//----------------------------------------
type modint struct {
	mod       int
	factMemo  []int
	ifactMemo []int
}

func newModint(m int) *modint {
	var ret modint
	ret.mod = m
	ret.factMemo = []int{1, 1}
	ret.ifactMemo = []int{1, 1}
	return &ret
}

func (m *modint) add(a, b int) int {
	ret := (a + b) % m.mod
	if ret < 0 {
		ret += m.mod
	}
	return ret
}

func (m *modint) sub(a, b int) int {
	ret := (a - b) % m.mod
	if ret < 0 {
		ret += m.mod
	}
	return ret
}

func (m *modint) mul(a, b int) int {
	a %= m.mod
	b %= m.mod
	ret := a * b % m.mod
	if ret < 0 {
		ret += m.mod
	}
	return ret
}

func (m *modint) div(a, b int) int {
	a %= m.mod
	ret := a * m.modinv(b)
	ret %= m.mod
	return ret
}

func (m *modint) pow(p, n int) int {
	ret := 1
	x := p % m.mod
	for n != 0 {
		if n%2 == 1 {
			ret *= x
			ret %= m.mod
		}
		n /= 2
		x = x * x % m.mod
	}
	return ret
}


// 拡張オイラーの互除法で逆元を求める
func (mm *modint) modinv(a int) int {
	m := mm.mod
	b, u, v := m, 1, 0
	for b != 0 {
		t := a / b
		a -= t * b
		a, b = b, a
		u -= t * v
		u, v = v, u
	}
	u %= m
	if u < 0 {
		u += m
	}
	return u
}

//-----------------------------------------------
// 行列累乗
//  A[][]のp乗を求める
//-----------------------------------------------
func (m *modint) powModMatrix(A [][]int, p int) [][]int {
	N := len(A)
	ret := make([][]int, N)
	for i := 0; i < N; i++ {
		ret[i] = make([]int, N)
		ret[i][i] = 1
	}

	for p > 0 {
		if p&1 == 1 {
			ret = m.mulMod(ret, A)
		}
		A = m.mulMod(A, A)
		p >>= 1
	}

	return ret
}

func (m *modint) mulMod(A, B [][]int) [][]int {
	H := len(A)
	W := len(B[0])
	K := len(A[0])
	C := make([][]int, W)
	for i := 0; i < W; i++ {
		C[i] = make([]int, W)
	}

	for i := 0; i < H; i++ {
		for j := 0; j < W; j++ {
			for k := 0; k < K; k++ {
				C[i][j] += A[i][k] * B[k][j]
				C[i][j] %= m.mod
			}
		}
	}

	return C
}

//---------------------------------------------------
// nCk 計算関連: TLEすることがあるかも
//                ※pow(x, p-2)を何度も取るので
// 厳しそうな場合は、ここを削除して高速なのを使う
//---------------------------------------------------
func (m *modint) mfact(n int) int {
	if len(m.factMemo) > n {
		return m.factMemo[n]
	}
	if len(m.factMemo) == 0 {
		m.factMemo = append(m.factMemo, 1)
	}
	for len(m.factMemo) <= n {
		size := len(m.factMemo)
		m.factMemo = append(m.factMemo, m.factMemo[size-1]*size%m.mod)
	}
	return m.factMemo[n]
}

func (m *modint) mifact(n int) int {
	if len(m.ifactMemo) > n {
		return m.ifactMemo[n]
	}
	if len(m.ifactMemo) == 0 {
		m.factMemo = append(m.ifactMemo, 1)
	}
	for len(m.ifactMemo) <= n {
		size := len(m.ifactMemo)
		m.ifactMemo = append(m.ifactMemo, m.ifactMemo[size-1]*m.pow(size, m.mod-2)%m.mod)
	}
	return m.ifactMemo[n]
}

func (m *modint) nCr(n, r int) int {
	if n == r {
		return 1
	}
	if n < r || r < 0 {
		return 0
	}
	ret := 1
	ret *= m.mfact(n)
	ret %= m.mod
	ret *= m.mifact(r)
	ret %= m.mod
	ret *= m.mifact(n - r)
	ret %= m.mod
	return (ret)
}

まとめ

Go言語用のmod付き計算のライブラリを実装してみました。このままカット&ペーストして利用できるようにしていますので、活用ください。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました