【PyTorch】ニューラルネットワークを構築する方法【NN】

公開日: : 最終更新日:2022/03/17 機械学習 , , , , ,

今回は、PyTorch を使って、ニューラルネットワーク(NN)を構築したときのメモです。

ファッションアイテムの公開データセットである Fashion-MNIST を用いて、ニューラルネットワークを構築し、学習・テストまでを行います。

Fashion-MNIST については、以前の記事にまとめてありますので、もしよければ参考にしてみてください。

【Fashion-MNIST】ファッションアイテムのデータセットを使ってみた【TensorFlow】

動作環境としては、Anaconda の jupyter 上で動かし、CPU の AMD Ryzen 7 を使っています。

Contents

データセットの読み込み

まず、データセットのダウンロードと読み込みを行います。

以下のように、「data」フォルダが作成され、データセットがダウンロードされます。

jupyter 上での出力は以下のようになりました。

ニューラルネットワーク(NN)の入力は、畳込みニューラルネットワークと違って、フラットな形状である必要があるため、

lambda x: x.view(-1)」として、2階テンソルを1階テンソルに変換しています。

lambda」は、「lambda 引数: 返り値」というように定義して使います。

また、「batch_size = 64」として、ミニバッチのサイズを 64 としています。

ミニバッチは、1回の学習・テストに用いるデータの数を指定しています。データはランダムに抽出して使います。

最後の出力は以下のようになり、訓練データ・テストデータのミニバッチのサイズが確認できます。

ニューラルネットワークのモデルを定義・生成

モデルの定義、生成を行い、その構造を出力します。今回は、2層の浅いネットワークとなっています。

モデルの定義は、以下のように行います。

ここではモデルを定義しているだけなので、出力はありません。

次にモデルの生成を行います。

モデルの生成は、以下のように行います。

torch.cuda.is_available()」で GPU が使用可能か確認することができます。「True」が返れば、GPU が使えるということになります。

モデルの構造は、以下のようになりました。

「out_features=10」は、 Fashion-MNIST が10種類あるため、10種類のファッションアイテムを分類するためです。

損失関数と最適化手法の設定

損失関数と最適化手法(オプティマイザー)の設定を行います。

最適化手法には、勾配降下アルゴリズムを用います。

パラメータ更新用の関数を定義

学習時のパラメータを更新する train_step 関数を定義します。

モデル評価用の関数を定義

テストデータを使って評価を行うための test_step 関数を定義します。

早期終了判定を行う関数を定義

早期終了判定を行う EarlyStopping 関数を定義します。

学習

実際に、学習・テストを繰り返し行います。

エポック数が 200 に設定されているため、最大200回繰り返されます。

%%time」とすることで実行時間を計測することができます。

結果、73エポックで終了となりました。

CPU での実行時間は約13分でした。

ちなみに、算出している精度 accuracy は、正解率となります。

精度のグラフ化

エポックごとの損失と精度をグラフ化します。

※横軸の epoc が80を超えていますが、これは一度学習・検証をし直したためです。

参考文献

関連記事

【機械学習】 scikit-learn で精度・再現率・F値を算出する方法【Python】

今回は、2クラス分類で Python の scikit-learn を使った評価指標である、精度(P

記事を読む

【機械学習】パーセプトロン(Perceptron)について。

パーセプトロンは、教師あり学習の中でも、入出力モデルベース(eager learning:働き者の学

記事を読む

【転移学習】学習済みVGG16 による転移学習を行う方法【PyTorch】

今回は、PyTorch を使って、学習済みのモデル VGG16 を用いて転移学習をしてみました。

記事を読む

【探索】縦型・横型・反復深化法の探索手法の比較。

探索とは、チェスや将棋や囲碁などのゲームをコンピュータがプレイするときに、どの手を指すかを決定するの

記事を読む

【機械学習】モンテカルロ法(Monte Carlo method)について。

モンテカルロ法(Monte Carlo method)とは、シュミレーションや数値計算を乱数を用いて

記事を読む

【探索】ダイクストラ法・最良優先探索・Aアルゴリズムの比較。

縦型探索や横型探索では、機械的に順序を付け、最小ステップでゴールを目指します。 つまり、

記事を読む

【PyTorch】GPUのメモリ不足でエラーになったときの対処方法。

PyTorch で深層学習していて、 GPUのメモリ不足でエラーが出てしまったので、対処方法のメモで

記事を読む

【画像認識】 Google画像検索結果を取得する方法 【google image download】

今回は、深層学習(DeepLearning)で画像認識をするための画像データの収集を、Google画

記事を読む

【PyTorch】畳込みニューラルネットワークを構築する方法【CNN】

今回は、PyTorch を使って畳込みニューラルネットワーク(CNN)を構築する方法について紹介しま

記事を読む

【機械学習】 scikit-learn で不正解データを抽出する方法【Python】

Python の scikit-learn ライブラリを使って機械学習でテストデータを識別(2クラス

記事を読む

無料動画編集ソフト AviUtl で mp4 形式の動画を読み込み・出力する方法【Windows】

今回は、無料動画編集ソフト AviUtl で mp4 形式の動画を読み

【Cubase】イヤホンから音がでないときの対処方法。

Cubase でイヤホンから音がでなくなったときの対処方法のメモです。

【Cubase】特定のトラックを無効にする方法。

今回は、Cubaseで特定のトラックのみを無効にする方法について紹介し

【転移学習】学習済みVGG16 による転移学習を行う方法【PyTorch】

今回は、PyTorch を使って、学習済みのモデル VGG16 を用い

【PyTorch】畳込みニューラルネットワークを構築する方法【CNN】

今回は、PyTorch を使って畳込みニューラルネットワーク(CNN)

→もっと見る

PAGE TOP ↑