(19)【発行国】日本国特許庁(JP)
(12)【公報種別】特許公報(B2)
(11)【特許番号】
(24)【登録日】2023-11-13
(45)【発行日】2023-11-21
(54)【発明の名称】学習装置、学習方法、及び、プログラム
(51)【国際特許分類】
G06N 20/00 20190101AFI20231114BHJP
G06N 3/096 20230101ALI20231114BHJP
G06T 7/00 20170101ALI20231114BHJP
G06V 10/70 20220101ALI20231114BHJP
【FI】
G06N20/00 130
G06N3/096
G06T7/00 350B
G06V10/70
(21)【出願番号】P 2021555705
(86)(22)【出願日】2019-11-13
(86)【国際出願番号】 JP2019044596
(87)【国際公開番号】W WO2021095176
(87)【国際公開日】2021-05-20
【審査請求日】2022-04-15
(73)【特許権者】
【識別番号】000004237
【氏名又は名称】日本電気株式会社
(74)【代理人】
【識別番号】100107331
【氏名又は名称】中村 聡延
(74)【代理人】
【識別番号】100104765
【氏名又は名称】江上 達夫
(74)【代理人】
【識別番号】100131015
【氏名又は名称】三輪 浩誉
(72)【発明者】
【氏名】石井 遊哉
【審査官】多賀 実
(56)【参考文献】
【文献】米国特許出願公開第2018/0268292(US,A1)
【文献】特開2019-053569(JP,A)
【文献】特開2019-215861(JP,A)
【文献】WANG, Mengjiao et al.,"DISCOVER THE EFFECTIVE STRATEGY FOR FACE RECOGNITION MODEL COMPRESSION BY IMPROVED KNOWLEDGE DISTILLATION",2018 25th IEEE International Conference on Image Processing (ICIP) [online],米国,IEEE,2018年09月06日,pp.2416-2420,[検索日 2020.02.04], インターネット:<URL:https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8451808>
【文献】藤吉 弘亘 ほか,「深層学習による画像認識」,日本ロボット学会誌,一般社団法人日本ロボット学会,2017年04月15日,第35巻, 第3号,pp.8-13
(58)【調査した分野】(Int.Cl.,DB名)
G06N 3/00-99/00
G06T 7/00- 7/90
G06V 10/70-10/86
(57)【特許請求の範囲】
【請求項1】
学習データに対する推論結果を出力する教師モデルと、
前記学習データに対する推論結果を出力する生徒モデルと、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいてトータル損失を算出する損失算出手段と、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する更新手段と、
を備え、
前記損失算出手段は、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、の
うちの複数の損失の加重平均により前記トータル損失を算出する学習装置。
【請求項2】
前記損失算出手段は、前記真値と前記生徒モデルの出力との差、前記真値と前記教師モデルの出力との差、及び、前記教師モデルの出力と前記生徒モデルの出力との差を距離として計算する距離計算器を備える請求項1に記載の学習装置。
【請求項3】
前記損失算出手段は、前記自信度をエントロピー関数により算出する請求項1又は2に記載の学習装置。
【請求項4】
前記学習データは画像データであり、
前記画像データから特徴を抽出して特徴マップを生成する特徴抽出手段を備え、
前記教師モデル及び前記生徒モデルは、前記特徴マップに対して規定したアンカー毎に前記推論結果を出力する請求項1乃至
3のいずれか一項に記載の学習装置。
【請求項5】
前記教師モデル及び前記生徒モデルは、前記特徴抽出手段が抽出した特徴マップに基づいて、前記画像データに含まれる対象物のクラス分類を行う請求項
4に記載の学習装置。
【請求項6】
前記教師モデル及び前記生徒モデルは、前記特徴抽出手段が抽出した特徴マップに基づいて、前記画像データに含まれる対象物の位置を検出する請求項
5に記載の学習装置。
【請求項7】
教師モデルを用いて、学習データに対する推論結果を出力し、
生徒モデルを用いて、前記学習データに対する推論結果を出力し、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいて、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、の
うちの複数の損失の加重平均によりトータル損失を算出し、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する学習方法。
【請求項8】
教師モデルを用いて、学習データに対する推論結果を出力し、
生徒モデルを用いて、前記学習データに対する推論結果を出力し、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいて、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、の
うちの複数の損失の加重平均によりトータル損失を算出し、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する処理をコンピュータに実行させるプログラム。
【発明の詳細な説明】
【技術分野】
【0001】
本発明は、物体検知に関する。
【背景技術】
【0002】
近年、深層学習を用いたニューラルネットワークによる物体検知技術が多く提案されている。物体検知とは、画像や動画に映る物体が何であるかを推定し、それと同時に物体に外接する矩形の位置を求めることで、物体の位置と大きさを推定することである。物体検出器は、物体の矩形位置と、その物体が各クラスである確率を示す信頼度を出力する。
【0003】
非特許文献1は、元画像から特徴抽出により得た特徴マップ上にアンカーを規定し、注目すべきアンカーに重点を置いて学習を行う手法を記載している。また、非特許文献2は、さらに蒸留という手法を利用し、学習済みの教師モデルの出力を用いて生徒モデルの学習を行う手法を記載している。
【先行技術文献】
【非特許文献】
【0004】
【文献】Focal Loss for Dense Object Detection, Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollar, arXiv:1708.02002v2, 2018
【文献】Learning Efficient Detector with Semi-supervised Adaptive Distillation, Shitao Tang, Litong Feng, Wenqi Shao, Zhanghui Kuang, Wei Zhang, Yimin Chen, arXiv: 1901.00366v2, 2019
【発明の概要】
【発明が解決しようとする課題】
【0005】
非特許文献2の手法は、主として以下の方針で生徒モデルの学習を行う。
・教師モデルが自信のないアンカーほど、教師モデルと生徒モデルとの出力の差を小さくする。
・教師モデルと生徒モデルの出力の差が大きいアンカーほど、教師モデルと生徒モデルとの出力の差を小さくする。
【0006】
しかし、上記の方針によると、教師モデルの精度が低い場合に、誤った方向性で学習が行われることがある。即ち、1つ目の方針では、教師モデルが自信のないアンカーほど、生徒モデルも自信のない出力を行うように学習されてしまうことがある。また、2つ目の方針では、教師モデルの出力が不正解だと、仮に生徒モデルが正解していたとしても、生徒モデルは不正解を学習し直してしまう恐れがある。
【0007】
本発明の1つの目的は、教師モデルの精度が特別高くない場合でも、それによる悪影響を受けずに、正しく学習を行うことが可能な学習方法を提供することにある。
【課題を解決するための手段】
【0008】
本発明の一つの観点では、学習装置は、
学習データに対する推論結果を出力する教師モデルと、
前記学習データに対する推論結果を出力する生徒モデルと、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいてトータル損失を算出する損失算出手段と、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する更新手段と、
を備え、
前記損失算出手段は、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、のうちの複数の損失の加重平均により前記トータル損失を算出する。
【0009】
本発明の他の観点では、学習方法は、
教師モデルを用いて、学習データに対する推論結果を出力し、
生徒モデルを用いて、前記学習データに対する推論結果を出力し、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいて、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、のうちの複数の損失の加重平均によりトータル損失を算出し、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する。
【0010】
本発明の他の観点では、プログラムは、
教師モデルを用いて、学習データに対する推論結果を出力し、
生徒モデルを用いて、前記学習データに対する推論結果を出力し、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいて、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、のうちの複数の損失の加重平均によりトータル損失を算出し、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する処理をコンピュータに実行させる。
【発明の効果】
【0011】
本発明によれば、教師モデルの精度が特別高くない場合でも、それによる悪影響を受けずに、正しく学習を行うことができる。
【図面の簡単な説明】
【0012】
【
図1】ニューラルネットワークを用いた物体検知モデルの基本構成例を示す。
【
図3】第1実施例に係る物体検知装置のハードウェア構成を示す。
【
図4】第1実施例に係る物体検知装置のクラス分類に関する機能構成を示す。
【
図8】第1実施例による物体検知装置の全体の機能構成を示す。
【
図9】第2実施例に係る物体検知装置のクラス分類に関する機能構成を示す
【
図12】第2実施例による物体検知装置の全体の機能構成を示す。
【
図13】本発明の第2実施形態に係る学習装置の機能構成を示す。
【
図14】第2実施形態による学習処理のフローチャートである。
【発明を実施するための形態】
【0013】
以下、図面を参照して、本発明の好適な実施形態について説明する。
<物体検知モデルの基本構成>
図1は、ニューラルネットワークを用いた物体検知モデルの基本構成例を示す。なお、
図1の例は、「RetinaNet」と呼ばれるネットワークであるが、本発明は他に「SSD」、「YOLO」、「Faster RCNN」などのネットワークにも適用可能である。学習モデル70は、特徴抽出部71と、クラス分類部72と、矩形位置検出部73とを備える。特徴抽出部71は、CNN(Convolutional Neural Network)などにより入力画像から特徴を抽出し、特徴マップを生成する。クラス分類部72は、特徴マップに基づいて検出対象のクラス分類を行い、分類結果を出力する。
図1の例では、検出対象は、「自転車」、「馬」、「犬」、「車」の4クラスであり、クラス分類部72は分類結果として各クラスの信頼度(確率)を出力する。また、矩形位置検出部73は、特徴マップに基づいて、各検出対象に外接する矩形位置の座標を出力する。
【0014】
入力画像に対しては、あらかじめ正解データ(「真値(ground truth)」とも呼ぶ。)が用意されている。クラス分類部72によるクラス分類結果と、クラス分類の正解データとに基づいてクラス分類の差分(以下、「分類ロス」とも呼ぶ。なお、「ロス」を「損失」とも呼ぶ。)Lclsが算出される。また、矩形位置検出部73により検出された矩形位置の座標と、矩形位置の座標の正解データとの差分(以下、「回帰ロス」とも呼ぶ。)Lregが算出される。なお、矩形位置の座標と矩形位置の座標の正解データの差分は、回帰以外の方法で算出することもできるが、本実施形態では回帰により算出するものとし、「回帰ロス」と呼ぶ。そして、以下に示す分類ロスLclsと回帰ロスLregの合計(「トータルロス」とも呼ぶ。)Lを最小化するように、学習モデルの学習が行われる。
【0015】
【0016】
<クラス分類器の学習>
次に、クラス分類器の学習について説明する。
[フォーカルロス]
まず、「フォーカルロス(Focal Loss:以下、「FL」とも記す。)」と呼ばれる手法について説明する。
図1に示すRetinaNetは、特徴抽出部71により抽出された特徴マップ上に、その画素ごとに広がりを持つ「アンカー」を埋め込み、アンカーごとにクラス分類と矩形位置の検出を行う手法である。特に、フォーカルロスは、特徴マップに含まれる複数のアンカーのうち、注目すべきアンカーに重きをおいて学習を行う。例えば、特徴マップ上に設定された複数のアンカーのうち、背景に対応するアンカーよりも、検出対象が存在すると予測されるアンカーに注目する。具体的には、DNN(Deep Neural Network)により予測が難しいアンカー、即ち、正解と予測との差が大きいアンカーほど注目度を高くする。フォーカルロスFL(p)は以下の式で表される。なお、「α」は、学習データのクラスバランスに基づき決定される定数である。
【0017】
【0018】
例えば、検出対象として犬と自転車が含まれるアンカーでは、「犬」と「自転車」については正解クラスの式を使用し、それ以外については不正解クラスの式を使用する。そして、前述の式(1)における分類ロスとして、Lclsの代わりに、FL(p)を使用し、以下のトータルロスLを用いてモデルの学習を行う。
【0019】
【0020】
[蒸留]
次に、蒸留(Knowledge Distillation)と呼ばれる手法について説明する。蒸留は、既に学習済みの教師モデルの出力を用いて生徒モデルの学習を行う手法である。
図2は、蒸留を用いた物体検知モデルを示す。生徒モデル80は、学習の対象となるモデルであり、特徴抽出部81と、クラス分類部82と、矩形位置検出部83とを備える。特徴抽出部81は入力画像から特徴マップを生成する。クラス分類部82は特徴マップに基づいて検出対象のクラス分類結果を出力する。また、矩形位置検出部83は、特徴マップに基づいて検出対象の矩形位置の座標を出力する。
【0021】
一方、教師モデル90は、多数の画像を用いて予め学習済みのモデルであり、特徴抽出部91と、クラス分類部92と、矩形位置検出部93とを備える。入力画像は、教師モデル90にも入力される。教師モデル90では、特徴抽出部91は入力画像から特徴マップを生成する。クラス分類部92は特徴マップに基づいて検出対象のクラス分類結果を出力する。また、矩形位置検出部93は、特徴マップに基づいて検出対象の矩形位置の座標を出力する。
【0022】
生徒モデル80が出力したクラス分類結果と、教師モデル90が出力したクラス分類結果との差分が分類ロスLclsとして算出され、生徒モデル80が出力した矩形位置の座標と、教師モデル90が出力した矩形位置の座標との差分が回帰ロスLregとして算出される。なお、回帰ロスLregとしては、生徒モデル80が出力した矩形位置の座標と、真値との差分を用いてもよい。そして、上記の式(1)で示すトータルロスLが最小となるように、生徒モデル80の学習が行われる。
【0023】
[ADL]
次に、ADL(Adaptive Distillation knowledge Loss)について説明する。ADLは、フォーカルロスの考え方を蒸留に適用した学習方法であり、以下の方針でモデルの学習を行う。
・教師モデルと生徒モデルの出力の差が大きいアンカーほど、教師モデルと生徒モデルの出力の差を小さくする。
・教師モデルが自信のないアンカーほど、教師モデルと生徒モデルの出力の差を小さくする。
上記の方針より、ADLは以下の式で示される。
【0024】
【数4】
・「q」は教師モデルの出力であり、「p」は生徒モデルの出力である。
・「KL」は、KL Divergenceであり、「KL距離」又は単に「距離」とも例えられる。KL(q||p)は、qとpの値の近さを測る関数であり、q=pのときに最小値「0」をとる。
・「T」はエントロピー関数であり、T(q)=-qlog[q]で与えられる。T(q)は、上に凸の関数であり、q=0.5のときに最大となり、q=0,1のとき最小となる。q=0,1のときは、教師モデルのクラス分類結果の自信が大きいことを示す。一方、q=0.5のときは、教師モデルのクラス分類結果がどちらかわからず、自信がないことを示す。よって、T(q)により教師モデルの出力の自信度を測ることができる。
・「β」、「γ」は、ハイパーパラメータである。
【0025】
また、フォーカルロスの手法と、ADLの手法を組み合わせると、以下の方針が得られる。
・生徒モデルが間違えているアンカーほど、真値と生徒モデルの出力との差を小さくする。
この場合のトータルロスは、以下の式が用いられる。
【0026】
【0027】
以上より、下記の方針が得られる。
(A)生徒が間違えているアンカーほど、真値と生徒の出力との差を小さくする。
(B)教師が自信のない(0.5を出力する)アンカーほど、教師モデルと生徒モデルの出力の差を小さくする。
(C)教師モデルと生徒モデルの出力の差が大きいアンカーほど、教師モデルと生徒モデルの出力の差を小さくする。
【0028】
しかし、上記の方針(B)、(C)によると、教師モデルの精度が低い場合に、誤った方向性で学習が行われることがある。即ち、方針(B)では、教師モデルが自信のないアンカーほど、生徒モデルも自信のない出力を行うように学習されることがある。また、方針(C)では、教師モデルの出力が不正解だと、仮に生徒モデルが正解していたとしても、生徒モデルは不正解を学習し直してしまう恐れがある。
【0029】
<第1実施形態>
[基本原理]
上記の観点から、本実施形態では以下の方針(1)~(4)を考慮する。
(1)教師モデルと生徒モデルの出力の差が大きいアンカーほど、教師モデルと生徒モデルの出力の差を小さくする。これは、上記の方針(C)と同一である。この方針で得られる教師モデルと生徒モデルのロス(損失)を「L1」とすると、L1は以下の式で得られる。なお、「γ1」はハイパーパラメータである。
【0030】
【0031】
(2)教師モデルが自信のないアンカーほど、真値と生徒モデルの出力との差を小さくする。これにより、上記の方針(B)の不具合が解消できる。この方針で得られる教師モデルと生徒モデルのロスを「L2」とすると、L2は以下の式で得られる。なお、「γ2」はハイパーパラメータである。
【0032】
【0033】
(3)真値と教師モデルの出力との差が大きいアンカーほど、真値と生徒モデルの出力との差を小さくする。この方針で得られる教師モデルと生徒モデルのロスを「L3」とすると、L3は以下の式で得られる。なお、「γ2」はハイパーパラメータである。
【0034】
【0035】
(4)教師モデルと生徒モデルの出力の差が大きく、かつ、真値と教師モデルの出力との差が小さいアンカーほど、教師モデルと生徒モデルの出力の差を小さくする。これにより、上記の方針(C)の不具合が解消できる。この方針で得られる教師モデルと生徒モデルのロスを「L4」とすると、L4は以下の式で得られる。なお、「γ1」はハイパーパラメータである。
【0036】
【0037】
ここで、ロスL4は方針(C)の不具合を解消するものであるので、上記のロスL1の代わりにロスL4を使用することが望ましい。よって、本実施形態では、上記のロスL2~L4の少なくとも1つ、即ち、いずれか1つ又は複数の組み合わせを「myADL」とし、下記のトータルロスLが小さくなるようにモデルの学習を行う。
【0038】
【数10】
これにより、上記の方針(B)、(C)により、教師モデルの精度が低い場合に、誤った方向性で学習が行われるという不具合を解消することができる。
【0039】
以上より、本実施形態によれば、特別に精度の高い教師モデルを用いなくても、生徒モデルの精度を向上させることができる。また、教師モデルの出力を目標として生徒モデルの学習を行うので、生徒モデルの出力を真値に近づける場合と比較して、学習の収束を早めることができる。言い換えると、学習データが少なくても、十分な認識精度を得ることができる。なお、上記の説明ではロスL1~ロスL4を挙げて説明しているが、本発明においては、ロスL1は用いなくてもよく、ロスL2~L4のうちの少なくとも1つを用いればよい。
【0040】
[第1実施例]
次に、第1実施形態の第1実施例について説明する。第1実施例は、上記のロスL1~L4のうち、ロスL1及びL2を使用するものである。
【0041】
(ハードウェア構成)
図3は、第1実施例に係る物体検知装置のハードウェア構成を示すブロック図である。図示のように、物体検知装置100は、入力インタフェース(IF)12と、プロセッサ13と、メモリ14と、記録媒体15と、データベース(DB)16と、を備える。
【0042】
入力IF12は、物体検知に必要なデータを外部から入力するためのインタフェースである。具体的に、物体検知装置100が学習時に使用する学習データや、学習後の実際の物体検知処理に使用する画像データなどが入力IF12を介して入力される。
【0043】
プロセッサ13は、CPU(Central Processing Unit)又はGPU(Graphics Processing Unit)などのコンピュータであり、予め用意されたプログラムを実行することにより、物体検知装置100の全体を制御する。具体的に、プロセッサ13は、後述する物体検知モデルの学習を行う。
【0044】
メモリ14は、ROM(Read Only Memory)、RAM(Random Access Memory)などにより構成される。メモリ14は、プロセッサ13により実行される各種のプログラムを記憶する。また、メモリ14は、プロセッサ13による各種の処理の実行中に作業メモリとしても使用される。
【0045】
記録媒体15は、ディスク状記録媒体、半導体メモリなどの不揮発性で非一時的な記録媒体であり、物体検知装置100に対して着脱可能に構成される。記録媒体15は、プロセッサ13が実行する各種のプログラムを記録している。物体検知装置100が各種の処理を実行する際には、記録媒体15に記録されているプログラムがメモリ14にロードされ、プロセッサ13により実行される。
【0046】
データベース16は、入力IF12を通じて外部装置から入力される画像データなどを記憶する。具体的には、物体検知装置100の学習に使用される画像データが記憶される。なお、上記に加えて、物体検知装置100は、ユーザが指示や入力を行うためのキーボード、マウスなどの入力機器や、ユーザに物体検知の結果を提示する表示装置などを備えていても良い。
【0047】
(機能構成)
前述のように、物体検知装置はクラス分類と矩形位置検出を行うが、説明の便宜上、まず、クラス分類に関する構成のみを先に説明する。
図4は、第1実施例に係る物体検知装置100のクラス分類に関する機能構成を示すブロック図である。なお、
図4は、物体検知装置100の学習のための構成を示す。
【0048】
図示のように、物体検知装置100は、教師モデル110と、生徒モデル120と、L1計算部130と、L2計算部140と、FL計算部150と、加重平均計算器161と、パラメータ更新量計算器162と、を備える。
【0049】
学習のためのデータとしては、学習データD1と、それに対応する真値yとが用意される。学習データD1は、検知対象が写った画像を含む画像データである。真値yは、学習データに対するクラス分類の正解を示すデータである。学習データD1は、教師モデル110と、生徒モデル120に入力される。また、真値yは、L2計算部140及びFL計算部150に入力される。
【0050】
教師モデル110は、多数の学習データを用いて既に学習済みのモデルであり、学習データD1から物体を検出してそのクラス分類結果(以下、「教師モデル出力」とも呼ぶ。)qをL1計算部130及びL2計算部140に入力する。生徒モデル120は、学習の対象となるモデルであり、学習データD1から物体を検出してそのクラス分類結果(以下、「生徒モデル出力」とも呼ぶ。)pをL1計算部130、L2計算部140及びFL計算部150に入力する。
【0051】
L1計算部130は、教師モデル出力qと、生徒モデル出力pを用いて、前述のロスL1を算出する。
図5は、L1計算部130の構成を示す。L1計算部130は、距離計算器131と、係数1計算器132と、L1計算器133とを備える。距離計算器131は、教師モデル出力qと生徒モデル出力pの距離KL(p||q)を算出する。係数1計算器132は、上記の式(6)に基づき、以下の係数1を算出する。
【0052】
【数11】
そして、L1計算部133は、係数1と、距離KL(q||p)に基づいて、式(6)によりロスL1を算出する。
【0053】
L2計算部140は、教師モデル出力qと、生徒モデル出力pと、真値yとを用いて、ロスL2を算出する。
図6は、L2計算部140の構成を示す。L2計算部140は、自信計算器141と、係数2計算器142と、距離計算器143と、L2計算器144とを備える。自信計算器141は、前述のエントロピー関数Tを用いて、教師モデル出力qの自信T(q)を算出する。係数2計算器142は、上記の式(7)に基づき、以下の係数2を算出する。
【0054】
【数12】
この係数2は、教師モデル出力qの自信T(q)が低いほど大きくなり、本発明の第1の重みに相当する。
【0055】
一方、距離計算器143は、真値yと生徒モデル出力pとの距離KL(y||p)を算出する。そして、L2計算器144は、係数2と、距離KL(y||p)に基づいて、式(7)によりロスL1を算出する。
【0056】
FL計算部150は、真値yと生徒モデル出力pを用いて、前述のフォーカルロスFLを算出する。
図7は、FL計算部150の構成を示す。FL計算部150は、FL係数計算器151と、距離計算器152と、FL計算器153とを備える。FL係数計算器151は、真値yと生徒モデル出力pとを用いて上記の式(2)に基づき、下記のFL係数を算出する。
【0057】
【数13】
また、距離計算器152は、真値yと生徒モデル出力pの距離KL(y||p)を算出する。そして、FL計算器153は、FL係数と距離KL(y||p)とに基づいて、上記の式(2)により係数FLを算出する。
【0058】
図4に戻り、加重平均計算器161は、所定の重みを用いて、L1計算部130から出力されるロスL1と、L2計算部140から出力されるロスL2と、FL計算部150から出力されるロスFLとの加重平均を算出し、ロスLaとしてパラメータ更新量計算器162に入力する。パラメータ更新量計算器162は、ロスLaが小さくなるように、生徒モデル120のパラメータを更新する。物体検知装置100は、複数の学習データD1及びその真値yを用いて物体検知モデルの学習を行い、所定の終了条件が具備されたときに、学習を終了する。
【0059】
次に、物体検知装置100の全体構成について説明する。
図8は、第1実施例による物体検知装置100の全体の機能構成を示す。物体検知装置100は、
図4に示すクラス分類に関する部分に加えて、回帰ロス計算器163を備える。また、教師モデル110は、特徴抽出器111と、矩形位置計算器112と、分類計算器113とを備える。また、生徒モデル120は、特徴抽出器121と、分類計算器122と、矩形位置計算器123とを備える。
【0060】
教師モデル110においては、特徴抽出器111は、学習データD1に対して特徴抽出を行い、特徴マップを矩形位置計算器112と分類計算器113に入力する。なお、本実施例では、矩形位置計算器112の計算結果は使用しない。分類計算器113は、特徴マップに基づいてクラス分類を行い、教師モデルのクラス分類結果qを出力する。
【0061】
一方、生徒モデル120においては、特徴抽出器121は、学習データD1に対して特徴抽出を行い、特徴マップを分類計算器122と矩形位置計算器123に出力する。分類計算器122は、特徴マップに基づいてクラス分類を行い、生徒モデルのクラス分類結果pを出力する。矩形位置計算器123は、特徴マップに基づいて矩形位置cを算出し、回帰ロス計算器163に出力する。回帰ロス計算器163には、矩形位置の真値ctが入力されており、回帰ロス計算器163は、矩形位置cとその真値ctの差分を回帰ロスLregとして算出し、加重平均計算器161に出力する。
【0062】
加重平均計算器161は、所定の重みを用いて、ロスL1と、ロスL2と、ロスFLと、回帰ロスLregとの加重平均を算出し、ロスLaとしてパラメータ更新量計算器162に入力する。このロスLaは、式(10)に示すトータルロスLに相当する。パラメータ更新量計算器162は、ロスLaが小さくなるように、生徒モデル120のパラメータを更新する。こうして、物体検知モデルの学習が行われる。
[第2実施例]
次に、第1実施形態の第2実施例について説明する。第2実施例は、上記のロスL1~L4のうち、ロスL3及びL4を使用するものである。
【0063】
(ハードウェア構成)
第2実施例に係る物体検知装置のハードウェア構成は、
図3に示す第1実施例のものと同様であるので、説明を省略する。
【0064】
(機能構成)
第2実施例においても、まず、クラス分類に関する構成のみを先に説明する。
図9は、第2実施例に係る物体検知装置100xのクラス分類に関する機能構成を示すブロック図である。なお、
図9は、物体検知装置100xの学習のための構成を示す。
【0065】
図示のように、物体検知装置100xは、教師モデル110と、生徒モデル120と、FL計算部150と、加重平均計算器161と、パラメータ更新量計算器162と、L3計算部170と、L4計算部180と、を備える。即ち、第2実施例の物体検知装置100xは、第1実施例の物体検知装置100におけるL1計算部130とL2計算部140の代わりに、L3計算部170とL4計算部180を設けたものであり、それ以外の点は第1実施例の物体検知装置100と同様である。
【0066】
L3計算部170は、教師モデル出力qと、生徒モデル出力pと、真値yとを用いて、前述のロスL3を算出する。
図10は、L3計算部170の構成を示す。L3計算部170は、距離計算器171及び172と、係数3計算器173と、L3計算器174とを備える。距離計算器171は、教師モデル出力qと真値yの距離KL(y||q)を算出する。距離計算器172は、真値yと生徒モデル出力pとの距離KL(y||p)を算出する。係数3計算器173は、上記の式(8)に基づき、以下の係数3を算出する。
【0067】
【数14】
この係数3は、真値yと教師モデル出力qとの差が大きいほど大きくなり、本発明における第2の重みに相当する。そして、L3計算器174は、係数3と、距離KL(y||p)とに基づいて、上記の式(8)により、ロスL3を算出する。
【0068】
L4計算部180は、教師モデル出力qと、生徒モデル出力pと、真値yとを用いて、前述のロスL4を算出する。
図11は、L4計算部180の構成を示す。L4計算部180は、距離計算器181及び182と、係数4計算器183と、係数5計算器184と、L4計算器185とを備える。距離計算器181は、教師モデル出力qと生徒モデル出力pの距離KL(q||p)を算出する。距離計算器182は、真値yと教師モデル出力qとの距離KL(y||q)を算出する。係数4計算器183は、上記の式(9)に基づき、以下の係数4を算出する。
【0069】
【数15】
この係数4は、教師モデル出力qと前記生徒モデル出力pとの差が大きいほど大きくなり、本発明における第3の重みに相当する。
【0070】
また、係数5計算器184は、上記の式(9)に基づき、以下の係数5を算出する。
【0071】
【数16】
この係数5は、真値yと教師モデル出力qとの差が小さいほど大きくなり、本発明における第4の重みに相当する。そして、L4計算器185は、係数4と、係数5と、距離KL(q||p)とに基づいて、式(9)によりロスL4を算出する。
【0072】
図9に戻り、加重平均計算器161は、所定の重みを用いて、L3計算部170から出力されるロスL3と、L4計算部180から出力されるロスL4と、FL計算部150から出力されるロスFLとの加重平均を算出し、ロスLaとしてパラメータ更新量計算器162に入力する。パラメータ更新量計算器162は、ロスLaが小さくなるように、生徒モデル120のパラメータを更新する。物体検知装置100は、複数の学習データD1及びその真値yを用いて物体検知モデルの学習を行い、所定の終了条件が具備されたときに、学習を終了する。
【0073】
次に、物体検知装置100xの全体構成について説明する。
図12は、第2実施例による物体検知装置100xの全体の機能構成を示す。第2実施例による物体検知装置100xは、
図8に示す第1実施例の物体検知装置100におけるL1計算部130とL2計算部140の代わりに、L3計算部170とL4計算部180を設けたものであり、それ以外の点は第1実施例の物体検知装置100と同様である。
【0074】
物体検知装置100xでは、加重平均計算器161は、所定の重みを用いて、ロスL3と、ロスL4と、ロスFLと、回帰ロスLregとの加重平均を算出し、ロスLaとしてパラメータ更新量計算器162に入力する。このロスLaは、式(10)に示すトータルロスLに相当する。パラメータ更新量計算器162は、ロスLaが小さくなるように、生徒モデル120のパラメータを更新する。こうして、物体検知モデルの学習が行われる。
【0075】
[第2実施形態]
次に、第2実施形態について説明する。
図13は、本発明の第2実施形態に係る学習装置の機能構成を示す。なお、学習装置50のハードウェア構成は基本的に
図3と同様である。
【0076】
図示のように、学習装置50は、教師モデル51と、生徒モデル52と、損失算出部53と、更新部54とを備える。教師モデル51は、入力された学習データに対する推論結果を損失算出部53に出力する。生徒モデルは、入力された学習データに対する推論結果を損失算出部53に出力する。損失算出部53は、教師モデルの出力と、生徒モデルの出力と、学習データに対する真値とに基づいてトータル損失を算出する。
【0077】
ここで、損失算出部53は、
(1)教師モデル51の出力の自信度が低いほど大きくなる第1の重みを、真値と生徒モデル52の出力との差に乗算して得た第1の損失と、
(2)真値と教師モデル51の出力との差が大きいほど大きくなる第2の重みを、真値と生徒モデル52の出力との差に乗算して得た第2の損失と、
(3)教師モデル51の出力と生徒モデル52の出力との差が大きいほど大きくなる第3の重み、及び、真値と教師モデル51の出力との差が小さいほど大きくなる第4の重みを、教師モデル51の出力と生徒モデル52の出力との差に乗算して得た第3の損失と、の少なくとも1つを用いてトータル損失を算出する。そして、更新部54は、トータル損失に基づいて、生徒モデル52のパラメータを更新する。
【0078】
図14は、学習装置50による学習処理のフローチャートである。学習データ及びそれに対する真値が入力されると、教師モデル51は、学習データの推論を行い、推論結果を出力する(ステップS11)。次に、生徒モデル52は、学習データの推論を行い、推論結果を出力する(ステップS12)。次に、損失算出部53は、上記の方法により、第1~第3の損失の少なくとも1つを用いてトータル損失を算出する(ステップS13)。そして、更新部54は、トータル損失に基づいて、生徒モデル52のパラメータを更新する。
【0079】
[変形例]
上記の実施形態では、教師モデル出力、生徒モデル出力及び真値の距離としてKL距離を使用している。この場合、教師モデル出力qと生徒モデル出力pのKL距離は以下の式で与えられる。
【0080】
【数17】
その代わりに、以下に示すユークリッド距離(「L2ノルム」とも呼ばれる。)を用いてもよい。
【0081】
【0082】
上記の実施形態の一部又は全部は、以下の付記のようにも記載されうるが、以下には限られない。
【0083】
(付記1)
学習データに対する推論結果を出力する教師モデルと、
前記学習データに対する推論結果を出力する生徒モデルと、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいてトータル損失を算出する損失算出部と、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する更新部と、
を備え、
前記損失算出部は、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、の少なくとも1つを用いて前記トータル損失を算出する学習装置。
【0084】
(付記2)
前記損失算出部は、前記真値と前記生徒モデルの出力との差、前記真値と前記教師モデルの出力との差、及び、前記教師モデルの出力と前記生徒モデルの出力との差を距離として計算する距離計算器を備える付記1に記載の学習装置。
【0085】
(付記3)
前記損失算出部は、前記自信度をエントロピー関数により算出する付記1又は2に記載の学習装置。
【0086】
(付記4)
前記損失算出部は、前記第1の損失、前記第2の損失及び前記第3の損失のうちの複数の損失の加重平均により前記トータル損失を算出する付記1乃至3のいずれか一項に記載の学習装置。
【0087】
(付記5)
前記学習データは画像データであり、
前記画像データから特徴を抽出して特徴マップを生成する特徴抽出部を備え、
前記教師モデル及び前記生徒モデルは、前記特徴マップに対して規定したアンカー毎に前記推論結果を出力する付記1乃至4のいずれか一項に記載の学習装置。
【0088】
(付記6)
前記教師モデル及び前記生徒モデルは、前記特徴抽出部が抽出した特徴マップに基づいて、前記画像データに含まれる対象物のクラス分類を行う付記5に記載の学習装置。
【0089】
(付記7)
前記教師モデル及び前記生徒モデルは、前記特徴抽出部が抽出した特徴マップに基づいて、前記画像データに含まれる対象物の位置を検出する付記6に記載の学習装置。
【0090】
(付記8)
教師モデルを用いて、学習データに対する推論結果を出力し、
生徒モデルを用いて、前記学習データに対する推論結果を出力し、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいて、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、の少なくとも1つを用いてトータル損失を算出し、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する学習方法。
【0091】
(付記9)
教師モデルを用いて、学習データに対する推論結果を出力し、
生徒モデルを用いて、前記学習データに対する推論結果を出力し、
前記教師モデルの出力と、前記生徒モデルの出力と、前記学習データに対する真値とに基づいて、
(1)前記教師モデルの出力の自信度が低いほど大きくなる第1の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第1の損失と、
(2)前記真値と前記教師モデルの出力との差が大きいほど大きくなる第2の重みを、前記真値と前記生徒モデルの出力との差に乗算して得た第2の損失と、
(3)前記教師モデルの出力と前記生徒モデルの出力との差が大きいほど大きくなる第3の重み、及び、前記真値と前記教師モデルの出力との差が小さいほど大きくなる第4の重みを、前記教師モデルの出力と前記生徒モデルの出力との差に乗算して得た第3の損失と、の少なくとも1つを用いてトータル損失を算出し、
前記トータル損失に基づいて、前記生徒モデルのパラメータを更新する処理をコンピュータに実行させるプログラムを記録した記録媒体。
【0092】
以上、実施形態及び実施例を参照して本発明を説明したが、本発明は上記実施形態及び実施例に限定されるものではない。本発明の構成や詳細には、本発明のスコープ内で当業者が理解し得る様々な変更をすることができる。
【符号の説明】
【0093】
50 学習装置
100、100x 物体検知装置
110 教師モデル
120 生徒モデル
130 L1計算部
140 L2計算部
150 FL計算部
170 L3計算部
180 L4計算部
161 加重平均計算器
162 パラメータ更新量計算器