【転移学習】学習済みVGG16 による転移学習を行う方法【PyTorch】

今回は、PyTorch を使って、学習済みのモデル VGG16 を用いて転移学習をしてみました。

VGG16 は、ImageNet という大量の画像データセットで 1000カテゴリの分類を学習したモデルになります。

この VGG16 モデルに対して転移学習を行って、新たに「アリ」と「ハチ」の画像を学習させます。

Contents

データセットの準備

「アリ」と「ハチ」の画像を PyTorch の公式サイトからダウンロードします。

「data」フォルダ以下にダウンロードされました。

画像に対して前処理を行うクラスを定義します。

画像の処理前と処理後のものを比較してみます。

【処理前】

【処理後】

学習用とテスト用の画像パスをそれぞれリストに格納します。

データセット作成用のクラスを定義します。

データローダーを生成します。

ミニバッチのサイズは 32 となります。

学習済みモデルをロード

学習済みモデル VGG16 をロードします。

出力層をデフォルトの 1000クラス分類から 「ハチ」か「アリ」かの2クラス分類に変更します。

モデルの構成は以下のようになります。

重みを更新するレイヤーを設定します。

param.requires_grad = True」とすると、重みが更新され、

param.requires_grad = False」とすると、重みが更新されなくなります。

今回は、出力層以外は学習しない(重みの更新を行わない)ようにします。

更新するパラメータ名は以下のように出力されます。

VGG16 の出力層のインデックスは「6」であることが確認できます。

学習を行う関数を定義

実際に学習を行う関数を定義します。

学習・検証を行う

実際に学習・検証を実行します。

結果は以下のように、学習は約11秒で終了しました。

おそらく更新するパラメータが少ないため、早く学習が終了したのだと考えられます。

参考文献

 

関連記事

【機械学習・手法比較】決定木とナイーブベイズを比較してみた。

同じデータを使って、教師有り機械学習手法の 決定木(Decision Tree)とナイーブベイズ(N

記事を読む

【PyTorch】畳込みニューラルネットワークを構築する方法【CNN】

今回は、PyTorch を使って畳込みニューラルネットワーク(CNN)を構築する方法について紹介しま

記事を読む

【Weka】フリーの機械学習ソフトをインストールする方法。

Weka は、GUIで使えるフリーの機械学習ソフトです。 https://ja.wikiped

記事を読む

【TensorFlow】GPUを認識しない時の対処方法【Python】

TensorFlow で GPU を認識させようとしたときにハマってしまったので、その対処方法のメモ

記事を読む

【深層学習】 TensorFlow と Keras をインストールする【Python】

今回は、Google Colaboratory 上で、深層学習(DeepLearning)フレームワ

記事を読む

【PyTorch】GPUのメモリ不足でエラーになったときの対処方法。

PyTorch で深層学習していて、 GPUのメモリ不足でエラーが出てしまったので、対処方法のメモで

記事を読む

【探索】縦型・横型・反復深化法の探索手法の比較。

探索とは、チェスや将棋や囲碁などのゲームをコンピュータがプレイするときに、どの手を指すかを決定するの

記事を読む

【Weka】欠損データを自動的に補完するフィルタを使ってみた。

機械学習で用いるデータについてです。データは完璧なことに越したことはないが、通常は、ある属性の値が入

記事を読む

【機械学習】 scikit-learn で不正解データを抽出する方法【Python】

Python の scikit-learn ライブラリを使って機械学習でテストデータを識別(2クラス

記事を読む

【Weka】CSVファイルを読み込んで決定木を実行。

フリーの機械学習ソフト Weka を使って、CSVファイルを読み込んで決定木(Decision Tr

記事を読む

無料動画編集ソフト AviUtl で mp4 形式の動画を読み込み・出力する方法【Windows】

今回は、無料動画編集ソフト AviUtl で mp4 形式の動画を読み

【Cubase】イヤホンから音がでないときの対処方法。

Cubase でイヤホンから音がでなくなったときの対処方法のメモです。

【Cubase】特定のトラックを無効にする方法。

今回は、Cubaseで特定のトラックのみを無効にする方法について紹介し

【転移学習】学習済みVGG16 による転移学習を行う方法【PyTorch】

今回は、PyTorch を使って、学習済みのモデル VGG16 を用い

【PyTorch】畳込みニューラルネットワークを構築する方法【CNN】

今回は、PyTorch を使って畳込みニューラルネットワーク(CNN)

→もっと見る

PAGE TOP ↑