(19)【発行国】日本国特許庁(JP)
(12)【公報種別】特許公報(B2)
(11)【特許番号】
(24)【登録日】2024-11-11
(45)【発行日】2024-11-19
(54)【発明の名称】学習装置、学習方法、及び、プログラム
(51)【国際特許分類】
G06N 20/00 20190101AFI20241112BHJP
【FI】
G06N20/00
(21)【出願番号】P 2023549281
(86)(22)【出願日】2021-09-27
(86)【国際出願番号】 JP2021035277
(87)【国際公開番号】W WO2023047562
(87)【国際公開日】2023-03-30
【審査請求日】2024-03-01
(73)【特許権者】
【識別番号】000004237
【氏名又は名称】日本電気株式会社
(74)【代理人】
【識別番号】100107331
【氏名又は名称】中村 聡延
(74)【代理人】
【識別番号】100104765
【氏名又は名称】江上 達夫
(74)【代理人】
【識別番号】100131015
【氏名又は名称】三輪 浩誉
(72)【発明者】
【氏名】吉田 周平
【審査官】今城 朋彬
(56)【参考文献】
【文献】国際公開第2021/144943(WO,A1)
【文献】特表2021-502626(JP,A)
【文献】MEISTER, Clara et al.,Generalized Entropy Regularization or: There's Nothing Special about Label Smoothing,arXiv [online],2020年05月12日,インターネット<URL:https://arxiv.org/pdf/2005.00820.pdf>
(58)【調査した分野】(Int.Cl.,DB名)
G06N 20/00
(57)【特許請求の範囲】
【請求項1】
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力する推論手段と、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算する重み計算手段と、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算する重み総和計算手段と、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算する正則化項計算手段と、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する最適化手段と、
を備える学習装置。
【請求項2】
前記正則化項計算手段は、前記クラススコアが高い場合に前記正則化項の値を大きくし、前記クラススコアが低い場合に前記正則化項の値を小さくする請求項1に記載の学習装置。
【請求項3】
前記クラススコアと、前記訓練データに対応する正解クラスとに基づいて損失を計算する損失計算手段を備え、
前記総損失は、前記損失と、前記正則化項との和である請求項1又は2に記載の学習装置。
【請求項4】
前記クラススコアは、1つの訓練データに対する各クラスの信頼度スコアを含み、
前記重み関数は、前記各クラスの信頼度スコアの2乗を全クラスにわたり合計する関数であり、
前記リスケール関数は、前記総和の平方根を計算する関数である請求項1乃至3のいずれか一項に記載の学習装置。
【請求項5】
前記クラススコアは、1つの訓練データに対する各クラスの信頼度スコアを含み、
前記重み関数は、前記各クラスの信頼度スコアの2乗の自然対数を全クラスにわたり合計する関数であり、
前記リスケール関数は、前記総和の対数を計算する関数である請求項1乃至3のいずれか一項に記載の学習装置。
【請求項6】
前記クラススコアは、1つの訓練データに対する各クラスの信頼度スコアを含み、
前記重み関数は、前記各クラスの信頼度スコアの自然対数を全クラスにわたり合計する関数であり、
前記リスケール関数は、前記総和の対数を計算する関数である請求項1乃至3のいずれか一項に記載の学習装置。
【請求項7】
コンピュータが、
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力し、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算し、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算し、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算し、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する学習方法。
【請求項8】
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力し、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算し、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算し、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算し、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する処理をコンピュータに実行させるプログラム
。
【発明の詳細な説明】
【技術分野】
【0001】
本開示は、機械学習モデルの学習方法に関する。
【背景技術】
【0002】
深層学習などの大規模な機械学習モデルを学習する際、過学習を抑制するために正則化を行うことが知られている。例えば、特許文献1は、誤差関数に正則化項を加えたコスト関数を用いて、ニューラルネットワークの重みパラメータを更新する手法を開示している。
【先行技術文献】
【特許文献】
【0003】
【発明の概要】
【発明が解決しようとする課題】
【0004】
従来の手法では、全ての訓練データに対して一様に正則化を行っていた。このため、予測の簡単な訓練データに対して正則化が弱くなって過適合が生じたり、予測の難しい訓練データに対して正則化が強くなって学習の効率が低下したりすることがあった。
【0005】
本開示の1つの目的は、深層学習において、訓練データに応じて正則化の強度を適応的に制御することにある。
【課題を解決するための手段】
【0006】
本開示の一つの観点では、学習装置は、
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力する推論手段と、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算する重み計算手段と、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算する重み総和計算手段と、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算する正則化項計算手段と、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する最適化手段と、
を備える。
【0007】
本開示の他の観点では、学習方法は、
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力し、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算し、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算し、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算し、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する。
【0008】
本開示のさらに他の観点では、プログラムは、
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力し、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算し、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算し、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算し、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する処理をコンピュータに実行させる。
【発明の効果】
【0009】
本開示によれば、深層学習において、訓練データに応じて正則化の強度を適応的に制御することが可能となる。
【図面の簡単な説明】
【0010】
【
図1】第1実施形態の学習装置のハードウェア構成を示すブロック図である。
【
図2】第1実施形態の学習装置の機能構成を示すブロック図である。
【
図4】第1実施形態の学習装置による学習処理のフローチャートである。
【
図5】第2実施形態の学習装置の機能構成を示すブロック図である。
【
図6】第2実施形態の学習装置による学習処理のフローチャートである。
【発明を実施するための形態】
【0011】
以下、図面を参照して、本開示の好適な実施形態について説明する。
<第1実施形態>
[学習装置]
(ハードウェア構成)
図1は、第1実施形態の学習装置100のハードウェア構成を示すブロック図である。図示のように、学習装置100は、インタフェース(I/F)11と、プロセッサ12と、メモリ13と、記録媒体14と、データベース(DB)15と、を備える。
【0012】
インタフェース11は、外部装置との間でデータの入出力を行う。具体的に、学習に使用される訓練データセットは、インタフェース11を通じて学習装置100に入力される。
【0013】
プロセッサ12は、CPU(Central Processing Unit)などのコンピュータであり、予め用意されたプログラムを実行することにより学習装置100の全体を制御する。なお、プロセッサ12は、GPU(Graphics Processing Unit)またはFPGA(Field-Programmable Gate Array)であってもよい。プロセッサ12は、後述する学習処理を実行する。
【0014】
メモリ13は、ROM(Read Only Memory)、RAM(Random Access Memory)などにより構成される。メモリ13は、プロセッサ12による各種の処理の実行中に作業メモリとしても使用される。
【0015】
記録媒体14は、ディスク状記録媒体、半導体メモリなどの不揮発性で非一時的な記録媒体であり、学習装置100に対して着脱可能に構成される。記録媒体14は、プロセッサ12が実行する各種のプログラムを記録している。学習装置100が各種の処理を実行する際には、記録媒体14に記録されているプログラムがメモリ13にロードされ、プロセッサ12により実行される。DB15は、必要に応じて、I/F11を通じて入力された訓練データセットを記憶する。
【0016】
(機能構成)
図2は、第1実施形態の学習装置100の機能構成を示すブロック図である。学習装置100は、推論部21と、損失関数計算部22と、総和計算部23と、重み関数計算部24と、重み総和計算部25と、リスケール関数計算部26と、パラメータ更新部27と、を備える。
【0017】
学習装置100には、訓練データセットが入力される。訓練データセットは、複数の訓練データxiと、各訓練データxiに対応する正解クラスyiとを含む。訓練データxiは推論部21に入力され、正解クラスyiは損失関数計算部22へ入力される。
【0018】
推論部21は、学習装置100による学習の対象となる深層学習モデルを用いて推論を行う。具体的には、推論部21は、学習の対象となる深層学習モデルを構成するニューラルネットワークを備える。推論部21は、入力された訓練データxiに対する推論を行い、推論結果としてクラススコアv→
iを出力する。詳細には、推論部21は、訓練データxiに対するクラス分類を行い、クラス毎の信頼度スコアを示すベクトルであるクラススコアv→
iを出力する。なお、本明細書では、便宜上ベクトルを示す「→」を、「v」の右側に上付きで表記する。クラススコアv→
iは、損失関数計算部22及び重み関数計算部24へ入力される。
【0019】
損失関数計算部22は、予め用意された損失関数を用いて、クラススコアv→
iに対する損失lcls,iを計算する。具体的に、損失関数計算部22は、ある訓練データxiに対するクラススコアv→
iと、その訓練データxiに対する正解クラスyiとを用いて、式(1)に示すように損失lcls,iを計算する。計算された損失lcls,iは、総和計算部23へ入力される。
【0020】
【0021】
一方、重み関数計算部24は、推論部21が生成したクラススコアv→
iに基づいて、訓練データxiに対する重みを計算する。具体的に、重み関数計算部24は、訓練データxiに対する推論結果であるクラススコアv→
iから、以下の式(2)により、単一の実数値である重みwiを決定する。
【0022】
【0023】
重み関数としては、クラススコアv→
iに含まれる各クラスの信頼度スコアが過大または過少なときに急速に増大する関数が選ばれる。「急速に」とは、線形より早く、という意味である。重み関数の増大が急速であるという条件は、クラススコアv→
iに含まれる過大または過小な信頼度スコアを強調するために必要となる。即ち、急速に増大する関数を用いて重みを計算することにより、クラススコアv→
iが過大または過小な信頼度スコアの値を含む場合、それら過大または過小な値が強調され、重みwiはより大きな値となる。これにより、重み関数の選択が、後述する正則化項の勾配に対する各訓練データの重みの寄与度を決定することになる。なお、重み関数計算部24は、単にクラススコアv→
iに含まれる各クラスの信頼度スコアを重み関数に入力した結果を出力するため、出力される重みwiの値は特に正規化された値ではない。重み関数計算部24は、計算した重みwiを重み総和計算部25へ出力する。
【0024】
重み総和計算部25は、重みwiのミニバッチ分の総和を計算する。ミニバッチとは、所定数(例えばN個)の訓練データの集合である。具体的に、重み総和計算部25は、下記の式(3)により、N個の訓練データxiに対応するN個の重みwiの総和Sを計算する。
【0025】
【0026】
重み総和計算部25は、計算した総和Sをリスケール関数計算部26へ出力する。
【0027】
リスケール関数計算部26は、入力された総和Sに基づき、リスケール関数の計算を行って正規化項Lregを生成する。具体的に、リスケール関数計算部26は、以下の式(4)により、正規化項Lregを生成する。
【0028】
【0029】
式(4)において、「g(S)」はリスケール関数である。リスケール関数g(S)としては、緩やかに増大する単調増加関数が選ばれる。なお、この緩やかに増大する単調増加関数は、数学的な「緩増加関数」とは異なる。
【0030】
ここで、「緩やかに」とは、線形より遅く、という意味である。リスケール関数g(S)が緩やかであるという条件は、急速に増大する重み関数によって正則化項の勾配が増大し、その結果、学習が不安定になることを抑制するために必要となる。言い換えると、重み関数により過大または過小な信頼度スコアが強調された重みwiをそのまま使うと正則化が強すぎてしまう恐れがあるため、リスケール関数g(S)を用いて、重みwiの全体のスケールを調整している。この点、リスケール関数g(S)は、重みwiを正規化し、全体の正則化の強さを調整していると捉えることもできる。リスケール関数計算部26は、こうして得られた正規化項Lregを総和計算部23へ出力する。
【0031】
総和計算部23は、損失関数計算部22から入力される損失lcls,iと、リスケール関数計算部26から入力される正規化項Lregとの総和(以下、「総損失L」とも呼ぶ。)を計算する。具体的に、総和計算部23は、下記の式(5)により、損失lcls,iと正規化項Lregの和を訓練データ数i個分加算した値を、ミニバッチに含まれる訓練データ数Nで除して総損失Lを計算する。
【0032】
【0033】
そして、総和計算部23は、得られた総損失Lをパラメータ更新部27へ出力する。
【0034】
パラメータ更新部27は、入力された総損失Lに基づいて推論部21を最適化する。具体的には、パラメータ更新部27は、総損失Lに基づいて、推論部21を構成するニューラルネットワークのパラメータを更新する。こうして、推論部21を構成する深層学習モデルの学習が行われる。
【0035】
以上のように、第1実施形態の学習装置100によれば、正則化項をミニバッチの単位で計算することにより、各訓練データの正則化項に対する寄与度を適応的に決定することができる。また、学習装置100は、推論部21が出力する過大または過小な推論結果を重み関数を用いて強調することで、簡単な訓練データに対しては正則化を強めることで過適合を防ぎ、難しい訓練データに対しては正則化を弱めることで学習の効率を上げることができる。さらに、学習装置100は、リスケール関数を用いて重みの全体のスケールを調整することで、重み関数を用いて部分的に強調された重みを正規化し、全体の正則化の強さを調整することができる。その結果、訓練データに応じて正則化の強度を適応的に決定し、より高い汎化性能、即ち分類精度を得ることが可能となる。
【0036】
上記の構成において、推論部21は推論手段の一例であり、損失関数計算部22は損失計算手段の一例であり、重み関数計算部24は重み計算手段の一例であり、重み総和計算部25は重み計算手段の一例であり、リスケール関数計算部26は正則化項計算手段の一例であり、パラメータ更新部27は最適化手段の一例である。
【0037】
(関数の例)
図3は、重み関数とリスケール関数の例を示す。第1の例では、重み関数は、クラススコアv
→
iに含まれる各クラスの信頼度スコアv
icの2乗を、全クラス数cにわたり合計する関数である。また、リスケール関数は、重み総和計算部25が出力する総和Sの平方根を計算する関数である。
【0038】
第2の例では、重み関数は、クラススコアv→
iに含まれる各クラスの信頼度スコアvicの2乗の自然対数を全クラス数cにわたり合計する関数である。また、リスケール関数は、重み総和計算部25が出力する総和Sの対数を計算する関数である。
【0039】
第3の例では、重み関数は、クラススコアv→
iに含まれる各クラスの正負の信頼度スコアvicの自然対数を、全クラス数cにわたり合計する関数である。また、リスケール関数は、重み総和計算部25が出力する総和Sの対数を計算する関数である。
【0040】
(学習処理)
図4は、学習装置100による学習処理のフローチャートである。この処理は、
図1に示すプロセッサ12が予め用意されたプログラムを実行し、
図2に示す各要素として動作することにより実現される。
【0041】
まず、推論部21は、入力された訓練データxiに対する推論を行う(ステップS11)。推論部21は、推論により得られたクラススコアv→
iを、損失関数計算部22及び重み関数計算部24に出力する。損失関数計算部22は、クラススコアv→
iに基づき、式(1)を用いて損失lcls,iを計算し、総和計算部23へ出力する(ステップS12)。
【0042】
次に、重み関数計算部24は、クラススコアv→
iに基づき、式(2)を用いて重みwiを計算し、重み総和計算部25へ出力する(ステップS13)。次に、重み総和計算部25は、式(3)によりミニバッチ毎に重みwiの総和Sを計算し、リスケール関数計算部26へ出力する(ステップS14)。次に、リスケール関数計算部26は、リスケール関数を用いて、入力された総和Sから正規化項Lregを計算し、総和計算部23へ出力する(ステップS15)。なお、ステップS12と、ステップS13~S15の処理は、逆の順序で行われてもよく、時間的に並行して行われてもよい。
【0043】
次に、総和計算部23は、損失関数計算部22から入力される損失lcls,iと、リスケール関数計算部26から入力される正規化項Lregとに基づき、式(5)を用いて損失の総和(総損失L)を計算し、パラメータ更新部27へ出力する(ステップS16)。次に、パラメータ更新部27は、損失の総和(総損失L)に基づいて、推論部21を構成するニューラルネットワークのパラメータを更新する(ステップS17)。
【0044】
次に、学習の終了条件が具備されたか否かが判定される(ステップS18)。終了要件としては、例えば、全ての訓練データが使用されたこと、または、推論部21の精度が所定の精度に達したこと、などを用いることができる。終了条件が具備されていない場合(ステップS18:No)、処理はステップS11へ戻り、次の訓練データを用いてステップS11~S17の処理が行われる。一方、終了条件が具備された場合(ステップS18:Yes)、学習処理は終了する。
【0045】
<第2実施形態>
図5は、第2実施形態の学習装置の機能構成を示すブロック図である。学習装置200は、推論手段201と、重み計算手段202と、重み総和計算手段203と、正則化項計算手段204と、最適化手段205と、を備える。
【0046】
図6は、第2実施形態の学習装置200による学習処理のフローチャートである。まず、推論手段201は、訓練データに対する推論を行い、クラススコアを出力する(ステップS21)。次に、重み計算手段202は、推論手段201が出力したクラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算する(ステップS22)。次に、重み総和計算手段203は、所定数の訓練データを含むミニバッチにわたって、重みの総和を計算する(ステップS23)。次に、正則化項計算手段204は、線形より緩やかに増大する単調増加関数であるリスケール関数を総和に適用し、正則化項を計算する(ステップS24)。そして、最適化手段205は、正則化項を含む損失を用いて、推論手段を最適化する(ステップS25)。
【0047】
第2実施形態の学習装置200によれば、深層学習において、訓練データに応じて正則化の強度を適応的に制御することが可能となる。
【0048】
上記の実施形態の一部又は全部は、以下の付記のようにも記載されうるが、以下には限られない。
【0049】
(付記1)
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力する推論手段と、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算する重み計算手段と、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算する重み総和計算手段と、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算する正則化項計算手段と、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する最適化手段と、
を備える学習装置。
【0050】
(付記2)
前記正則化項計算手段は、前記クラススコアが高い場合に前記正則化項の値を大きくし、前記クラススコアが低い場合に前記正則化項の値を小さくする付記1に記載の学習装置。
【0051】
(付記3)
前記クラススコアと、前記訓練データに対応する正解クラスとに基づいて損失を計算する損失計算手段を備え、
前記総損失は、前記損失と、前記正則化項との和である付記1又は2に記載の学習装置。
【0052】
(付記4)
前記クラススコアは、1つの訓練データに対する各クラスの信頼度スコアを含み、
前記重み関数は、前記各クラスの信頼度スコアの2乗を全クラスにわたり合計する関数であり、
前記リスケール関数は、前記総和の平方根を計算する関数である付記1乃至3のいずれか一項に記載の学習装置。
【0053】
(付記5)
前記クラススコアは、1つの訓練データに対する各クラスの信頼度スコアを含み、
前記重み関数は、前記各クラスの信頼度スコアの2乗の自然対数を全クラスにわたり合計する関数であり、
前記リスケール関数は、前記総和の対数を計算する関数である付記1乃至3のいずれか一項に記載の学習装置。
【0054】
(付記6)
前記クラススコアは、1つの訓練データに対する各クラスの信頼度スコアを含み、
前記重み関数は、前記各クラスの信頼度スコアの自然対数を全クラスにわたり合計する関数であり、
前記リスケール関数は、前記総和の対数を計算する関数である付記1乃至3のいずれか一項に記載の学習装置。
【0055】
(付記7)
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力し、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算し、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算し、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算し、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する学習方法。
【0056】
(付記8)
推論モデルを用いて訓練データに対する推論を行い、クラススコアを出力し、
出力された前記クラススコアから、当該クラススコアが過大又は過少なときに線形より急速に増大する重み関数を用いて重みを計算し、
所定数の訓練データを含むミニバッチにわたって前記重みの総和を計算し、
線形より緩やかに増大する単調増加関数であるリスケール関数を前記総和に適用して、正則化項を計算し、
前記正則化項を含む総損失を用いて、前記推論モデルを最適化する処理をコンピュータに実行させるプログラムを記録した記録媒体。
【0057】
以上、実施形態及び実施例を参照して本開示を説明したが、本開示は上記実施形態及び実施例に限定されるものではない。本開示の構成や詳細には、本開示のスコープ内で当業者が理解し得る様々な変更をすることができる。
【符号の説明】
【0058】
12 プロセッサ
21 推論部
22 損失関数計算部
23 総和計算部
24 重み関数計算部
25 重み総和計算部
26 リスケール関数計算部
27 パラメータ更新部
100、200 学習装置