機械学習
記事内に商品プロモーションを含む場合があります

テンソル演算の可読性を向上させる、Einopsの使い方【Pytorch】

Aru

ディープラーニングにおいて、reshape, permute, squeeze, unsqueezeなどのテンソル操作はコードの可読性を損ねる要因です。複雑なテンソル操作をシンプルに記述し、可読性を向上させるにはEinopsというライブラリが役立ちます。本記事では、Einopsの基本的な使い方をチートシート形式まとめ、さらに、気になるパフォーマンスについて、ベンチマーク結果を紹介します。

Einopsとは

einopsとは

Einopsは、PyTorchやTensorFlowなどのディープラーニングフレームワークでテンソル操作を行うためのライブラリです。「Einstein operations」の略称であり、テンソル操作をより簡潔かつ柔軟に記述できるように設計されています。

Einopsを使用すると、テンソルの形状の変更次元の軸の入れ替え畳み込みやプーリングなどの操作を、直感的で読みやすい形式で記述可能です。たとえば、einopsを使用して畳み込み操作を行うと、一般的な畳み込み演算子よりも簡潔で理解しやすいコードを記述できます。

einopsの特徴は、単純な構文と柔軟な操作性です。また、einopsは他のライブラリやフレームワークとも統合が容易であり、さまざまなプロジェクトで幅広く活用されています。

ディープラーニングの実装において、テンソル操作は不可欠な部分ですが、einopsを使用することでより効率的かつ可読性の高いコードを記述することができます。

パフォーマンス面では色々議論がありますが、可読性の高さは正義だと思います。

基本的なPyTorchのテンソルの操作はこちら
PyTorchテンソル操作・演算の逆引きチートシート 【初級 深層学習講座】
PyTorchテンソル操作・演算の逆引きチートシート 【初級 深層学習講座】

einopsの気になる点

Einopsは可読性を向上させますが、処理が極端に遅くなるのでは使えません。そこで、この記事では、einopsを使用した場合に処理時間がどう変化するのかを簡単にベンチマークしてみました。

一般的には、「einopsを使用した場合のオーバヘッドは非常に小さい」とされていますが、実際に計測してみると処理速度が多少は変化することを確認できました。しかしながら、ネットワークの全体の処理時間から考えると、その影響は「非常に小さい」と見なすこともできそうです。

ベンチマーク結果は記事後半にありますので、自身で結果を確認して判断してみてください。

前準備

インストール

インストールはpipで行うことができます。

pip install einops

ライブラリのインポート

torchとnumpyも合わせてインポートしています。

公式では、from…でのインポートとなっていましたので同じようにしました。

import torch
import numpy as np
from einops import rearrange, reduce, repeat, parse_shape#

Einopsの機能

Einopsには、大きく分けて以下の機能があります。ここでは、これらについて解説します。

  • rearrange(再配置)
    軸を入れ替えたり、一部の次元を1つにまとめます
  • reduce(削除)
    次元を削減します
  • repeat(繰り返し)
    テンソルの内容を繰り返します
  • parse_shape
    一部を抽出します

rearrange(再配置)

軸を入れ替えたり、一部の次元を1つにまとめたりすることができます。

基本

基本的な使い方です。'a b -> b a'の左側が元の次元、右が変更後の次元になります。

次元の名前は、なんでも良いです

a b -> b aをみれは、aとbを入れ替える処理をするこが一目でわかります。この可読性の高さがeinopsの魅力です。また、pytorch, numpy, tensorlow、どれでも同じ記述をすることができます。フレームワークごとに微妙に違う関数名に悩まされずに済むのは魅力です。

x = np.arange(20*5).reshape(20,5)
print(x.shape, x)
#(20, 5) [[ 0  1  2  3  4]
# [ 5  6  7  8  9]
# [10 11 12 13 14]
# [15 16 17 18 19]
# [20 21 22 23 24]
# [25 26 27 28 29]
# [30 31 32 33 34]
# [35 36 37 38 39]
# [40 41 42 43 44]
# [45 46 47 48 49]
# [50 51 52 53 54]
# [55 56 57 58 59]
# [60 61 62 63 64]
# [65 66 67 68 69]
# [70 71 72 73 74]
# [75 76 77 78 79]
# [80 81 82 83 84]
# [85 86 87 88 89]
# [90 91 92 93 94]
# [95 96 97 98 99]]
y = rearrange(x, 'a b -> b a')
print(y.shape, y)
#(5, 20) [[ 0  5 10 15 20 25 30 35 40 45 50 55 60 65 70 75 80 85 90 95]
# [ 1  6 11 16 21 26 31 36 41 46 51 56 61 66 71 76 81 86 91 96]
# [ 2  7 12 17 22 27 32 37 42 47 52 57 62 67 72 77 82 87 92 97]
# [ 3  8 13 18 23 28 33 38 43 48 53 58 63 68 73 78 83 88 93 98]
# [ 4  9 14 19 24 29 34 39 44 49 54 59 64 69 74 79 84 89 94 99]]

(b, h, w, c) → (b, c, h, w)に入れ替える

PyTorchでよく行うパターンです。バッチサイズ、画像の幅・高さ、チャネル(RGB)の並びを、バッチサイズ、チャネル、画像の高さ・幅に並び替えます。

x = torch.randn(1, 320, 240, 3)
print(x.shape)
# torch.Size([1, 320, 240, 3])
y = rearrange(x, 'b h w c -> b c h w')
print(y.shape)
# torch.Size([1, 3, 320, 240])

(h, w, c)→(w, h, c)に入れ替える

こちらは、画像の縦と横を入れ替えるパターンです

x = torch.randn(320, 240, 3)
print(x.shape)
# torch.Size([320, 240, 3])
y = rearrange(x, 'h w c ->  w h c')
print(y.shape)
# torch.Size([240, 320, 3])

小さめの例で見ると、並びの変更に合わせて要素の並びも入れ替わっていることがわかります。

x = torch.tensor([[0,1],[2,3]])
print(x)
#tensor([[0, 1],
#        [2, 3]])
y = rearrange(x, 'x y -> y x')
print(y)
#tensor([[0, 2],
#        [1, 3]])

次元を減らす(c, h, w)→(c, h*w)

()を使うことで、次元を減らすこともできます。

x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> c (h w)')
print(y.shape)
# torch.Size([3, 76800])

次元を減らす(c, h, w)→(c*h*w)

1次元に変換するサンプルです。

x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> (c h w)')
print(y.shape)
# torch.Size([230400])

サイズを変更

rearrangeを使って、サイズを変更することも可能です。下記の例は、画像のサイズを半分にするイメージです。

x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = rearrange(x, 'c h w -> (c h w)', c=3, h=120, w=160)
print(y.shape)
# torch.Size([3, 120, 160])

reduce(削除)

reduceは、次元を削減する関数です。

次元を削除する

次元を削除する時に、どのような演算を行うか指定できます。指定できるのは、以下の通りです。

パラメータ操作
‘min’最小のものを抽出します
‘max’最大のものを抽出します
‘sum’合計を計算します
‘mean’平均を計算します
‘prod’要素を掛け合わせます

方向

実際の動きは、例を見た方がわかりやすいです。'x y -> x'の場合は、毎に集計されます。

x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = reduce(x, 'x y -> x', 'min')
print(y)
# [1 3]

方向

'x y -> y'の場合は、毎に集計されます。

x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = reduce(x, 'x y -> y', 'min')
print(y)
# [1 2]

(c, h, w) →(c, h/2, w/2)にする

縦と横を半分にして、2×2のうちのmaxに置き換える処理です(maxプーリングと同じ動きです)。

x = torch.randn(3, 240, 320)
print(x.shape)
# torch.Size([3, 240, 320])
y = reduce(x, 'c (h 2) (w 2) -> c h w', 'max')
print(y.shape)
# torch.Size([3, 120, 160])

工夫すれば色々な演算が可能ですが、可読性を挙げるために利用しているのやりすぎて可読性を下げないように注意しましょう。

repeat(繰り返し)

繰り返しです。指定した軸を繰り返すことで追加します。

x = np.array([[1,2],[4,3]])
print(x)
#[[1 2]
# [4 3]]
y = repeat(x, 'x y -> x c y', c = 3)
print(y)
# [[[1 2]
#  [4 3]]
#
# [[1 2]
#  [4 3]]
#
# [[1 2]
#  [4 3]]]

parse_shape 

シェープを取り出します。x.shapeでも良いですが、こちらは名前をつけて辞書型で取り出すことが可能です。

x = torch.randn(1, 3, 240, 320)
print(x.shape)
# torch.Size([1, 3, 240, 320])
y = parse_shape(x, 'b _ h w')
print(y)
# {'b': 1, 'h': 240, 'w': 320}

ベンチマーク結果

rearrangeの実行時間を調べる

頻繁に使う(h, w, c)→(c, h, w)の並べ替えの時間を測定してみました。処理はCPUで行なっています。

import torch
from einops import rearrange, reduce, repeat, parse_shape


import time
x = torch.randn(240,320,3)

loop = 1000000

time_sta = time.time()
for _ in range(loop):
    y = rearrange(x, 'h w c -> c w h')
time_end = time.time()
tim = time_end- time_sta
print(tim)
#2.2942588329315186

time_sta = time.time()
for _ in range(loop):
    y = x.permute(2,0,1)
time_end = time.time()
tim = time_end- time_sta
print(tim)
#0.9618110656738281

結果は、einopsを使うと約2.3秒、torch.premuteを使うと約0.96秒でした。2倍ほど遅くなる感じです。思ったより処理時間が増加しています。

しかしながら、深層学習のモデルの処理全体に占めるpermute処理の比率はそこまで大きくないと思います。ここが2倍重くなっても、全体の処理としてはたいして影響がない可能性もあります。

処理時間の増加と可読性の向上を天秤にかけて、検討する必要がありそうです。

まとめ

可読性向上に便利なeinopsについて説明しました。ディープラーニング関連のコードは、テンソルの演算や入れ替えを頻繁に行うため、知らない人には読みにくいコードになりがちです。einopsをつかって可読性を上げておくと、他の人に引き継ぐとき、あとで読み返すときには確実に有用だと思います。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました