【機械学習】 scikit-learn で不正解データを抽出する方法【Python】

公開日: : 最終更新日:2020/12/25 機械学習 , , , ,

Python の scikit-learn ライブラリを使って機械学習でテストデータを識別(2クラス分類)をしたときに、正解・不正解データを抽出する方法について紹介します。

scikit-learn で正解率を出したいのであれば、混合行列(confusion matrix)を出力すれば良いのですが、これだと正解・不正解データの数しか出力されません。

これだと、不正解データ(正しく識別できなかったデータ)にはどのような傾向があるかなどの分析ができないため、正解・不正解データを抽出したいなぁと思って調べてみました。

Contents

classifier.predict について

scikit-learn の「classifier.predict」関数にテストデータを与えると、識別結果(学習した識別器が判断したラベルデータ)がリストで返ってきます。

今回は「classifier.fit」関数で学習データを用いて既に学習済みの場合を想定しています。

2クラス分類でラベル「same」「diff」の場合の結果が以下のようになります。

これが、与えた全てのテストデータの識別結果となります。

テストデータのうち、正解ラベルが same のものを抽出して識別した結果を抽出したい場合は、以下のようにします。

labels_test_data は、テストデータのラベル(正解ラベル「same」or「diff」が並んだもの)となります。

一方、テストデータのうち、正解ラベルが diff のものを抽出して識別した結果を抽出したい場合は、以下のようにします。

numpy.where を使う

whrere は Numpy ライブラリの関数で、条件を満たす要素の位置(インデックス)を返してくれます。

先ほど抽出した、正解ラベルが same のものと、diff のものを利用します。

例えば「正解ラベルが same のうち、識別結果が diff のもの」を不正解データと扱うことができます。

コードで書くと以下のようになります。

リストの要素の位置(インデックス)が返ってくることが分かります。

不正解データの抽出

以上を踏まえて、不正解データ(識別器が誤って識別してしまったデータ)の抽出を行います。

不正解データのパターンとしては、予測した結果のうち、

  • 正解ラベルが same のうち、識別結果が diff と誤って判断してしまった(False Negative; FN)。
  • 正解ラベルが diff のうち、識別結果が same と誤って判断してしまった(False Positive; FP)。

の2パターンが考えられます。

where で抽出した リストの要素の位置 を iloc で与えて行番号から抽出を行います。

iloc は、pandas で行番号・列番号を指定して抽出を行う関数です。

●False Negative; FN

df_2_class_test は元の(学習に使用した特徴量なども含めた)データです。

●False Positive; FP

正解データの抽出

正解データのパターンとしては、予測した結果のうち、

  • 正解ラベルが same のうち、識別結果が same と正しく判断した(True Positive; TP)。
  • 正解ラベルが diff のうち、識別結果が diff と正しく判断した(True Negative; TN)。

の2パターンが考えられます。

●True Positive; TP

●True Negative; TN

抽出結果

結果が以下のように出力され、それぞれに該当する行と列だけ抽出されていることが確認できます。

関連記事

【探索】縦型・横型・反復深化法の探索手法の比較。

探索とは、チェスや将棋や囲碁などのゲームをコンピュータがプレイするときに、どの手を指すかを決定するの

記事を読む

【PyTorch】GPUのメモリ不足でエラーになったときの対処方法。

PyTorch で深層学習していて、 GPUのメモリ不足でエラーが出てしまったので、対処方法のメモで

記事を読む

【機械学習・手法比較】決定木とナイーブベイズを比較してみた。

同じデータを使って、教師有り機械学習手法の 決定木(Decision Tree)とナイーブベイズ(N

記事を読む

【深層学習】 TensorFlow と Keras をインストールする【Python】

今回は、Google Colaboratory 上で、深層学習(DeepLearning)フレームワ

記事を読む

【Fashion-MNIST】ファッションアイテムのデータセットを使ってみた【TensorFlow】

今回は、機械学習用に公開されているデータセットの1つである「Fashion-MNIST」について紹介

記事を読む

【PyTorch】ニューラルネットワークを構築する方法【NN】

今回は、PyTorch を使って、ニューラルネットワーク(NN)を構築したときのメモです。 フ

記事を読む

【Chainer】手書き数字認識をしてみた【Deep Learning】

Chainerを用いて、ニューラルネットワークを構築し、手書き数字認識を行ったときのメモです。

記事を読む

【転移学習】学習済みVGG16 による転移学習を行う方法【PyTorch】

今回は、PyTorch を使って、学習済みのモデル VGG16 を用いて転移学習をしてみました。

記事を読む

【機械学習】 scikit-learn で精度・再現率・F値を算出する方法【Python】

今回は、2クラス分類で Python の scikit-learn を使った評価指標である、精度(P

記事を読む

【Weka】フリーの機械学習ソフトをインストールする方法。

Weka は、GUIで使えるフリーの機械学習ソフトです。 https://ja.wikiped

記事を読む

無料動画編集ソフト AviUtl で mp4 形式の動画を読み込み・出力する方法【Windows】

今回は、無料動画編集ソフト AviUtl で mp4 形式の動画を読み

【Cubase】イヤホンから音がでないときの対処方法。

Cubase でイヤホンから音がでなくなったときの対処方法のメモです。

【Cubase】特定のトラックを無効にする方法。

今回は、Cubaseで特定のトラックのみを無効にする方法について紹介し

【転移学習】学習済みVGG16 による転移学習を行う方法【PyTorch】

今回は、PyTorch を使って、学習済みのモデル VGG16 を用い

【PyTorch】畳込みニューラルネットワークを構築する方法【CNN】

今回は、PyTorch を使って畳込みニューラルネットワーク(CNN)

→もっと見る

PAGE TOP ↑