(19)【発行国】日本国特許庁(JP)
(12)【公報種別】公開特許公報(A)
(11)【公開番号】P2024104494
(43)【公開日】2024-08-05
(54)【発明の名称】モデル修正プログラム、モデル修正方法および情報処理装置
(51)【国際特許分類】
G06N 20/00 20190101AFI20240729BHJP
【FI】
G06N20/00 130
【審査請求】未請求
【請求項の数】6
【出願形態】OL
(21)【出願番号】P 2023008724
(22)【出願日】2023-01-24
(71)【出願人】
【識別番号】000005223
【氏名又は名称】富士通株式会社
(74)【代理人】
【識別番号】110002918
【氏名又は名称】弁理士法人扶桑国際特許事務所
(72)【発明者】
【氏名】中川 尊雄
(72)【発明者】
【氏名】徳本 晋
(57)【要約】
【課題】機械学習モデルの修正後の精度を向上させる。
【解決手段】情報処理装置10は、訓練済みの機械学習モデル14に第1入力データ群13を入力し、第1入力データ群13のデータのそれぞれのクラス分類の推論結果を示すクラス確率分布を生成する。情報処理装置10は、クラス確率分布から推論結果と正解との誤差を算出するとともに、クラス確率分布におけるクラス間のクラス確率の差に基づいて、機械学習モデル14の混乱の度合いを示す混乱度を算出する。情報処理装置10は、誤差と混乱度に基づいて、第1入力データ群13から除去対象とする第1データを特定し、第1データを第1入力データ群13から除去した第2入力データ群16に基づいて、機械学習モデル14を修正する。
【選択図】
図1
【特許請求の範囲】
【請求項1】
訓練済みの機械学習モデルと第1入力データ群を取得し、
前記機械学習モデルに前記第1入力データ群を入力して、前記第1入力データ群に含まれるデータのそれぞれのクラス分類の推論結果を示すクラス確率分布を生成し、
前記クラス確率分布から前記推論結果と正解との誤差を算出するとともに、前記クラス確率分布におけるクラス間のクラス確率の差に基づいて、前記機械学習モデルの混乱の度合いを示す混乱度を算出し、
前記誤差と前記混乱度に基づいて、前記第1入力データ群から除去対象とする第1データを特定し、
前記第1データを前記第1入力データ群から除去した第2入力データ群に基づいて、前記機械学習モデルを修正する、
処理をコンピュータに実行させるモデル修正プログラム。
【請求項2】
前記混乱度は、前記クラス確率分布において、前記クラス確率が最も大きい第1クラスの前記クラス確率と、前記クラス確率が2番目に大きい第2クラスの前記クラス確率との差が大きくなるほど小さくなる、
請求項1記載のモデル修正プログラム。
【請求項3】
前記第1データを特定する処理は、
前記誤差の大きい順に、前記第1入力データ群から第1データ件数分の第2データを抽出し、
前記混乱度の大きい順に、前記第1入力データ群から第2データ件数分の第3データを抽出する、
処理を含み、
前記第2データであり、かつ前記第3データである前記データが、前記第1データである、
請求項1記載のモデル修正プログラム。
【請求項4】
前記第1データを特定する処理は、さらに、
前記第2データであり、かつ前記第3データではなく、前記機械学習モデルに入力した場合に推論が失敗する前記データを、前記第1データとして追加する、処理を含む請求項3記載のモデル修正プログラム。
【請求項5】
訓練済みの機械学習モデルと第1入力データ群を取得し、
前記機械学習モデルに前記第1入力データ群を入力して、前記第1入力データ群に含まれるデータのそれぞれのクラス分類の推論結果を示すクラス確率分布を生成し、
前記クラス確率分布から前記推論結果と正解との誤差を算出するとともに、前記クラス確率分布におけるクラス間のクラス確率の差に基づいて、前記機械学習モデルの混乱の度合いを示す混乱度を算出し、
前記誤差と前記混乱度に基づいて、前記第1入力データ群から除去対象とする第1データを特定し、
前記第1データを前記第1入力データ群から除去した第2入力データ群に基づいて、前記機械学習モデルを修正する、
処理をコンピュータが実行するモデル修正方法。
【請求項6】
訓練済みの機械学習モデルと第1入力データ群を記憶する記憶部と、
前記機械学習モデルと前記第1入力データ群を取得し、前記機械学習モデルに前記第1入力データ群を入力して、前記第1入力データ群に含まれるデータのそれぞれのクラス分類の推論結果を示すクラス確率分布を生成し、前記クラス確率分布から前記推論結果と正解との誤差を算出するとともに、前記クラス確率分布におけるクラス間のクラス確率の差に基づいて、前記機械学習モデルの混乱の度合いを示す混乱度を算出し、前記誤差と前記混乱度に基づいて、前記第1入力データ群から除去対象とする第1データを特定し、前記第1データを前記第1入力データ群から除去した第2入力データ群に基づいて、前記機械学習モデルを修正する処理部と、
を有する情報処理装置。
【発明の詳細な説明】
【技術分野】
【0001】
本発明はモデル修正プログラム、モデル修正方法および情報処理装置に関する。
【背景技術】
【0002】
機械学習モデルは、画像認識、音声認識、機械翻訳などの様々な分野で用いられている。機械学習モデルは、例えば、深層学習によって訓練させるニューラルネットワークである。
【0003】
訓練済みの機械学習モデルは、修正されることがある。例えば、機械学習モデルを用いたシステムの運用時において不具合が生じた場合などに、修正が行われる。機械学習モデルを修正するために、さらなる訓練データが用いられて再訓練されることがある。しかし、CACE(Changing Anything Changes Everything)原理により、再訓練によって、かえって機械学習モデルの性能が劣化する可能性がある。例えば、訓練データを増やして、機械学習モデルの全体の精度を上げたとしても、特定のデータに対しては逆に正しい推論ができなくなる場合がある。
【0004】
そこで、再訓練を行わずに、機械学習モデルの修正を行う手法が提案されている。この手法では、訓練過程で推論に成功したデータ(以下、成功データという)と、訓練過程において推論に失敗したデータ(以下、失敗データという)が用いられる。まず、成功データを機械学習モデルに入力した場合の出力に影響せず、失敗データを機械学習モデルに入力した場合の出力に影響し、訓練過程で値が大きく変化した機械学習モデルのパラメータが特定される。そして、特定されたパラメータを調整することで、正しい振舞いへの影響が少なく誤った振舞いに影響が限定されるように機械学習モデルが修正される。
【0005】
なお、教師なしデータの中から外れ値を検出し、異常か否かを示すラベルを教師なしデータに付与して教師ありデータを生成し、異常を判定するルールを教師ありデータを用いて訓練する外れ値検出装置が提案されている。また、多層畳み込みニューラルネットワークの損失関数を、誤差の統計情報に基づいて動的に調整する方法が提案されている。
【0006】
また、好ましくないイベントに関連する異常値を訓練データから除去し、除去後の訓練データを用いて機械学習モデルを訓練するシステムが提案されている。また、全てのデータレコードを用いて予測モデルを訓練し、予測モデルを用いて各データレコードの異常スコアを算出し、外れ値を判定するための閾値を決定する判定装置が提案されている。
【先行技術文献】
【特許文献】
【0007】
【特許文献1】特開2003-5970号公報
【特許文献2】米国特許出願公開第2017/0300811号明細書
【特許文献3】米国特許出願公開第2018/0081913号明細書
【特許文献4】特開2018-190127号公報
【非特許文献】
【0008】
【非特許文献1】Shogo Tokui, Susumu Tokumoto, Akihito Yoshii, Fuyuki Ishikawa, Takao Nakagawa, Kazuki Munakata and Shinji Kikuchi "NeuRecover: Regression-Controlled Repair of Deep Neural Networks with Training History", Proc. of the 29th IEEE International Conference on Software Analysis, Evolution and Reengineering (SANER 2022), pp. 1111-1121, March 2022
【発明の概要】
【発明が解決しようとする課題】
【0009】
訓練済の機械学習モデルを修正するときに用いられる成功データと失敗データには、外れ値が含まれる可能性がある。外れ値として、例えば、誤ったラベル付けがされたデータ、判別不能なデータ、ノイズが大きすぎるデータなどがある。
【0010】
このような外れ値が含まれる場合、無意味な修正のため、これまで推論に成功したデータの推論に失敗する現象(退行と呼ばれることもある)が発生することがある。または、その外れ値を成功データまたは失敗データとして維持することで、修正が進まなくなることがある。つまり、このような外れ値は機械学習モデルの修正後の精度を悪化させることがある。
【0011】
そこで、1つの側面では、本発明は、機械学習モデルの修正後の精度を向上させることを目的とする。
【課題を解決するための手段】
【0012】
1つの態様では、以下の処理をコンピュータに実行させるモデル修正プログラムが提供される。訓練済みの機械学習モデルと第1入力データ群を取得する。機械学習モデルに第1入力データ群を入力して、第1入力データ群に含まれるデータのそれぞれのクラス分類の推論結果を示すクラス確率分布を生成する。クラス確率分布から推論結果と正解との誤差を算出するとともに、クラス確率分布におけるクラス間のクラス確率の差に基づいて、機械学習モデルの混乱の度合いを示す混乱度を算出する。誤差と混乱度に基づいて、第1入力データ群から除去対象とする第1データを特定し、第1データを第1入力データ群から除去した第2入力データ群に基づいて、機械学習モデルを修正する。
【0013】
また、1つの態様では、コンピュータが実行するモデル修正方法が提供される。また、1つの態様では、記憶部と処理部とを有する情報処理装置が提供される。
【発明の効果】
【0014】
1つの側面では、機械学習モデルの修正後の精度が向上する。
【図面の簡単な説明】
【0015】
【
図1】第1の実施の形態の情報処理装置を説明するための図である。
【
図3】正解と推論結果のクラス確率分布の例を示す図である。
【
図5】成功データが得られる推論結果と損失および混乱度の関係の例を示す図である。
【
図6】失敗データが得られる推論結果と損失および混乱度の関係の例を示す図である。
【
図7】情報処理装置のハードウェア例を示すブロック図である。
【
図8】情報処理装置の機能例を示すブロック図である。
【
図9】混乱度を算出するプログラムコードの一例を示す図である。
【
図12】機械学習モデル修正の手順例を示すフローチャートである。
【
図13】機械学習モデルの修正後の精度の例を示す図である。
【
図14】機械学習モデルの修正後の修正率と退行率の比を示す図である。
【発明を実施するための形態】
【0016】
以下、本実施の形態を図面を参照して説明する。
[第1の実施の形態]
第1の実施の形態を説明する。
【0017】
図1は、第1の実施の形態の情報処理装置を説明するための図である。
第1の実施の形態の情報処理装置10は、訓練済の機械学習モデルの修正を行う。機械学習モデルの訓練は、情報処理装置10が行ってもよいし他の情報処理装置が行ってもよい。情報処理装置10は、クライアント装置でもよいしサーバ装置でもよい。情報処理装置10は、コンピュータと呼ばれてもよい。
【0018】
情報処理装置10は、記憶部11および処理部12を有する。記憶部11は、RAM(Random Access Memory)などの揮発性半導体メモリでもよいし、HDD(Hard Disk Drive)やフラッシュメモリなどの不揮発性ストレージでもよい。処理部12は、例えば、CPU(Central Processing Unit)、GPU(Graphics Processing Unit)、DSP(Digital Signal Processor)などのプロセッサである。ただし、処理部12が、ASIC(Application Specific Integrated Circuit)やFPGA(Field Programmable Gate Array)などの電子回路を含んでもよい。プロセッサは、例えば、RAMなどのメモリ(記憶部11でもよい)に記憶されたプログラムを実行する。プロセッサの集合が、マルチプロセッサまたは単に「プロセッサ」と呼ばれてもよい。
【0019】
記憶部11は、第1入力データ群13、訓練済みの機械学習モデル14を記憶する。
第1入力データ群13は、機械学習モデル14に入力した場合に推論が成功するデータ(以下、成功データという)と、推論に失敗するデータ(以下、失敗データという)を含む。より具体的には、成功データは、機械学習モデル14に入力したときに、正しいクラスに分類されるデータである。失敗データは、機械学習モデル14に入力したときに、正しいクラスに分類されないデータである。これらのデータは、例えば、機械学習モデル14の訓練時および機械学習モデル14を用いたシステムの運用時において、取得される。なお、各データに対する正解のクラスを示す情報についても、記憶部11に記憶されていてよい。
【0020】
修正対象の機械学習モデル14は、第1入力データ群13に対して推論処理を行う機械学習モデルである。機械学習モデル14は、訓練データを用いて訓練済みである。推論処理は、第1入力データ群13に含まれる各データのクラス分類を行うことを含む。クラス分類の一例として、車両の画像データに基づく、当該車両の車種の分類がある。この場合、車両の画像データが第1入力データ群13に含まれるデータの一例であり、車両の車種がクラスの一例である。機械学習モデル14は、ニューラルネットワークであってもよい。ニューラルネットワークは、畳み込み層、プーリング層および全結合層を含んでもよい。
【0021】
処理部12は、第1入力データ群13と機械学習モデル14を取得する。そして、処理部12は、機械学習モデル14に第1入力データ群13を入力して、第1入力データ群13に含まれる各データのクラス分類の推論結果を生成する。推論結果は、クラス確率分布で示される。
【0022】
図1には第1入力データ群13に含まれるあるデータに対するクラス分類の推論結果の例が示されている。クラス確率分布は、各クラス(
図1の例ではクラスA,B,C)のクラス確率で表される。
図1の例ではクラス確率は、0から1の値で表されている。クラスA,Cのクラス確率は0.3、クラスBのクラス確率は0.4である。また、
図1の例では、正解であるクラスBのクラス確率が1で表されている。なお、クラス確率は、0~100%の値で表されていてもよい。
【0023】
また、処理部12は、クラス確率分布から推論結果と正解との誤差を算出する。このような誤差は、損失とも呼ばれる。以下では、誤差の代わりに損失と呼ぶことにする。損失は、各種の損失関数を用いて算出できる。損失関数として、例えば、交差エントロピー、スパースカテゴリカル交差エントロピー、二乗誤差などを用いることができる。
【0024】
さらに、処理部12は、クラス確率分布におけるクラス間のクラス確率の差に基づいて、機械学習モデル14の混乱の度合いを示す混乱度を算出する。混乱度は、例えば、クラス確率が最も大きい第1クラスのクラス確率と、クラス確率が2番目に大きい第2クラスのクラス確率との差に応じて定められ、差が大きくなるほど小さい値とする。混乱度は、例えば、クラス確率が最も大きい第1クラスのクラス確率と、クラス確率が2番目に大きい第2クラスのクラス確率との差の逆数である。
図1のような推論結果が得られている場合、第1クラスはクラスBであり、そのクラス確率は0.4である。第2クラスはクラスA,Cであり、これらのクラス確率は0.3である。したがって、混乱度は、1/(0.4-0.3)=10、と算出できる。
【0025】
処理部12は、算出した損失と混乱度に基づいて、第1入力データ群13から除去対象とする第1データ(以下、除去対象データという)15を特定する。除去対象データ15の特定は、例えば、以下のように行うことができる。
【0026】
処理部12は、損失の大きい順に、第1入力データ群13から第1データ件数分の第2データを抽出する。同様に処理部12は、混乱度が大きい順に、第1入力データ群13から第2データ件数分の第3データを抽出する。そして、処理部12は、損失と混乱度の両方に関して抽出されたデータ、すなわち、第2データであり、かつ第3データであるデータを、除去対象データ15として特定する。このような除去対象データ15は、極端に分類が難しく、成功データであっても機械学習モデル14を極度に混乱させ、修正を困難にする外れ値である可能性が高い。
【0027】
なお、除去対象データ15を特定する方法は、上記の方法に限定されない。例えば、処理部12は、第1入力データ群13に含まれる各データに対して得られた損失と混乱度の調和平均を算出し、調和平均の値が大きい順に所定数のデータを除去対象データ15として特定してもよい。
【0028】
また、処理部12は、除去対象データ15のデータ件数などに応じて、上記の第2データであって、かつ第3データではない失敗データを、除去対象データ15として追加してもよい。このとき、処理部12は、ユーザによる指定に基づいて、第2データであって、かつ第3データではない失敗データを、除去対象データ15として追加するか否かを決定してもよい。
【0029】
図2は、外れ値の一例を説明する図である。
図2では、2次元の特徴空間上にデータがプロットされている。
データ20aなど白丸で表されているデータは、クラス20に分類される成功データである。データ21aなど白の星印で表されているデータは、クラス21に分類される成功データである。データ20c,21b,21cは、失敗データである。分類境界22は、機械学習モデルの訓練によって決まる。機械学習モデルの修正は、この分類境界22を変更することに相当する。
【0030】
図2において、データ20cは失敗データであるが、クラス21の成功データの分布に比較的近い位置にある。このためデータ20cは、誤ったラベル付けがされているなどの理由から外れ値になっている可能性がある。このようなデータ20cに対しても適切に分類ができるように機械学習モデル14を修正すると、退行を引き起こす可能性がある。
【0031】
データ20bは、成功データであるが、クラス20の他の成功データの分布から極端に離れている。このようなデータ20bはノイズが大きすぎるデータなどの外れ値である可能性がある。このようなデータ20bを維持して、機械学習モデル14を修正しても適切に修正が行われない可能性がある。そのため、このような外れ値を除去対象データ15として特定して、除去することが望ましい。
【0032】
処理部12は、除去対象データ15を特定すると、除去対象データ15を、第1入力データ群13から除去した第2入力データ群16に基づいて、機械学習モデル14を修正する。処理部12は、第2入力データ群16を出力してもよい。処理部12は、第2入力データ群16を不揮発性ストレージに保存してもよいし、表示装置に表示してもよいし、他の情報処理装置に送信してもよい。第2入力データ群16は、成功データと失敗データを含む。機械学習モデル14の修正は、例えば、以下のように行われる。
【0033】
処理部12は、成功データを機械学習モデル14に入力した場合の出力に影響せず、失敗データを機械学習モデル14に入力した場合の出力に影響する、機械学習モデル14の重み値を特定する。機械学習モデル14が、畳み込み層や、全結合層を含むニューラルネットワークである場合、重み値は、畳み込み層または全結合層の重み値である。
【0034】
処理部12は、特定した重み値に対して、例えば、粒子群最適化を行うことで、正しい振舞いへの影響が少なく誤った振舞いに影響が限定されるように機械学習モデル14を修正する。粒子群最適化の例については後述する(
図11参照)。
【0035】
処理部12は、修正後の機械学習モデルである機械学習モデル17を出力する。処理部12は、機械学習モデル17を不揮発性ストレージに保存してもよいし、表示装置に表示してもよいし、他の情報処理装置に送信してもよい。
【0036】
以上説明したように、第1の実施の形態の情報処理装置10は、訓練済みの機械学習モデル14と第1入力データ群13を取得する。情報処理装置10は、機械学習モデル14に第1入力データ群13を入力して、第1入力データ群13に含まれる各データのクラス分類の推論結果を示すクラス確率分布を生成する。情報処理装置10は、クラス確率分布から損失を算出するとともに、クラス確率分布におけるクラス間のクラス確率の差に基づいて、機械学習モデル14の混乱度を算出する。情報処理装置10は、損失と混乱度に基づいて、第1入力データ群13から除去対象とする除去対象データ15を特定する。そして、情報処理装置10は、除去対象データ15を第1入力データ群13から除去した第2入力データ群16に基づいて、機械学習モデル14を修正する。
【0037】
これにより、機械学習モデル14の修正後の精度を悪化させる外れ値を除去した第2入力データ群16を用いて、機械学習モデル14の修正が行われるため、修正後の精度が向上する。
【0038】
(比較例1:VAE(Variational Auto Encoder)を用いた外れ値の除去)
比較例1の外れ値の除去方法は、機械学習モデルの一例であるVAEの訓練時に、外れ値を除去する方法である。VAEは、エンコーダとデコーダとを有し、元画像をエンコードして潜在空間上に畳み込み、畳み込まれた情報をデコードして再現画像を得る。エンコーダとデコーダにはニューラルネットワークが用いられる。
【0039】
外れ値を特定するために、VAEに入力される元画像と、VAEから出力される再現画像とが画素レベルで比較される。両画像の差が大きいほど、元画像が外れ値の可能性が高いと判定される。訓練データからこのような外れ値を除去することで、VAEの精度や汎化性能が向上する。
【0040】
しかし、VAEを用いた外れ値の検出方法を、訓練済みの機械学習モデルの修正に用いる場合、適切に外れ値を特定できない。
例えば、訓練データと、運用時の入力データとの間にデータドリフトがある場合、VAEを用いた外れ値の検出方法の適用は難しい。具体的には、訓練データとして車両の側面画像が多く用いられ、運用時の入力データとして車両の正面画像が多く用いられた場合、正面画像は、ほとんど外れ値として特定されてしまう可能性がある。VAEは、訓練時にあまり用いられなかった正面画像を適切に再現できないためである。
【0041】
また、VAEを用いた外れ値の検出方法では、分類に寄与しづらい特徴についての外れ値も検出してしまう場合がある。例えば、バスとセダンを分類する問題において、車両に施された模様の有無などは車両の形状とは関係なく、分類に寄与しない可能性が高い。しかし、VAEを用いた外れ値の検出方法では、模様の有無によって外れ値を特定してしまう可能性がある。
【0042】
第1の実施の形態の情報処理装置10では、損失と混乱度に基づいて、除去対象データ15を特定するため、分類に寄与しない特徴によって外れ値が特定されることが抑制される。
【0043】
(比較例2:損失に応じた外れ値の除去)
比較例2の外れ値の除去方法は、損失が大きすぎるデータを外れ値として特定し、除去する方法である。
【0044】
図3(A)~
図3(C)は、正解と推論結果のクラス確率分布の例を示す図である。
図3(A)の正解のクラス確率分布では、クラスAのクラス確率が1.0、クラスB,Cのクラス確率は0である。
図3(A)の推論結果のクラス確率分布では、クラスAのクラス確率が0.25、クラスBのクラス確率は0.65、クラスCのクラス確率は0.1である。
【0045】
図3(B)の正解のクラス確率分布では、クラスBのクラス確率が1.0、クラスA,Cのクラス確率は0である。
図3(B)の推論結果のクラス確率分布では、クラスA,Cのクラス確率が0.1、クラスBのクラス確率は0.8である。
【0046】
図3(C)の正解のクラス確率分布では、クラスBのクラス確率が1.0、クラスA,Cのクラス確率は0である。
図3(C)の推論結果のクラス確率分布では、クラスA,Cのクラス確率が0.3、クラスBのクラス確率は0.4である。
【0047】
図3(A)や
図3(C)では、損失が大きくなる。損失が大きいデータは分類が困難なデータであり、誤ったラベル付けがされているなどの外れ値である可能性がある。
しかし、損失から外れ値を特定する場合、本質的に判別が難しいクラスが存在すると、そのクラスのデータばかり外れ値として特定されて、除去されてしまう可能性がある。
【0048】
図4は、判別が難しいクラスの一例を示す図である。
図4では、2次元の特徴空間上において、3つのクラス25,26,27の何れかに分類されたデータがプロットされている。
【0049】
データ25aなど白丸で表されているデータは、クラス25に分類される成功データである。データ26a,26bなど白の星印で表されているデータは、クラス26に分類される成功データである。データ27a,27bなど白の三角で表されているデータは、クラス27に分類される成功データである。分類境界28,29は、機械学習モデルの訓練によって決まる。
【0050】
図4の例では、クラス26に分類されるデータとクラス27に分類されるデータは特徴空間上での位置が近く、判別が難しい。このため、特に分類境界29付近のデータ(例えば、データ26b,27b)は、損失が大きくなりやすい。したがって、分類境界29付近のデータは、外れ値として見なされ除去されてしまう可能性がある。
【0051】
分類境界29付近には、所属クラスが判別しづらく退行の可能性があるため維持したい成功データや、機械学習モデルの修正により適切に判別される見込みがある失敗データが含まれる場合がある。このような有益なデータを除去することは望ましくない。
【0052】
第1の実施の形態の情報処理装置10では、損失だけではなく、損失と混乱度に基づいて、除去対象データ15を特定するため、上記のような有益なデータが外れ値として除去されることが抑制される。
【0053】
(比較例3:最大クラス確率に基づく外れ値の除去)
比較例3の外れ値の除去方法は、クラス確率の最大値が小さいデータを外れ値として特定し、除去する方法である。このようなデータは、機械学習モデルがデータの所属クラスについて確信していないと見なすことができる。
【0054】
しかし、比較例3の外れ値の除去方法でも、比較例2の外れ値の除去方法と同様に、本質的に判別が難しいクラスが存在すると、そのクラスのデータばかり外れ値として除去されてしまう可能性がある。失敗データに関しては、クラス確率の最大値が小さいほど、機械学習モデルの修正により、適切に判別される見込みがある。このような失敗データまで、外れ値として除去することは適切ではない。
【0055】
第1の実施の形態の情報処理装置10では、損失と混乱度に基づいて、除去対象データ15を特定するため、機械学習モデルの修正により適切に判別される見込みがある失敗データが外れ値として除去されることが抑制される。
【0056】
(推論結果と損失および混乱度の関係の例)
図5は、成功データが得られる推論結果と損失および混乱度の関係の例を示す図である。
図6は、失敗データが得られる推論結果と損失および混乱度の関係の例を示す図である。
【0057】
図5、
図6のそれぞれには、損失と混乱度の両方が大きくなる推論結果、損失と混乱度の両方が小さくなる推論結果、損失が大きく混乱度が小さくなる推論結果、損失が小さく混乱度が大きくなる推論結果が示されている。損失は、交差エントロピーの値で表されている。混乱度は、クラス確率が最も大きい第1クラスのクラス確率と、クラス確率が2番目に大きい第2クラスのクラス確率との差の逆数である。
【0058】
なお、損失と混乱度の値は、小数点第三位までで表されている(小数点第四位は四捨五入)。損失と混乱度は非独立の関係にあるため、各推論結果の間の損失と混乱度の大きさは、相対的なものである。
【0059】
図5において、損失と混乱度の両方が大きくなる推論結果が得られる成功データは、クラスBのクラス確率が0.35で最大である。しかし、最大のクラス確率は、2番目に大きいクラス確率(0.25)との差が比較的小さいだけでなく、他のクラス確率(0.2)との差も比較的小さい。つまり、機械学習モデルは2つのクラスA,Bの間だけではなく、全クラスA~D間で分類を迷っていたと考えられる。このような成功データは、損失が大きいだけでなく、機械学習モデルを極度に混乱させ、修正を困難にする外れ値である可能性が高い。
【0060】
図5において、損失は小さいが混乱度が大きくなる推論結果が得られる成功データは、クラスBのクラス確率が0.55で最大である。最大のクラス確率は、2番目に大きいクラス確率(0.45)との差が比較的小さいものの、他のクラスのクラス確率との差が大きい。このような成功データは、クラスA,Bの間で分類が難しいにも関わらず成功したデータであるため、維持することが望ましい。
【0061】
図5において、損失は大きいが混乱度が小さくなる推論結果が得られる成功データは、最大のクラス確率が0.46であり、比較的小さいものの、他のクラス確率(0.18)との差は大きい。このような成功データも、機械学習モデルはある程度確信をもって分類できたデータであると考えられるため、維持することが望ましい。
【0062】
図5において、損失と混乱度の両方が小さくなる推論結果が得られる成功データは、クラスBのクラス確率が0.6で最大である。最大のクラス確率は、他のクラス確率との差が比較的大きい。このような成功データは、容易に分類できたデータといえる。
【0063】
図6において、損失と混乱度の両方が大きくなる推論結果が得られる失敗データは、クラスCのクラス確率が0.35で最大である。さらに、最大のクラス確率は、2番目に大きいクラスDのクラス確率(0.25)との差が比較的小さい。機械学習モデルは、正解ではないクラスC,Dの間で分類を迷っていたと考えられる。このような失敗データは、機械学習モデルを見当違いな分類結果の間で混乱させており、修正を困難にする外れ値である可能性が高い。
【0064】
図6において、損失は小さいが混乱度が大きくなる推論結果が得られる失敗データは、クラスCのクラス確率が0.55で最大である。最大のクラス確率は、2番目に大きいクラスBのクラス確率(0.45)との差が比較的小さい。クラスBは正解のクラスである。また、最大のクラス確率は、他のクラスA,Dのクラス確率(0)との差が大きい。このような失敗データは、わずかの差で分類に失敗したデータであると考えられるため、維持することが望ましい。
【0065】
図6において、損失は大きいが混乱度が小さくなる推論結果が得られる失敗データは、最大のクラス確率が0.7であり大きい。また、最大のクラス確率と他のクラス確率(0.1)との差は大きい。このような失敗データは、機械学習モデルが確信をもって分類したにもかかわらず分類に失敗したデータであり、修正を困難にする外れ値である可能性がある。
【0066】
図6において、損失と混乱度の両方が小さくなる推論結果が得られる失敗データは、クラスCのクラス確率が0.45で最大である。最大のクラス確率は、他のクラス確率との差が比較的大きい。このような失敗データは、機械学習モデルがある程度確信をもって分類したにもかかわらず分類に失敗したデータであるが、損失は相対的に大きくない。このため、機械学習モデルを混乱させる有害な外れ値であるとは言い切れない。なお、混乱度が小さければ、損失は大きくなる傾向にあるため、このような失敗データが出現することは稀である。
【0067】
図1に示した第1の実施の形態の情報処理装置10は、損失と混乱度に基づいて、除去対象データ15を特定するため、上記のような維持することが望ましいデータを除去対象とせず、有害な外れ値となる可能性の高いデータを適切に特定できる。
【0068】
なお、損失は大きいが混乱度が小さくなる推論結果が得られる失敗データは、
図5のように2番目にクラス確率が大きくなるクラスが正解となる場合もある。処理部12は、このような失敗データについては、例えば、ユーザによる指定に基づいて、除去対象データ15として追加するか否かを決定してもよい。
【0069】
[第2の実施の形態]
次に、第2の実施の形態を説明する。
第2の実施の形態の情報処理装置100は、訓練済の機械学習モデルの修正を行う。機械学習モデルの訓練は、情報処理装置100が行ってもよいし他の情報処理装置が行ってもよい。情報処理装置10は、クライアント装置でもよいしサーバ装置でもよい。情報処理装置100は、コンピュータと呼ばれてもよい。
【0070】
図7は、情報処理装置のハードウェア例を示すブロック図である。
情報処理装置100は、バスに接続されたCPU101、RAM102、HDD103、GPU104、入力インタフェース105、媒体リーダ106および通信インタフェース107を有する。CPU101は、第1の実施の形態の処理部12に対応する。RAM102またはHDD103は、第1の実施の形態の記憶部11に対応する。
【0071】
CPU101は、プログラムの命令を実行するプロセッサである。CPU101は、HDD103に記憶されたプログラムおよびデータをRAM102にロードし、プログラムを実行する。情報処理装置100は、複数のプロセッサを有してもよい。
【0072】
RAM102は、CPU101で実行されるプログラムおよびCPU101で演算に使用されるデータを一時的に記憶する揮発性半導体メモリである。情報処理装置100は、RAM以外の種類の揮発性メモリを有してもよい。
【0073】
HDD103は、オペレーティングシステム(OS:Operating System)やミドルウェアやアプリケーションソフトウェアなどのソフトウェアのプログラムと、データとを記憶する不揮発性ストレージである。情報処理装置100は、フラッシュメモリやSSD(Solid State Drive)などの他の種類の不揮発性ストレージを有してもよい。
【0074】
GPU104は、CPU101と連携して画像処理を行い、情報処理装置100に接続された表示装置111に画像を出力する。表示装置111は、例えば、CRT(Cathode Ray Tube)ディスプレイ、液晶ディスプレイ、有機EL(Electro Luminescence)ディスプレイまたはプロジェクタである。情報処理装置100に、プリンタなどの他の種類の出力デバイスが接続されてもよい。
【0075】
また、GPU104は、GPGPU(General Purpose Computing on Graphics Processing Unit)として使用されてもよい。GPU104は、CPU101からの指示に応じてプログラムを実行し得る。情報処理装置100は、RAM102以外の揮発性半導体メモリをGPUメモリとして有してもよい。
【0076】
入力インタフェース105は、情報処理装置100に接続された入力デバイス112から入力信号を受け付ける。入力デバイス112は、例えば、マウス、タッチパネルまたはキーボードである。情報処理装置100に複数の入力デバイスが接続されてもよい。
【0077】
媒体リーダ106は、記録媒体113に記録されたプログラムおよびデータを読み取る読み取り装置である。記録媒体113は、例えば、磁気ディスク、光ディスクまたは半導体メモリである。磁気ディスクには、フレキシブルディスク(FD:Flexible Disk)およびHDDが含まれる。光ディスクには、CD(Compact Disc)およびDVD(Digital Versatile Disc)が含まれる。媒体リーダ106は、記録媒体113から読み取られたプログラムおよびデータを、RAM102やHDD103などの他の記録媒体にコピーする。読み取られたプログラムは、CPU101によって実行されることがある。
【0078】
記録媒体113は、可搬型記録媒体であってもよい。記録媒体113は、プログラムおよびデータの配布に用いられることがある。また、記録媒体113およびHDD103が、コンピュータ読み取り可能な記録媒体と呼ばれてもよい。
【0079】
通信インタフェース107は、ネットワーク114を介して他の情報処理装置と通信する。通信インタフェース107は、スイッチやルータなどの有線通信装置に接続される有線通信インタフェースでもよいし、基地局やアクセスポイントなどの無線通信装置に接続される無線通信インタフェースでもよい。
【0080】
次に、情報処理装置100の機能および処理手順について説明する。
図8は、情報処理装置の機能例を示すブロック図である。
情報処理装置100は、第1入力データ群記憶部121、モデル記憶部122、高損失データ記憶部123、高混乱データ記憶部124、第2入力データ群記憶部125、修正済モデル記憶部126を有する。これらの記憶部は、例えば、RAM102またはHDD103を用いて実装される。
【0081】
また、情報処理装置100は、高損失データ抽出部127、高混乱データ抽出部128、外れ値除去部129、モデル修正部130を有する。これらの処理部は、例えば、CPU101およびプログラムを用いて実装される。なお、第2入力データ群記憶部125、モデル修正部130および修正済モデル記憶部126は、他の情報処理装置に分離されていてもよい。
【0082】
第1入力データ群記憶部121は、第1入力データ群を記憶する。第1入力データ群は、訓練済みの機械学習モデルに入力した場合に推論が成功する成功データと、推論に失敗する失敗データを含む。第1入力データ群は、ユーザにより情報処理装置100に保存されてもよく、他の情報処理装置から情報処理装置100に転送されてもよい。
【0083】
モデル記憶部122は、訓練済み機械学習モデルを記憶する。機械学習モデルは、ニューラルネットワークであってもよい。ニューラルネットワークは、畳み込み層、プーリング層および全結合層を含んでもよい。機械学習モデルは、ユーザにより情報処理装置100に保存されてもよく、他の情報処理装置から情報処理装置100に転送されてもよい。
【0084】
高損失データ記憶部123は、第1入力データ群から高損失データ抽出部127により抽出された高損失データを特定する情報(例えば、識別番号)を記憶する。高損失データは、第1入力データ群に含まれる成功データと失敗データのうち、損失の大きさが上位のN(Nは2以上の整数)件に含まれるデータである。
【0085】
高混乱データ記憶部124は、第1入力データ群から高混乱データ抽出部128により抽出された高混乱データを特定する情報(例えば、識別番号)を記憶する。高混乱データは、第1入力データ群に含まれる成功データと失敗データのうち、混乱度の大きさが上位のN件に含まれるデータである。
【0086】
第2入力データ群記憶部125は、第2入力データ群を記憶する。第2入力データ群は、外れ値除去部129が、第1入力データ群から除去対象データ(外れ値)を除去したデータ群である。
【0087】
修正済モデル記憶部126は、モデル修正部130が修正した機械学習モデルを記憶する。
高損失データ抽出部127は、第1入力データ群記憶部121に記憶された第1入力データ群を、モデル記憶部122に記憶された機械学習モデルに入力する。そして、高損失データ抽出部127は、第1入力データ群に含まれる各データのクラス分類の推論結果を示すクラス確率分布を生成する。また、高損失データ抽出部127は、クラス確率分布から推論結果と正解との誤差である損失を、各データについて算出する。損失は、各種の損失関数を用いて算出できる。損失関数として、例えば、交差エントロピー、スパースカテゴリカル交差エントロピー、二乗誤差などを用いることができる。さらに、高損失データ抽出部127は、第1入力データ群に含まれる成功データと失敗データを、損失の大きさ順に配列し、損失の大きさが上位のN件を、高損失データとして抽出する。高損失データ抽出部127は、抽出した高損失データを特定する情報を、高損失データ記憶部123に保存する。
【0088】
高混乱データ抽出部128は、クラス確率分布におけるクラス間のクラス確率の差に基づいて、機械学習モデルの混乱の度合いを示す混乱度を算出する。混乱度の算出は、例えば、以下のようなプログラムコードを実行することで行われる。
【0089】
図9は、混乱度を算出するプログラムコードの一例を示す図である。
1行目において、“confidence”は、最大のクラス確率を格納する変数である。np.max(class_prob)は、クラス確率分布の配列の中から最大値を取得する関数である。
【0090】
2行目において、“second_most_idx”は、2番目に大きいクラス確率を格納する変数である。“np.argpartition(class_prob,-2)[-2]”は、クラス確率を大きい順に配列していったときの2番目に大きい値を取得する関数である。
【0091】
3行目において、“confusion”は、混乱度を格納する変数である。
つまり、高混乱データ抽出部128は、最大のクラス確率と、2番目に大きいクラス確率との差の逆数を算出することで混乱度を算出できる。
【0092】
さらに、高混乱データ抽出部128は、第1入力データ群に含まれる成功データと失敗データを、混乱度の大きさ順に配列し、混乱度の大きさが上位のN件を、高混乱データとして抽出する。高混乱データ抽出部128は、抽出した高混乱データを特定する情報を、高混乱データ記憶部124に保存する。
【0093】
なお、
図8の例では、高損失データ抽出部127と高混乱データ抽出部128の両方が、訓練済みの機械学習モデルを用いて推論を行っているが、このような形態に限定されるわけではない。高損失データ抽出部127と高混乱データ抽出部128の一方が、訓練済みの機械学習モデルを用いて推論を行い、高損失データ抽出部127と高混乱データ抽出部128の他方が、その推論結果を用いてもよい。または、高損失データ抽出部127と高混乱データ抽出部128とは別の推論部が訓練済みの機械学習モデルを用いて推論を行い、高損失データ抽出部127と高混乱データ抽出部128は、その推論結果を用いて、高損失データと高混乱データを抽出するようにしてもよい。
【0094】
外れ値除去部129は、高損失データ記憶部123に記憶されている高損失データを特定する情報と、高混乱データ記憶部124に記憶されている高混乱データを特定する情報から、除去対象データを特定する。外れ値除去部129は、高損失データであり、かつ高混乱データであるデータを、除去対象データとして特定する。なお、外れ値除去部129は、失敗データについては、ユーザによる指定に基づいて、高損失データであって、高混乱データでないものを、除去対象データとして追加するか否かを決定してもよい。外れ値除去部129は、除去対象データのデータ件数などに応じて、高損失データであって、高混乱データでない失敗データを、除去対象データとして追加するか否かを決定してもよい。
【0095】
また、外れ値除去部129は、第1入力データ群から除去対象データを除去した第2入力データ群を生成し、第2入力データ群記憶部125に保存する。外れ値除去部129は、第2入力データ群を、表示装置111に表示してもよいし、他の情報処理装置に送信してもよい。
【0096】
モデル修正部130は、第2入力データ群を用いて、モデル記憶部122に記憶された訓練済みの機械学習モデルを修正する。機械学習モデルの修正は、例えば、非特許文献1で開示されている技術を用いて、以下のように行うことができる。
【0097】
モデル修正部130は、第2入力データ群のうち失敗データを機械学習モデルに入力し、その際の順伝搬・逆伝搬の値から、失敗データに対する影響の大きい重み値を特定する。さらに、モデル修正部130は、第2入力データ群のうち成功データを機械学習モデルに入力し、その際の順伝搬・逆伝搬の値から、成功データに対する影響の大きい重み値を特定する。なお、モデル修正部130は、成功データのデータ件数が多い場合には、計算コストを減らすために、サンプリングを行い、データ件数を減らすようにしてもよい。
【0098】
モデル修正部130は、これら2つの特定結果に基づいて、失敗データのみに影響する重み値を抽出する。
図10は、重み値の抽出例を説明する図である。
【0099】
図10には、第2入力データ群記憶部125に記憶されている第2入力データ群が入力され、推論結果151を出力する機械学習モデル150の例が示されている。機械学習モデル150は、ニューロン150a1,150a2,…,150a7を含むニューラルネットワークで表されている。各ニューロン間のエッジには、機械学習モデル150のパラメータの一例である重み値が設定されている。モデル修正部130は、これらの重み値の中から、成功データを機械学習モデル150に入力したときの推論の成否に影響せず、失敗データを機械学習モデル150に入力したときの推論の成否に影響する重み値を抽出する。
【0100】
図10には、抽出されたn個の重み値がw
1,w
2,…,w
nと表されている。例えば、w
1は、ニューロン150a3,150a7間のエッジの重み値であり、w
2は、ニューロン150a3,150a6間のエッジの重み値である。
【0101】
モデル修正部130は、失敗データを正しく認識できるように、抽出した重み値をメタヒューリスティック最適化手法によって修正することで、機械学習モデルを修正する。モデル修正部130は、メタヒューリスティック最適化手法として、例えば、粒子群最適化を用いることができる。
【0102】
図11は、粒子群最適化の例を説明する図である。
k個の粒子x
1,x
2,x
3,…,x
kの値は、それぞれ、抽出されたw
1~w
nの値を表す。粒子x
1,x
2,x
3,…,x
kのそれぞれに対して、評価値(以下、fitnessと表記する)が算出される。fitnessは、例えば、以下の式(1)で表される。
【0103】
fitness=修正率+(1-退行率)+(失敗損失(M’)+Δ)/(失敗損失(Morig)+Δ)+(成功損失(M’)+Δ)/(成功損失(Morig)+Δ) (1)
修正率は、機械学習モデルに入力された全失敗データのうち、推論に成功した失敗データの割合である。退行率は、機械学習モデルに入力された全成功データのうち、推論に失敗した成功データの割合である。失敗損失(M’)は、機械学習モデルに各失敗データを入力した場合の推論結果から得られる損失の和、または損失の平均誤差である。成功損失(M’)は、機械学習モデル150に各成功データを入力した場合の推論結果から得られる損失の和、または損失の平均誤差である。失敗損失(Morig)は、粒子群最適化の適用前の機械学習モデルに、各失敗データを入力した場合の推論結果151から得られる損失の和、または損失の平均誤差である。成功損失(Morig)は、粒子群最適化の適用前の機械学習モデルに、各成功データを入力した場合の推論結果から得られる損失の和、または損失の平均誤差である。Δは、0による除算を防ぐための所定の微小値である。
【0104】
なお、fitnessは上記の式(1)に限定されるわけではない。
fitnessが最大の粒子の値がグローバルベストである。イタレーション番号0の試行では、fitnessが最大となるx2の値が、グローバルベストである。粒子群最適化では、他の粒子についてもグローバルベストに近づくように、値の更新が行われる。
【0105】
イタレーション番号tにおける粒子xi(i=1,2,3,…,k)の値は、以下の式(2)、式(3)により計算できる。
xi(t)=xi(t-1)+vi(t-1) (2)
vi(t)=c0vi(t-1)+c1r1(pl-xi(t))+c2r2(pg-xi(t)) (3)
式(3)において、c0,c1,c2は所定の定数、r1,r2は乱数、plはローカルベスト、pgはグローバルベストである。ローカルベストは、イタレーション番号tの試行までに得られたxiの値のうち、fitnessが最大となる値である。
【0106】
イタレーション番号1の試行では、x3の値がグローバルベストとなっている。粒子群最適化では、指定回数、またはグローバルベストが一定回数変わらなくなるまで、上記の処理が行われる。
【0107】
図11の例では、イタレーション番号Nの試行において得られているグローバルベスト(x
1の値)が、端子群最適化の結果である。このときのx
1の値、すなわち、w
1~w
nの値が、元のw
1~w
nの値の代わりに用いられることで、機械学習モデルが修正される。
【0108】
モデル修正部130は、修正した機械学習モデルを修正済モデル記憶部126に保存する。モデル修正部130は、修正した機械学習モデルを表示装置111に表示してもよいし、他の情報処理装置に送信してもよい。
【0109】
図12は、機械学習モデル修正の手順例を示すフローチャートである。S10~S27は、処理のステップを表している。
(S10)高損失データ抽出部127と高混乱データ抽出部128は、第1入力データ群記憶部121に記憶された第1入力データ群と、モデル記憶部122に記憶された機械学習モデルを取得する。
【0110】
(S11)高損失データ抽出部127は、第1入力データ群を、機械学習モデルに入力し、推論を実行し、推論結果を示すクラス確率分布を生成する。
(S12)高損失データ抽出部127は、推論結果に基づいて、第1入力データ群に含まれるデータごとに損失を算出し、損失が大きい順にデータを特定する情報を配列した損失ランキングリストを作成する。
【0111】
(S13)高混乱データ抽出部128は、第1入力データ群を、機械学習モデルに入力し、推論を実行し、推論結果を示すクラス確率分布を生成する。
(S14)高混乱データ抽出部128は、推論結果に基づいて、第1入力データ群に含まれるデータごとに混乱度を算出し、混乱度が大きい順にデータを特定する情報を配列した混乱度ランキングリストを作成する。
【0112】
(S15)高損失データ抽出部127は、損失ランキングリストから、損失の大きさが上位のN件を高損失データとして抽出する。そして、高損失データ抽出部127は、抽出した高損失データを特定する情報を含む高損失データリストを作成する。高混乱データ抽出部128は、混乱度ランキングリストから、混乱度の大きさが上位のN件を高混乱データとして抽出する。そして、高混乱データ抽出部128は、抽出した高混乱データを特定する情報を含む高混乱データリストを作成する。なお、抽出するデータ件数であるNの値は、例えば、ユーザによって、全データ件数の数%の値などと指定されてもよい。また、抽出される高損失データと、高混乱データの件数は、異なっていてもよい。
【0113】
(S16)外れ値除去部129は、成功データを特定する情報を含む成功データリストと、失敗データを特定する情報を含む失敗データリストを作成する。
(S17)外れ値除去部129は、第1入力データ群記憶部121に記憶されているデータ(成功データまたは失敗データ)を選択する。
【0114】
(S18)外れ値除去部129は、高損失データリストと高混乱データリストを参照し、選択したデータが、高損失データ、かつ高混乱データであるか否かを判定する。高損失データ、かつ高混乱データであると判定された場合、ステップS19の処理が行われる。高損失データ、かつ高混乱データではないと判定された場合、ステップS20の処理が行われる。
【0115】
(S19)外れ値除去部129は、ステップS17の処理で選択したデータを除去リストに追加する。ステップS19の処理後、ステップS23の処理が行われる。
(S20)外れ値除去部129は、高損失データリストと高混乱データリストと失敗データリストに基づいて、選択したデータが高損失かつ低混乱(高損失データであり、高混乱データではない)の失敗データであるか否かを判定する。高損失かつ低混乱の失敗データであると判定された場合、ステップS21の処理が行われる。高損失かつ低混乱の失敗データではないと判定された場合、ステップS23の処理が行われる。
【0116】
(S21)外れ値除去部129は、ユーザにより、高損失かつ低混乱の失敗データを除去する指定があるか否かを判定する。高損失かつ低混乱の失敗データを除去する指定があると判定された場合、ステップS22の処理が行われ、高損失かつ低混乱の失敗データを除去する指定がないと判定された場合、ステップS23の処理が行われる。
【0117】
(S22)外れ値除去部129は、ステップS17の処理で選択したデータを除去リストに追加する。その後、ステップS23の処理が行われる。
(S23)外れ値除去部129は、第1入力データ群記憶部121に記憶されているデータのうちで、未選択のデータがあるか否かを判定する。未選択のデータがあると判定された場合、ステップS17からの処理が繰り返される。未選択のデータがないと判定された場合、ステップS24の処理が行われる。
【0118】
(S24)外れ値除去部129は、第1入力データ群から除去リストに含まれる除去対象データを除去した第2入力データ群を生成する。
(S25)外れ値除去部129は、第2入力データ群を出力する。
【0119】
(S26)モデル修正部130は、第2入力データ群を用いて、モデル記憶部122に記憶された訓練済みの機械学習モデルを、例えば前述の方法を用いて修正する。
(S27)モデル修正部130は、修正した機械学習モデルを出力する。以上で、機械学習モデルの修正処理が終了する。
【0120】
なお、上記の各ステップの処理順序は一例であって、適宜入れ替えてもよい。
以上説明したように、第2の実施の形態の情報処理装置100は、第1入力データ群に含まれる成功データと失敗データに対して推論を行い、その結果から高損失データと高混乱データを抽出する。そして、情報処理装置100は、高損失データであり、かつ高混乱データであるデータを、第1入力データ群から除去した第2入力データ群を生成する。情報処理装置100は、第2入力データ群を用いて機械学習モデルを修正する。
図5や
図6を用いて説明したように、損失と混乱度の両方が大きくなる推論結果が得られるデータは、機械学習モデルを極度に混乱させ、修正を困難にする有害な外れ値である可能性が高い。このため、上記のように抽出した高損失データであり、かつ高混乱データであるデータを、第1入力データ群から除去した第2入力データ群を用いて機械学習モデルを修正することで、修正後の精度が向上する。
【0121】
また、情報処理装置100は、高損失かつ低混乱の失敗データについても、例えば、ユーザにより除去する旨の指定、または除去対象データのデータ件数などに応じて、除去する。
図6を用いて説明したように、高損失かつ低混乱の失敗データについても、機械学習モデルの修正後の精度を悪化させる外れ値である可能性があるため、このような失敗データについても除去することで、修正後の精度がより向上する。
【0122】
(適用例)
以下、第2の実施の形態の情報処理装置100を、車両画像から車種(7クラス)を判別する事例に適用した例を示す。
【0123】
本事例では、機械学習モデルの訓練時に用いられた入力データ群に含まれる車両の正面および後ろ正面の画像の割合よりも、運用時の入力データ群に含まれる当該画像の割合が10倍大きいため、機械学習モデルの判別精度が低下したものとする。つまり、訓練時の入力データ群と、運用時の入力データ群との間にデータドリフトが生じている。
【0124】
また、本事例では、前述の第1入力データ群は、データ件数が5547件の訓練時の入力データ群と、データ件数が1237件の運用時の入力データ群を含む。さらに、機械学習モデルの修正結果を確認するためのデータが、2476件、用いられる。修正結果を確認するためのデータは、修正対象の機械学習モデルを用いた本来の業務などにおいては、存在しないデータである。このようなデータは、機械学習モデルが汎化性能を担保しつつ、判別に失敗したデータを正しく判別できるように修正されているかどうかを確認するために用いられる。
【0125】
情報処理装置100は、上記のような第1入力データ群を用いて、
図12に示した手順により、機械学習モデルの修正を行う。損失は、スパースカテゴリカル交差エントロピーにより算出され、混乱度は、
図9に示したようなプログラムコードを実行することで算出される。また、本事例において、高損失データリストと高混乱データリストのデータ件数は、それぞれ100件(すなわちN=100)である。
【0126】
なお、
図12のステップS26の機械学習モデルを修正する処理では、第2入力データ群のうち、訓練時の入力データ群は、成功データが用いられる。訓練時において判別に成功していたデータについては、判別結果を維持するためである。また、第2入力データ群のうち、運用時の入力データ群は、失敗データが用いられる。訓練時にあまり想定されていなかったデータについても、機械学習モデルで正しく判別できるようにするためである。訓練時にあまり想定されていなかったデータは、例えば、前述のように、車両の正面および後ろ正面の画像である。
【0127】
図13は、機械学習モデルの修正後の精度の例を示す図である。
本事例において、高損失データでかつ高混乱データのデータを除去した第2入力データ群を用いて機械学習モデルを修正した場合の精度が示されている。
図13に示されている精度は、分類精度を改善したいクラスに属する入力データ(画像)のうち、そのクラスに正しく分類された割合(正答率)である。
図13において、精度は、10回の実験によって得られた値が示されている。精度の最大値と最小値は、横バーで示され、精度の四分位範囲がブロックで表されている。ブロック内の横線は中央値を示している。
【0128】
図13には比較のために、データの除去を行わず、第1入力データ群を用いて機械学習モデルを修正した場合の精度が、さらに示されている。また、比較のために、前述のVAEによる外れ値の除去方法によりデータを除去した入力データ群を用いて、機械学習モデルを修正した場合の精度が、さらに示されている。
【0129】
図13のように、データの除去を行わない場合や、VAEによる外れ値の除去方法を用いた場合よりも、高損失データでかつ高混乱データのデータを除去した場合には、中央値や平均値レベルが高くなっている。この点から、高損失データでかつ高混乱データのデータを除去した場合には、機械学習モデルの修正後の精度が、データの除去を行わない場合や、VAEによる外れ値の除去方法を用いた場合よりも、向上していることが分かる。
【0130】
図14は、機械学習モデルの修正後の修正率と退行率の比を示す図である。
本事例において、高損失データでかつ高混乱データのデータを除去した第2入力データ群を用いて機械学習モデルを修正した場合の、修正率と退行率の比(修正率/退行率)が示されている。修正率は、運用時において判別に失敗したデータが、修正後の機械学習モデルに入力された場合に、正解のクラスに分類される確率である。修正率は、訓練時において判別に成功したデータが、修正後の機械学習モデルに入力された場合に、誤ったクラスに分類される確率である。
図14において、修正率/退行率は、10回の実験によって得られた値が示されている。修正率/退行率の最大値と最小値は、横バーで示され、修正率/退行率の四分位範囲がブロックで表されている。ブロック内の横線は中央値を示している。
【0131】
図14には比較のために、データの除去を行わず、第1入力データ群を用いて機械学習モデルを修正した場合の修正率/退行率が、さらに示されている。また、比較のために、前述のVAEによる外れ値の除去方法によりデータを除去した入力データ群を用いて、機械学習モデルを修正した場合の修正率/退行率が、さらに示されている。
【0132】
図14のように、データの除去を行わない場合や、VAEによる外れ値の除去方法を用いた場合よりも、高損失データでかつ高混乱データのデータを除去した場合には、中央値や平均値レベルが高くなっている。この点からも、高損失データでかつ高混乱データのデータを除去した場合には、機械学習モデルの修正後の精度が、データの除去を行わない場合や、VAEによる外れ値の除去方法を用いた場合よりも、向上していることが分かる。
【0133】
以上、実施の形態に基づき、本発明のモデル修正プログラム、モデル修正方法および情報処理装置の一観点について説明してきたが、これらは一例にすぎず、上記の記載に限定されるものではない。
【符号の説明】
【0134】
10 情報処理装置
11 記憶部
12 処理部
13 第1入力データ群
14 機械学習モデル(訓練済み)
15 除去対象データ
16 第2入力データ群
17 機械学習モデル(修正後)