プログラミング
記事内に商品プロモーションを含む場合があります

Go言語でmod計算をするライブラリ for AtCoder

Go言語でMOD演算
tadanori

Go言語には、C++みたいな演算子のオーバライトがありません。
mod計算を行うのが非常に面倒です。

その他のAtCoderに役立つ記事の一覧は以下にあります。

あわせて読みたい
AtCoderで役立つ記事一覧(Python, Go言語)
AtCoderで役立つ記事一覧(Python, Go言語)

Go言語でmod計算をする

Go言語には、オペレータ(演算子の)オーバロードの機能がありません。言語ポリシーとして無いみたいなので方向性が変化しない限りは変わらないと思います。

C++は、オペレータオーバーロードができますので、AtCoder提供のACLなどでは、普通の+, -などの記号でmod計算ができてしまいます。

これに比べて、Go言語は以下のように頑張って書くか、関数呼び出しで書くかしかありません。

const mod = int(1e9 + 7)
x = ((x * y)%mod + z)%mod 

関数で書く場合は、こんな感じ。

mod = newModint(1e9+7)
x = mod.add(mod.mul(x, y), z)

どちらも面倒です・・・

とはいえ、関数形式にしておけば、pow()とか、面倒なdiv()とかも同じ形で記述することができるので、少しだけマシです。

ということで、上記のような形式で呼び出すライブラリを作成して使っています。

使い方

初期化

初期化は、以下のように書きます。

引数は、modをとる値です。AtCoderだと998244353か1e9+7のどちらかになると思います。一応、逆元の部分は拡張オイラーを使っているので、素数以外でも条件付きで使えるとは思います。

mod = newModint(998244353)

四則演算等

基本演算としては、以下を用意しています。

Goの場合、別ファイルのライブラリの関数を呼び出す場合は、関数名は大文字というルールがありますが、このライブラリはmainパッケージに貼り付ける前提で作っています(提出時を想定)

関数名処理
add(x, y)$x+y$
sub(x, y)$x-y$
mul(x, y)$x*y$
div(x, y)$x/y$
pow(x, y)$x^y$

mod.mod

自分でmodを書く場合は、以下のようにすることもできます。

mod = newModint(998244353)
x = (x + y)%mod.mod

その他の演算

その他の演算として、nCrを用意しています。組み合わせ(コンビネーション)の計算です。結構使うことがあるので、入れておきました。

これのせいでコードが長くなっています。

ライブラリコード

私が使っているmod計算用のライブラリです。いまいちな実装と思われる部分もあるかもしれませんが、大抵の場合はこれで何とかなっています。

使う場合は、下のコードをコピーしてコードの適当な場所に貼り付けてください。それで使えます。

Go言語で競技プログラミング、速度は十分なのですが、情報量不足と標準ライブラリ不足で何かと不利ですが、使っている方、一緒に頑張りましょう。

//----------------------------------------
// modint
//----------------------------------------
type modint struct {
	mod       int
	factMemo  []int
	ifactMemo []int
}

func newModint(m int) *modint {
	var ret modint
	ret.mod = m
	ret.factMemo = []int{1, 1}
	ret.ifactMemo = []int{1, 1}
	return &ret
}

func (m *modint) add(a, b int) int {
	ret := (a + b) % m.mod
	if ret < 0 {
		ret += m.mod
	}
	return ret
}

func (m *modint) sub(a, b int) int {
	ret := (a - b) % m.mod
	if ret < 0 {
		ret += m.mod
	}
	return ret
}

func (m *modint) mul(a, b int) int {
	a %= m.mod
	b %= m.mod
	ret := a * b % m.mod
	if ret < 0 {
		ret += m.mod
	}
	return ret
}

func (m *modint) div(a, b int) int {
	a %= m.mod
	ret := a * m.modinv(b)
	ret %= m.mod
	return ret
}

func (m *modint) pow(p, n int) int {
	ret := 1
	x := p % m.mod
	for n != 0 {
		if n%2 == 1 {
			ret *= x
			ret %= m.mod
		}
		n /= 2
		x = x * x % m.mod
	}
	return ret
}


// 拡張オイラーの互除法で逆元を求める
func (mm *modint) modinv(a int) int {
	m := mm.mod
	b, u, v := m, 1, 0
	for b != 0 {
		t := a / b
		a -= t * b
		a, b = b, a
		u -= t * v
		u, v = v, u
	}
	u %= m
	if u < 0 {
		u += m
	}
	return u
}

//-----------------------------------------------
// 行列累乗
//  A[][]のp乗を求める
//-----------------------------------------------
func (m *modint) powModMatrix(A [][]int, p int) [][]int {
	N := len(A)
	ret := make([][]int, N)
	for i := 0; i < N; i++ {
		ret[i] = make([]int, N)
		ret[i][i] = 1
	}

	for p > 0 {
		if p&1 == 1 {
			ret = m.mulMod(ret, A)
		}
		A = m.mulMod(A, A)
		p >>= 1
	}

	return ret
}

func (m *modint) mulMod(A, B [][]int) [][]int {
	H := len(A)
	W := len(B[0])
	K := len(A[0])
	C := make([][]int, W)
	for i := 0; i < W; i++ {
		C[i] = make([]int, W)
	}

	for i := 0; i < H; i++ {
		for j := 0; j < W; j++ {
			for k := 0; k < K; k++ {
				C[i][j] += A[i][k] * B[k][j]
				C[i][j] %= m.mod
			}
		}
	}

	return C
}

//---------------------------------------------------
// nCk 計算関連: TLEすることがあるかも
//                ※pow(x, p-2)を何度も取るので
// 厳しそうな場合は、ここを削除して高速なのを使う
//---------------------------------------------------
func (m *modint) mfact(n int) int {
	if len(m.factMemo) > n {
		return m.factMemo[n]
	}
	if len(m.factMemo) == 0 {
		m.factMemo = append(m.factMemo, 1)
	}
	for len(m.factMemo) <= n {
		size := len(m.factMemo)
		m.factMemo = append(m.factMemo, m.factMemo[size-1]*size%m.mod)
	}
	return m.factMemo[n]
}

func (m *modint) mifact(n int) int {
	if len(m.ifactMemo) > n {
		return m.ifactMemo[n]
	}
	if len(m.ifactMemo) == 0 {
		m.factMemo = append(m.ifactMemo, 1)
	}
	for len(m.ifactMemo) <= n {
		size := len(m.ifactMemo)
		m.ifactMemo = append(m.ifactMemo, m.ifactMemo[size-1]*m.pow(size, m.mod-2)%m.mod)
	}
	return m.ifactMemo[n]
}

func (m *modint) nCr(n, r int) int {
	if n == r {
		return 1
	}
	if n < r || r < 0 {
		return 0
	}
	ret := 1
	ret *= m.mfact(n)
	ret %= m.mod
	ret *= m.mifact(r)
	ret %= m.mod
	ret *= m.mifact(n - r)
	ret %= m.mod
	return (ret)
}

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました