テンソル演算の可読性を向上させる、Einopsの使い方【Pytorch】
ディープラーニングにおいて、reshape, permute, squeeze, unsqueezeなどのテンソル操作はコードの可読性を損ねる要因です。複雑なテンソル操作をシンプルに記述し、可読性を向上させるにはEinopsというライブラリが役立ちます。本記事では、Einopsの基本的な使い方をチートシート形式まとめ、さらに、気になるパフォーマンスについて、ベンチマーク結果を紹介します。
Einopsとは
einopsとは
Einopsは、PyTorchやTensorFlowなどのディープラーニングフレームワークでテンソル操作を行うためのライブラリです。「Einstein operations」の略称であり、テンソル操作をより簡潔かつ柔軟に記述できるように設計されています。
Einopsを使用すると、テンソルの形状の変更、次元の軸の入れ替え、畳み込みやプーリングなどの操作を、直感的で読みやすい形式で記述可能です。たとえば、einopsを使用して畳み込み操作を行うと、一般的な畳み込み演算子よりも簡潔で理解しやすいコードを記述できます。
einopsの特徴は、単純な構文と柔軟な操作性です。また、einopsは他のライブラリやフレームワークとも統合が容易であり、さまざまなプロジェクトで幅広く活用されています。
ディープラーニングの実装において、テンソル操作は不可欠な部分ですが、einopsを使用することでより効率的かつ可読性の高いコードを記述することができます。
パフォーマンス面では色々議論がありますが、可読性の高さは正義だと思います。
einopsの気になる点
Einopsは可読性を向上させますが、処理が極端に遅くなるのでは使えません。そこで、この記事では、einopsを使用した場合に処理時間がどう変化するのかを簡単にベンチマークしてみました。
一般的には、「einopsを使用した場合のオーバヘッドは非常に小さい」とされていますが、実際に計測してみると処理速度が多少は変化することを確認できました。しかしながら、ネットワークの全体の処理時間から考えると、その影響は「非常に小さい」と見なすこともできそうです。
ベンチマーク結果は記事後半にありますので、自身で結果を確認して判断してみてください。
前準備
インストール
インストールはpipで行うことができます。
pip install einops
ライブラリのインポート
torchとnumpyも合わせてインポートしています。
公式では、from…でのインポートとなっていましたので同じようにしました。
import torch
import numpy as np
from einops import rearrange, reduce, repeat, parse_shape#
Einopsの機能
Einopsには、大きく分けて以下の機能があります。ここでは、これらについて解説します。
- rearrange(再配置)
軸を入れ替えたり、一部の次元を1つにまとめます - reduce(削除)
次元を削減します - repeat(繰り返し)
テンソルの内容を繰り返します - parse_shape
一部を抽出します
rearrange(再配置)
軸を入れ替えたり、一部の次元を1つにまとめたりすることができます。
基本
基本的な使い方です。'a b -> b a'
の左側が元の次元、右が変更後の次元になります。
次元の名前は、なんでも良いです
a b -> b a
をみれは、aとbを入れ替える処理をするこが一目でわかります。この可読性の高さがeinopsの魅力です。また、pytorch, numpy, tensorlow、どれでも同じ記述をすることができます。フレームワークごとに微妙に違う関数名に悩まされずに済むのは魅力です。
x = np.arange(20*5).reshape(20,5)
print(x.shape, x)
#(20, 5) [[ 0 1 2 3 4]
# [ 5 6 7 8 9]
# [10 11 12 13 14]
# [15 16 17 18 19]
# [20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]
# [40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]
# [60 61 62 63 64]
# [65 66 67 68 69]
# [70 71 72 73 74]
# [75 76 77 78 79]
# [80 81 82 83 84]
# [85 86 87 88 89]
# [90 91 92 93 94]
# [95 96 97 98 99]]
y = rearrange(x, 'a b -> b a')
print(y.shape, y)
#(5, 20) [[ 0 5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95]
# [ 1 6 11 16 21 26 31 36 41 46 51 56 61 66 71 76 81 86 91 96]
# [ 2 7 12 17 22 27 32 37 42 47 52 57 62 67 72 77 82 87 92 97]
# [ 3 8 13 18 23 28 33 38 43 48 53 58 63 68 73 78 83 88 93 98]
# [ 4 9 14 19 24 29 34 39 44 49 54 59 64 69 74 79 84 89 94 99]]
(b, h, w, c) → (b, c, h, w)に入れ替える
PyTorchでよく行うパターンです。バッチサイズ、画像の幅・高さ、チャネル(RGB)の並びを、バッチサイズ、チャネル、画像の高さ・幅に並び替えます。
x = torch.randn(1, 320, 240, 3)
print(x.shape)
# torch.Size([1, 320, 240, 3])
y = rearrange(x, 'b h w c -> b c h w')
print(y.shape)
# torch.Size([1, 3, 320, 240])
(h, w, c)→(w, h, c)に入れ替える
こちらは、画像の縦と横を入れ替えるパターンです
x = torch.randn(320, 240, 3)
print(x.shape)
# torch.Size([320, 240, 3])
y = rearrange(x, 'h w c -> w h c')
print(y.shape)
# torch.Size([240, 320, 3])
小さめの例で見ると、並びの変更に合わせて要素の並びも入れ替わっていることがわかります。
x = torch.tensor([[0,1],[2,3]])
print(x)
#tensor([[0, 1],
# [2, 3]])
y = rearrange(x, 'x y -> y x')
print(y)
#tensor([[0, 2],
# [1, 3]])
次元を減らす(c, h, w)→(c, h*w)
()
を使うことで、次元を減らすこともできます。
x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> c (h w)')
print(y.shape)
# torch.Size([3, 76800])
次元を減らす(c, h, w)→(c*h*w)
1次元に変換するサンプルです。
x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> (c h w)')
print(y.shape)
# torch.Size([230400])
サイズを変更
rearrange
を使って、サイズを変更することも可能です。下記の例は、画像のサイズを半分にするイメージです。
x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> (c h w)', c=3, h=120, w=160)
print(y.shape)
# torch.Size([3, 120, 160])
reduce(削除)
reduce
は、次元を削減する関数です。
次元を削除する
次元を削除する時に、どのような演算を行うか指定できます。指定できるのは、以下の通りです。
パラメータ | 操作 |
‘min’ | 最小のものを抽出します |
‘max’ | 最大のものを抽出します |
‘sum’ | 合計を計算します |
‘mean’ | 平均を計算します |
‘prod’ | 要素を掛け合わせます |
行方向
実際の動きは、例を見た方がわかりやすいです。'x y -> x'
の場合は、行毎に集計されます。
x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = reduce(x, 'x y -> x', 'min')
print(y)
# [1 3]
列方向
'x y -> y'
の場合は、列毎に集計されます。
x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = reduce(x, 'x y -> y', 'min')
print(y)
# [1 2]
(c, h, w) →(c, h/2, w/2)にする
縦と横を半分にして、2×2のうちのmaxに置き換える処理です(maxプーリングと同じ動きです)。
x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = reduce(x, 'c (h 2) (w 2) -> c h w', 'max')
print(y.shape)
# torch.Size([3, 120, 160])
工夫すれば色々な演算が可能ですが、可読性を挙げるために利用しているのやりすぎて可読性を下げないように注意しましょう。
repeat(繰り返し)
繰り返しです。指定した軸を繰り返すことで追加します。
x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = repeat(x, 'x y -> x c y', c = 3)
print(y)
# [[[1 2]
# [4 3]]
#
# [[1 2]
# [4 3]]
#
# [[1 2]
# [4 3]]]
parse_shape
シェープを取り出します。x.shapeでも良いですが、こちらは名前をつけて辞書型で取り出すことが可能です。
x = torch.randn(1, 3, 240, 320)
print(x.shape)
# torch.Size([1, 3, 240, 320])
y = parse_shape(x, 'b _ h w')
print(y)
# {'b': 1, 'h': 240, 'w': 320}
ベンチマーク結果
rearrangeの実行時間を調べる
頻繁に使う(h, w, c)→(c, h, w)の並べ替えの時間を測定してみました。処理はCPUで行なっています。
import torch
from einops import rearrange, reduce, repeat, parse_shape
import time
x = torch.randn(240,320,3)
loop = 1000000
time_sta = time.time()
for _ in range(loop):
y = rearrange(x, 'h w c -> c w h')
time_end = time.time()
tim = time_end- time_sta
print(tim)
#2.2942588329315186
time_sta = time.time()
for _ in range(loop):
y = x.permute(2,0,1)
time_end = time.time()
tim = time_end- time_sta
print(tim)
#0.9618110656738281
結果は、einopsを使うと約2.3秒、torch.premuteを使うと約0.96秒でした。2倍ほど遅くなる感じです。思ったより処理時間が増加しています。
しかしながら、深層学習のモデルの処理全体に占めるpermute処理の比率はそこまで大きくないと思います。ここが2倍重くなっても、全体の処理としてはたいして影響がない可能性もあります。
処理時間の増加と可読性の向上を天秤にかけて、検討する必要がありそうです。
まとめ
可読性向上に便利なeinopsについて説明しました。ディープラーニング関連のコードは、テンソルの演算や入れ替えを頻繁に行うため、知らない人には読みにくいコードになりがちです。einopsをつかって可読性を上げておくと、他の人に引き継ぐとき、あとで読み返すときには確実に有用だと思います。