Python|たった一行追加して再帰関数を高速化する方法(メモ化)
Pythonには、たった一行を追加するだけでメモ化再帰を実現する機能があります。これを使うと再帰を使った探索問題の計算効率を大幅に向上させることが可能です。この記事では、サンプルプログラムを使ってメモ化再帰の有用性と、Pythonでの記述の方法を解説します。
再帰関数とは
再帰関数とは、自身の中で自分自身を呼び出す関数のことです。数式的には、$f(x) = xf(x-1)$のように右辺にも左辺にも関数$f$が現れるものになります。先の例は再帰を使わずにfor
文で書くこともできますが、再帰を使った方がシンプルにかけることもあります。再帰関数を使った有名なものとしては、階乗計算やフィボナッチ数列の生成などがあります。
階乗の再帰プログラムの例
階乗計算($n!$)は、1からnまでを掛け合わせる計算です。階乗計算は以下の式で定義されます。
$$
\begin{align}
n! &= n \times (n-1)\\
0! &= 1
\end{align}$$
これをPythonで再帰関数として書くと以下のようになります。
def factorial(n):
if n == 0:
return 1
else:
return n * factorial(n - 1)
例えば、factorial(3)
は、3*factrial(2)
→ 3*2*factrial(1)
→ 3*2*1
と計算されます。このように、自身の関数から自分自身を呼び出す構造の関数を再帰関数と呼びます。
Pythonで再帰関数を利用する場合、再起回数の制限に注意する必要があります。デフォルトでは、再起呼び出しの回数が小さいので、これを大きくしておく必要があります。具体的には以下のようなコードを先頭に挿入して、再起回数の制限を緩和します。
import sys
sys.setrecursionlimit(10**6)
メモ化再帰とは
メモ化再帰とは、再起的な計算において、「一度計算した結果を保存しておき、同じ計算を繰り返さないようにする手法」です。特に、同じ計算が難度も行われるケースでは、計算時間を大幅に短縮可能可能です。
再帰関数で同じ計算が行われるかどうかは、引数を見ればわかります。基本的に、引数の値が全て同じ場合は同じ計算が行われると考えることが可能です。
フィボナッチ数列を再帰プログラムで記述する
フィボナッチ数列は以下の式で定義される数列です。
$$
\begin{align}
F(0) &= 0\\
F(1) &= 1\\
F(n) &= F(n-1) + F(n-2) (n>=2)
\end{align}
$$
これを再帰プログラムで記述すると以下のようになります。
def fibonacci(n):
if n == 0:
return 0
elif n == 1:
return 1
else:
return fibonacci(n - 1) + fibonacci(n - 2)
非常に素直な実装です。fibonacci(5)
を実行すると以下のような計算が行われます。
fibonacci(5) -> fibonacci(4) + fibonacci(3)
fibonacci(4) -> fibonacci(3) + fibonacci(2)
fibonacci(3) -> fibonacci(2) + fibonacci(1)
fibonacci(2) -> fibonacci(1) + fibonacci(0)
上記を見て分かるように、同じ引数で関数がなん度も呼び出されています。つまり同じ計算を難度も行うことになります。
以下のようなコードで上記の関数を実行すると、30付近から実行が遅くなります。
def fibonacci(n):
if n == 0:
return 0
elif n == 1:
return 1
else:
return fibonacci(n - 1) + fibonacci(n - 2)
for i in range(100):
print(i, fibonacci(i))
これは、同じ計算を難度も繰り返すため、処理時間がかかっているからです。
メモ化再帰(自分で記述する例)
ここで、メモ化再帰を導入します。
def fibonacci(n):
if n in memo:
return memo[n]
if n == 0:
result = 0
elif n == 1:
result = 1
else:
result = fibonacci(n - 1) + fibonacci(n - 2)
memo[n] = result
return result
memo = {}
for i in range(100):
print(i, fibonacci(i))
上記の例では、辞書memo
に、計算結果を格納し、一度計算したことのある値については、メモした結果を返すようになっています。
このプログラムを実行すると一瞬で処理が完了します。実際に実行するとメモ化の効果を感じれるかと思います。
Pythonで一行でメモ化する方法(@cache)
先ほどのメモ化では、memo
を自身で定義し、メモ化するコードを挿入しましたが、Pythonでは1行追加するだけでもメモ化を実現することが可能です。
以下は、コードです。
from functools import cache
@cache
def fibonacci(n):
if n == 0:
result = 0
elif n == 1:
result = 1
else:
result = fibonacci(n - 1) + fibonacci(n - 2)
return result
for i in range(100):
print(i, fibonacci(i))
一行と書きましたが、import
文があるので正確には2行追加です。@cache
というデコレータを再帰関数の手前に追加するだけでメモ化することが可能です。
似たような機能に@lru_cache
もありますが、こちらはmaxsizeまで記録するものになります。競技プログラミングなどでは@cache
だけで良いかと思います
@cache
は、@lru_cache(maxsize=None)
と同じです(無限にキャッシュする設定)。
サイズを制限したい場合にはlru_cache(maxsize=サイズ)
で指定します。
このように、コードにメモ機能を追加せずにメモ化を実現できるのがPythonの強みです。
AtCoder ABC-275 D問題を解いてみる
メモ化再帰を使った問題です。問題文は以下の通りです
問題文
非負整数 xx に対し定義される関数 f(x)f(x) は以下の条件を満たします。
- $f(0) = 1$
- 任意の正整数 kk に対し $f(k) = f(\lfloor \frac{k}{2} \rfloor) + f(\lfloor \frac{k}{3} \rfloor)$
このとき、 $f(N)$を求めてください。
制約
- $N$ は$0\leq N \leq 10^{18}$を満たす整数
素直に書くと以下のようになります。
import sys
sys.setrecursionlimit(10**6)
n = int(input())
def f(x) :
if x == 0 : return 1
return f(x//2) + f(x//3)
print(f(n))
上記のプログラムを提出するとTLE(時間超過)してしまいます。ここでメモ化を導入します。メモ化のプログラムは以下になります。
違いは、importが増えたのと@cache
が追加されたことです。
from functools import cache
import sys
sys.setrecursionlimit(10**6)
n = int(input())
@cache
def f(x) :
if x == 0 : return 1
return f(x//2) + f(x//3)
print(f(n))
これを提出するとACになります。
メモ化を行う場合は、必要ない引数を増やしすぎないように注意が必要です。引数に対してメモが行われるので引数が多すぎるとメモリ消費が激しくなります。
まとめ
Pythonでデコレータを用いてメモ化を簡単にメモ化を行う方法について解説しました。他の言語よりかなり楽なの手順で導入できますので活用してください。
まず、愚直な再帰で実装して、その後@cache
を追加するという方法が良いかと思います。