IP Force 特許公報掲載プロジェクト 2022.1.31 β版

知財求人 - 知財ポータルサイト「IP Force」

▶ 国立大学法人 東京大学の特許一覧 ▶ トヨタ自動車株式会社の特許一覧

特開2024-51536学習システム、学習方法、およびプログラム
<>
  • 特開-学習システム、学習方法、およびプログラム 図1
  • 特開-学習システム、学習方法、およびプログラム 図2
  • 特開-学習システム、学習方法、およびプログラム 図3
  • 特開-学習システム、学習方法、およびプログラム 図4
  • 特開-学習システム、学習方法、およびプログラム 図5
  • 特開-学習システム、学習方法、およびプログラム 図6
< >
(19)【発行国】日本国特許庁(JP)
(12)【公報種別】公開特許公報(A)
(11)【公開番号】P2024051536
(43)【公開日】2024-04-11
(54)【発明の名称】学習システム、学習方法、およびプログラム
(51)【国際特許分類】
   G06N 3/08 20230101AFI20240404BHJP
【FI】
G06N3/08
【審査請求】未請求
【請求項の数】5
【出願形態】OL
(21)【出願番号】P 2022157754
(22)【出願日】2022-09-30
(71)【出願人】
【識別番号】504137912
【氏名又は名称】国立大学法人 東京大学
(71)【出願人】
【識別番号】000003207
【氏名又は名称】トヨタ自動車株式会社
(74)【代理人】
【識別番号】100103894
【弁理士】
【氏名又は名称】家入 健
(72)【発明者】
【氏名】鹿島 大河
(72)【発明者】
【氏名】中山 英樹
(72)【発明者】
【氏名】安間 絢子
(57)【要約】
【課題】機械学習に使用したデータの漏洩を防止できる学習システム、学習方法、およびプログラムを実現する。
【解決手段】本開示の一形態に係る学習システム100は、複数のクライアント端末110とサーバ120を備える。各クライアント端末110は、モデルの重み係数を学習し、重み係数からなるベクトルとキーベクトルの内積をサーバ120に送信する。サーバ120は、複数のクライアント端末110からそれぞれ受信した内積に基づいて、機械学習モデルの重み係数を最適化する。
【選択図】図1
【特許請求の範囲】
【請求項1】
複数のクライアント端末とサーバを備える学習システムであって、
各クライアント端末が、
機械学習モデルの重み係数を学習し、前記重み係数からなるベクトルとキーベクトルの内積を前記サーバに送信し、
前記サーバが、
前記複数のクライアント端末からそれぞれ受信した前記内積に基づいて、前記機械学習モデルの重み係数を最適化する、
学習システム。
【請求項2】
前記サーバは、
前記キーベクトルを独立変数とし、前記内積を従属変数とする線形回帰を行うことで、前記機械学習モデルの重み係数を最適化する、
請求項1に記載の学習システム。
【請求項3】
前記サーバは、ロバスト回帰またはラッソ回帰を行う、
請求項2に記載の学習システム。
【請求項4】
各クライアント端末が、機械学習モデルの重み係数を学習し、前記重み係数からなるベクトルとキーベクトルの内積をサーバに送信するステップと、
前記サーバが、複数のクライアント端末からそれぞれ受信した前記内積に基づいて、前記機械学習モデルの重み係数を最適化するステップと、
を含む学習方法。
【請求項5】
コンピュータに、
機械学習モデルの重み係数からなるベクトルとキーベクトルの内積を各クライアント端末から受信する処理と、
複数のクライアント端末からそれぞれ受信した前記内積に基づいて、前記機械学習モデルの重み係数を最適化する処理と、
を実行させるプログラム。
【発明の詳細な説明】
【技術分野】
【0001】
本開示は、学習システム、学習方法、およびプログラムに関する。
【背景技術】
【0002】
サーバと複数のクライアント端末が協力して学習を行う連合学習システムが提案されている(例えば、特許文献1参照)。連合学習システムでは、各クライアント端末が機械学習モデルの重み係数を学習してサーバに送信する。サーバは、複数のクライアント端末からそれぞれ受け取った重み係数に基づいて、機械学習モデルの重み係数を更新する。サーバは、更新した重み係数を複数のクライアント端末に配布する。
【先行技術文献】
【特許文献】
【0003】
【特許文献1】特開2019-028656号公報
【発明の概要】
【発明が解決しようとする課題】
【0004】
本出願人は、以下の課題を見出した。重み係数に基づいて、学習に用いた画像(例えば、家庭内の画像)を復元される可能性が指摘されている。したがって、重み係数をそのまま送信することは、プライバシー保護の観点で問題がある。また、性能のよい構成の大きなモデルの場合、重み係数の送信の負荷が高いという問題がある。さらに、悪意のあるユーザがモデルの性能を低下させる可能性が指摘されている。
【0005】
本開示は、このような問題点に鑑みてなされたものであり、機械学習に使用したデータの漏洩を防止できる学習システム、学習方法、およびプログラムを実現する。
【課題を解決するための手段】
【0006】
本開示の一態様の学習システムは、複数のクライアント端末とサーバを備える学習システムであって、
各クライアント端末が、
機械学習モデルの重み係数を学習し、前記重み係数からなるベクトルとキーベクトルの内積を前記サーバに送信し、
前記サーバが、
前記複数のクライアント端末からそれぞれ受信した前記内積に基づいて、前記機械学習モデルの重み係数を最適化する。
【0007】
本開示の一態様の学習方法は、
各クライアント端末が、機械学習モデルの重み係数を学習し、前記重み係数からなるベクトルとキーベクトルの内積をサーバに送信するステップと、
前記サーバが、複数のクライアント端末からそれぞれ受信した前記内積に基づいて、前記機械学習モデルの重みを最適化するステップと
を含む。
【0008】
本開示の一態様のプログラムは、
コンピュータに、
機械学習モデルの重み係数からなるベクトルとキーベクトルの内積を各クライアント端末から受信する処理と、
複数のクライアント端末からそれぞれ受信した前記内積に基づいて、前記機械学習モデルの重み係数を最適化する処理と、
を実行させる。
【発明の効果】
【0009】
本開示によれば、機械学習に使用したデータの漏洩を防止できる学習システム、学習方法、およびプログラムを実現できる。
【図面の簡単な説明】
【0010】
図1】実施形態1にかかる学習システムの構成を説明する図である。
図2】実施形態1にかかる学習システムの動作の検証結果を説明する図である。
図3】実施形態1にかかる学習システムの動作の検証結果を説明する図である。
図4】実施形態1にかかる学習システムでロバスト回帰を用いた場合の検証結果を説明する図である。
図5】実施形態1にかかる学習システムでロバスト回帰を用いた場合の検証結果を説明する図である。
図6】実施形態1にかかる学習システムでラッソ回帰を用いた場合の検証結果を説明する図である。
【発明を実施するための形態】
【0011】
以下、本開示を適用した具体的な実施の形態について、図面を参照しながら詳細に説明する。但し、本開示が以下の実施の形態に限定される訳ではない。また、説明を明確にするため、以下の記載及び図面は、適宜、簡略化されている。
【0012】
本開示に至る経緯
まず、図1を参照して、一般的な連合学習方法の流れについて説明する。図1に示す学習システム100は、クライアント端末110_1、クライアント端末110_2、クライアント端末110_3、およびサーバ120を備える。クライアント端末110_1、クライアント端末110_2、クライアント端末110_3、およびサーバ120は、ネットワークNを介して相互に通信可能に接続されている。
【0013】
クライアント端末110_1は家屋10_1に配置され、クライアント端末110_2は家屋10_2に配置され、クライアント端末110_3は家屋10_3に配置される。家屋10_1、家屋10_2、および家屋10_3を互いに区別しない場合、家屋10と言う。クライアント端末110_1、110_2、および110_3を互いに区別しない場合には、クライアント端末110と言う。
【0014】
まず、サーバ120は、Webから収集した画像や、3次元モデル(シミュレータとも言う)から生成した画像を用いて機械学習モデル(例えば、物体認識用のモデル)をトレーニングする。そして、サーバ120は、機械学習モデルの重み係数の初期値を、クライアント端末110_1、110_2、および110_3に配布する。
【0015】
次に、クライアント端末110_1は家屋10_1で学習データ20_1を収集し、クライアント端末110_2は家屋10_2で学習データ20_2を収集し、クライアント端末110_3は家屋10_3で学習データ20_3を収集する。学習データ20_1、20_2、および20_3を互いに区別しない場合、学習データ20と言う。
【0016】
次に、クライアント端末110_1が学習データ20_1で機械学習モデルの重み係数を更新し、クライアント端末110_2が学習データ20_2で機械学習モデルの重み係数を更新し、クライアント端末110_3が学習データ20_3で機械学習モデルの重み係数を更新する。次に、各クライアント端末110は、学習した機械学習モデルの重み係数をサーバ120に送信する。
【0017】
次に、サーバ120は、クライアント端末110_1から受け取った重み係数、クライアント端末110_2から受け取った重み係数、およびクライアント端末110_3から受け取った重み係数に基づいて、機械学習モデルの重み係数を更新する。サーバ120は、受け取った重み係数の平均値を計算し、新たな重み係数としてもよい。
【0018】
次に、サーバ120は、更新した機械学習モデルの重み係数をクライアント端末110_1、クライアント端末110_2、およびクライアント端末110_3に配布する。そして、クライアント端末110が学習データ20を収集する工程に戻る。
【0019】
一般的な連合学習方法では、クライアント端末110は、学習データ20ではなく、機械学習モデルの重み係数を送信する。これにより、学習データ20が漏洩するリスクを低減できる。しかし、学習に用いた画像を、重み係数から復元される可能性があることが知られている。また、性能のよいモデルの場合、重み係数の送信の負荷が高いという問題がある。さらに、悪意のあるユーザが連合学習に参加している場合、機械学習モデルの性能が低下する恐れがある。本願の発明者は、以上の経緯から実施形態1にかかる学習システムに想到した。
【0020】
実施形態1
以下、図面を参照して実施形態1にかかる学習システムについて説明する。まず、図1を参照して、実施形態1にかかる学習システム100の概要について説明する。
【0021】
図1ではクライアント端末110の数が3台の場合を示しているが、クライアント端末110の数は2台であってもよく、4台以上であってもよい。クライアント端末110は、カメラとPCの組み合わせであってもよく、ロボットであってもよい。
【0022】
クライアント端末110およびサーバ120は、例えばCPU(Central Processing Unit)などの演算部と、各種制御プログラムやデータ等が格納されたRAM(Random Access Memory)、ROM(Read Only Memory)等の記憶部とを備える。すなわち、クライアント端末110およびサーバ120は、コンピュータとしての機能を有しており、上記各種制御プログラム等に基づいて処理を行う。
【0023】
次に、学習システム100の動作について説明する。学習システム100では、まず、クライアント端末110_1、110_2、および110_3が、収集した学習データ20_1、20_2、および20_3を用いて機械学習モデルの重み係数を表す行列W1、W2、およびW3を学習する。行列W1、W2、およびW3を互いに区別しない場合、行列Wと言う。行列W1に含まれる各列ベクトルをベクトルw1と言い、行列W2に含まれる各列ベクトルをベクトルw2と言い、行列W3に含まれる各列ベクトルをベクトルw3と言う。ベクトルw1、w2、およびw3を互いに区別しない場合、ベクトルwと言う。ベクトルwは、機械学習モデルの重み係数からなる。
【0024】
学習データ20は、例えば、機械学習に有用な画像と、その画像に写った物体を示すアノテーションとを含む。クライアント端末110は、画像をユーザに提示し、ユーザの回答に応じてアノテーションを定めてもよい。また、アノテーションは、ユーザの動作や反応に応じて定められてもよい。画像は、家屋10内に設置されたカメラやロボットによって撮影されてもよい。なお、機械学習モデルは物体認識用のモデルには限られない。
【0025】
次に、クライアント端末110_1がキーベクトルv1を生成し、クライアント端末110_2がキーベクトルv2を生成し、クライアント端末110_3がキーベクトルv3を生成する。ベクトルv1、v2、およびv3を互いに区別しない場合、キーベクトルvと言う。キーベクトルvの成分の数は、ベクトルwの成分の数と一致しているものとする。例えば、vがn次元ベクトルである場合、v=[v,v,・・・,v]と表される。
【0026】
クライアント端末110_1はベクトルw1とキーベクトルv1の内積y1を計算する。クライアント端末110_2はベクトルw2とキーベクトルv2の内積y2を計算する。クライアント端末110_3はベクトルw3とキーベクトルv3の内積y3を計算する。内積y1、y2、およびy3を互いに区別しない場合、内積yと言う。
【0027】
ベクトルwは行列Wの列数だけ存在するため、内積yは行列Wの列数だけ存在する。複数の内積yを並べたベクトルを、内積ベクトルYと言う。行列Wのl番目(lはアルファベットのエルを表す)の列ベクトルをWとし、l番目の内積の計算結果をYと表すと、Y=<W,v>=Wl1・v+Wl2・v+・・・である。なお、Wljは、Wのj番目の要素を表す。
【0028】
ところで、aおよびbがn次元ベクトルである場合、内積cは、c=<a,b>=a・b+a・b+・・・+a・bで与えられる。aはベクトルaのj番目の要素を表し、bはベクトルbのj番目の要素を表す。内積cとベクトルaが与えられても、ベクトルbは一意に定まらないという特徴がある。例えば、a=[1,1,1,1,1]、c=[1]とした場合、c=<a,b>を満たすベクトルbは一意に定めらない。さらに、内積cの次元は、ベクトルaの次元やベクトルbの次元よりも小さいという特徴がある。
【0029】
したがって、内積yは、ベクトルwをキーベクトルvで圧縮したデータであると考えることもできる。クライアント端末110は、行列Wに含まれる複数の列ベクトルをキーベクトルvで圧縮する。
【0030】
次に、クライアント端末110_1が、キーベクトルv1および内積ベクトルY1をサーバ120に送信する。クライアント端末110_2が、キーベクトルv2および内積ベクトルY2をサーバ120に送信する。クライアント端末110_3が、キーベクトルv3および内積ベクトルY3をサーバ120に送信する。
【0031】
クライアント端末110は、行列Wに含まれる複数の列ベクトルを圧縮する。通常、キーベクトルvのデータ量と内積ベクトルYのデータ量の和は、行列Wのデータ量よりも小さい。
【0032】
次に、サーバ120は、内積ベクトルY1~Y3およびキーベクトルv1~v3に基づいて、機械学習モデルの重み係数を最適化する。具体的には、サーバ120は、内積ベクトルYをキーベクトルvで説明するモデルをY=Wvと表し、最適な行列Wを推定する。行列Wは、行列Wの転置行列を表す。
【0033】
学習システム100は、キーベクトルvを独立変数とし、内積ベクトルYを従属変数とする線形回帰を行うことで機械学習モデルの重み係数を最適化する。回帰分析を行うときには、最小二乗法を用いてもよく、ロバスト回帰を用いてもよく、ラッソ回帰を用いてもよい。最小二乗法を用いる場合、サーバ120は、式(1)に示す損失関数Lを最小化するように重み係数を算出する。行列Wは、重み係数を表す行列Wを最適化した行列を表す。Yは、内積ベクトルY1、Y2、およびY3を並べた行列を表す。Vは、キーベクトルv1、v2、およびv3を並べた行列を表す。
【数1】
【0034】
また、サーバ120がロバスト回帰を用いる場合、悪意のあるユーザからの攻撃によって機械学習モデルの性能が低下することを抑制できる。サーバ120がラッソ回帰を用いた場合、不要な重み係数を低減できるため、通信負荷を特に低減できる。
【0035】
クライアント端末110の数が1台であり、クライアント端末110_1のみが存在した場合、内積ベクトルY1のl番目の要素であるY1とキーベクトルv1から、Y1=Wl1・v1+Wl2・v1+・・・を満たすWを推定しなければならない。しかし、内積の性質から、内積Y1とキーベクトルv1のみからWを一意に定めることはできない。学習システム100は、複数のクライアント端末110が存在することを前提としている。サーバ120は、複数のクライアント端末110で計算した内積から、一つのWを復号すればよい。
【0036】
学習システム100は、上記処理を1ラウンド(Communication Roundと言う)とし、ラウンドを繰り返すことで学習を進める。
【0037】
学習システム100では、各クライアント端末110は、機械学習モデルの重み係数ではなく、内積とキーベクトルを送信する。内積とキーベクトルが傍受されたとしても、機械学習モデルの重み係数を一意に定めることはできないため、連合学習に参加するユーザのプライバシーの保護を強化できる。
【0038】
さらに、内積のデータ量とキーベクトルのデータ量の和は、重み係数を表す行列のデータ量よりも小さい。したがって、学習システム100により、通信負荷を低減できる。
【0039】
図1では、クライアント端末110がキーベクトルの生成を行う場合を説明したが、サーバ120が、キーベクトルvを生成し、キーベクトルvをクライアント端末110に配布してもよい。この場合、クライアント端末110は、内積を送信するときにキーベクトルを送信しなくてもよい。
【0040】
次に、図2および図3を参照して、学習システム100の動作の検証結果を説明する。図2では、機械学習モデルとしてCNN(Convolutional Neural Network)を用いた。データセットとして、CIFAR(Canadian Institute for Advanced Research)-10を用いた。機械学習モデルの重み係数は、最小二乗法により推定した。横軸はラウンド数を示しており、縦軸は正解率(Accuracy)を示す。
【0041】
実施形態1を用いた場合の正解率は点線で示されている。グラフ31、グラフ32、グラフ33、およびグラフ34は、クライアント端末110の数が1台、30台、50台、および100台である場合の正解率を示す。
【0042】
従来技術を用いた場合の正解率は実線で示されている。従来技術では、重み係数の平均値を新たな重み係数とした。グラフ41、グラフ42、グラフ43、およびグラフ44は、クライアント端末110の数が1台、30台、50台、および100台である場合の正解率を示す。グラフ42、43、および44は、ほぼ重なっている。
【0043】
図2を参照すると、クライアント端末110の数が少ない場合には正解率が低いが、クライアント端末110の数が多い場合には正解率が高い。クライアント端末110の数が多い場合、高い精度の機械学習モデルを得られる。クライアント端末110の数が増加するにつれて、実施形態1における正解率は従来技術における正解率(ベースライン)に近づく。
【0044】
また、図3は、機械学習モデルとして多層パーセプトロン(MLP:MultiLayer Perceptron)を用いた場合の検証結果を示す。データセットとして、MNIST(Modified National Institute of Standards and Technology)を用いた。
【0045】
実施形態1を用いた場合の正解率は点線で示されている。グラフ51、グラフ52、グラフ53、およびグラフ54は、クライアント端末110の数が1台、30台、50台、および100台である場合の正解率を示す。
【0046】
従来技術を用いた場合の正解率は実線で示されている。グラフ61、グラフ62、グラフ63、およびグラフ64は、従来技術を用いた場合の正解率を示す。グラフ61、グラフ62、グラフ63、およびグラフ64は、クライアント端末110の数が1台、30台、50台、100台である場合の正解率を示す。グラフ62、63、および64は、ほぼ重なっている。
【0047】
図2と同様に、クライアント端末110の数が増加すると高い精度の機械学習モデルを得られることがわかる。実施形態1は、機械学習モデルが多層パーセプトロンである場合にも効果を奏する。
【0048】
次に、図4および図5を参照して、ロバスト回帰を用いる場合について説明する。ロバスト回帰は、機械学習モデルの性能を低下させる攻撃に対して頑健である。ロバスト回帰では、Huber損失が用いられる。
【0049】
図4は、データセットとしてCIFARを用いた場合の検証結果を示す。グラフ71~73は、悪意のあるクライアントの数が1である場合の正解率を示す。グラフ71~73は、従来技術を用いた場合、損失としてL2損失を用いた場合、損失としてHuber損失を用いた場合の正解率を示す。
【0050】
グラフ74~76は、悪意のあるクライアントの数が5である場合の正解率を示す。グラフ74~76は、従来技術を用いた場合、損失としてL2損失を用いた場合、損失としてHuber損失を用いた場合の正解率を示す。
【0051】
図5は、データセットとして、MNISTを用いた場合の検証結果を示す。グラフ81~83は、悪意のあるクライアントの数が1である場合の正解率を示す。グラフ81~83は、従来技術を用いた場合、損失としてL2損失を用いた場合、損失としてHuber損失を用いた場合の正解率を示す。
【0052】
グラフ84~86は、悪意のあるクライアントの数が5である場合の正解率を示す。グラフ84~86は、従来技術を用いた場合、損失としてL2損失を用いた場合、損失としてHuber損失を用いた場合の正解率を示す。
【0053】
図4および図5を参照すると、従来技術やL1損失を用いた場合、攻撃によって機械学習モデルの性能が低下している。Huber損失を用いる場合、つまりロバスト回帰を用いる場合、機械学習モデルは、悪意のあるクライアントからの攻撃に対して頑健になっている。
【0054】
次に、図6を参照して、ラッソ回帰を用いる場合について説明する。表90の一列目は、従来技術を用いるか実施形態1を用いるかを表す。二列目は、クライアント側およびサーバ側でラッソ回帰を行うか否かを表す。「Yes」はラッソ回帰を行うことを表し、「No」はラッソ回帰を行わないことを表す。ラッソ回帰では、損失としてL1損失が用いられる。三列目および四列目は、誤答率およびスパーシティを表す。
【0055】
一行目および二行目は、従来技術を用いた場合の誤答率およびスパーシティを示す。一行目はクライアント端末110でラッソ回帰を行わない場合を表し、二行目はクライアント端末110でラッソ回帰を行う場合を表す。なお、従来技術ではサーバ側で重み係数の推定を行わないため、サーバ側は「No」になっている。
【0056】
三行目、四行目、五行目、および六行目は、実施形態1を用いた場合の検証結果を示す。三行目は、クライアント端末110およびサーバ120の両方がラッソ回帰を行わない場合を表す。四行目は、クライアント端末110がラッソ回帰を行い、サーバ120がラッソ回帰を行わない場合を表す。五行目は、クライアント端末110がラッソ回帰を行わず、サーバ120がラッソ回帰を行う場合を表す。六行目は、クライアント端末110がラッソ回帰を行い、サーバ120がラッソ回帰を行う場合を示す。
【0057】
図6を参照すると、サーバ120がラッソ回帰を用いた場合のスパーシティが高く、機械学習モデルをスパースにできることがわかる。
【0058】
上述したプログラムは、コンピュータに読み込まれた場合に、1又はそれ以上の機能をコンピュータに行わせるための命令群(又はソフトウェアコード)を含む。プログラムは、非一時的なコンピュータ可読媒体又は実体のある記憶媒体に格納されてもよい。限定ではなく例として、コンピュータ可読媒体又は実体のある記憶媒体は、random-access memory(RAM)、read-only memory(ROM)、フラッシュメモリ、solid-state drive(SSD)又はその他のメモリ技術、CD-ROM、digital versatile disc(DVD)、Blu-ray(登録商標)ディスク又はその他の光ディスクストレージ、磁気カセット、磁気テープ、磁気ディスクストレージ又はその他の磁気ストレージデバイスを含む。プログラムは、一時的なコンピュータ可読媒体又は通信媒体上で送信されてもよい。限定ではなく例として、一時的なコンピュータ可読媒体又は通信媒体は、電気的、光学的、音響的、またはその他の形式の伝搬信号を含む。
【0059】
なお、本開示は上記実施の形態に限られたものではなく、趣旨を逸脱しない範囲で適宜変更することが可能である。
【符号の説明】
【0060】
100 学習システム
110、110_1、110_2、110_3 クライアント端末
120 サーバ
10、10_1、10_2、10_3 家屋
20、20_1、20_2、20_3 学習データ
31~34、41~44、51~54、61~64、71~76、81~86 グラフ
90 表
図1
図2
図3
図4
図5
図6