IP Force 特許公報掲載プロジェクト 2022.1.31 β版

知財求人 - 知財ポータルサイト「IP Force」

▶ 株式会社Preferred Networksの特許一覧

特許7315748データ識別器訓練方法、データ識別器訓練装置、プログラム及び訓練方法
(19)【発行国】日本国特許庁(JP)
(12)【公報種別】特許公報(B2)
(11)【特許番号】
(24)【登録日】2023-07-18
(45)【発行日】2023-07-26
(54)【発明の名称】データ識別器訓練方法、データ識別器訓練装置、プログラム及び訓練方法
(51)【国際特許分類】
   G06N 3/0475 20230101AFI20230719BHJP
   G06N 3/084 20230101ALI20230719BHJP
   G06N 3/094 20230101ALI20230719BHJP
【FI】
G06N3/0475
G06N3/084
G06N3/094
【請求項の数】 24
(21)【出願番号】P 2022070856
(22)【出願日】2022-04-22
(62)【分割の表示】P 2019203198の分割
【原出願日】2018-06-28
(65)【公開番号】P2022101650
(43)【公開日】2022-07-06
【審査請求日】2022-05-20
(31)【優先権主張番号】P 2017127769
(32)【優先日】2017-06-29
(33)【優先権主張国・地域又は機関】JP
(73)【特許権者】
【識別番号】515130201
【氏名又は名称】株式会社Preferred Networks
(74)【代理人】
【識別番号】100091487
【弁理士】
【氏名又は名称】中村 行孝
(74)【代理人】
【識別番号】100120031
【弁理士】
【氏名又は名称】宮嶋 学
(74)【代理人】
【識別番号】100118876
【弁理士】
【氏名又は名称】鈴木 順生
(74)【代理人】
【識別番号】100202429
【弁理士】
【氏名又は名称】石原 信人
(72)【発明者】
【氏名】宮戸 岳
【審査官】北川 純次
(56)【参考文献】
【文献】国際公開第2017/094899(WO,A1)
【文献】国際公開第2017/094267(WO,A1)
【文献】米国特許出願公開第2015/0170020(US,A1)
【文献】YOSHIDA, Y. et al.,Spectral Norm Regularization for Improving the Generalizability of Deep Learning,arXiv.org [online],2017年05月31日,[検索日 2023.03.27], インターネット:<URL:https://arxiv.org/pdf/1705.10941v1.pdf>,<DOI: 10.48550/arXiv.1705.10941>
(58)【調査した分野】(Int.Cl.,DB名)
G06N 3/02-3/10
G06N 20/00
(57)【特許請求の範囲】
【請求項1】
ニューラルネットワークの重み行列をスペクトル正規化して、正規化された重み行列を算出することと、
前記正規化された重み行列を用いて取得した出力に基づいて、前記ニューラルネットワークの前記重み行列を更新することと、を備え、
前記正規化された重み行列は、第1データと前記第1データとは異なる第2データとの判別に用いられる、
訓練方法。
【請求項2】
前記出力に基づいて誤差を算出すること、を更に備え、
前記ニューラルネットワークの前記重み行列は、前記誤差に基づいて更新される、
請求項1に記載の訓練方法。
【請求項3】
前記スペクトル正規化は、前記重み行列のスペクトルノルムに基づく前記重み行列の正規化である、
請求項1又は請求項2に記載の訓練方法。
【請求項4】
前記スペクトル正規化は、前記重み行列の特異値に基づく前記重み行列の正規化である、
請求項1又は請求項2に記載の訓練方法。
【請求項5】
前記特異値は、前記重み行列の最大の特異値である、
請求項4に記載の訓練方法。
【請求項6】
前記特異値は、べき乗法を用いて算出される、
請求項4又は請求項5に記載の訓練方法。
【請求項7】
前記ニューラルネットワークは、CNN(Convolutional Neural Network)である、
請求項1乃至請求項6のいずれか一項に記載の訓練方法。
【請求項8】
前記重み行列は、前記ニューラルネットワークの所定の層における重み行列である、
請求項1乃至請求項7のいずれか一項に記載の訓練方法。
【請求項9】
前記重み行列は、前記ニューラルネットワークの複数の層における重み行列である、
請求項1乃至請求項8のいずれか一項に記載の訓練方法。
【請求項10】
前記第1データ及び前記第2データは、画像データである、
請求項1乃至請求項9のいずれか一項に記載の訓練方法。
【請求項11】
ニューラルネットワークの重み行列を前記重み行列のスペクトルノルムに基づいて正規化し、正規化された重み行列を算出することと、
前記正規化された重み行列に基づいて、評価値を算出することと、
前記評価値に基づいて、前記ニューラルネットワークの前記重み行列を更新することと、
を備える訓練方法。
【請求項12】
ニューラルネットワークの重み行列を前記重み行列の最大の特異値に基づいて正規化し、正規化された重み行列を算出することと、
前記正規化された重み行列に基づいて、評価値を算出することと、
前記評価値に基づいて、前記ニューラルネットワークの前記重み行列を更新することと、
を備える訓練方法。
【請求項13】
前記特異値は、べき乗法を用いて算出される、
請求項12に記載の訓練方法。
【請求項14】
前記ニューラルネットワークは、分類を行うニューラルネットワークである、
請求項11乃至請求項13のいずれか一項に記載の訓練方法。
【請求項15】
前記ニューラルネットワークは、敵対的生成ネットワークのディスクリミネータである、
請求項11乃至請求項13のいずれか一項に記載の訓練方法。
【請求項16】
請求項1乃至請求項15のいずれか一項に記載の訓練方法を用いて前記正規化された重み行列を生成する、
モデル生成方法。
【請求項17】
請求項1乃至請求項15のいずれか一項に記載の訓練方法によって生成された前記正規化された重み行列を用いて、敵対的生成ネットワークのジェネレータを生成する、
モデル生成方法。
【請求項18】
請求項17に記載のモデル生成方法を用いて生成された前記ジェネレータを用いて、データを生成する、
データ生成方法。
【請求項19】
少なくとも1つのプロセッサ、を備え、
前記少なくとも1つのプロセッサは、請求項1乃至請求項15のいずれか一項に記載の訓練方法を実行する、
訓練装置。
【請求項20】
少なくとも1つのプロセッサ、を備え、
前記少なくとも1つのプロセッサは、請求項16又は請求項17に記載のモデル生成方法を実行する、
モデル生成装置。
【請求項21】
少なくとも1つのプロセッサ、を備え、
前記少なくとも1つのプロセッサは、請求項18に記載のデータ生成方法を実行する、
データ生成装置。
【請求項22】
少なくとも1つのプロセッサに、請求項1乃至請求項15のいずれか一項に記載の訓練方法を実行させる、
プログラム。
【請求項23】
少なくとも1つのプロセッサに、請求項16又は請求項17に記載のモデル生成方法を実行させる、
プログラム。
【請求項24】
少なくとも1つのプロセッサに、請求項18に記載のデータ生成方法を実行させる、
プログラム。
【発明の詳細な説明】
【技術分野】
【0001】
本発明は、データ識別器訓練方法、データ識別器訓練装置、プログラム及び訓練方法に関する。
【背景技術】
【0002】
敵対的生成ネットワーク(GAN:Generative Adversarial Networks、Ian Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu, David Warde-Farley, Sherjil Ozair, Aaron Courville, and Yoshua Bengio, "Generative adversarial networks," arXiv preprint arXiv:1406.2661, 10 Jun 2014)は、近年、生成モデルのフレームワークとして広く研究され、種々のデータセットに適用されている。GANは、与えられたターゲット分布を模したモデル分布を生成するためのフレームワークであり、モデル分布を生成するジェネレータと、モデル分布をターゲットから区別するディスクリミネータとで構成される。各ステップにおいて、モデル分布と、それに対するディスクリミネータにより測定されたターゲット分布との差を減少させるように、最良なディスクリミネータを連続に訓練する。
【0003】
GANの訓練において、ディスクリミネータのパフォーマンスの制御が問題となる。高次元空間では、ディスクリミネータによる密度比推定は、その訓練中に、不正確、不安定であることが多い。この結果、ジェネレータネットワークは、ターゲット分布のマルチモーダル構造を学習することができない。さらに、モデル分布のサポートと、ターゲット分布のサポートが分離している場合、モデル分布をターゲット分布から完全に区別することができるディスクリミネータが存在する。この状況下において、このようなディスクリミネータが生成されると、入力に対する当該ディスクリミネータの導関数が0となるため、ジェネレータの訓練は、停止する。
【発明の開示】
【0004】
本発明の実施形態は、ディスクリミネータネットワークの訓練の安定化をするデータ判別器方法、装置及びプログラムを提案する。
【0005】
一実施形態によれば、データ識別器訓練方法は、正解データと、擬データと、を識別するニューラルネットワークモデルを備えるデータ識別器を訓練する、データ識別器訓練方法であって、前記データ識別器に前記正解データを入力し、第1予測結果を取得するステップと、前記データ識別器に前記擬データを入力し、第2予測結果を取得するステップと、取得された前記第1予測結果及び取得された前記第2予測結果に基づいて、誤差を算出するステップと、前記誤差と、前記ニューラルネットワークモデルの各層の重み行列の特異値と、に基づいて、前記重み行列を更新するステップと、を備える。
【0006】
一実施形態によれば、GANにおけるディスクリミネータの訓練の安定化を図ることが可能となる。
【図面の簡単な説明】
【0007】
図1】本実施形態に係るデータ判別器の学習の流れを示すフローチャート。
図2】本実施形態に係るデータ判別装置の機能を示すブロック図。
図3】比較実験におけるパラメータの設定例を示す図。
図4A】比較実験におけるインセプションスコアを示すグラフ。
図4B】比較実験におけるインセプションスコアを示すグラフ。
図5A】比較実験におけるインセプションスコアを示す図。
図5B】比較実験におけるインセプションスコアを示す図。
図6】比較実験における出力結果の例を示す図。
図7】比較実験における出力結果の例を示す図。
図8】比較実験における重み行列の特異値の大きさの理論値を示す図。
図9A】比較実験における重み行列の特異値の大きさを示す図。
図9B】比較実験における重み行列の特異値の大きさを示す図。
図10A】比較実験における処理時間を示す図。
図10B】比較実験における処理時間を示す図。
図11】比較実験におけるインセプションスコアを示す図。
図12A】比較実験におけるロスを示す図。
図12B】比較実験におけるロスを示す図。
図13】比較実験におけるインセプションスコアを示す図。
【発明を実施するための形態】
【0008】
以下、説明文中において、数式中の変数又は関数の上部に付するバーは「/」と表し、同じくハットは「^」、チルダは「」と表す。例えば、xに上記のものを付する場合には、それぞれ、「x/」、「x^」、「x」と表す。また、ベクトル又は行列に対して右側に「」と記載した場合、それぞれベクトル又は行列の転置を表す。
【0009】
(敵対的生成ネットワーク)
まず、本実施形態の基礎となる敵対的生成ネットワーク(以下、GANと記載する。)について簡単に説明する。GANは、ジェネレータと、ディスクリミネータと、を備え、ジェネレータと、ディスクリミネータを並行して学習する生成モデルの訓練方法の一種である。
【0010】
ジェネレータ(データ生成器)は、ターゲットデータである訓練データ(正解データ)の学習をして、訓練データに類似するデータ(擬データ)を生成する。このジェネレータは、雑音データが入力されると擬データを出力するモデルとして学習される。ディスクリミネータ(データ判別器)は、ジェネレータが生成した擬データと正解データとの判別を行う。すなわち、ジェネレータは、ターゲットデータの分布(訓練データの分布)と生成したモデルの分布とが一致するように学習され、一方で、ディスクリミネータは、正解データと擬データとを区別するように学習される。
【0011】
この学習においては、ジェネレータネットワークとディスクリミネータネットワークの2つが存在する。ネットワークとしては、例えば、多層パーセプトロン(MLP:Multi-Layer Perceptron)、コンボリューションネットワーク等が用いられる。
【0012】
例えば、以下の式で表されるMLPによるディスクリミネータのネットワークについて説明する。なお、CNNにおいても、例えば、畳み込み層における重み行列のそれぞれいついても以下の式を用いることが可能であり、本実施形態にかかる重み行列の正規化を同様に適用することができる。
【数1】
【0013】
ベクトルhは、第l層の出力、行列Wは、第l-1層と第l層との間の重み付け行列、ベクトルbは、第l層におけるバイアス、aは、エレメントごとの非線形の活性化関数を示す。ここで、dim(l)がlの次元を表し、Rが実数体を表すものとして、W∈Rdim(l)×dim(l-1)、b∈Rdim(l)、h∈Rdim(l)、h(x)=xである。上記の式を一連の構成として解釈すると、入力ベクトルxを有するネットワークの最終層の出力がhとなる。以下の説明では、簡単のため、f(x)=h(x)として記載する。
【0014】
このように定義すると、Aをユーザが選択した距離測定の発散に対応する活性化関数として、ディスクリミネータの出力は、D(x)=A(f(x))として与えられる。GANの標準的な形式は、以下の式のように表される。
【数2】
【0015】
ここで、Gをジェネレータの出力とする。G及びDの最大値及び最小値は、それぞれ、ジェネレータ及びディスクリミネータのセットに引き継がれる。V(G,D)の一般的な式は、以下のように与えられる。
【数3】
【0016】
ここで、E[・]は、期待値を表し、qは、ターゲットデータの分散、pは、敵対的最小値最大値最適化を介して学習されるモデルのジェネレータの分散、x’は、ジェネレータにより生成された擬データである。この形式のDにおいて用いられる活性化関数Aは、例えば、シグモイド関数のような[0,1]の範囲の連続的な関数である。固定されたジェネレータGに対して、V(G,D)のこの形式の最適なディスクリミネータは、D =q(x)/(q(x)+p(x))で与えられることが知られている。
【0017】
ディスクリミネータが選択された関数空間がGANのパフォーマンスに決定的に影響を及ぼすと考えられている。多くの研究において、統計の有界性を保証する上で、リプシッツ連続性の重要性が指摘されている。このようなことに鑑みると、例えば、GANの最適なディスクリミネータは、以下のように表される。
【数4】
【0018】
この微分は、以下のように表され、これは、有界ではなく(unbound)、又は、現実的に計算することができない(incomputable)ようになり得る。
【数5】
【0019】
そこで、この配列において、入力サンプルxに定義された正規化項を追加することによりディスクリミネータのリプシッツ定数を制御する方法がある。以下の式のように、リプシッツ連続関数のセットからディスクリミネータDを探る。
【数6】
【0020】
ここで、||f||Lip≦Kは、全てのx、x’に対して、||f(x)-f(x’)||/||x-x’||≦Kであることを意味する。ただし、ノルム(||・||)は、Lノルムを表すものとする。
【0021】
入力ベースの正則化は、サンプルに基づく比較的容易な公式化を可能とするが、ヒューリスティック及びそれに近い手段を除き、ジェネレータ及びターゲットデータ分布のサポートの外側の空間に対する正規化をインポーズすることが困難である。そこで、本実施形態においては、スペクトル正規化(Spectral Normalization、Yuichi Yoshida and Takeru Miyato, "Spectral norm regularization for improving the generalizability of deep learning," arXiv preprint arXiv:1705.10941, 31 May 2017)を用いて重み行列を正規化する。
【0022】
(スペクトル正規化)
例えば、活性化関数aをReLU(Rectified Linear Unit)、leakyReLUとすると、||a||Lip=1となる。活性化関数aのそれぞれが、||a||Lip=1を満たす場合、不等式||g*g||Lip≦||g||Lip・||g||Lipであるので、||f||Lipを以下のように上から押さえることができる。ただし、上記の式においては、g*gは、gとgとの合成関数であることを意味する。
【数7】
【0023】
ここで、σ(W)は、行列WのLノルムであるスペクトルノルムであり、Wの最大の特異値と同等であるものとして、以下のように表される。
【数8】
【0024】
スペクトル正規化は、以下のように、重み行列Wのスペクトルノルムを規格化し、リプシッツ定数を1とすることが可能である。
【数9】
【0025】
[数8]を用いてそれぞれのWを規格化した場合、||f||Lipが1で上から押さえられることから[数6]の不等式が成立し、σ(W/SN(W))=1とすることができる。
【0026】
ijに対するW/SN(W)の勾配は、以下のように表される。
【数10】
【0027】
ここで、Eijは、(i,j)要素が1、他の要素が0である行列を表し、uは、Wの第1左特異ベクトル、vは、Wの第1右特異ベクトルである。第1左特異ベクトル、第1右特異ベクトルとは、Wを特異値分解した場合に、特異値成分を左上から右下へと向かって降順にソートした特異値行列において、左上成分である特異値(第1特異値)に対応する左特異ベクトル、右特異ベクトルのことを示す。
【0028】
もし、hがネットワーク内において、重み行列Wにより変換される隠れたノードであるならば、ディスクリミネータDのWに関するミニバッチ上で計算されたV(G,D)の導関数は、以下のように与えられる。
【数11】
【0029】
ここで、E^[・]は、ミニバッチにおける経験的期待値を表し、δ=(∂V(G,D)/∂(W/SNh))、λ=E^[δ(W/SNh)]である。
【0030】
[数10]の下段の式において、第1項のE^[δh]は、正規化していない重みの導関数と等しい。この観点から、第2項は、補償正規化係数λを用いて第1特異値成分をペナルティ化する正規化の項とみることができる。λは、δ及びW/SNhが同じ方向を示すのであれば、正の値となり、これは、訓練中にWの列が一方向に集中するのを防止する。換言すると、スペクトル正規化は、それぞれのレイヤにおける変換が一方向にセンシティブになるのを防止する。
【0031】
この正規化の効果を利用して、アルゴリズムの多彩なバージョンを検討することができる。以下のように与えられるディスクリミネータの重み行列の別パラメータ化を考えることも可能である。
【数12】
【0032】
ここで、γは、学習されるスカラー変数である。このパラメータ化は、注目しているレイヤにおける1-リプシッツ制約を含むが、モデルが縮退するのを防止するとともに、モデルに自由度を与えることができる。この再パラメータ化を行うためには、勾配ペナルティ(Gradient Penalty)のような他の手段によりリプシッツ状態を制御する。
【0033】
上述したように、ディスクリミネータのネットワークの各層において正規化するスペクトルノルムσ(W)は、Wの最大の特異値となる。このようにアルゴリズムの各ラウンドにおいて、単純に特異値分解を適用すると、計算のコストが膨大なものとなる。そこで、σ(W)を評価するためにべき乗法(Power Iteration Method)を使用してもよい。
【0034】
この手法では、乱数により初期化されたベクトルu及びvから開始する。もし、優性な特異値において多重しない場合、かつ、u及びvが第1特異ベクトルと直交しない場合、u及びvは、以下の更新ルールに基づいて、第1左特異ベクトルu及び第1右特異ベクトルvへとそれぞれ収束する。
【数13】
【0035】
さらに、以下のように、Wのスペクトルノルムを上記のように近似した特異ベクトルのペアによって近似することができる。
【数14】
【0036】
もし、SGD(Stochastic Gradient Descent:確率的勾配降下法)を更新に用いるのであれば、各更新におけるWの変化は小さく、したがって、最大の特異値が変化する。実装において、この事実を利用し、アルゴリズムの各ステップで計算されたuを次のステップの初期ベクトルとして再利用する。このリサイクルプロシージャにより、1ラウンド以上のラウンドのべき乗反復を行う。
【0037】
以下、本実施形態に係るスペクトル正規化に基づいたGANの手法について、フローチャートに基づいて説明する。図1は、本実施形態の処理の流れを示すフローチャートである。
【0038】
なお、以下の説明において、ジェネレータ生成についての説明は省略するが、ジェネレータ生成の手法については、特に限定するものではない。本実施形態においては、ディスクリミネータの生成について説明する。また、一般的なニューラルネットワークモデルの生成と同様の処理について、例えば、ハイパーパラメータの設定、順伝播、逆伝播等の詳細な説明等は、省略することがある。ディスクリミネータの学習は、上述したようにミニバッチを用いて行ってもよいし、別の例として、バッチ学習又はオンライン学習により行っても下記と同様に処理を行うことができる。
【0039】
まず、ジェネレータ及びディスクリミネータのニューラルネットワークモデルを設定する(S100)。設定されるモデルは、上述したように、例えば、MLP、CNN等のニューラルネットワークモデルである。
【0040】
次に、ディスクリミネータの各層を接続する重み行列Wについて、左特異ベクトルu (∈Rdim(l))の初期化を行う(S102)。初期化は、例えば、等方正分布に基づいた乱数を用いて行われる。
【0041】
モデル及び変数の初期化が終了した後、ジェネレータ、ディスクリミネータの学習に移行する。上述したように、ジェネレータとディスクリミネータの最適化については、それぞれの出力結果を[数3]等の式により評価することにより、並行して又は各ステップにおいて交互に実行される。
【0042】
以下のS104からS110の説明においては、各層ごとの重み行列Wに対しての処理を行うことを記載している。例えば、第l-1層と、第l層とを接続する重み行列Wについて処理することを説明する。
【0043】
逆伝播において、重み行列の更新には、スペクトル正規化を用いる。そのため、誤差逆伝播処理においては、まず、左右それぞれの特異ベクトルを更新する(S104)。第1特異ベクトルの更新は、例えば、以下に示される式に基づき実行される。
【数15】
【0044】
ここで、u は、重み行列Wの左特異ベクトル、v は、重み行列Wの右特異ベクトルをそれぞれ示す。すなわち、乱数により初期化された左特異ベクトルu 及び重み行列Wを用いて、右特異ベクトルv を更新する。更新された右特異ベクトルv 及び重み行列Wを用いて、左特異ベクトルu を更新する。このようにべき乗法に基づき交互に更新する収束演算を行うことにより、第1右特異ベクトル及び第1左特異ベクトルを算出する。このステップは、任意で、所定数回繰り返し行うようにしてもよい。
【0045】
次に、更新された左特異ベクトルu及び右特異ベクトルvに基づいて、重み行列を正規化する(S106)。この正規化は、Wのスペクトルノルムに基づき、上述したスペクトル正規化により実行される。Wのスペクトルノルムσ(W)を用いて、例えば、以下に示す数式によりスペクトル正規化された重み行列W/SN が算出される。
【数16】
【0046】
次に、スペクトル正規化された重み行列W/SN に基づいて、誤差の算出する(S108)。トレーニングデータである正解データと、ジェネレータの出力結果である擬データと、がディスクリミネータに入力され、順伝播される。出力層において、例えば、ディスクリミネータから出力された結果が、ジェネレータの出力結果が偽、正解データの出力結果が真となるラベルに基づき、出力層における誤差を算出する。
【0047】
より具体的な例として、このステップ(S108)は、次の3つのステップを備える。まず、正解データが正解データであると判定されるか否かの第1予測結果を取得する。次に、擬データが正解データでは無いと判定されるか否かの第2予測結果を取得する。そして、[数3]で表されるようなロス関数により、これら第1予測結果及び第2予測結果に基づき、誤差を算出する。この誤差を逆伝播することにより、以下に説明するように、重み行列の更新が行われる。
【0048】
次に、算出された誤差に基づいて、重み行列Wを更新する(S110)。例えば、以下の式に基づいて、重み行列Wが更新される。
【数17】
【0049】
ここで、W/SN (W)は、スペクトル正規化された重み行列を示し、Dは、所定のデータセットに基づいた値であることを示す。例えば、ミニバッチ処理を行っている場合には、Dは、ミニバッチ内のデータセットに基づいて重み行列Wを更新することを示す。また、lは、ロス関数を示し[数3]等に基づいた関数、例えば、後述する[数17]、アルゴリズムによっては後述する[数20]等で示される関数である。これらの式では、正解データをデータ識別器に入力した場合の第1予測結果と真値との誤差(第1部分誤差)、及び、偽データをデータ識別器に入力した場合の第2予測結果と偽値との誤差(第2部分誤差)に基づいて全体的な誤差(ロス)を算出する。例えば、[数3]の第1項が第1予測結果と真値との誤差、第2項が第2予測結果と偽値との誤差を示し、これらの和を求めることによりロスを算出する。[数16]に示される更新は、スペクトル正規化された重み行列W/SN を用いていること以外は、一般的なSGDによる更新に基づくものである。ミニバッチ学習等により学習を行っている場合には、各ミニバッチからの出力に基づいて、さらに重み行列を更新してもよい。各ミニバッチからの出力に基づく重み行列の更新は、一般的な手法により行われる。
【0050】
次に、最適化が終了しているか否かを判断する(S112)。最適化の終了は、例えば、全ての層において重み行列の更新がされた、所定回数のエポックが終了した、評価関数が所定条件を満たした、ロス関数が所定条件を満たした等に基づいて判断される。バッチ学習、ミニバッチ学習等の学習をしている場合には、必要となるデータに対して学習が終了したか否かを判断してもよい。
【0051】
例えば、全ての層において重み行列が更新されていない場合、算出した誤差を逆伝播することにより、1つ前の層の重み行列の更新を続けて行う。あるエポックにおいて全てのミニバッチに対して全ての層の重み行列の更新が行われた後であれば、エポックすうが所定回数に達しているか否かを判断し、訓練を終了、又は、続行する。図1におけるS112では、異なるレベルにおける終了条件をまとめて記載しているが、もちろん、最適化の終了条件をより細かく設定してもよく、ネストされたループとしてフローチャートを理解できるものであるとする。
【0052】
最適化が終了していない場合(S112:No)、S104からS110の処理が繰り返し行われる。上述したように、例えば、全ての層において重み行列の更新がされていない場合、S110において重み行列が更新された層の前の層にロスの逆伝播をし、前の層の重み行列の更新処理を行う。所定回数のエポックが終了していない場合、所定回数となるまで処理を行う。評価関数、又は、ロス関数等が所定条件を満たしていない場合、所定条件を満たすまで処理を行う。バッチ学習、ミニバッチ学習等においては、必要となるデータに対して学習が終了するまで処理を行い、その上で、上記のエポック数、評価関数、ロス関数の条件を満たすまで処理が繰り返される。なお、上述した処理において、特に、左特異ベクトルの初期値は、前ステップにおいてべき乗法により最適化されたベクトルを用いてもよい。
【0053】
最適化が終了した場合(S112:Yes)、学習済みモデルを出力し(S114)、処理を終了する。
【0054】
図2は、本実施形態に係るデータ判別器生成装置の機能を示すブロック図である。データ判別器生成装置1は、データ判別器初期化部10と、データ判別器記憶部12と、入力部14と、順伝播部16と、誤差算出部18と、逆伝播ブロック20と、を備える。
【0055】
データ判別器初期化部10は、GANにおけるディスクリミネータ(データ判別器)のモデルの初期化を行う。例えば、モデルとして用いられるニューラルネットワークモデルの選択、隠れ層の数、各層間を接続する重み行列等の初期化を行う。ニューラルネットワークモデルの選択及び隠れ層の数は、ユーザによる指定を受け付けるものであってもよい。重み行列の初期化は、ユーザによる指定を受け付けるものであってもよいし、乱数等により自動生成されるものであってもよい。データ判別器初期化部10により、上述したS100の処理が行われる。また、モデルの生成と併せて、S102に示される各層間における重み行列の左特異ベクトルの初期化を行ってもよい。
【0056】
なお、データ判別器初期化部10は、データ判別器生成装置1に必須の構成ではない。例えば、ユーザがデータ判別器記憶部12にあらかじめ生成されているモデルを入力することにより、ディスクリミネータのモデルが記憶されてもよい。別の例として、データ判別器生成装置1の外部において自動的な処理により生成されたモデルがデータ判別器記憶部12へと入力され、ディスクリミネータのモデルとして記憶されてもよい。
【0057】
データ判別器記憶部12は、データ判別器初期化部10により初期化されたモデル及び当該モデルを最適化したモデル等を記憶する。学習の最中においては、重み行列等が更新されたモデルを記憶しておいてもよい。順伝播部16及び逆伝播ブロック20は、このデータ判別器記憶部12に記憶されているモデルを用いて順伝播及び逆伝播を行い、当該モデルを更新する。
【0058】
入力部14は、ジェネレータ(データ生成器)が生成した正解データ(トレーニングデータ)に類似するデータである擬データ及び正解データを順伝播部16へと入力する。ディスクリミネータは、ジェネレータが生成した擬データと、正解データとを判別するように最適化される。
【0059】
順伝播部16は、データ判別器記憶部12に記憶されているデータ判別器に上記の擬データ、又は、正解データを入力し、順伝播を行う。
【0060】
順伝播部16は、データ判別器の入力層へとデータを入力し、出力層からの判別結果を取得する。データ判別器として、データ判別器記憶部12に記憶されているモデルを使用する。
【0061】
誤差算出部18は、データ判別器に擬データが入力された場合の出力と、データ判別器に正解データが入力された場合の出力とを比較し、誤差を算出する。誤差の算出は、例えば、[数3]に示される数式を用いる。この誤差算出部18が、S104の処理を行う。誤差算出部18が算出した誤差は、逆伝播ブロック20へと入力され、誤差逆伝播処理が実行される。また、誤差算出部18は、誤差逆伝播中において、誤差を算出する。算出された誤差を用いて、逆伝播ブロック20が誤差の逆伝播及び重み行列の更新を行う。
【0062】
逆伝播ブロック20は、データ判別器記憶部12に記憶されているデータ判別器のモデルを誤差逆伝播により更新する。例えば、モデルにおける重み行列(パラメータ)が更新される。逆伝播ブロック20は、逆伝播部200と、更新部202と、を備える。
【0063】
逆伝播部200は、データ判別器記憶部12に記憶されているデータ判別器のモデルと、誤差算出部18が算出した誤差とに基づいて、誤差逆伝播処理を行う。
【0064】
更新部202は、逆伝播処理のうち特に、重み行列を更新する処理を行う。例えば、この更新部202が、S104からS110の処理を行う。フローチャートに示されるように、逆伝播するタイミングにおいて、S106における特異ベクトルの更新、及び、S108の重み行列の正規化の処理により、重み行列の更新を行う。
【0065】
隠れ層がL層あるモデルの場合、まず、逆伝播部200が出力層から第L層へと誤差を逆伝播し、更新部202が第L層における重み行列を更新する。ミニバッチ学習である場合には、ミニバッチ内において第L層の重み行列を更新する。次に、第L層から第L-1層へと誤差を逆伝播し、同様に重み行列を更新する。このように、逐次的に誤差を逆伝播することにより、各層における重み行列を更新する。ミニバッチ学習の場合、隠れ層の重み行列の更新が終了し、ミニバッチ内における学習ステップ、例えば、上述したように評価値等に基づいた学習ステップが終了した後、次のミニバッチを生成し、同じように重み行列を更新していく。ミニバッチの処理については、一般的な手法で行われる。
【0066】
上述においては、基本的なGANに対してスペクトル正規化を適用する例を説明したが、GANではなく、WGAN(Wesserstein GAN)、WGAN-GP(Wesserstein GAN with Gradient Penalty)、DCGAN(Deep Convolutional GAN)、DRAGAN(Deep Regret Analytic GAN)等の他のGANのアルゴリズムに対してもスペクトル正規化を適用することが可能である。
【0067】
本実施形態に係るデータ判別器生成装置1を備えるGANの学習の安定性について、当該GANにより生成されたジェネレータによる画像生成の例を挙げながら説明する。以下の例においては、ジェネレータ、ディスクリミネータともにCNNに基づいたモデルの学習を行っている。
【0068】
以下の説明において、誤差逆伝播におけるSGDとしてAdam(Adaptive Moment Estimation)を用いてシミュレーションを行った。なお、Adam以外のSGD手法、Momentum、AdaGrad、RMSProp、AdaDelta等の他の手法を用いてもよい。本実施形態におけるディスクリミネータの更新に用いるロス関数は、以下の式を用いた。
【数18】
【0069】
また、ジェネレータの更新に用いるコスト関数は、以下に示される式を用いた。
【数19】
【0070】
本実施形態と、比較例とにおけるジェネレータが生成した擬データの評価として、以下のように定義されるインセプションスコア(Inception score)を用いた。
【数20】
【0071】
ここで、DKL[・]は、KLダイバージェンス(カルバック・ライブラー情報量:Kullback-Leibler Divergence)をあらわす。また、p(y)は、(1/N)Σn=1 p(y|x)で周辺確率として計算できる。
【0072】
図3は、シミュレーションのパラメータ例を挙げたものである。設定の項は、それぞれのパラメータの名称を示す。α、β、βは、それぞれAdamにおけるハイパーパラメータである。αは、学習率、βは、1次モーメンタム、βは、2次モーメンタムをそれぞれ示す。ndisは、ジェネレータが1回更新されるごとにディスクリミネータが更新される回数を示す。これらのパラメータの組み合わせは、シミュレーション結果を記載するための一例として表示されたものであり、本実施形態に係るスペクトル正規化を用いるGANの手法において重要な箇所ではなく、任意に変更してよい箇所である。
【0073】
設定Aは、WGAN-GP手法の論文(I. Gulrajani, et.al, "Improved training of Wasserstein gans." arXiv preprint, arXiv:1704.00028, 2017)で示されたパラメータである。設定Bは、論文(D. Warde-Farley, et.al, "Improving generative adversarial networks with denoising feature matching," ICLR, Nov. 6, 2016)で示されたパラメータである。設定Cは、DCGAN(Deep Convolutional GAN)手法の論文(A. Radford, et.al, "Unsupervised representation learning with deep convolutional generative adversarial networks," arXiv preprint, arXiv:1611.06624, 2016)で示されたパラメータである。
【0074】
設定A乃至設定Cは、既に論文として発表されている結果と比較するために挙げた例である。一方、設定D乃至設定Fは、さらに積極的な学習をする状況において、アルゴリズムの改善を評価するために設定されたパラメータセットである。これらのそれぞれの設定において、10万回のジェネレータのアップデータを学習させた。
【0075】
以下、図において、本実施形態に係るスペクトル正規化によるディスクリミネータの学習を用いたジェネレータの学習をSNと記載する。また、WN(Weight Normalization:T. Salimans, et.al, "Weight normalization: A simple reparameterization to accelerate training of deep neural networks," Advance in Neural Information Processing Systems, p901, 2016)、WGAN-GPの手法をそれぞれ比較対象とした結果をそれぞれ比較例1及び比較例2として記載する。
【0076】
図4Aは、データセットCIFAR-10の画像を用いて、図4Bは、データセットSTL-10の画像を用いて、図3に示す各設定におけるシミュレーションを行った結果を示すグラフである。縦軸は、上述したインセプションスコアを示す。
【0077】
これらの図から、SNは、積極的な学習率とモーメンタムパラメータに対して、比較例1及び比較例2と比較してロバストであることが読み取れる。比較例2は、高い学習率と高いモーメンタムパラメータに対して、GANによってよい出力結果を得ることに失敗している。比較例1は、CIFAR-10よりも多様な例により構成されているSTL-10においてSN及び比較例2よりも最適化の実効性が確保できていない。SNは、他のCIFAR-10及びSTL-10の双方において、他の方法よりも優れている。
【0078】
図5Aは、データセットCIFAR-10の画像を用いて、図5Bは、データセットSTL-10の画像を用いて、SN、比較例1、比較例2、及び、その他の手法を用いたインセプションスコアの結果を示す表である。リアルデータは、データセット中のデータを用いて取得されたインセプションスコアであることを示す。
【0079】
図5Aに示すように、SNは、Warde-Farley(図3の設定Bの記載されている論文の手法)以外においては、よりよいインセプションスコアを出していることが分かる。CIFAR-10よりも多様性のある画像から構成されるSTL-10においては、全ての他の手法よりもよいインセプションスコアである。
【0080】
図6は、データセットCIFAR-10の画像を用いて学習したジェネレータにより生成された画像を示し、図7は、STL-10の画像を用いて学習したジェネレータにより生成された画像を示す。
【0081】
上に描かれている8×8ブロックに分割された48×48ピクセルの画像は、データセットとして学習に与えられた画像である。下に描かれている画像は、上から順番に、SN、比較例1、比較例2を用いたGANにより学習されたジェネレータが生成した画像である。このように、生成された画像をみると、SNによるジェネレータが生成した画像が比較的よい結果であることが見られる。
【0082】
特に、学習率を挙げた場合には、比較例1及び比較例2に対して良好な結果が出力されている。設定D乃至設定Fでは、比較例1においては、例えば、全面がほぼ同一色といった全体的にコントラストが低い画像が出力され、比較例2においては、ほぼ雑音データが出力されている。一方、本実施形態に係るSNによれば、比較的コントラストが高いデータが生成されている。
【0083】
図8は、SN及び比較例1における重い行列の特異値の2乗値の存在する領域を示す理論値を示す図である。図8及び図9においては、重み行列の特異値を昇順に並べ、当該インデクスを横軸とし、縦軸として各特異値の2乗値を最大の特異値で正規化したものである。重み行列に対して、様々な状況において取り得る理論値の分布を示すものが図8のグラフである。実線は、SNでの理論値、破線は、比較例1での理論値を示す。この図8に示すように、SNでは、比較例1と比べ特異値の存在する幅が広いことが分かる。
【0084】
比較例1においては、特異値の非対称な分布となることから、重み行列の列空間は、低次元のベクトル空間となる。一方、SNにおいては、ディスクリミネータにより使用される特徴の次元数を妥協することなく利用できる。重み行列を掛けられた隠れ層の出力のノルムをできるだけ保持し、ディスクリミネータをより高精度にするためには、この(正規化された)ノルムを大きくすることが望まれる。
【0085】
例えば、比較例1においては、ノルムを大きくするためには、ランクを下げることとなるが、ランクを下げると、ディスクリミネータにおいて判断材料となる特徴量の数を減少させることとなる。より詳しくは、それぞれの特徴量に基づく判断は、特異値のノルムの大きさに依存する。すなわち、比較例1のように、一部の特異値だけが大きい値をとり、他の特異値がほぼ0となるような場合、特異値のノルムが大きい特徴量が重視され、特異値のノルムが小さい特徴量が判断に及ぼす影響が小さくなる。しかしながら、より高精度のディスクリミネータを学習するためには、特徴量の数を減少させることは得策ではない。このように比較例1においては、より高精度のディスクリミネータを学習するために、ノルムを大きくする(多くのノルムを取得可能とする)ことと、特徴量の数を減少させないことを両立することが困難である。
【0086】
図8に示すように、SNにおいては、比較例1に比べて正規化されたノルムを大きく保つことが可能である。すなわち、SNでは、ノルムを大きく保つことと、特徴量の数を減少させないことを両立することが可能となる。これは、線形演算のリプシッツ定数が最大の特異値によってのみ評価されることに基づく。すなわち、スペクトルノルムは、行列のランクとは独立していることに基づく。
【0087】
図9A及び図9Bは、異なる手法であるSN及び比較例1を用いて学習を行った場合の各層における重み行列の特異値の2乗値をそれぞれ示すグラフである。図9Aは、CIFAR-10のデータセットを用いたもの、図9Bは、STL-10のデータセットを用いたものである。図8と同様に、実線は、SNによる結果を示し、破線は、比較例1による結果を示す。
【0088】
図9A及び図9Bに示されるように、SNによれば、比較例1よりもほとんどの範囲において特異値のノルムが大きくなっていることが分かる。このように、特異値のノルムを大きくし、かつ、ランクを下げないようにすることが可能となり、正規化を行う場合において、ランク安定性を確保することができる。
【0089】
第1層乃至第5層において、比較例1においては、いくつかの値に集中している。すなわち、比較例1においては、これらの層における重み行列のランクが不足している。一方、SNにおいては、広く分布している。高次元空間にエンベデッドされた低次元非線形データの多様性乗の確率分布の対を区別することが目標である場合、下位層におけるランク不足は特に致命的となる蓋然性がある。下位層の出力は、線形変換の数少ないセットを介した出力であり、ほとんどの部分が線形である空間に偏向していることを示す。このような空間における入力分布の多くの特徴を過小評価してしまうことは、過剰に単純化されたディスクリミネータを生成することに繋がる。
【0090】
図7に示すシミュレーション結果によれば、このように過剰に単純化されたディスクリミネータが及ぼす影響を実際に確認することができる。スペクトル正規化を用いて生成された画像は、比較例1による画像よりも多様性を有し、かつ、複雑な画像である。
【0091】
図10A及び図10Bは、ジェネレータを100回更新した場合における演算時間を示す図である。縦軸は、各種法におけるジェネレータを100回更新した場合における演算時間[秒]を示す。図10Aは、データセットCIFAR-10を用いた結果であり、図10Bは、データセットSTL-10を用いた結果である。比較例2においては、誤差関数として、GP(Gradient Penalty)である||∇D||を余分に求める必要があるので、他の手法に比べて長い時間が必要となっている。
【0092】
これらの図において、SNは、比較例1とほぼ同等の時間で演算をできていることが示されている。これは、べき乗法に必要となる相対的な計算コストは、順伝播及び逆伝播のコストと比較して無視できる程度に小さいためである。
【0093】
以上のように、本実施形態に係るSN手法よれば、GANにおける各層の重み行列の更新において、スペクトル正規化を行った重み行列を用いることにより、安定したディスクリミネータの学習を実現することが可能となる。結果から読み取れるとおり、比較例と比べて多様性があり、複雑な画像を生成することが可能である。さらに、処理時間に関しては、比較例1と比べそれほど長い時間が掛かる訳ではなく、例えば、STL-10のデータセットを用いた場合等は、ほぼ同等の時間で処理を行える。
【0094】
(変形例)
前述の実施形態においては、GANの手法にスペクトル正規化を適用する例について述べたが、これには限られない。すなわち、WGAN-GP(比較例2)の手法において、スペクトル正規化を適用してもよい。以下、比較例2にSNを適用したものを、比較例2+SN等と表す。この場合、誤差関数は、以下の式を用いる。
【数21】
【0095】
シミュレーション結果は、図11に示す通りである。図11は、比較例2+SNについてのインセプションスコアを示す表である。スタンダードなCNNについてSNを適用したもの、及び、ResNet(Residual Network)を用いたCNNについてSNを適用したものを示している。比較として、比較例2+比較例1、及び、比較例2の結果も示している。シミュレーションにおいては、ディスクリミネータの学習における[数11]で表される関数を全てSN及び比較例1の手法により正規化した。図11から、比較例2、及び、比較例2+比較例1に比べてインセプションスコアが改善されている。
【0096】
図12Aは、評価としてのロス、図12Bは、バリデーションのロスを示す図である。実線は、比較例2+SN、破線は、比較例2+比較例1、点線は、比較例2による結果を示す。これらの図から、比較例2+SNによる学習は、比較例2及び比較例2+比較例1による学習よりも過学習をしていないことが示される。特に図12Bから、バリデーションデータに対しても評価値が下がっていないことから、比較例2+SNによる学習が他の手法に比べて過学習をしていないことが示されている。
【0097】
図13は、同シミュレーション状況におけるインセプションスコアを示す図である。実線は、比較例2+SN、破線は、比較例2+比較例1、点線は、比較例2による結果を示す。この図13からも、過学習の度合いがそれぞれの場合について示されている。最終的な結果ではなく、学習中に最適なものを抽出したとしても、比較例2+SNによる手法(7.28)は、他の手法(7.04、6.69)と比べてよいインセプションスコアであることが示される。
【0098】
以上のように、スタンダードなGANだけではなく、WGAN-GPの手法においても、本実施形態に係るスペクトル正規化は、より安定なディスクリミネータの学習を提供することが可能である。
【0099】
なお、上述の例では、正解データであるか否かのラベルを備えるトレーニングデータに基づいてデータの識別を行うデータ識別器としての訓練について記載したがこれには限られない。GANには限られず、例えば、カテゴリによりラベル付けされたトレーニングデータを用いて、上述の重み行列の更新を行うことにより、分類器として訓練を行うことも可能である。さらに、これらの他にも、一般的な重み行列を更新することにより訓練を行うニューラルネットワークモデルであれば、当該重み行列の更新に上述したスペクトル正規化を用いた訓練方法を適用することが可能である。[数8]に表される正規化を行うことにより、これらの重み付け行列の正規化を行い、重み付け行列の更新を行うことが可能となる。さらには、複数のラベル付けされたトレーニングデータを用いて、入力に対して連続値又は離散値を出力できるような訓練を行うようにしてもよい。
【0100】
上記の全ての記載において、データ判別器生成装置1の少なくとも一部はハードウェアで構成されていてもよいし、ソフトウェアで構成され、ソフトウェアの情報処理によりCPU等が実施をしてもよい。ソフトウェアで構成される場合には、データ判別器生成装置1及びその少なくとも一部の機能を実現するプログラムをフレキシブルディスクやCD-ROM等の記憶媒体に収納し、コンピュータに読み込ませて実行させるものであってもよい。記憶媒体は、磁気ディスクや光ディスク等の着脱可能なものに限定されず、ハードディスク装置やメモリなどの固定型の記憶媒体であってもよい。すなわち、ソフトウェアによる情報処理がハードウェア資源を用いて具体的に実装されるものであってもよい。さらに、ソフトウェアによる処理は、FPGA(Field-Programmable Gate Array)等の回路に実装され、ハードウェアが実行するものであってもよい。仮想環境の構築等の処理は、例えば、GPU(Graphical Processing Unit)等のアクセラレータを使用して行ってもよい。
【0101】
データ判別器生成装置1及び当該装置により生成されたデータ判別器は、上記のようにプログラムにより生成されるほか、アナログ回路又はデジタル回路により構成されていてもよい。この場合、一部又は全ての機能を制御する、制御回路を備えていてもよい。すなわち、データ判別器生成装置1及びデータ判別器は、制御回路と、メモリを備え、その機能の一部又は全部が制御回路により制御されるものであってもよい。
【0102】
上記の全ての記載に基づいて、本発明の追加、効果又は種々の変形を当業者であれば想到できるかもしれないが、本発明の態様は、上記した個々の実施形態に限定されるものではない。特許請求の範囲に規定された内容及びその均等物から導き出される本発明の概念的な思想と趣旨を逸脱しない範囲において種々の追加、変更及び部分的削除が可能である。
【符号の説明】
【0103】
1:データ判別器生成装置、10:データ判別器初期化部、12:データ判別器記憶部、14:入力部、16:順伝播部、18:誤差算出部、20:逆伝播ブロック、200:逆伝播部、202:更新部
図1
図2
図3
図4A
図4B
図5A
図5B
図6
図7
図8
図9A
図9B
図10A
図10B
図11
図12A
図12B
図13