プログラミング
記事内に商品プロモーションを含む場合があります

Pythonで深さ優先探索(DFS),幅優先探索(BFS), 優先キューを使ったダイクストラ法を実装

PythonでDFS, BSF, ダイクストラ法を記述する
tadanori

はじめに

経路探索の代表的な手法である、DFS(深さ優先探索)、BFS(幅優先探索)、ダイクストラ法(優先キュー、Priority Queue)の3つの典型コードを紹介します。

AtCoderなどの競技プログラムに参加する場合、これらの基本的なアルゴリズムはサクッとかけないとなかなかレートが上がりません。

それぞれのアルゴリズムの詳細はWikipediaに任せるとして、ここではPythonでこれらのアルゴリズムを記述する場合のテンプレートを紹介します。

この形を基本形として組めるようになっておけば、応用は簡単です。

Wikipediaへのリンク

その他のAtCoderに役立つ記事の一覧は以下にあります。

あわせて読みたい
AtCoderで役立つ記事一覧(Python, Go言語)
AtCoderで役立つ記事一覧(Python, Go言語)

深さ優先探索(DFS)

頂点間の接続が与えられるパターン

N個の頂点番号と、頂点Aと頂点Bの接続情報が与えられるタイプのパターンの場合、各頂点間の接続情報をリスト構造に格納して探索を行います。

ここでは、A62 – Depth First Searchのコードを作成してみます。

DFSは再帰で実装しますが、深い再帰を利用する場合、Pythonでは再帰用のスタックサイズを大きくしておく必要があります

import sys
sys.setrecursionlimit(10**6)

入力を受け取ります。N,Mと、頂点AとBの接続情報を受け取り、nodeに格納しています。また、入力が1スタートなので0スタートに変更しています。これで、node[0]には、点0が接続している頂点番号のリストが格納されます。

N, M = map(int, input().split())

node = [[] for _ in range(N)]
for _ in range(M) :
    a, b = map(int, input().split())
    a -= 1
    b -= 1
    node[a].append(b)
    node[b].append(a)

DFSの関数を作成します。ここは、ほぼテンプレートです。usedは頂点に訪問したかどうかを示すフラグで、最初Falseに初期化しておきます。dfs()の引数として入力された場合は訪問したとしてusedTrueを代入します。

あとは、点curに接続している各頂点(node[cur])に対して、訪問していなければ再帰的にdfs()を呼び出します。usedフラグをチェックしないと、無限ループになるので注意です。

以下のコードが、dfsを記述する場合の典型コードになります。

def dfs(cur) :
    used[cur] = True
    for e in node[cur]:
        if used[e] == False :
            dfs(e)

この関数の呼び出しと、判定は以下になります。今回の問題では、全ての頂点が接続されているかどうかを答えるので、頂点1(0番)から探索して、usedTrueになった個数が頂点数と同じなら、全体が連結されていると判定します。

used = [False for _ in range(N)]
dfs(0)

if sum(used) == N :
    print("The graph is connected.")
else:
    print("The graph is not connected.")

プログラムの全リストは以下になります。

import sys
sys.setrecursionlimit(10**6)

N, M = map(int, input().split())

node = [[] for _ in range(N)]
for _ in range(M) :
    a, b = map(int, input().split())
    a -= 1
    b -= 1
    node[a].append(b)
    node[b].append(a)


def dfs(cur) :
    used[cur] = True
    for e in node[cur]:
        if used[e] == False :
            dfs(e)

used = [False for _ in range(N)]
dfs(0)

if sum(used) == N :
    print("The graph is connected.")
else:
    print("The graph is not connected.")

迷路が与えられるパターン

H, Wのサイズの迷路が文字列で与えられるパターンです。

ここでは、A – 深さ優先探索 のコードを作成します。

DFSは再帰で実装しますが、深い再帰を利用する場合、Pythonでは再帰用のスタックサイズを大きくしておく必要があります

import sys
sys.setrecursionlimit(10**6)

入力を受け取ります。1行目でHとWを、2行目で迷路を読み込みます。Pythonの場合、このように1行で読み込みを記述できます。

H, W = map(int, input().split())

c = [input() for _ in range(H)]

この問題では、スタートとゴールの位置を探す必要があるので、先に探しておきます。

sx, sy = 0,0
gx, gy = 0,0

for h in range(H):
    for w in range(W):
        if c[h][w] == 's' :
            sx, sy = w, h
        if c[h][w] == 'g' :
            gx, gy = w, h

dfs()の定義です。c[y][x]でアクセス可能な迷路を探索する場合は、以下のコードが基本となります。dx, dyは、移動できる方向(dx, dy)です。dfs()では、4方向に対して探索します。探索時には、(1)エリア外(2)壁(3)既に訪問した部分についてはスキップ処理をします。

上記以外の到達可能な場所の場合、dfs()を再帰的に呼び出します。

x, yの座標と配列の関係をミスらないように。c[x][y]とアクセスするとバグります。

以下のコードが、dfsを記述する場合の典型コードになります。

dx = [-1, 1, 0, 0]
dy = [0 ,0 ,-1, 1]
def dfs(cx, cy) :
    used[cy][cx] = True
    for i in range(4):
        px = cx + dx[i]
        py = cy + dy[i]
        if px < 0 or px >= W or py < 0 or py >= H : continue
        if c[py][px] == '#' : continue
        if used[py][px] : continue
        dfs(px, py)

dfs()を呼び出す部分です。訪問済みのフラグusedを初期化してからdfs(sx, sy)を呼び出します。

used[gy][gx]Trueになっている場合、目的地に到着しているのでYesを返します。

used = [[False for _ in range(W)] for _ in range(H)]
dfs(sx, sy)


if used[gy][gx] :
    print("Yes")
else :
    print("No")

プログラムの全リストは以下になります。

import sys
sys.setrecursionlimit(10**6)

H, W = map(int, input().split())

c = [input() for _ in range(H)]

sx, sy = 0,0
gx, gy = 0,0

for h in range(H):
    for w in range(W):
        if c[h][w] == 's' :
            sx, sy = w, h
        if c[h][w] == 'g' :
            gx, gy = w, h



dx = [-1, 1, 0, 0]
dy = [0 ,0 ,-1, 1]
def dfs(cx, cy) :
    used[cy][cx] = True
    for i in range(4):
        px = cx + dx[i]
        py = cy + dy[i]
        if px < 0 or px >= W or py < 0 or py >= H : continue
        if c[py][px] == '#' : continue
        if used[py][px] : continue
        dfs(px, py)

used = [[False for _ in range(W)] for _ in range(H)]
dfs(sx, sy)


if used[gy][gx] :
    print("Yes")
else :
    print("No")

幅優先探索(BFS)

頂点間の接続が与えられるパターン

N個の頂点番号と、頂点Aと頂点Bの接続情報が与えられるタイプのパターンの場合、各頂点間の接続情報をリスト構造に格納して探索を行います。

ここでは、A63 – Shortest Path 1のコードを作成してみます。

BFSを作る場合、キューが必要になります。Pythonでは、dequeを使いますのでこれをインポートします。

from collections import deque

入力の受け取りです。深さ優先探索(DFS)とコードは同じです。

node = [[] for _ in range(N)]
for _ in range(M):
    a, b = map(int, input().split())
    a -= 1
    b -= 1
    node[a].append(b)
    node[b].append(a)

今回のコードでは、頂点1からの距離を求めるので、distに距離を格納します。とりあえず、全て大きな値で初期化しておきます。

次に、dist[0]=0に、dequeに頂点1(0)を代入します。

あとは、キューが空になるまで、先頭を取り出して、接続されている頂点の探索を繰り返します。if文では、距離がdist[cur]+1より大きい場合に、キューに追加していますが、dist[e]==infとしても大丈夫です。

以下のコードが、bfsを記述する場合の典型コードになります。

inf = 10**9
dist = [inf for _ in range(N)]

dist[0] = 0
q = deque([0])

while len(q) != 0 :
    cur = q[0]
    q.popleft()
    for e in node[cur]:
        if dist[e] > dist[cur] + 1 :
            dist[e] = dist[cur]+1
            q.append(e)

最後に、各頂点への距離を出力します。distinfの頂点は訪問されていないので-1を出力します。

for e in dist:
    if e == inf : print(-1)
    else:
        print(e)

プログラムの全リストは以下になります。

from collections import deque

N, M = map(int, input().split())

node = [[] for _ in range(N)]
for _ in range(M):
    a, b = map(int, input().split())
    a -= 1
    b -= 1
    node[a].append(b)
    node[b].append(a)


inf = 10**9
dist = [inf for _ in range(N)]

dist[0] = 0
q = deque([0])

while len(q) != 0 :
    cur = q[0]
    q.popleft()
    for e in node[cur]:
        if dist[e] > dist[cur] + 1 :
            dist[e] = dist[cur]+1
            q.append(e)


for e in dist:
    if e == inf : print(-1)
    else:
        print(e)

迷路が与えられるパターン

H, Wのサイズの迷路が文字列で与えられるパターンです。

ここでは、C – 幅優先探索のコードを作成します。

BFSを作る場合、キューが必要になります。Pythonでは、dequeを使いますのでこれをインポートします。

from collections import deque

入力の受け取りです。入力sx, sy, gx, gyは1スタートになっていますが、0スタートに変更します。

R, C = map(int, input().split())

sy, sx = map(int, input().split())
gy, gx = map(int, input().split())
sy -= 1
sx -= 1
gy -= 1
gx -= 1

c = [input() for _ in range(R)]

BFSのメイン処理です。距離を保存するリストdistを初期化し、キューにスタート位置を入れてから、キューが空になるまでループを繰り返します。

dx, dyを使って次の位置を計算するのはDFSの例と同じです。領域外の判定を行なっていますが、この問題では壁マスで囲まれているので、枠外の判定は必要ありませんが、テンプレートとしてとりあえず入れています。

以下のコードが、bfsを記述する場合の典型コードになります。

inf = 10**9
dist = [[inf for _ in range(C)] for _ in range(R)]

dist[sy][sx] = 0
q = deque()
q.append((sx, sy))

dx = [-1, 1, 0, 0]
dy = [0, 0, -1, 1]
while len(q) != 0 :
    cx, cy = q[0]
    q.popleft()
    for i in range(4):
        px, py = cx + dx[i], cy + dy[i]
        if px < 0 or px >= C or py < 0 or py >= R : continue
        if c[py][px] == '#' : continue
        if dist[py][px] > dist[cy][cx] + 1:
            dist[py][px] = dist[cy][cx] + 1
            q.append((px, py))

最後に、目的地の距離を出力して終了です。

print(dist[gy][gx])

プログラムの全リストは以下になります。

from collections import deque

R, C = map(int, input().split())

sy, sx = map(int, input().split())
gy, gx = map(int, input().split())
sy -= 1
sx -= 1
gy -= 1
gx -= 1

c = [input() for _ in range(R)]


inf = 10**9
dist = [[inf for _ in range(C)] for _ in range(R)]

dist[sy][sx] = 0
q = deque()
q.append((sx, sy))

dx = [-1, 1, 0, 0]
dy = [0, 0, -1, 1]
while len(q) != 0 :
    cx, cy = q[0]
    q.popleft()
    for i in range(4):
        px, py = cx + dx[i], cy + dy[i]
        if px < 0 or px >= C or py < 0 or py >= R : continue
        if c[py][px] == '#' : continue
        if dist[py][px] > dist[cy][cx] + 1:
            dist[py][px] = dist[cy][cx] + 1
            q.append((px, py))


print(dist[gy][gx])

ダイクストラ法

N個の頂点番号と、頂点Aと頂点Bの接続情報が与えられるタイプのパターンの場合、各頂点間の接続情報をリスト構造に格納して探索を行います。

ここでは、A64 – Shortest Path 2 をサンプルとたコードを作成します。

ダイクストラ法を実装する場合、優先キュー(Priority Queue)を使います。Pythonではheapqパッケージを利用します。

import heapq

入力を受け取ります。基本的には、DFS、BSFと同じです。

N, M = map(int, input().split())

node = [[] for _ in range(N)]
for i in range(M) :
    a, b, c = map(int, input().split())
    a -= 1
    b -= 1
    node[a].append((b, c))
    node[b].append((a, c))

ダイクストラ法のメイン処理です。最初に、距離リストをinfに初期化します。次に、最初の頂点1(0)の距離を0に初期化し、優先キューに(距離、頂点番号)の組み合わせで入力します。

heapqでは、要素を昇順に並べます。今回の場合、距離の短い順で頂点番号が小さい順に並びます。

メインのループですが、実はBFSとほとんど変わりません。違うのはキューが優先キューに変わったこと、距離を考慮することくらいです。なので、ダイクストラ法とBFSのコードはどちらかを覚えたら、反対側もかけるようになります。

以下のコードが、ダイクストラ法を記述する場合の典型コードになります。

inf = int(1e18)
dist = [inf for _ in range(N)]
dist[0] = 0
pq = []
heapq.heappush(pq, (0, 0))


while len(pq) != 0 :
    cur_cost, cur  = pq[0]
    heapq.heappop(pq)
    if cur_cost > dist[cur] : continue
    for e, cost in node[cur]:
        if dist[e] > dist[cur] + cost :
            dist[e] = dist[cur] + cost
            heapq.heappush(pq, (dist[e], e))

最後に結果を出力します。今回は、各頂点への距離なので、distをそのまま出力します。

for e in dist:
    if e == inf : print(-1)
    else : print(e)

プログラムの全リストは以下になります。

import heapq

N, M = map(int, input().split())

node = [[] for _ in range(N)]
for i in range(M) :
    a, b, c = map(int, input().split())
    a -= 1
    b -= 1
    node[a].append((b, c))
    node[b].append((a, c))


inf = int(1e18)
dist = [inf for _ in range(N)]
dist[0] = 0
pq = []
heapq.heappush(pq, (0, 0))


while len(pq) != 0 :
    cur_cost, cur  = pq[0]
    heapq.heappop(pq)
    if cur_cost > dist[cur] : continue
    for e, cost in node[cur]:
        if dist[e] > dist[cur] + cost :
            dist[e] = dist[cur] + cost
            heapq.heappush(pq, (dist[e], e))

for e in dist:
    if e == inf : print(-1)
    else : print(e)

終わりに

以上、DFS、BFS、ダイクストラ法をPythonで記述する場合の典型パターンを紹介しました。この3つは、AtCoderでは頻繁に利用しますので、基本パターンはサクッとかけるようになっていた方が良いです。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました