どうも、カタミチです。
さて、今日も「最短コースでわかる PyTorch&深層学習プログラミング」のひとり読書会、やっていきますかねー。
今日は2-2「テンソル」の続きですね。
2-2. テンソル(続き)
テンソルの各要素の型は浮動小数点数型である「float32」にしておく…ってのがここまでのテンソルの話でしたが、整数値型で格納しておいた方が都合がいいケースもあるようです。そのために使うのは「long関数」ですね。
私の知っている言葉でいうと、型変換(キャスト)ってやつですかね。long関数を噛ませた整数型は「int64」と表記されるようですね。
…まぁ、上の変換はあくまでも、これまで使った変数(r1)を流用したくてこんな書き方になっていますが、整数型で定義したいのなら初めからlong関数で定義しておけば良さそうですね。
次に「view関数」についてです。前回例で使った3階テンソルである変数r3を流用します。
view関数は、テンソルの階数を変換するときに使う関数のようです。上の(3, 2, 2)の3階テンソルを、(3, 4)の2階テンソルに変換する場合…
こう書きます。
重要なルールは「変換元の要素の数と変換先の要素の数が一致する」ということです。元の要素数は、\(3 \times 2 \times 2 = 12\)ってことになりますので、2階テンソルの行の数が\(3\)であるならば、列の数は\(4\)ってことになります。\(3 \times 4 = 12\)ですからね。
さてコードの書き方ですが、本当は「r6 = r3.view(3, 4)」って書く方が厳密なんですが、「\(3\)にいくつをかければ\(12\)になるっけ?」というのを導出するのが苦手な人向けに、「-1って書いとけば勝手に数値を導出しておくよー」という記法があるようです。…まぁ、苦手じゃなくても「-1」を使っておいたほうが計算ミスが無くて無難ですかね。
ただ逆に、書いた数値が間違っていた場合にも、-1の方をいい感じにエラー無く調整してくれるので、本当に厳密に書きたいのなら使うべきではないのかもしれません。可読性も落ちますしね。要は、状況によりけり…って感じですかね。ま、そうは言いつつ-1を積極的に使うことにはなりそうです(汗)。
しかし、階数が減った場合にどういう並びになるのかは、上の結果を見てしっかり目で追っておいたほうが良さそうですね。
このview関数、いろんな階数のテンソルを1階テンソルにする時に良く使うみたいです。
カタマリになってるやつを「びよーーーん」と引き伸ばす感じですかね。これはイメージしやすいですね。
さて、次に書かれていたのは「required_grad属性」と、「device属性」です。なんかこれらは、のちに出てくるみたいなので、今回は顔見せ程度のようです。
次に「item関数」。テンソルから数値を取り出すのに使うようです。0階テンソルと要素の中身がすべて同じテンソルで利用可能なようで、損失関数のデータ記録用などに使うようです。
次に「max関数」。こんな感じで書くと…
要素の中で一番大きな数値を返してくれるようです。ちなみに「r2.max()」は「torch.max(r2)」とも書けます。同じ意味ですね。これは別にこの関数に限った話というわけではないですが、両方の書き方を覚えておいたほうが良さそうですね。
で、この後者の書き方を拡張して、引数をもうひとつ取ったものがコレ。
第2引数は軸方向の指定で、「1」は「行方向」ってことのようです。つまり、行で1番大きいやつを抜き出したことになります。なので[6., 4.]が抜き出されたってことになりますね。ついでに、何番目のインデックスを取ったのか?も分かるようになっています。[2, 0]なのでそれぞれ、2番目と0番目ですね。
で、その行で最大の値のある場所のインデックスだけ引っこ抜く書き方が…
コレ↑です。よく使うらしいので、覚えておきますかねー。簡単に言うと「お山の大将はあそこにいますよ」ってことですね(分かりにくい例え)。
最後に、テンソル型をNumPy配列に変換するやり方ですね。
「.data.numpy()」ですね。とりあえず了解しました。ちなみに、以前の節であった「参照渡しだから気をつけろ」ってのは意識しておく必要がありそうです。
ということで
ひとまず、テンソルの基本的な使い方は分かったかなー、という感じですかね、たぶん(自信なし)。まぁ、まだ準備段階なので出てきたものを覚える…ということを繰り返していきましょう。
次の節からは微分の話ですね。
ではまた。