(19)【発行国】日本国特許庁(JP)
(12)【公報種別】特許公報(B2)
(11)【特許番号】
(24)【登録日】2024-10-15
(45)【発行日】2024-10-23
(54)【発明の名称】学習装置、学習方法、および、学習プログラム
(51)【国際特許分類】
G06N 3/094 20230101AFI20241016BHJP
G06N 20/00 20190101ALI20241016BHJP
【FI】
G06N3/094
G06N20/00 130
(21)【出願番号】P 2023553921
(86)(22)【出願日】2021-10-18
(86)【国際出願番号】 JP2021038503
(87)【国際公開番号】W WO2023067669
(87)【国際公開日】2023-04-27
【審査請求日】2024-02-16
(73)【特許権者】
【識別番号】000004226
【氏名又は名称】日本電信電話株式会社
(74)【代理人】
【識別番号】110002147
【氏名又は名称】弁理士法人酒井国際特許事務所
(72)【発明者】
【氏名】山下 智也
(72)【発明者】
【氏名】山田 真徳
【審査官】今城 朋彬
(56)【参考文献】
【文献】中国特許出願公開第113255531(CN,A)
【文献】WANG, Yisen et al.,Improving Adversarial Robustness Requires Revisiting Misclassified Examples,2020年03月11日,pp.1-14,https://openreview.net/forum?id=rklOg6EFwS
(58)【調査した分野】(Int.Cl.,DB名)
G06N 3/00-20/00
(57)【特許請求の範囲】
【請求項1】
Adversarial Exampleを含む入力データの識別結果を出力するためのモデルの学習用データを取得するデータ取得部と、
前記モデルのAdversarial Exampleに対する識別難易度を判定する難易度判定部と、
判定された前記モデルのAdversarial Exampleに対する識別難易度が高いほど、MART(Misclassification Aware adveRsarial Training)によるモデルの学習に用いられるloss関数における、前記Adversarial Exampleに対する識別結果と、前記Adversarial Exampleのノイズが付加される前のデータに対する識別結果との差の値の重みを大きくして、前記モデルの学習を行う学習処理部と
を備えることを特徴とする学習装置。
【請求項2】
前記学習処理部は、
判定された前記モデルのAdversarial Exampleに対する識別難易度が低いほど、前記loss関数における、前記Adversarial Exampleに対する識別結果と、前記Adversarial Exampleのノイズが付加される前のデータに対する識別結果との差の値の重みを小さくして、前記モデルの学習を行うこと
を特徴とする請求項1に記載の学習装置。
【請求項3】
前記難易度判定部は、
学習対象のモデルに、Adversarial Exampleを入力することにより、前記モデルのAdversarial Exampleに対する識別難易度を判定する
ことを特徴とする請求項1に記載の学習装置。
【請求項4】
Adversarial Trainingが行われた前記モデルを用いて、入力データを識別する識別部
をさらに備えることを特徴とする請求項1に記載の学習装置。
【請求項5】
学習装置により実行される学習方法であって、
Adversarial Exampleを含む入力データの識別結果を出力するためのモデルの学習用データを取得する工程と、
前記モデルのAdversarial Exampleに対する識別難易度を判定する工程と、
判定された前記モデルのAdversarial Exampleに対する識別難易度が高いほど、MART(Misclassification Aware adveRsarial Training)によるモデルの学習に用いられるloss関数における、前記Adversarial Exampleに対する識別結果と、前記Adversarial Exampleのノイズが付加される前のデータに対する識別結果との差の値の重みを大きくして、前記モデルの学習を行う工程と
を含むことを特徴とする学習方法。
【請求項6】
Adversarial Exampleを含む入力データの識別結果を出力するためのモデルの学習用データを取得する工程と、
前記モデルのAdversarial Exampleに対する識別難易度を判定する工程と、
判定された前記モデルのAdversarial Exampleに対する識別難易度が高いほど、MART(Misclassification Aware adveRsarial Training)によるモデルの学習に用いられるloss関数における、前記Adversarial Exampleに対する識別結果と、前記Adversarial Exampleのノイズが付加される前のデータに対する識別結果との差の値の重みを大きくして、前記モデルの学習を行う工程と
をコンピュータに実行させるための学習プログラム。
【発明の詳細な説明】
【技術分野】
【0001】
本発明は、学習装置、学習方法、および、学習プログラムに関する。
【背景技術】
【0002】
従来、データにノイズを加えることでモデル(例えば、分類器)に誤判定させるAdversarial Exampleがある。このAdversarial Exampleに対しロバストなモデルの学習方法として、MART(Misclassification Aware adveRsarial Training)がある。
【0003】
MARTは、Adversarial Trainingの一手法であり、元の入力データ(ノイズを加える前のデータ)に対するモデルの識別難易度を基にして決定した方針に基づき、モデルの学習(Adversarial Training)を行う手法である。
【先行技術文献】
【非特許文献】
【0004】
【文献】YisenWang, Difan Zou, Jinfeng Yi, James Bailey, Xingjun Ma, Quanquan Gu,”IMPROVING ADVERSARIAL ROBUSTNESS REQUIRES REVISITING MISCLASSIFIED EXAMPLES”, ICLR 2020
【発明の概要】
【発明が解決しようとする課題】
【0005】
しかし、MARTは、モデルが元の入力データを正しく認識できる場合であっても、Adversarial Exampleのラベルを正しく識別する学習がモデルの課題として厳しいことがある。その結果、MARTにより学習されたモデルが、Adversarial Exampleを精度よく分類できない場合があるという問題がある。そこで、本発明は、Adversarial Exampleに対し、識別精度の高いモデルの学習を行うことを課題とする。
【課題を解決するための手段】
【0006】
前記した課題を解決するため、本発明は、Adversarial Exampleを含む入力データの識別結果を出力するためのモデルの学習用データを取得するデータ取得部と、前記モデルのAdversarial Exampleに対する識別難易度を判定する難易度判定部と、判定された前記モデルのAdversarial Exampleに対する識別難易度が高いほど、MART(Misclassification Aware adveRsarial Training)によるモデルの学習に用いられるloss関数における、前記Adversarial Exampleに対する識別結果と、前記Adversarial Exampleのノイズが付加される前のデータに対する識別結果との差の値の重みを大きくして、前記モデルの学習を行う学習処理部とを備えることを特徴とする。
【発明の効果】
【0007】
本発明によれば、Adversarial Exampleに対し、識別精度の高いモデルの学習を行うことができる。
【図面の簡単な説明】
【0008】
【
図2】
図2は、学習装置の処理手順の例を示すフローチャートである。
【
図3】
図3は、学習装置の処理手順の例を示すフローチャートである。
【
図5】
図5は、学習装置により学習されたモデルに対する実験結果を示す図である。
【
図6】
図6は、学習プログラムを実行するコンピュータの構成例を示す図である。
【発明を実施するための形態】
【0009】
以下、図面を参照しながら、本発明を実施するための形態(実施形態)について説明する。本発明は、本実施形態に限定されない。
【0010】
[学習装置の概要]
まず、
図1を用いて、本実施形態の学習装置の概要を説明する。学習装置は、既存手法であるMARTを改善した手法により、モデルのAdversarial Trainingを行う。
【0011】
すなわち、学習装置は、まず、学習対象のモデルのAdversarial Exampleに対する識別難易度を判定する。そして、学習対象のモデルのAdversarial Exampleに対する識別難易度が低い場合、学習装置は、Adversarial Exampleに対するモデルの出力が、元の入力データ(Adversarial Exampleのノイズが付加される前のデータ)に対するモデルの出力から大きく外れないようにすることよりも、Adversarial Exampleの正しいラベルを識別することを重視してモデルの学習を行う。
【0012】
一方、学習対象のモデルのAdversarial Exampleに対する識別難易度が高い場合、学習装置は、Adversarial Exampleの正しいラベルを識別することよりも、Adversarial Exampleに対するモデルの出力が、元の入力データに対するモデルの出力から大きく外れないようにすることを重視してモデルの学習を行う。
【0013】
このように学習装置10は、Adversarial Exampleに対する識別精度(Robust Accuracy)に基づき、モデルの学習方針を決定し、モデルを学習するので、識別精度の高いモデルの学習を行うことができる。
【0014】
[前提知識]
ここで、本実施形態におけるAdversarial ExampleとAdversarial Trainingについて説明する。
【0015】
[Adversarial Example]
Adversarial Exampleは、入力データに対し、人の目では認識できないほどの微小なノイズを乗せることでモデルに誤判断を起こさせる攻撃手法である。Adversarial Exampleの目的関数は式(1)に示す通りである。
【0016】
【0017】
上記の式(1)におけるl(θ,x,y)は、モデルのloss関数である。θはモデルのパラメータ、xはモデルの入力データ、yはモデルから出力される入力データxの識別結果である。Adversarial Exampleを生成する代表的なアルゴリズムとして、例えば、FGSM(Fast Gradient Sign Method)、PGD(Projection Gradient Descent)がある。例えば、FGSMでは以下の式(2)に従って、入力データxにノイズを乗せる。
【0018】
【0019】
式(2)におけるεはノイズの大きさである。FGSMでは、モデルのloss関数の値が大きくなるように、入力データxにノイズを乗せる。
【0020】
また、PGDでは以下の式(3)に示すアルゴリズムに従って、入力データxにノイズを乗せる。
【0021】
【0022】
[Adversarial Training]
Adversarial Trainingは、Adversarial Exampleに対し、ロバストなモデルを学習する学習方法である。Adversarial Trainingの目的関数は、式(4)に示す通りである。
【0023】
【0024】
一般的なAdversarial Trainingでは、以下の式(5)に示すアルゴリズムに従って、モデルのパラメータθを更新する。
【0025】
【0026】
l(θ,x,y)は、loss関数であり、主にCE(Cross Entropy)関数等が用いられる。式(5)におけるx´は、入力データxのAdversarial Exampleである。このアルゴリズムが一般的な学習アルゴリズムと異なる点は、学習データにAdversarial Exampleを利用する点である。
【0027】
[MART]
MARTでは以下の式(6)に示すloss関数を用いてモデルの学習を行う。
【0028】
【0029】
このloss関数中のBCEはBoosted Cross Entropy関数であり、以下の式(7)により表される。
【0030】
【0031】
また、式(6)においてKLで表される関数はカルバックライブラー距離であり、確率分布同士の距離を測る指標として用いられる。
【0032】
MARTでは、式(6)に示すloss関数を用いることで元の入力データに対する識別難易度(1-py(x,θ))に基づき、モデルの学習の方針を決定することを可能にしている。
【0033】
ここで、モデルが元の入力データを正しく識別できる場合(つまり、識別難易度が低い場合)、式(6)に示すloss関数の第1項(BCEの項)を重視した学習が行われる。すなわち、モデルがAdversarial Exampleを正しいラベルに識別することを重視した学習が行われる。
【0034】
一方、モデルが元の入力データを正しく識別できない場合(つまり、識別難易度が高い場合)、式(6)に示すloss関数の第2項(KLの項)を重視した学習が行われる。すなわち、Adversarial Exampleに対するモデルの出力が、元の入力データに対するモデルの出力から大きく外れないようにすることを重視した学習が行われる。
【0035】
MARTで用いるloss関数は、モデルが元の入力データを正しく認識できない場合、当該モデルがAdversarial Exampleを正しいラベルに認識するよう学習することはモデルの課題として厳しすぎるという直感に基づいて設計されている。
【0036】
しかし、モデルが元の入力データを正しく認識できる場合でも、Adversarial Exampleの正しいラベルに認識する学習が当該モデルの課題として厳しすぎる可能性もある。
【0037】
そこで、本実施形態の学習装置は、モデルのAdversarial Exampleに対する識別難易度を基に、モデルがAdversarial Exampleを正しいラベルに識別する学習が当該モデルの学習として厳しすぎるか否かを判定する。
【0038】
[学習装置の構成例]
図1を用いて、学習装置10の構成例を説明する。学習装置10は、例えば、入力部11、出力部12、通信制御部13、記憶部14、および、制御部15を備える。
【0039】
入力部11は、各種データの入力を受け付けるインタフェースである。例えば、入力部11は、後述する学習処理および予測処理に用いるデータの入力を受け付ける。出力部12は、各種データの出力を行うインタフェースである。例えば、出力部12は、制御部15により予測されたデータのラベルを出力する。
【0040】
通信制御部13は、NIC(Network Interface Card)等で実現され、ネットワークを介したサーバ等の外部の装置と制御部15との通信を制御する。例えば、通信制御部13は、学習対象のデータを管理する管理装置(
図4参照)等と制御部15との通信を制御する。
【0041】
記憶部14は、RAM(Random Access Memory)、フラッシュメモリ(Flash Memory)等の半導体メモリ素子、または、ハードディスク、光ディスク等の記憶装置によって実現され、後述する学習処理により学習されたモデルのパラメータ等が記憶される。
【0042】
制御部15は、例えば、CPU(Central Processing Unit)等を用いて実現され、記憶部14に記憶された処理プログラムを実行する。これにより、制御部15は、
図1に例示するように、取得部15a、学習部15bおよび予測部15cとして機能する。
【0043】
取得部15aは、後述する学習処理および予測処理に用いるデータを、入力部11あるいは通信制御部13を介して取得する。
【0044】
学習部15bは、Adversarial Exampleを含む入力データのラベルを予測するためのモデルの学習(Adversarial Training)を行う。学習部15bは、難易度判定部151と、学習処理部152とを備える。
【0045】
難易度判定部151は、学習対象のモデルのAdversarial Exampleに対する識別難易度を判定する。例えば、難易度判定部151は、学習対象のモデルに、学習用のAdversarial Exampleを入力し、当該モデルがAdversarial Exampleを正しいラベルに識別する確率py(x´,θ))を得る。そして、難易度判定部151は、(1-py(x´,θ))を学習対象のモデルのAdversarial Exampleに対する識別難易度とする。
【0046】
学習処理部152は、難易度判定部151により判定されたモデルのAdversarial Exampleに対する識別難易度の高さに基づき、当該モデルの学習方針を決定し、学習を行う。
【0047】
例えば、学習処理部152は、モデルのAdversarial Exampleに対する識別難易度が高いほど、MARTによるモデルの学習に用いられるloss関数における、Adversarial Exampleに対する識別結果と、Adversarial Exampleのノイズが付加される前のデータに対する識別結果との差の値の重みを大きくして、モデルの学習を行う。
【0048】
つまり、学習処理部152は、Adversarial Exampleに対するモデルの出力が、元の入力データに対するモデルの出力から大きく外れないようにすることよりも、Adversarial Exampleを正しいラベルに識別することの方を重視してモデルの学習を行う。
【0049】
また、学習処理部152は、モデルのAdversarial Exampleに対する識別難易度が低いほど、MARTによるモデルの学習に用いられるloss関数における、Adversarial Exampleに対する識別結果と、Adversarial Exampleのノイズが付加される前のデータに対する識別結果との差の値の重みを小さくして、モデルの学習を行う。
【0050】
つまり、学習処理部152は、Adversarial Exampleを正しいラベルに識別することよりも、Adversarial Exampleに対するモデルの出力が元の入力データに対するモデルの出力から大きく外れないようにすることの方を重視してモデルの学習を行う。
【0051】
例えば、学習処理部152は、以下の式(8)に示すloss関数を用いてモデルの学習を行う。
【0052】
【0053】
MARTで用いられるloss関数(式(6)参照)との違いは、モデルの学習方針を決定する識別難易度の部分(1-py(x,θ))が、Adversarial Exampleに対する識別難易度(1-py(x´,θ))に置き換わっていることである。これにより、学習処理部152は、モデルの学習方針を決定する際、学習対象のモデルのAdversarial Exampleに対する識別難易度を用いるので、Adversarial Exampleに対し識別精度の高いモデルの学習を行うことができる。
【0054】
予測部15cは、学習部15bにより学習されたモデルを用いて、入力データのラベルを予測(識別)する。例えば、予測部15cは、学習されたモデルを用いて、新たに取得されたデータの各ラベルの確率を算出し、最も確率が高いラベルを出力する。これにより、学習装置10は、例えば、入力データがAdversarial Exampleであった場合にも、正しいラベルを出力することができる。
【0055】
[学習処理]
次に、
図2を参照して、学習装置10による学習処理手順の例について説明する。
図2に示す処理は、例えば、学習処理の開始を指示する操作入力があったタイミングで開始される。
【0056】
まず、取得部15aが、Adversarial Exampleを含む学習データを取得する(S1)。次に、学習部15bが、学習データと、loss関数(式(8)参照)とを用いて、入力データのラベルの確率分布を表すモデルを学習する(S2)。学習部15bは、S2で学習されたモデルのパラメータを記憶部14に記憶する。
【0057】
[予測処理]
次に、
図3を参照して、学習装置10による入力データのラベルの予測処理の例について説明する。
図3に示す処理は、例えば、予測処理の開始を指示する操作入力があったタイミングで開始される。
【0058】
まず、取得部15aは、ラベルの予測対象のデータを取得する(S11)。次に、予測部15cは、学習部15bにより学習されたモデルを用いて、S11で取得されたデータのラベルを予測する(S12)。例えば、予測部15cは、学習されたモデルを用いて、S11で取得されたデータx’のp(x’)を算出し、最も確率が高いラベルを出力する。これにより、例えば、データx’がAdversarial Exampleであった場合でも、学習装置10は、正しいラベルを出力することができる。
【0059】
[学習装置の適用例]
上記の学習装置10を、画像中の物体認識処理に適用してもよい。この場合の適用例を、
図4を参照しながら説明する。
【0060】
例えば、学習装置10は、データ取得装置から取得した教師データ(学習データ)と、前記したloss関数とを用いて、モデルの学習(Adversarial Training)を行う。その後、学習装置10は、データ取得装置から画像データを取得すると、学習済みモデルを用いて、取得した画像データのラベルの予測を行う。そして、学習装置10は、予測結果に基づき、物体の認識結果を出力する。
【0061】
[実験]
次に、学習装置10により学習されたモデルの実験結果を説明する。実験では、CIFAR10のデータセットに対してMARTによるモデルの学習と、学習装置10によるモデルの学習とを行い、それぞれのモデルの識別精度を比較した。実験に用いたモデルは、ResNet18である。学習のパラメータはMARTの論文(非特許文献1)の設定と一致させることとし、ハイパーパラメータλは上記の論文で用いられていた1,2,…,10を用いた。
【0062】
MARTにより学習されたモデルと、学習装置10により学習されたモデル(Propose)の評価結果を
図5に示す。
図5におけるNatAccは、元の入力データに対するモデルの正解率、RobAccは、Adversarial Exampleに対するモデルの正解率を示す。なお、
図5に示す精度は、Adversarial Exampleに対する精度が最も高いエポックでの精度である。
【0063】
図5に示すとおり、λ=7の場合を除き、RobAccについては、ProposeがMARTを上回っていることが確認できた。また、λ=7における、各エポックのRobAccを確認したところ、あるエポックにおいて、大きく上振れたRobAccが存在することが確認できた。
【0064】
[システム構成等]
また、図示した各部の各構成要素は機能概念的なものであり、必ずしも物理的に図示のように構成されていることを要しない。すなわち、各装置の分散・統合の具体的形態は図示のものに限られず、その全部又は一部を、各種の負荷や使用状況等に応じて、任意の単位で機能的又は物理的に分散・統合して構成することができる。さらに、各装置にて行われる各処理機能は、その全部又は任意の一部が、CPU及び当該CPUにて実行されるプログラムにて実現され、あるいは、ワイヤードロジックによるハードウェアとして実現され得る。
【0065】
また、前記した実施形態において説明した処理のうち、自動的に行われるものとして説明した処理の全部又は一部を手動的に行うこともでき、あるいは、手動的に行われるものとして説明した処理の全部又は一部を公知の方法で自動的に行うこともできる。この他、上記文書中や図面中で示した処理手順、制御手順、具体的名称、各種のデータやパラメータを含む情報については、特記する場合を除いて任意に変更することができる。
【0066】
[プログラム]
前記した学習装置10は、パッケージソフトウェアやオンラインソフトウェアとしてプログラム(学習プログラム)を所望のコンピュータにインストールさせることによって実装できる。例えば、上記のプログラムを情報処理装置に実行させることにより、情報処理装置を学習装置10として機能させることができる。ここで言う情報処理装置にはスマートフォン、携帯電話機やPHS(Personal Handyphone System)等の移動体通信端末、さらには、PDA(Personal Digital Assistant)等の端末等がその範疇に含まれる。
【0067】
図6は、学習プログラムを実行するコンピュータの一例を示す図である。コンピュータ1000は、例えば、メモリ1010、CPU1020を有する。また、コンピュータ1000は、ハードディスクドライブインタフェース1030、ディスクドライブインタフェース1040、シリアルポートインタフェース1050、ビデオアダプタ1060、ネットワークインタフェース1070を有する。これらの各部は、バス1080によって接続される。
【0068】
メモリ1010は、ROM(Read Only Memory)1011及びRAM(Random Access Memory)1012を含む。ROM1011は、例えば、BIOS(Basic Input Output System)等のブートプログラムを記憶する。ハードディスクドライブインタフェース1030は、ハードディスクドライブ1090に接続される。ディスクドライブインタフェース1040は、ディスクドライブ1100に接続される。例えば磁気ディスクや光ディスク等の着脱可能な記憶媒体が、ディスクドライブ1100に挿入される。シリアルポートインタフェース1050は、例えばマウス1110、キーボード1120に接続される。ビデオアダプタ1060は、例えばディスプレイ1130に接続される。
【0069】
ハードディスクドライブ1090は、例えば、OS1091、アプリケーションプログラム1092、プログラムモジュール1093、プログラムデータ1094を記憶する。すなわち、上記の学習装置10が実行する各処理を規定するプログラムは、コンピュータにより実行可能なコードが記述されたプログラムモジュール1093として実装される。プログラムモジュール1093は、例えばハードディスクドライブ1090に記憶される。例えば、学習装置10における機能構成と同様の処理を実行するためのプログラムモジュール1093が、ハードディスクドライブ1090に記憶される。なお、ハードディスクドライブ1090は、SSD(Solid State Drive)により代替されてもよい。
【0070】
また、上述した実施形態の処理で用いられるデータは、プログラムデータ1094として、例えばメモリ1010やハードディスクドライブ1090に記憶される。そして、CPU1020が、メモリ1010やハードディスクドライブ1090に記憶されたプログラムモジュール1093やプログラムデータ1094を必要に応じてRAM1012に読み出して実行する。
【0071】
なお、プログラムモジュール1093やプログラムデータ1094は、ハードディスクドライブ1090に記憶される場合に限らず、例えば着脱可能な記憶媒体に記憶され、ディスクドライブ1100等を介してCPU1020によって読み出されてもよい。あるいは、プログラムモジュール1093及びプログラムデータ1094は、ネットワーク(LAN(Local Area Network)、WAN(Wide Area Network)等)を介して接続される他のコンピュータに記憶されてもよい。そして、プログラムモジュール1093及びプログラムデータ1094は、他のコンピュータから、ネットワークインタフェース1070を介してCPU1020によって読み出されてもよい。
【符号の説明】
【0072】
10 学習装置
11 入力部
12 出力部
13 通信制御部
14 記憶部
15 制御部
15a 取得部
15b 学習部
15c 予測部
151 難易度判定部
152 学習処理部