機械学習
記事内に商品プロモーションを含む場合があります

テンソル演算の可読性を向上させる、Einopsの使い方【Pytorch】

tadanori

ディープラーニングをやっていると、reshape, permute, squeeze, unsqueezeなどによるテンソルの操作でコードが見にくくなります。可読性を向上させるにはEinopsが便利です。この記事は、Einopsの使い方のチートシートと、簡単なベンチマークの結果です。

Einopsとは

einopsとは

Einopsは、PyTorchやTensorFlowなどのディープラーニングフレームワークでテンソル操作を行うためのライブラリです。”Einstein operations”の略称であり、テンソル操作をより簡潔かつ柔軟に記述できるように設計されています。

einopsを使用することで、テンソルの形状の変更、次元の軸の入れ替え、畳み込みやプーリングなどの操作を、直感的で読みやすい形式で記述することができます。

例えば、einopsを使用して畳み込み操作を行うと、一般的な畳み込み演算子よりも簡潔で理解しやすいコードを記述することができます。

einopsの特徴は、単純な構文と柔軟な操作性です。また、einopsは他のライブラリやフレームワークとも統合が容易であり、様々なプロジェクトで幅広く活用されています。

ディープラーニングの実装において、テンソル操作は不可欠な部分ですが、einopsを使用することでより効率的かつ可読性の高いコードを記述することができます。

後で読み直す時にはかなり重宝します

基本的なPyTorchのテンソルの操作はこちら
PyTorch|テンソルの操作と基本的な演算を理解する【初級 深層学習講座】
PyTorch|テンソルの操作と基本的な演算を理解する【初級 深層学習講座】

einopsの気になる点

可読性が高くなっても、処理が重くなるのであれば使えません。今回は、einopsを使った場合に、処理時間がどう変化するか簡単なベンチマークも行ってみました。

「einopsを使用した場合のオーバーヘッドは、一般的には非常に小さい」と言われていますが、実際に計測すると処理時間は結構増加しました。

ただ、ネットワーク全体の処理から考えると、「非常に小さい」と考えることもできる気がします。

とりあえず、ベンチマーク結果を確認して、自身で判断してください。

前準備

インストール

インストールはpipで行うことができます。

pip install einops

ライブラリのインポート

torchとnumpyも合わせてインポートしています。

公式では、from…でのインポートとなっていましたので同じようにしました。

import torch
import numpy as np
from einops import rearrange, reduce, repeat, parse_shape

rearrange(再配置)

軸を入れ替えたり、一部の次元を1つにまとめたりすることができます。

基本

基本的な使い方です。'a b -> b a'の左側が元の次元、右が変更後の次元になります。

次元の名前は、なんでも良いです

a b -> b aをみれは、aとbを入れ替える処理をするこが一目でわかります。この可読性の高さがeinopsの魅力です。また、pytorch, numpy, tensorlow、どれでも同じ記述をすることができます。フレームワークごとに微妙に違う関数名に悩まされずに済むのは魅力です。

x = np.arange(20*5).reshape(20,5)
print(x.shape, x)
#(20, 5) [[ 0  1  2  3  4]
# [ 5  6  7  8  9]
# [10 11 12 13 14]
# [15 16 17 18 19]
# [20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]
# [40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]
# [60 61 62 63 64]
# [65 66 67 68 69]
# [70 71 72 73 74]
# [75 76 77 78 79]
# [80 81 82 83 84]
# [85 86 87 88 89]
# [90 91 92 93 94]
# [95 96 97 98 99]]
y = rearrange(x, 'a b -> b a')
print(y.shape, y)
#(5, 20) [[ 0  5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95]
# [ 1  6 11 16 21 26 31 36 41 46 51 56 61 66 71 76 81 86 91 96]
# [ 2  7 12 17 22 27 32 37 42 47 52 57 62 67 72 77 82 87 92 97]
# [ 3  8 13 18 23 28 33 38 43 48 53 58 63 68 73 78 83 88 93 98]
# [ 4  9 14 19 24 29 34 39 44 49 54 59 64 69 74 79 84 89 94 99]]

(b, h, w, c) → (b, c, h, w)に入れ替える

PyTorchでよく行うパターンです。バッチサイズ、画像の幅・高さ、チャネル(RGB)の並びを、バッチサイズ、チャネル、画像の高さ・幅に並び替えます。

x = torch.randn(1, 320, 240, 3)
print(x.shape)
# torch.Size([1, 320, 240, 3])
y = rearrange(x, 'b h w c -> b c h w')
print(y.shape)
# torch.Size([1, 3, 320, 240])

(h, w, c)→(w, h, c)に入れ替える

こちらは、画像の縦と横を入れ替えるパターンです

x = torch.randn(320, 240, 3)
print(x.shape)
# torch.Size([320, 240, 3])
y = rearrange(x, 'h w c ->  w h c')
print(y.shape)
# torch.Size([240, 320, 3])

小さめの例で見ると、並びの変更に合わせて要素の並びも入れ替わっていることがわかります。

x = torch.tensor([[0,1],[2,3]])
print(x)
#tensor([[0, 1],
#        [2, 3]])
y = rearrange(x, 'x y -> y x')
print(y)
#tensor([[0, 2],
#        [1, 3]])

次元を減らす(c, h, w)→(c, h*w)

()を使うことで、次元を減らすこともできます。

x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> c (h w)')
print(y.shape)
# torch.Size([3, 76800])

次元を減らす(c, h, w)→(c*h*w)

1次元に変換するサンプルです。

x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> (c h w)')
print(y.shape)
# torch.Size([230400])

サイズを変更

rearrangeを使って、サイズを変更することも可能です。下記の例は、画像のサイズを半分にするイメージです。

x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> (c h w)', c=3, h=120, w=160)
print(y.shape)
# torch.Size([3, 120, 160])

reduce(削除)

reduceは、次元を削減する関数です。

次元を削除する

次元を削除する時に、どのような演算を行うか指定できます。指定できるのは、以下の通りです。

パラメータ操作
‘min’最小のものを抽出します
‘max’最大のものを抽出します
‘sum’合計を計算します
‘mean’平均を計算します
‘prod’要素を掛け合わせます

方向

実際の動きは、例を見た方がわかりやすいです。'x y -> x'の場合は、毎に集計されます。

x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = reduce(x, 'x y -> x', 'min')
print(y)
# [1 3]

方向

'x y -> y'の場合は、毎に集計されます。

x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = reduce(x, 'x y -> y', 'min')
print(y)
# [1 2]

(c, h, w) →(c, h/2, w/2)にする

縦と横を半分にして、2×2のうちのmaxに置き換える処理です(maxプーリングと同じ動きです)。

x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = reduce(x, 'c (h 2) (w 2) -> c h w', 'max')
print(y.shape)
# torch.Size([3, 120, 160])

工夫すれば色々な演算が可能ですが、可読性を挙げるために利用しているのやりすぎて可読性を下げないように注意しましょう。

repeat(繰り返し)

繰り返しです。指定した軸を繰り返すことで追加します。

x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = repeat(x, 'x y -> x c y', c = 3)
print(y)
# [[[1 2]
#  [4 3]]
#
# [[1 2]
#  [4 3]]
#
# [[1 2]
#  [4 3]]]

parse_shape 

シェープを取り出します。x.shapeでも良いですが、こちらは名前をつけて辞書型で取り出すことが可能です。

x = torch.randn(1, 3, 240, 320)
print(x.shape)
# torch.Size([1, 3, 240, 320])
y = parse_shape(x, 'b _ h w')
print(y)
# {'b': 1, 'h': 240, 'w': 320}

簡単なベンチマーク

rearrangeの実行時間を調べる

頻繁に使う(h, w, c)→(c, h, w)の並べ替えの時間を測定してみました。処理はCPUで行なっています。

import torch
from einops import rearrange, reduce, repeat, parse_shape


import time
x = torch.randn(240,320,3)

loop = 1000000

time_sta = time.time()
for _ in range(loop):
    y = rearrange(x, 'h w c -> c w h')
time_end = time.time()
tim = time_end- time_sta
print(tim)
#2.2942588329315186

time_sta = time.time()
for _ in range(loop):
    y = x.permute(2,0,1)
time_end = time.time()
tim = time_end- time_sta
print(tim)
#0.9618110656738281

結果は、einopsを使うと約2.3秒、torch.premuteを使うと約0.96秒でした。2倍ほど遅くなる感じです。少し気になる速度ですが、可読性とどちらを優先するかでしょうか。

ただ、他の処理に比べてそこまで重くないので、全体の処理に対してはそれほど影響がない可能性もあります。

まとめ

可読性向上に便利なeinopsについて説明しました。ディープラーニング関連のコードは、テンソルの演算や入れ替えを頻繁に行うため、知らない人には読みにくいコードになりがちです。einopsをつかって可読性を上げておくと、後で見直すときに有効かもしれません。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました