初級ディープラーニング
記事内に商品プロモーションを含む場合があります

PyTorch|テンソルの操作と基本的な演算を理解する【初級 深層学習講座】

tadanori

この記事では、PyTorchを利用する上で基本となる、テンソル(tensor)の操作と基本的な演算について紹介します。この記事は、テンソル操作のチートシートとしても利用できるように作成しました。

あわせて読みたい
テンソル演算の可読性を向上させる、Einopsの使い方【Pytorch】
テンソル演算の可読性を向上させる、Einopsの使い方【Pytorch】

PyTorchのテンソルとは

PyTorchのテンソルは、多次元の数値データを格納するためのデータ構造で、さまざまな演算をサポートしています。

数学的には、「テンソル」は多次元の配列であり、スカラー、ベクトル、行列などの一般化された概念です。PyTorchのテンソルは、これをデータ構造として実現したものになります。

Numpyの配列とよく似たものですが、GPUでの演算自動微分など、深層学習向けに特化された機能を持つ点で異なります。

テンソルの操作は、PyTorchを使う場合に必要になります。

Numpyとよく似ているのですが、若干異なる点がすこし厄介です。

以下のコードは、それぞれ単独で動くようにしていますが、torchのインポートが必要です。

import torch

テンソルの作成

テンソルの作成方法はいくつかあります。

torch.Tensor(): 空のテンソル

空のテンソルを作るには、torch.Tensor()を使います。

empty_tensor = torch.Tensor()
print(empty_tensor)
# tensor([])

なお、引数にリストやnumpy配列で配列を渡すことで、テンソルを作ることもできます。

tensor = torch.Tensor([[1,2], [3,4]])
print(tensor)
# tensor([[1., 2.],
#        [3., 4.]])

torch.zeros(), torch.ones(): 全ての要素が0または1のテンソル

すべての値が0または1のテンソルを作る場合は、torch.zeros()またはtorch.ones()を用います。引数は、テンソルの形状です。

例えば(3,2)を引数として渡すと、3行2列の行列が生成できます。

zeros_tensor = torch.zeros(3, 2)
print(zeros_tensor, zeros_tensor.shape)
# tensor([[0., 0.],
#         [0., 0.],
#         [0., 0.]]) torch.Size([3, 2])

ones_tensor = torch.ones(2,3,4)
print(ones_tensor, ones_tensor.shape)
# tensor([[[1., 1., 1., 1.],
#          [1., 1., 1., 1.],
#          [1., 1., 1., 1.]],
#         [[1., 1., 1., 1.],
#          [1., 1., 1., 1.],
#          [1., 1., 1., 1.]]]) torch.Size([2, 3, 4])

torch.randn(): ランダムな値を持つテンソルを作成

torch.randn()は、標準正規分布に従う乱数の値を持つテンソルを作成します。引数はテンソルの形状です。

実際に、作成したテンソルの平均と偏差を調べると、以下の例のように平均が0.0、偏差が1.0に近いことがわかります。

random_tensor = torch.randn(2, 2)
print(random_tensor, random_tensor.shape)
# tensor([[-2.2998,  1.4645],
#         [ 0.2852, -1.6519]]) torch.Size([2, 2])

random_tensor2 = torch.randn(10000)
print(random_tensor2.mean(), random_tensor2.std())
# tensor(0.0054) tensor(1.0017)

なお、ランダムな値のテンソルを作成する方法としては、torch.randint(整数のランダムな値)やtorch.rand(一様なランダムな値)のテンソルを生成する方法もあります。

torch.arange(): 指定された範囲の値でテンソルを作成

torch.arangeを使うことで指定された範囲の範囲の1次元テンソルを作成することが可能です。引数は、arange(start, end, step)となり、引数が1つの場合はstart=0, step=1となります。引数が2つの場合は、step=1と解釈されます。

range_tensor = torch.arange(0, 10)
print(range_tensor, range_tensor.shape)
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) torch.Size([10])

torch.linspace(): 等間隔の値でテンソルを作成

linspace()は、指定した分割数で等間隔のテンソルを作成します。

以下の例では、start=0, end=10で10個の値を生成します。注意点としては、0,….,10という数列をつくるので、10を9で割ったステップで値が生成される点です(植木算と言われます)。

linspace_tensor = torch.linspace(0, 10, 10)
print(linspace_tensor, linspace_tensor.shape)
# tensor([ 0.0000,  1.1111,  2.2222,  3.3333,  4.4444,  5.5556,  6.6667,  7.7778,
#         8.8889, 10.0000]) torch.Size([10])

テンソルの変形

.view(), .reshape(): テンソルの形状を変更

view, reshapeを使うことで、テンソルの形状を変更することが可能です。

以下の例は、1次元の1×10のテンソルを5×2に計上変更しています。

viewreshapeはほぼ同じ動作をしますが、形状を変更するテンソルの各要素のメモリ上の配置により挙動が異なります。viewは、元のテンソルがメモリ上で連続していない場合はエラーが発生する可能性があります。

転置操作などを行うと、テンソルが非連続になります。

tensor = torch.arange(0, 10)
reshape_tensor = tensor.reshape(5, 2)
print(reshape_tensor, reshape_tensor.shape)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7],
#         [8, 9]]) torch.Size([5, 2])

view_tensor = range_tensor.view(5, 2)
print(view_tensor, view_tensor.shape)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7],
#         [8, 9]]) torch.Size([5, 2])

.unsqueeze(), .squeeze(): 次元の追加や削除

テンソルの次元を追加または削除します。

unsqueezeをすると指定した次元が追加され、squeezeをすると次元が削除されます。

下記の例は、次元を追加して削除しています(.shapeの結果を確認してください)

tensor = torch.arange(0, 10)
unsqueeze_tensor = tensor.unsqueeze(1)
squeeze_tensor = unsqueeze_tensor.squeeze()
print(unsqueeze_tensor.shape, squeeze_tensor.shape)
# torch.Size([10, 1]) torch.Size([10])

実は、PyTorchでは結構利用します。例えば、ディープラーニングのモデルの入力は、

バッチ数×チャネル数×高さ×幅

ですが、画像は、

高さ×幅×チャネル数

となっていることが多いです。

これを入力のフォーマットに合わせるために、以下のような変換をします。

input = img.premute(2,1,0).unsqueeze(0)

torch.transpose(), .permute(): テンソルの次元を入れ替え

transpose()premute()は、テンソルの次元の入れ替えをします。

transpose()は、引数で指定した2つの次元を入れ替えます。

permute()は、引数の並びに合わせて次元を入れ替えます。例えば、permute(2,1,0)と書くと、0,1,2→2,1,0と次元が入れ替えられます。

tensor = torch.arange(0, 10).reshape(2,5)
print(tensor, tensor.shape)
# tensor([[0, 1, 2, 3, 4],
#         [5, 6, 7, 8, 9]]) torch.Size([2, 5])

transpose_tensor = tensor.transpose(0, 1)
print(transpose_tensor, transpose_tensor.shape)
# tensor([[0, 5],
#         [1, 6],
#         [2, 7],
#         [3, 8],
#         [4, 9]]) torch.Size([5, 2])

permute_tensor = tensor.permute(1, 0)
print(permute_tensor, permute_tensor.shape)
# tensor([[0, 5],
#         [1, 6],
#         [2, 7],
#         [3, 8],
#         [4, 9]]) torch.Size([5, 2])

テンソルの演算

四則演算: 加算 (+), 減算 (-), 乗算 (*), 除算 (/)

テンソルを四則演算すると要素毎に加算や減算が行われます。

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
add_result = tensor1 + tensor2
print(add_result)
# tensor([[ 6,  8],
#         [10, 12]])

sub_result = tensor1 - tensor2
print(sub_result)
# tensor([[-4, -4],
#         [-4, -4]])

mul_result = tensor1 * tensor2
print(mul_result)
# tensor([[ 5, 12],
#         [21, 32]])

div_result = tensor1 / tensor2
print(div_result)
# tensor([[0.2000, 0.3333],
#         [0.4286, 0.5000]])

ブロードキャスト

ブロードキャストは、次元数の異なる2つのテンソルでも、条件を満たせば四則演算を可能とするものです。

例えば、テンソルと整数の場合は、テンソル全てに対して整数との四則演算が行われます。

また、片方が2×2、片方が2の場合は、以下の例のように四則演算が行われます。

tensor1 = torch.tensor([[1, 2], [3, 4]])
add_result = tensor1 + 100
print(add_result)
# tensor([[101, 102],
#         [103, 104]])

tensor2 = torch.tensor([10, 100])
mul_result = tensor1*tensor2
print(mul_result)
# tensor([[ 10, 200],
#         [ 30, 400]]

行列演算: torch.matmul(), torch.mm(), @演算子

行列演算を行う場合は、matmul, mm, @を使います。どれも結果は同じです。

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])

matmul_result = torch.matmul(tensor1, tensor2)
print(matmul_result)
# tensor([[19, 22],
#         [43, 50]])

matmul_result = tensor1.mm(tensor2)
print(matmul_result)
# tensor([[19, 22],
#         [43, 50]])

matmul_result = tensor1 @ tensor2
print(matmul_result)
# tensor([[19, 22],
#         [43, 50]])

要素ごとの演算: .add(), .sub(), .mul(), .div()

.add(), .sub(), .mul(), .div()を使って四則演算を行うことも可能です。この場合は要素ごとの演算が行われます。

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6], [7, 8]])
add_elemwise = tensor1.add(tensor2)
print(add_elemwise)
# tensor([[ 6,  8],
#         [10, 12]])

sub_elemwise = tensor1.sub(tensor2)
print(sub_elemwise)
# tensor([[-4, -4],
#         [-4, -4]])

mul_elemwise = tensor1.mul(tensor2)
print(mul_elemwise)
# tensor([[ 5, 12],
#         [21, 32]])

div_elemwise = tensor1.div(tensor2)
print(div_elemwise)
# tensor([[0.2000, 0.3333],
#         [0.4286, 0.5000]])

比較演算子: torch.eq(), torch.gt(), torch.lt()

eq(), gt(), lt()などの比較演算もサポートされています。結果はTrue, Falseで返ります。

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[1, 2], [7, 0]])
eq_result = torch.eq(tensor1, tensor2)
print(eq_result)
# tensor([[ True,  True],
#         [False, False]])

gt_result = torch.gt(tensor1, tensor2)
print(gt_result)
# tensor([[False, False],
#         [False,  True]])

lt_result = torch.lt(tensor1, tensor2)
print(lt_result)
# tensor([[False, False],
#         [ True, False]])

転置

ベクトルを転置する場合は、.Tを用います。

tensor = torch.Tensor([[1,2], [3,4]])
t_tnesor = tensor.T
print(t_tnesor)
# tensor([[1., 3.],
#        [2., 4.]])

内積(dot)

1次元のテンソル同士の内積を求めるにはtorch.dotを用います

tensor = torch.Tensor([1,2])
dot = torch.dot(tensor, tensor)
print(dot)
# tensor(5.)

インデックスとスライスイング

インデックス指定: 特定の位置の要素を取得

インデックスを指定して要素の値を取得できます。

例では、0行1列目の値(=1)を取得しています。

tensor = torch.arange(9).reshape(3, 3)
print(tensor)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

indexing = tensor[0, 1]
print(indexing)
# tensor(1)

スライシング: 部分的なテンソルを取得

テンソルの一部の値を取得することができます。

例では、全ての行(:)の、1列目以降(1:)を取得しています。

tensor = torch.arange(9).reshape(3, 3)
print(tensor)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

slicing = tensor[:, 1:]
print(slicing)
# tensor([[1, 2],
#         [4, 5],
#         [7, 8]])

ファンシーインデックス

テンソルや配列でインデックス指定するファンシーインデックスを使うことも可能です。

下記の例では、0行目と2行目を指定して取得しています。

tensor = torch.arange(9).reshape(3, 3)
indices = torch.tensor([0, 2])
print(tensor)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

fancy_indexing = tensor[indices]
print(fancy_indexing)
# tensor([[0, 1, 2],
#         [6, 7, 8]])

ブールインデックス

論理演算によりインデックスすることも可能です。

下記の例では、tensor>5の条件に合う部分だけ抜き出しています。

 tensor = torch.arange(9).reshape(3, 3)
print(tensor)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])

bool_indexing = tensor[tensor > 5]
print(bool_indexing)
# tensor([6, 7, 8])

その他(結合など)の操作

torch.cat(): テンソルを結合

テンソルを結合するにはtorch.catを利用します。

dimを指定することで結合する次元を指定することが可能です。

どのように結合されるかは例を参照してください。

tensor1 = torch.arange(9).reshape(3, 3)
tensor2 = torch.arange(10, 19).reshape(3, 3)
concat_tensor = torch.cat([tensor1, tensor2], dim=0)
print(concat_tensor)
# tensor([[ 0,  1,  2],
#         [ 3,  4,  5],
#         [ 6,  7,  8],
#         [10, 11, 12],
#         [13, 14, 15],
#         [16, 17, 18]])
concat_tensor = torch.cat([tensor1, tensor2], dim=1)
print(concat_tensor)
# tensor([[ 0,  1,  2, 10, 11, 12],
#         [ 3,  4,  5, 13, 14, 15],
#         [ 6,  7,  8, 16, 17, 18]])

torch.stack(): テンソルを新たな次元で結合

テンソルを新たな次元で結合します。

catとの違いは新しい次元で結合されることです。下記はdim=0,1,2の例です。それぞれ指定した次元が追加されて結合されていることを確認してください。

stack_tensor = torch.stack([tensor1, tensor2])
print(stack_tensor, stack_tensor.shape)
# tensor([[[ 0,  1,  2],
#          [ 3,  4,  5],
#          [ 6,  7,  8]],
# 
#         [[10, 11, 12],
#          [13, 14, 15],
#          [16, 17, 18]]]) torch.Size([2, 3, 3])

stack_tensor = torch.stack([tensor1, tensor2], dim=1)
print(stack_tensor, stack_tensor.shape)
# tensor([[[ 0,  1,  2],
#          [10, 11, 12]],
# 
#         [[ 3,  4,  5],
#          [13, 14, 15]],
# 
#         [[ 6,  7,  8],
#          [16, 17, 18]]]) torch.Size([3, 2, 3])

stack_tensor = torch.stack([tensor1, tensor2], dim=2)
print(stack_tensor, stack_tensor.shape)
# tensor([[[ 0, 10],
#          [ 1, 11],
#          [ 2, 12]],
# 
#         [[ 3, 13],
#          [ 4, 14],
#          [ 5, 15]],
# 
#         [[ 6, 16],
#          [ 7, 17],
#          [ 8, 18]]]) torch.Size([3, 3, 2])

torch.split(): テンソルを指定されたサイズに分割

splitは、テンソルを指定した行数で分割します。

例のようには数がでた場合は、指定した行数に満たないテンソルが作られます。

例では、2行毎、1行枚に分割しています。戻り値はタプルです。

tensor = torch.arange(9).reshape(3, 3)
split_tensors = torch.split(tensor, 2)
print(split_tensors)
# (tensor([[0, 1, 2],
#         [3, 4, 5]]), tensor([[6, 7, 8]]))

split_tensors = torch.split(tensor, 1)
print(split_tensors)
# (tensor([[0, 1, 2]]), tensor([[3, 4, 5]]), tensor([[6, 7, 8]]))

torch.sort(): テンソルの要素をソート

sortはテンソルの要素をソートします。ソートする次元はdimで指定することができます。

以下の例ではdim=1dim=0でソートしています。dim=0では、行方向でソートされていることに注意してください。

また、descending=Trueを指定することで降順にソートできます。

tensor = torch.rand(2,5)
print(tensor)
# tensor([[0.5613, 0.9695, 0.2886, 0.2987, 0.0773],
#         [0.3537, 0.5369, 0.0145, 0.0271, 0.4820]])

sorted_tensor, sorted_indices = torch.sort(tensor, dim=1)
print(sorted_tensor)
# tensor([[0.0773, 0.2886, 0.2987, 0.5613, 0.9695],
#         [0.0145, 0.0271, 0.3537, 0.4820, 0.5369]])
print(sorted_indices)
# tensor([[4, 2, 3, 0, 1],
#         [2, 3, 0, 4, 1]])

sorted_tensor, sorted_indices = torch.sort(tensor, dim=0, descending=True)
print(sorted_tensor)
# tensor([[0.7752, 0.1893, 0.5090, 0.5800, 0.2460],
#         [0.6563, 0.0693, 0.4898, 0.5022, 0.0015]])
print(sorted_indices)
# tensor([[1, 1, 1, 0, 1],
#         [0, 0, 0, 1, 0]])

torch.unique(): テンソルのユニークな要素を取得

uniqueでテンソル中のユニーク(一意)な値を抜き出せます。

引数にreturn_counts=Trueを指定すると、各要素の数もカウントできます。

tensor = torch.randint(0, 10, (10, 10))
print(tensor)
# tensor([[4, 1, 0, 9, 4, 0, 6, 5, 1, 4],
#         [5, 0, 4, 3, 4, 4, 5, 6, 1, 8],
#         [9, 3, 9, 9, 1, 2, 1, 5, 9, 1],
#         [5, 9, 7, 9, 2, 3, 2, 3, 1, 9],
#         [7, 3, 8, 5, 6, 1, 9, 3, 3, 3],
#         [3, 5, 5, 1, 1, 7, 8, 5, 6, 1],
#         [1, 3, 0, 3, 3, 8, 2, 9, 0, 3],
#         [4, 6, 0, 5, 8, 8, 6, 2, 0, 2],
#         [9, 1, 4, 7, 5, 4, 7, 7, 9, 5],
#         [8, 9, 8, 7, 8, 8, 2, 3, 3, 3]])

unique_tensor = torch.unique(tensor)
print(unique_tensor)
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

unique_tensor, counts = torch.unique(tensor, return_counts = True)
print(unique_tensor, counts)
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) tensor([ 7, 13,  7, 16,  9, 12,  6,  7, 10, 13])

まとめ

以上、テンソルの操作と基本的な演算についてまとめました。PyTorchを使うとテンソルへの操作が必要になります。ここに書いた操作は覚えておくと、他人のコードを読む場合にも役にたちます。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました