(19)【発行国】日本国特許庁(JP)
(12)【公報種別】公開特許公報(A)
(11)【公開番号】P2023069960
(43)【公開日】2023-05-18
(54)【発明の名称】仲介処理に基づく連合学習方法及び連合学習システム
(51)【国際特許分類】
G06N 3/08 20230101AFI20230511BHJP
G06N 20/00 20190101ALI20230511BHJP
【FI】
G06N3/08
G06N20/00
【審査請求】有
【請求項の数】18
【出願形態】OL
(21)【出願番号】P 2021192090
(22)【出願日】2021-11-26
(31)【優先権主張番号】110141246
(32)【優先日】2021-11-05
(33)【優先権主張国・地域又は機関】TW
(71)【出願人】
【識別番号】599060434
【氏名又は名称】財團法人資訊工業策進會
(74)【代理人】
【識別番号】100108453
【弁理士】
【氏名又は名称】村山 靖彦
(74)【代理人】
【識別番号】100110364
【弁理士】
【氏名又は名称】実広 信哉
(74)【代理人】
【識別番号】100133400
【弁理士】
【氏名又は名称】阿部 達彦
(72)【発明者】
【氏名】王 秉豐
(72)【発明者】
【氏名】許 群昇
(72)【発明者】
【氏名】周 志遠
(57)【要約】
【課題】仲介処理に基づく連合学習方法及び連合学習システムを提供する。
【解決手段】連合学習方法は、複数の仲介者モジュールを生成するように、複数のクライアント装置を複数の仲介者グループに分けて;サーバー装置によって、初期モデル重みデータを前記複数の仲介者モジュールにブロードキャストし;前記複数の仲介者モジュールによって、前記複数の仲介者グループのそれぞれについて逐次訓練プロセスを行って、ターゲットモデルを訓練し、訓練済モデル重みデータを生成し;前記サーバー装置によって、加重連合平均アルゴリズムを実行することで、グローバルモデル重みデータを生成し;前記サーバー装置によって、前記グローバルモデル重みデータに基づいてターゲットモデルを構築することによりグローバルターゲットモデルを生成する。
【選択図】
図3
【特許請求の範囲】
【請求項1】
仲介処理に基づく連合学習方法であって、
サーバー装置は、複数のクライアント装置における複数のデータ分布情報に基づいて前記複数のクライアント装置を複数の仲介者グループに分けて、前記複数の仲介者グループを管理する複数の仲介者モジュールを生成し、
前記サーバー装置は、初期モデル重みデータを前記複数の仲介者モジュールにブロードキャストし、
前記複数の仲介者モジュールは、前記複数の仲介者グループのそれぞれに対し逐次訓練プロセスを行い、前記逐次訓練プロセスは、
前記複数のクライアント装置に対応するトレーニングシーケンスを決定し、
前記初期モデル重みデータを前記複数のクライアント装置に対応的に送信し、対応的な前記複数のクライアント装置は、複数のローカル情報を複数の訓練データとして、前記初期モデル重みデータ及び前記トレーニングシーケンスに基づいてターゲットモデルを順序に訓練することで、訓練済モデル重みデータを生成し、
前記訓練済モデル重みデータを前記サーバー装置に送信し、
前記サーバー装置は、前記複数の仲介者グループにおける複数の前記訓練済モデル重みデータを取得し、前記複数の訓練済モデル重みデータに基づいて前記複数の仲介者グループのそれぞれに対応する複数の重みを算出し、
前記サーバー装置は、前記複数の重みに基づいて前記複数の訓練済モデル重みデータに対し加重連合平均アルゴリズムを行うことで、グローバルモデル重みデータを生成し、
前記サーバー装置は、前記グローバルモデル重みデータに基づいて前記ターゲットモデルを構築することでグローバルターゲットモデルを生成する、
ことを特徴とする、仲介処理に基づく連合学習方法。
【請求項2】
前記逐次訓練プロセスはさらに、前記複数の仲介者モジュールが前記複数の仲介者グループのそれぞれに対し耐障害性プロセスを行うことが含まれ、前記耐障害性プロセスは、
訓練が行われている前記クライアント装置の接続状態を監視し、
訓練が行われている前記クライアント装置がオフライン状態になったことを検出すると、対応的な前記複数のクライアント装置のそれぞれの装置状態を確認し、
前記複数の装置状態に応じて前記複数のクライアント装置から新規のクライアントエンド装置を選択し、前記オフライン状態となった前記クライアント装置が訓練予定のモデル重みを前記新規のクライアント装置に送信して、訓練を行う、請求項1に記載の仲介処理に基づく連合学習方法。
【請求項3】
訓練が行われている前記クライアント装置の接続状態を監視するプロセスでは、
訓練が行われている前記クライアント装置に周期的な信号を送信し、
訓練が行われている前記クライアント装置が所定時間内に前記周期的な信号に応答しなかったかどうかを判断し、
訓練が行われている前記クライアント装置が前記所定時間内に前記周期的な信号に応答しなかったと判断すると、訓練が行われている前記クライアント装置は、前記オフライン状態になったと決定する、請求項2に記載の仲介処理に基づく連合学習方法。
【請求項4】
前記オフライン状態となった前記クライアント装置が訓練予定の前記モデル重みを前記新規のクライアント装置に送信するプロセスは、
順番が前記オフライン状態になった前記クライアント装置の1つの前の訓練済の前記クライアント装置が訓練予定の前記モデル重みを前記新規のクライアント装置に送信する、請求項2に記載の仲介処理に基づく連合学習方法。
【請求項5】
前記耐障害性プロセスは、前記オフライン状態になった前記クライアント装置に関する耐障害性情報を記録し、前記サーバー装置に送信することを含み、
前記サーバー装置が前記グローバルモデル重みデータを生成するプロセスは、前記サーバー装置が前記複数の仲介者グループの複数の前記耐障害性情報に応じて前記グローバルモデル重みデータを生成する、請求項1に記載の仲介処理に基づく連合学習方法。
【請求項6】
前記複数のクライアント装置はそれぞれ前記複数のローカル情報を統計することによって前記複数のデータ分布情報を生成し、前記サーバー装置に送信する、請求項1に記載の仲介処理に基づく連合学習方法。
【請求項7】
前記複数のクライアント装置の前記トレーニングシーケンスを決定するプロセスはさらに、
対応的な前記複数のクライアント装置のそれぞれの装置状態を1つずつ確認し、
前記複数の装置状態に応じて前記トレーニングシーケンスを決定することを含む、請求項1に記載の仲介処理に基づく連合学習方法。
【請求項8】
対応的な前記複数のクライアント装置が、前記初期モデル重みデータ及び前記トレーニングシーケンスに基づいてターゲットモデルを順序に訓練するプロセスは、
前記初期モデル重みデータを前記トレーニングシーケンスにおける順番第1の前記クライアント装置に送信し、
前記第1の前記クライアント装置は、前記初期モデル重みデータによって前記ターゲットモデルを訓練し、訓練が終了すると、第1の訓練済モデル重みデータを生成し、
前記第1の訓練済モデル重みデータを前記トレーニングシーケンスにおける順番第2の前記クライアント装置に送信し、
前記第2の前記クライアント装置は、前記第1の訓練済モデル重みデータによって、前記ターゲットモデルを訓練し、訓練が終了すると、第2の訓練済モデル重みデータを生成する、ことを含む、請求項1に記載の仲介処理に基づく連合学習方法。
【請求項9】
前記逐次訓練プロセスは、
前記訓練済モデル重みデータに基づいて前記逐次訓練プロセスを再実行する必要があるかどうかを判断し、
前記逐次訓練プロセスを再実行する必要がある判断に応じて、前記トレーニングシーケンスを再決定し、前記訓練済モデル重みデータを前記初期モデル重みデータとして前記複数のクライアント装置に対応的に送信し訓練を行い、
前記逐次訓練プロセスを再実行する必要がない判断に応じて、前記訓練済モデル重みデータを前記サーバー装置に送信する、請求項1に記載の仲介処理に基づく連合学習方法。
【請求項10】
仲介処理に基づく連合学習システムであって、
複数のクライアント装置と、
前記複数のクライアント装置に通信接続され、前記複数のクライアント装置の複数のデータ分布情報に基づいて前記複数のクライアント装置を複数の仲介者グループに分けるサーバー装置と、
前記サーバー装置が生成して、前記複数の仲介者グループを対応的に管理する複数の仲介者モジュールと、を含み、
前記サーバー装置は、初期モデル重みデータを前記複数の仲介者モジュールにブロードキャストし、
前記複数の仲介者モジュールは、前記複数の仲介者グループのそれぞれに対し逐次訓練プロセスを行い、前記逐次訓練プロセスは、
前記複数のクライアント装置に対応するトレーニングシーケンスを決定し、
前記初期モデル重みデータを前記複数のクライアント装置に対応的に送信し、対応的な前記複数のクライアント装置は、複数のローカル情報を複数の訓練データとして、前記初期モデル重みデータ及び前記トレーニングシーケンスに基づいてターゲットモデルを順序に訓練することで、訓練済モデル重みデータを生成し、
前記訓練済モデル重みデータを前記サーバー装置に送信し、
前記サーバー装置は、前記複数の仲介者グループにおける複数の前記訓練済モデル重みデータを取得し、前記複数の訓練済モデル重みデータに基づいて前記複数の仲介者グループのそれぞれに対応する複数の重みを算出し、
前記サーバー装置は、前記複数の重みに基づいて前記複数の訓練済モデル重みデータに対し加重連合平均アルゴリズムを行うことで、グローバルモデル重みデータを生成し、
前記サーバー装置は、前記グローバルモデル重みデータに基づいて前記ターゲットモデルを構築することでグローバルターゲットモデルを生成する、
ことを特徴とする、仲介処理に基づく連合学習システム。
【請求項11】
前記逐次訓練プロセスはさらに、前記複数の仲介者モジュールが前記複数の仲介者グループのそれぞれに対し耐障害性プロセスを行うことが含まれ、前記耐障害性プロセスは、
訓練が行われている前記クライアント装置の接続状態を監視し、
訓練が行われている前記クライアント装置がオフライン状態になったことを検出すると、対応的な前記複数のクライアント装置のそれぞれの装置状態を確認し、
前記複数の装置状態に応じて前記複数のクライアント装置から新規のクライアント装置を選択し、前記オフライン状態となった前記クライアント装置が訓練予定のモデル重みを前記新規のクライアント装置に送信して、訓練を行う、請求項10に記載の仲介処理に基づく連合学習システム。
【請求項12】
訓練が行われている前記クライアント装置の接続状態を監視するプロセスでは、
訓練が行われている前記クライアント装置に周期的な信号を送信し、
訓練が行われている前記クライアント装置が所定時間内に前記周期的な信号に応答しなかったかどうかを判断し、
訓練が行われている前記クライアント装置が前記所定時間内に前記周期的な信号に応答しなかったと判断すると、訓練が行われている前記クライアント装置は、前記オフライン状態になったと決定する、請求項11に記載の仲介処理に基づく連合学習システム。
【請求項13】
前記オフライン状態となった前記クライアント装置が訓練予定の前記モデル重みを前記新規のクライアント装置に送信するプロセスは、
順番が前記オフライン状態になった前記クライアント装置の1つの前の訓練済の前記クライアント装置が訓練予定の前記モデル重みを前記新規のクライアント装置に送信する、請求項11に記載の仲介処理に基づく連合学習システム。
【請求項14】
前記耐障害性プロセスは、前記オフライン状態になった前記クライアント装置に関する耐障害性情報を記録し、前記サーバー装置に送信することを含み、
前記サーバー装置が前記グローバルモデル重みデータを生成するプロセスは、前記サーバー装置が前記複数の仲介者グループの複数の前記耐障害性情報に応じて前記グローバルモデル重みデータを生成する、請求項10に記載の仲介処理に基づく連合学習システム。
【請求項15】
前記複数のクライアント装置はそれぞれ前記複数のローカル情報を統計することによって前記複数のデータ分布情報を生成し、前記サーバー装置に送信する、請求項10に記載の仲介処理に基づく連合学習システム。
【請求項16】
前記複数のクライアント装置の前記トレーニングシーケンスを決定するプロセスはさらに、
対応的な前記複数のクライアント装置のそれぞれの装置状態を1つずつ確認し、
前記複数の装置状態に応じて前記トレーニングシーケンスを決定することを含む、請求項10に記載の仲介処理に基づく連合学習システム。
【請求項17】
対応的な前記複数のクライアント装置が、前記初期モデル重みデータ及び前記トレーニングシーケンスに基づいてターゲットモデルを順序に訓練するプロセスは、
前記初期モデル重みデータを前記トレーニングシーケンスにおける順番第1の前記クライアント装置に送信し、
前記第1の前記クライアント装置は、前記初期モデル重みデータによって前記ターゲットモデルを訓練し、訓練が終了すると、第1の訓練済モデル重みデータを生成し、
前記第1の訓練済モデル重みデータを前記トレーニングシーケンスにおける順番第2の前記クライアント装置に送信し、
前記第2の前記クライアント装置は、前記第1の訓練済モデル重みデータによって、前記ターゲットモデルを訓練し、訓練が終了すると、第2の訓練済モデル重みデータを生成する、ことを含む、請求項10に記載の仲介処理に基づく連合学習システム。
【請求項18】
前記逐次訓練プロセスは、
前記訓練済モデル重みデータに基づいて前記逐次訓練プロセスを再実行する必要があるかどうかを判断し、
前記逐次訓練プロセスを再実行する必要がある判断に応じて、前記トレーニングシーケンスを再決定し、前記訓練済モデル重みデータを前記初期モデル重みデータとして前記複数のクライアント装置に対応的に送信し訓練を行い、
前記逐次訓練プロセスを再実行する必要がない判断に応じて、前記訓練済モデル重みデータを前記サーバー装置に送信する、請求項10に記載の仲介処理に基づく連合学習システム。
【発明の詳細な説明】
【技術分野】
【0001】
本発明は、連合学習方法及び連合学習システムに関し、特に仲介処理に基づく連合学習方法及び連合学習システムに関する。
【背景技術】
【0002】
既存の連合学習(Federated Learning)方法では、ローカルデータはクライアントから出る必要はなく、ローカルデバイス上で学習して共有モデルを構築し、更新するようになる。この方法は、高いレベルのプライバシーを確保するだけでなく、大量のデータの転送を集中するためのコストが不要となる。しかし、異なるクライアントが収集したローカルデータには、環境や場所の要因によるデータの偏りがあるため、このデータの偏りが学習したモデルの精度を低下させることがある。
【0003】
さらに、既存の連合学習方法は、サーバーを介して複数のクライアント装置を連携させて学習を行うが、選定されたクライアント装置のネットワークトラフィックの乱れにより、サーバーで実行されている連合学習方法の集計処理が中断されることがある。
【発明の概要】
【発明が解決しようとする課題】
【0004】
本発明が解決しようとする課題は、既存の技術の不足に対し、仲介処理に基づく連合学習方法及び連合学習システムを提供することである。
【課題を解決するための手段】
【0005】
本発明に係る特定の実施形態は、次のような品質検査方法を提供する。上記技術的課題を解決するために、本発明で採用した技術的解決策の一つは、次のような仲介処理に基づく連合学習方法を提供することである。仲介処理に基づく連合学習方法は、サーバー装置によって、複数のクライアント装置における複数のデータ分布情報に基づいて前記複数のクライアント装置を複数の仲介者グループに分けて、前記複数の仲介者グループを管理するための複数の仲介者モジュールを生成し;前記サーバー装置によって、初期モデル重みデータを前記複数の仲介者モジュールにブロードキャストし;前記複数の仲介者モジュールによって、前記複数の仲介者グループのそれぞれについて逐次訓練プロセスを行い、逐次訓練プロセスにおいては、前記対応する前記複数のクライアント装置のトレーニングシーケンスを決定するプロセスも含み;前記初期モデル重みデータを前記複数のクライアント装置に対応的に送信し、前記複数のクライアント装置は複数のローカル情報を複数の訓練データとして取り扱い、前記初期モデル重みデータ及び前記トレーニングシーケンスによって、ターゲットモデルを順序に訓練することによって、訓練済モデル重みデータを生成し;前記訓練済モデル重みデータを前記サーバー装置に送信する。連合学習方法はさらに次のようなステップを含み、前記サーバー装置によって、前記複数の仲介者グループにおける複数の前記訓練済モデル重みデータを取得し、前記複数の訓練済モデル重みデータに基づいて前記複数の仲介者グループのそれぞれに対応する複数の重みを算出し;前記サーバー装置によって、前記複数の重みに基づいて前記複数の訓練済モデル重みデータに対して、加重連合平均アルゴリズムを実行することによって、グローバルモデル重みデータを生成し;前記サーバー装置によって、前記グローバルモデル重みデータに基づいて前記ターゲットモデルを構築することによってグローバルターゲットモデルを生成する。
【0006】
上記技術的課題を解決するために、本発明で採用した技術的解決策のもう1つは、次のような仲介処理に基づく連合学習システムを提供することである。仲介処理に基づく連合学習システム複数のクライアント装置、サーバー装置及び複数の仲介者モジュールを含む。サーバー装置は、前記複数のクライアント装置に通信接続され、前記複数のクライアント装置における複数のデータ分布情報に基づいて前記複数のクライアント装置を複数の仲介者グループに分ける。前記複数の仲介者モジュールは、前記サーバー装置で生成され、それぞれは、前記複数の仲介者グループを管理するために用いられる。なかでも、サーバー装置は、初期モデル重みデータを前記複数の仲介者モジュールにブロードキャストする。前記複数の仲介者モジュールは、前記複数の仲介者グループのそれぞれに対し逐次訓練プロセスを実行する。逐次訓練プロセスは、対応する前記複数のクライアント装置のトレーニングシーケンスを決定するプロセスが含まれる。前記初期モデル重みデータを前記複数のクライアント装置のそれぞれに対応的に送信し、対応的な前記複数のクライアント装置が、複数のローカル情報を複数の訓練データとして取り扱い、前記初期モデル重みデータ及び前記トレーニングシーケンスに基づいてターゲットモデルを順序に訓練することによって、訓練済モデル重みデータを生成する。前記訓練済モデル重みデータを前記サーバー装置に送信する。なかでも、前記サーバー装置は、前記複数の仲介者グループにおける複数の前記訓練済モデル重みデータを取得し、前記複数の訓練済モデル重みデータに基づいて対応的な前記複数の仲介者グループの複数の重みを算出する。なかでも、前記サーバー装置は、前記複数の重みによって、前記複数の訓練済モデル重みデータに対し加重連合平均アルゴリズムを実行することで、グローバルモデル重みデータを生成する。
【0007】
本発明による有益な効果の1つとしては、本発明が提供する仲介処理に基づく連合学習方法及び連合学習システムは、連合学習において仲介者グループの訓練任務をあっせんするための仲介者を提供することによって、クライアントエンドとサーバーとの間にモデルの重みを疎通することによって、連合学習のデータ分布不均一な課題を解決しながら、優れたプライバシーと低コスト性を同時に持たせることができる。
【0008】
また、本発明は、連合学習の仲介者モジュールの下で耐障害性メカニズムを提供する仲介処理に基づく連合学習方法および連合学習システムを提供することにより、訓練処理中にクライアント装置との接続が切断された場合でも、訓練性能およびモデルの安定性を維持することができる。
【発明の効果】
【0009】
また、本発明は、仲介処理に基づく連合学習方法及び連合学習システムを提供するものであり、クライアント装置が複数の仲介モジュールの並列動作によりグローバルモデルを逐次更新することで、重みの偏りを回避するだけでなく、通信コストを削減し、連合学習の学習全体を高速化することができる。
【0010】
この発明の特徴および技術的内容をよりよく理解するために、以下の詳細な説明および添付の図面を参照していただきたいが、これらの説明および添付の図面は、この発明を説明することのみを目的としており、この発明の保護範囲をいかなる形でも限定しない。
【図面の簡単な説明】
【0011】
【
図1】本発明に係る実施形態の連合学習システムを示す模式図である。
【
図2】本発明に係る実施形態のサーバー装置及びクライアント装置を示す機能ブロック図である。
【
図3】本発明に係る実施形態の連合学習方法を示すフローチャートである。
【
図4】本発明に係る実施形態の逐次訓練プロセスを示すフローチャートである。
【
図5】本発明に係る実施形態の耐障害性プロセスを示すフローチャートである。
【発明を実施するための形態】
【0012】
下記より、具体的な実施例で本発明が開示する「仲介処理に基づく連合学習方法及び連合学習システム」に係る実施形態を説明する。当業者は本明細書の公開内容により本発明のメリット及び効果を理解し得る。本発明は他の異なる実施形態により実行又は応用できる。本明細書における各細節も様々な観点又は応用に基づいて、本発明の精神を逸脱しない限りに、均等の変形と変更を行うことができる。また、本発明の図面は簡単で模式的に説明するためのものであり、実際的な寸法を示すものではない。以下の実施形態において、さらに本発明に係る技術事項を説明するが、公開された内容は本発明を限定するものではない。
【0013】
図1は、本発明に係る実施形態の連合学習システムを示す模式図である。
図1に示すように、本発明に係る実施形態は、仲介処理に基づく連合学習システム1を提供する。仲介処理に基づく連合学習システム1は、クライアント装置100-1、100-2、…、100-K、サーバー装置12及び仲介者(mediator)モデル14-1、14-2、…、14-Nを含む。
【0014】
サーバー装置12は、前記複数のクライアント装置100-1、100-2、…、100-Kに通信接続され、クライアント装置100-1、100-2、…、100-Kにおける複数のデータ分布情報に基づいて、クライアント装置100-1、100-2、…、100-Kを仲介者グループ10-1、10-2、…、10-Nに分ける。仲介者モジュール14-1、14-2、…、14-Nは、サーバー装置12で生成され、前記複数の仲介者グループ10-1、10-2、…、10-Nを管理するために用いられる。かつ、仲介者モジュール14-1、14-2、…、14-Nと仲介者グループ10-1、10-2、…、10-Nとの個数は同様である。なお、本発明は上記の例に制限されない。
【0015】
連合学習システム1において、サーバー装置12の主な役割は初期化であり、データの分布状況に応じて、クライアント装置100-1、100-2、…、100-Kをそれぞれ仲介者グループ10-1、10-2、…、10-Nに分ける。例えば、仲介者グループ10-1がクライアント装置100-1、100-2、100-3を含み、仲介者グループ10-2がクライアント装置100-4、100-5、100-6を含んでもよい。サーバー装置12はさらに仲介者グループ10-1、10-2、…、10-Nについて仲介者モジュール14-1、14-2、…、14-Nを構築することによって、仲介者グループ10-1、10-2、…、10-Nについて訓練が済まされたモデル重み付きに対して加重連合平均アルゴリズムを行い、最終的に訓練成果を整合したモデルを生成する。
【0016】
一方、クライアント装置100-1、100-2、…、100-Kは、データプレーン(data plane)を担当し、データリクエスト、訓練の実行、及び訓練後のモデル重みの送信を行っている。また、連合学習システム1において、クライアント装置100-1、100-2、…、100-Kのそれぞれの訓練任務を調整するために、本発明では、仲介者モジュール14-1、14-2、…、14-Nが制御プレーン(control plane)を担当し、データプレーンのソフトウェアプログラムを制御し、さらに、どちらのクライアント装置が訓練を行うかを決定する。
【0017】
さらに
図2を参照する。
図2は、本発明に係る実施形態のサーバー装置及びクライアント装置を示す機能ブロック図である。
図2に示すように、特定の実施形態において、サーバー装置20は、クライアント装置22、24と通信接続する。例えば、サーバー装置20は、クライアント装置22と直接に接続してもよいが、クライアント装置24とネットワーク26を介して接続してもよい。なお、本発明は上記の例に制限されない。
【0018】
サーバー装置20は、プロセッサー200、通信インターフェース202及び記録媒体204を含む。プロセッサー200は通信インターフェース202及び記録媒体204に接続される。記録媒体204例えば、ハードディスク、ソリッドステートドライブまたはデータの保存に使用できる他の記憶装置であってもよく、それに少なくとも複数のコンピュータ読み取り可能命令D1、グローバルデータ分布情報D2、クラスタリングアルゴリズムD3、仲介者モジュール生成プロセスD4、加重連合平均アルゴリズムD5、初期モデル重みデータD6及びターゲットモデルデータD7を記憶するように構成されていてもよいが、これに限定されない。通信インターフェース202は、プロセッサー200の制御にて、ネットワークにアクセスしたり、クライアント装置22、24と通信したりする。
【0019】
クライアント装置22は、プロセッサー220、通信インターフェース222及び記録媒体224を含む。プロセッサー220は、通信インターフェース222及び記録媒体224に接続される。記録媒体224例えば、ハードディスク、SSDなど、データを記憶することができる記憶装置であって、少なくとも複数のコンピュータ読み取り可能命令D1’、ローカル情報D2’、データ分布情報D3’、訓練プロセスD4’、ターゲットモデルデータD5’及びモデル重みデータD6’を記憶するように構成されていてもよいが、これに限定されない。通信インターフェース222は、プロセッサー220の制御下でネットワークにアクセスするように構成されており、例えば、クライアント装置22と通信してもよい。
【0020】
同様に、クライアント装置24は、プロセッサー240、通信インターフェース242及び記録媒体244を含んでもよい。プロセッサー240は、通信インターフェース242及び記録媒体244に接続される。かつ、記録媒体244及び通信インターフェース242は、記録媒体224及び通信インターフェース222と同様するため、ここでは説明を繰り返さない。特定の実施形態において、クライアント装置22、24は、例えば、モバイルデバイス、IoT(Internet of Things)デバイス、Fog Computingデバイスなどであってもよい。
【0021】
また、
図1に示された仲介者モジュール14-1、14-2、…、14-Nは、ハードウェアまたはソフトウェアの形で実現してもよい。ハードウェアで実現する場合、仲介者モジュール14-1、14-2、…、14-Nをサーバー装置20の構成で実現し、サーバー装置20及びクライアント装置22、24の間に接続されてもよい。ソフトウェアで実現する場合、コンピュータ読み取り可能命令またはアプリの形でサーバー装置20の記録媒体204に格納され、プロセッサー200で実行される。なお、本発明は上記の例に制限されない。
【0022】
図3は、本発明に係る実施形態の連合学習方法を示すフローチャートである。
図3に示すように、本発明に係る実施形態は、前記連合学習システムに適用する、仲介処理に基づく連合学習方法を提供する。仲介処理に基づく連合学習方法は、少なくとも次のようなステップを含む。
【0023】
ステップS30:クライアント装置は、ローカル情報を統計することによってデータ分布情報を生成してからサーバー装置に送信する。詳しくは、ステップS10は初期化プロセスを含む。初期化プロセスにおいて、サーバー装置は、連合学習方法を実行しようとする複数のクライアント装置と通信を行って、登録プロセスを実行する。登録プロセスは、例えば、
図2に示すように、クライアント装置22、24に認識コードを配布し、クライアント装置22、24からデータ分布情報を受信する。例えば、クライアント装置22は、ローカル情報D2’を統計することで、データ分布情報D3’を生成し、それを登録プロセスにおいてサーバー装置20に送信して、最終的にサーバー装置20においてグローバルデータ分布情報D2として集約するように構成されていてもよい。なかでも、データ分布情報D3’としては、例えば、ローカル情報D2’から得られた統計量の平均値、標準誤差、中央値、標準偏差、サンプル分散、尖り度、歪度、範囲、最小値、最大値、総和などが含まれてもよい。
【0024】
ステップS31:サーバー装置は、複数のクライアント装置における複数のデータ分布情報に基づいて、前記複数のクライアント装置を複数の仲介者グループに分けて、前記複数の仲介者グループのそれぞれを管理するための複数の仲介者モジュールを生成する。ステップS31において、サーバー装置20のプロセッサー200は、クラスタリングアルゴリズムD3を行って、前記複数のクライアント装置22、24のデータ分布状況、例えば、グローバルデータ分布情報D2の平均値、標準誤差、中央値、標準偏差、サンプル分散、尖り度、歪度、範囲、最小値、最大値、合計値などの統計情報に基づいて、グループ分けを行うように構成されてもよい。グループ分けの結果によれば、
図1に示すように、クライアント装置100-1、100-2、…、100-Kを仲介者グループ10-1、10-2、…、10-Nに分ける。
【0025】
続いて、サーバー装置20は、仲介者モジュール生成プロセスD4を実行することによって、仲介者モジュールを設定する。ソフトウェアの形で実現する場合、仲介者モジュール生成プロセスD4を実行することによって、仲介者モジュールは全プロセスのどの部分で実行するかを決定してもよい。例えば、サーバー装置での実行に加えて、クライアント装置の地理的または距離的な特性を登録プロセスで収集し、対応する仲介者グループに応じて、共有サーバーがあれば、共有サーバーを選択して仲介モジュールを実行してもよい。ハードウェアの形で実現する場合、サーバーのような装置を設置して、仲介者のグループに含まれるクライアント装置の地理的または距離的な特性に応じて、対応する仲介者のグループを管理してもよい。上記はあくまでも例示であり、本発明はこれらに限定されるものではない。
【0026】
ステップS32:サーバー装置は、初期モデル重みデータを前記複数の仲介者モジュールにブロードキャストするように構成されてもよい。ステップS32において、事前に選択されたターゲットモデルを、
図2に示すようにターゲットモデルデータD7に入れ、ターゲットモデルを構成するための初期モデル重みデータD6をともにすべての仲介者モジュールにブロードキャストするように構成されてもよい。例えば、
図1に示すサーバー装置12は、初期モデル重みW0を、すべての仲介者モジュール14-1、14-2、…、14-Nにブロードキャストする。
【0027】
ステップS33:前記複数の仲介者モジュールは、前記複数の仲介者グループのそれぞれに対して逐次訓練プロセスを行う。
【0028】
図4を参照されたい。
図4は、本発明に係る実施形態の逐次訓練プロセスを示すフローチャートである。
【0029】
図4に示すように、逐次訓練プロセスは次のステップを含んでもよい。ステップS40:前記複数のクライアント装置の複数の装置状態を1つずつ確認する。ステップS41:前記複数の装置状態でトレーニングシーケンスを決定する。さらに、初期モデル重みデータを対応的な前記複数のクライアント装置に送信する。例えば、ステップS42を実行してもよい。ステップS42は、初期モデル重みデータをトレーニングシーケンスにおける順番第一のクライアント装置に送信する。
【0030】
さらに、対応的な前記複数のクライアント装置は、複数のローカル情報を複数の訓練データとして、初期モデル重みデータ及びトレーニングシーケンスに基づいて、ターゲットモデルを順序に訓練することによって、訓練済モデル重みデータを生成する。例えば、ステップS43~ステップS45を実行する。
【0031】
ステップS43:第1のクライアント装置は、初期モデル重みデータでターゲットモデルを訓練し、訓練が終了すると、第1の訓練済モデル重みデータを生成する。例えば、
図2に示すように、第1のクライアント装置は、例えば、クライアント装置22であってもよい。プロセッサー220は、訓練プロセスD4’を行うことでターゲットモデルを訓練する。例えば、初期モデル重みデータをターゲットモデルに入力した後、ローカル情報を訓練データとして訓練を行う。ターゲットモデルは、例えば、畳み込み層(Convolution layer)、プーリング層(Pooling layer)及び完全接続層(Fully Connected layer)を含む、畳み込みニューラルネットワーク(Convolutional neural network,CNN)であってもよい。畳み込みカーネルに与えられた重みのセットは、訓練完了条件(収束や事前に設定された精度など)が満たされるまで、訓練プロセスの進行に応じて更新され、重みのセットは訓練シーケンスの次のクライアント装置に渡される。
【0032】
ステップS44:第1の訓練済モデル重みデータをトレーニングシーケンスにおける順番第2のクライアント装置に送信する。
【0033】
ステップS45:第2のクライアント装置は、第1の訓練済モデル重みデータによってターゲットモデルを訓練し、訓練終了と第2の訓練済モデル重みデータを生成する。同様に、第1の訓練済モデル重みデータをターゲットモデルに入力した後、第2のクライアント装置におけるローカル情報を訓練データとして訓練を行ってもよい。
【0034】
そのように、すべてのクライアント装置の訓練が終了すると、訓練済モデル重みデータを生成して、ステップS46に移行する。ステップS46:訓練済モデル重みデータに基づいて、さらに逐次訓練プロセスを再実行する必要があるかを判断する。即ち、訓練結果によって、再訓練の必要性を評価する。例えば、訓練後のターゲットモデルの精度をさらに向上させる要望があれば、訓練結果によって再訓練を行ってもよい。
【0035】
ステップS46において前記逐次訓練プロセスを再実行する必要があると判断すると、逐次訓練プロセスはステップS40に戻り、対応的な前記複数のクライアント装置の装置状態を一つずつ確認することで、前記トレーニングシーケンスを再確定し、前記訓練済モデル重みデータは、前記初期モデル重み情報として、前記複数のクライアント装置のそれぞれに対応的に送信され、訓練を行う。
【0036】
ステップS47において前記逐次訓練プロセスを再実行する必要がないと判断すると、訓練済モデル重みデータをサーバー装置に送信する。例えば、クライアント装置22は、訓練済のモデル重みデータD6’を格納した後仲介者モジュールに送信してから、さらに、仲介者モジュールが訓練済モデル重みデータをサーバー装置に送信してもよい。
【0037】
なお、上記の処理は、複数のグループの複数の仲介者モジュールが並行して行うことができ、それぞれを順次学習させることで、クライアント装置が一定の順序で重みデータを更新してサーバー装置に送り返すことができ、偏った重みを避けることができるだけでなく、通信コストを削減し、全体の連合学習の学習を高速化することができる。
【0038】
また、
図4に示すように、特定の実施形態において、逐次訓練プロセスはさらに、前記複数の仲介者グループの訓練過程、例えば、
図4に示したステップS43及びS45の実行過程において、前記複数の仲介者モジュールは、耐障害性プロセス(ステップS48)を行うことを含んでもよい。
【0039】
図5を参照されたい。
図5は、本発明に係る実施形態の耐障害性プロセスを示すフローチャートである。
図5に示すように、耐障害性プロセスは次のステップを含む。
【0040】
ステップS50:訓練が実行されているククライアント装置の接続状態を監視する。
【0041】
例えば、ステップS51に移行し、訓練が行われているクライアント装置に周期的な信号を送信する。ステップS52:訓練が行われているクライアント装置が予定時間内に周期的な信号に応答しなかったかどうかを判断する。
【0042】
訓練が行われているクライアント装置が予定時間内に定期的な信号に応答していないと判断したことに応答して、ステップS53に移行する。ステップS53:訓練が行われているクライアント装置がオフライン状態になったと判断する。
【0043】
訓練が行われているクライアント装置がオフラインであることを検出したことに応答して、ステップS54に移行する。ステップS54:対応するクライアント装置の装置状態をそれぞれ確認する。
【0044】
ステップS55:前記複数のクライアント装置のうち、新規のクライアント装置を、前記装置状態のそれぞれに基づいて選択し、オフラインとなったクライアント装置から訓練用に予め設定されたモデル重みを新規のクライアント装置に渡して訓練を行う。
【0045】
例えば、ステップS56に移行してもよい。ステップS56:オフライン状態になったクライアント装置の直前に訓練が行ったクライアント装置は、新たなクライアント装置に、訓練するために予め設定されたモデルの重みを渡すように設定する。
【0046】
その後、ステップS50に戻り、オフライン状態が検出されたときに耐障害性プロセスが発動できるように、訓練が行われているクライアント装置の接続状態を引き続き監視する。
【0047】
一方、耐障害性プロセスは、オフライン状態が検出されるたびに、ステップS57に入り、クライアント装置がオフライン状態になったことに関連する耐障害性情報を記録し、サーバー装置に送信することも含む。耐障害性情報は、例えば、登録プロセスでオフライン状態になったクライアント装置に割り当てられたクライアント装置の識別コードであってもよく、それは、後続のステップで関連する重みを計算する際に使用される。
【0048】
そのため、連合学習の仲介者構成のフレームワークの下で耐障害性メカニズムを提供することで、訓練中に幾つかのクライアント装置がオフラインになった場合があっても、訓練のパフォーマンスとモデルの安定性を維持することができる。
【0049】
図3に戻って、ステップS34に移行する。ステップS34:全ての仲介者モジュールが、訓練済モデル重みデータをサーバー装置に送信するように制御する。例えば、仲介者モジュール14-1、14-2、…、14-Nを、
図1に示すように、訓練済モデル重みW0-1、W0-2、…、W0-Nをサーバー装置12に送信する。また、ステップS34において、仲介者モジュールは同時に、記録した耐障害性情報をサーバー装置に送信してもよい。
【0050】
ステップS35:サーバー装置は、前記複数の仲介者グループにおける複数の訓練済モデル重みデータを取得し、前記複数の訓練済モデル重みデータに基づいて、前記複数の仲介者グループに対応する複数の重みを計算するように構成される。例えば、サーバー装置は、現在のサイクルにおける各仲介者モジュールの訓練によって生成されたデータ量に基づいて、仲介者グループの重みを決定してもよい。また、このステップでは、サーバー装置は、記録された誤差許容情報に基づいて、仲介者のグループの重みを決定してもよい。
【0051】
ステップS36:前記複数の重みに基づいて前記複数の訓練済モデル重みデータに対して加重連合平均アルゴリズムを実行し、グローバルモデル重みデータを生成するようにサーバー装置を構成する。
【0052】
本発明において、加重連合平均アルゴリズムは、FedAVGアルゴリズムとも呼ばれ、大まかには、アーキテクチャ(トポロジー)の決定、勾配計算、情報交換、モデルの集約で構成される。本発明の枠組みでは、決定ステップ(トポロジー)は、初期モデル重みを確立し、このラウンドの連合学習に参加する仲介者モジュールを決定することである。勾配計算と情報交換のステップでは、まず、サーバー装置からダウンロードした初期モデル重みと対応するパラメータを確認し、クライアント装置でローカルに学習させた後、サーバー装置にアップロードする。なお、ステップS35およびS36は、実際にはモデル集計ステップに対応しており、学習済みモデルの重みがサーバーに送信され、サーバーは、選択されたメンバー(すなわち、すべての仲介者モジュールおよび仲介者グループ)に含まれるサンプル数に基づいて重みを与え、仲介者モジュールの学習済みモデルの重みを乗算し、それらを合計して平均化し、最終的なモデルの重みを得ることができる。最終的なモデルウェイトは、ステップS36で参照したグローバルモデル重みデータである。
【0053】
ステップS37:グローバルモデル重みデータに基づいて、連合学習プロセスを再度実行するかどうかを判断する。例えば、グローバルモデル重みデータを使用したターゲットモデルの精度がさらに向上するかどうかを検証したい場合、学習結果に基づいて、連合学習プロセスを再度実行するかどうかを決定することができる。
【0054】
ステップS37において、連合学習プロセスが行われなくなったと判断された場合には、ステップS38に移行し、グローバルモデル重みデータを用いてターゲットモデルを構成し、グローバルターゲットモデルを生成するようにサーバー装置を構成する。
【0055】
ステップS37で、連合学習プロセスを再度実行する必要があると判断した場合は、ステップS39:データ分布情報を再編成するためにサーバー装置を設定することに移行する。例えば、既にオフラインになっているクライアント機器を削除し、サーバー機器に接続されている新しいクライアント装置を追加して、全てのクライアント装置のデータ配信情報を再収集することができる。この時点で、連合学習方法は、ステップS31に戻って再編成することができる。
【0056】
本発明の連合学習方法および連合学習システムには、仲介者処理メカニズムが組み込まれており、仲介者グループの学習タスクを適切に調整することができるため、クライアントエンドとサーバーの間でモデルの重みを転送し、連合学習におけるデータの偏在を克服することができ、しかも、プライベート性が高く、コスト的にも有利であることは注目に値する。
【0057】
本発明の有益な効果の一つとして、本発明が提供する仲介処理に基づく連合学習方法及び連合学習システムは、連合学習に仲介者を追加して仲介者グループの学習タスクを調整する仲介処理を採用しているため、クライアントとサーバー間でモデルの重みを伝達して連合学習におけるデータの偏在を克服するのに役立ち、しかもプライベート性が高く、低コストであることが挙げられる。
【0058】
また、本発明は、仲介処理に基づく連合学習方法及び連合学習システムの下で耐障害性メカニズムを提供することにより、訓練処理中にクライアント装置がオフラインとなった場合でも、訓練性能およびモデルの安定性を維持することができる。
【0059】
また、本発明は、仲介処理に基づく連合学習方法及び連合学習システムを提供するものであり、クライアント装置が複数の仲介者モジュールの並列動作によりグローバルモデルを逐次更新することで、重みの偏りを回避するだけでなく、通信コストを削減し、連合学習の学習全体を高速化することができる。
【符号の説明】
【0060】
1 連合学習システム
100-1、100-2、100-3、100-4、100-5、100-6、100-K、22、24 クライアント装置
10-1、10-2、10-N 仲介者グループ
12 サーバー装置
14-1、14-2、14-N 仲介者モジュール
20 サーバー装置
26 ネットワーク
200、220、240 プロセッサー
202、222、242 通信インターフェース
204、224、244 記録媒体
D1、D1’ コンピュータ読み取り可能命令
D2 グローバルデータ分布情報
D2’ ローカル情報
D3 クラスタリングアルゴリズム
D3’ データ分布情報
D4 仲介者モジュール生成プロセス
D4’ 訓練プロセス
D5 加重連合平均アルゴリズム
D5’、D7 ターゲットモデルデータ
D6 初期モデル重みデータ
D6’ モデル重みデータ
W0 初期モデル重み
W0-1、W0-2、W0-N 訓練済モデル重み