IP Force 特許公報掲載プロジェクト 2022.1.31 β版

知財求人 - 知財ポータルサイト「IP Force」

▶ 日本電気株式会社の特許一覧

特開2024-73781学習装置、学習方法、及び、プログラム
<>
  • 特開-学習装置、学習方法、及び、プログラム 図1
  • 特開-学習装置、学習方法、及び、プログラム 図2
  • 特開-学習装置、学習方法、及び、プログラム 図3
  • 特開-学習装置、学習方法、及び、プログラム 図4
  • 特開-学習装置、学習方法、及び、プログラム 図5
  • 特開-学習装置、学習方法、及び、プログラム 図6
< >
(19)【発行国】日本国特許庁(JP)
(12)【公報種別】公開特許公報(A)
(11)【公開番号】P2024073781
(43)【公開日】2024-05-30
(54)【発明の名称】学習装置、学習方法、及び、プログラム
(51)【国際特許分類】
   G06N 20/00 20190101AFI20240523BHJP
【FI】
G06N20/00
【審査請求】未請求
【請求項の数】8
【出願形態】OL
(21)【出願番号】P 2022184674
(22)【出願日】2022-11-18
(71)【出願人】
【識別番号】000004237
【氏名又は名称】日本電気株式会社
(74)【代理人】
【識別番号】100107331
【弁理士】
【氏名又は名称】中村 聡延
(74)【代理人】
【識別番号】100104765
【弁理士】
【氏名又は名称】江上 達夫
(74)【代理人】
【識別番号】100131015
【弁理士】
【氏名又は名称】三輪 浩誉
(72)【発明者】
【氏名】谷本 啓
(57)【要約】
【課題】適切な損失関数を用いて、因果推論に用いるモデルを学習する手法を提案する。
【解決手段】学習装置において、取得手段は、説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得する。学習手段は、最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習する。ここで、損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている。
【選択図】図3
【特許請求の範囲】
【請求項1】
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得する取得手段と、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習する学習手段と、
を備え、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている学習装置。
【請求項2】
前記学習手段は、前記迷惑モデルと前記損失関数とを同時かつ敵対的に最適化する請求項1に記載の学習装置。
【請求項3】
前記学習手段は、前記迷惑モデルに関する損失関数と、前記因果推論を行うモデルに関する損失関数とを用いて学習を行う請求項1に記載の学習装置。
【請求項4】
前記損失関数は、前記迷惑モデルを重みとして含む請求項1に記載の学習装置。
【請求項5】
前記損失関数は、前記迷惑モデルを損失に対する重みとして用いた重み付き損失を算出する請求項1に記載の学習装置。
【請求項6】
前記損失関数は、前記因果推論を行うモデルによる条件付き因果効果の推定を含む請求項1に記載の学習装置。
【請求項7】
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得し、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習し、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている学習方法。
【請求項8】
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得し、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習する処理をコンピュータに実行させ、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されているプログラム。
【発明の詳細な説明】
【技術分野】
【0001】
本開示は、因果推論に関する。
【背景技術】
【0002】
入力データと出力データに基づいてデータ間の因果関係を推定する因果推論が知られている。特許文献1は、機械学習システムにおける因果関係を推定する手法を記載している。
【先行技術文献】
【特許文献】
【0003】
【特許文献1】特開2019-194849号公報
【発明の概要】
【発明が解決しようとする課題】
【0004】
本開示の1つの目的は、適切な損失関数を用いて、因果推論に用いるモデルを学習する手法を提案することにある。
【課題を解決するための手段】
【0005】
本開示の一つの観点では、学習装置は、
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得する取得手段と、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習する学習手段と、
を備え、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている。
【0006】
本開示の他の観点では、学習方法は、
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得し、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習し、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている。
【0007】
本開示のさらに他の観点では、プログラムは、
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得し、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習する処理をコンピュータに実行させ、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている。
【発明の効果】
【0008】
本開示によれば、適切な損失関数を用いて、因果推論に用いるモデルを学習することが可能となる。
【図面の簡単な説明】
【0009】
図1】因果推論の適用例を示す。
図2】第1実施形態に係る学習装置のハードウェア構成を示すブロック図である。
図3】第1実施形態に係る学習装置の機能構成を示すブロック図である。
図4】学習装置による学習処理のフローチャートである。
図5】第2実施形態の学習装置の機能構成を示すブロック図である。
図6】第2実施形態の学習装置による処理のフローチャートである。
【発明を実施するための形態】
【0010】
以下、図面を参照して、本開示の好適な実施形態について説明する。
<基本説明>
[因果推論]
近年、データ間の因果関係を推論する手法である因果推論が提案されている。教師あり学習による推論は、基本的に全ての事実に対する正解が用意されていることが前提となっている。教師あり学習において典型的に使用される損失としてクロスエントロピーが知られている。クロスエントロピーは、予測の対象となる全ての選択肢(クラス)についての予測値と正解との間のエントロピーの和として与えられる。よって、教師あり学習においては、「もしこの予測をしたら」という反事実に対する正解も用意されている。
【0011】
これに対し、意思決定問題に因果推論を用いる場合、一般に全ての選択肢の結果はわからない。即ち、実際に取られなかった行動(これを「反事実」と呼ぶ。)についての結果を知ることはできない。これは、部分観測、バンディットフィードバックとも呼ばれる。よって、意思決定問題に因果推論を用いる場合の問題は、反事実に対する結果が欠損していること、及び、その欠損が完全にランダムではなく、背景因子(「交絡因子」とも呼ぶ。)によって偏っていることである。
【0012】
いま、図1に示すように、ある患者に対して、治療のために何らかの処置を行うことを考える。この場合、説明変数xについて、何らかの行動aをとることにより結果yが得られるものとする。なお、説明変数xは、例えば患者の年齢、性別などの属性であり、上述の背景因子に相当する。いま、仮に患者に対して薬5を投与した場合、それによる結果yを観測することができる。しかし、その場合、この患者に対して薬6を投与することや注射を行うことは反事実に相当し、それらに対する結果yを観測することはできない。因果推論においては、これら反事実に対する結果を潜在結果として想定するが、それらを実際に観測することはできない。これが反事実の欠損という問題である。
【0013】
また、反事実の欠損が完全にランダムに発生するのではなく、背景因子によって偏るという問題がある。例えば、低年齢者には薬が処方されにくく、高齢者には薬が処方されやすい、というような背景因子があると、個々の反事実について欠損が発生する確率が異なってくる。
【0014】
例えば、高齢者は基礎疾患を有することが多いので、強い薬を処方するという背景因子があるとする。この場合、実際に強い薬を投与した結果、患者の予後が良くなかったとき、実際にはもともと基礎疾患があるから予後が良くないにも拘わらず、統計上は強い薬が原因で予後が良くなかったと判断されることもある。これは疑似相関とも呼ばれ、背景因子によって起きる問題である。
【0015】
しかし、背景因子に関する説明変数x、具体的には何に基づいて意思決定をしたのかという情報が得られていれば、上記の問題に対処することが可能である。
【0016】
[因果推論の精度指標]
因果推論の精度指標は、平均二乗誤差(MSE:Mean Square Error)を用いた以下の損失関数で表すことができる。
【数1】
【0017】
ここで、「x」は背景因子に相当する説明変数を表し、「a」は行動を示し、「f^(x,a)」は説明変数xにおいて行動aが選択された場合の結果yを予測するモデルによる予測結果を表す。なお、本明細書では、標記の便宜上、ある記号、例えば「f」の上に「^」が付いたものを「f^」と表し、予測値や予測結果を表す。他の記号についても同様である。「A」は行動の集合を表す。MSE(f^)は、予測モデルf(x,a)による予測結果の精度を表し、MSEの「」は、行動の集合Aから行動aを一様(uniform)に選択することを表す。「y」は行動aが選択されたときの結果を表す。「E[y|x]」は背景因子xにおいて結果yが生じる期待値を表す。「μ(a|x)」は背景因子xのもとで行動aが選択される条件付き確率を表す。μ(a|x)は、過去の意思決定者の意思決定ポリシーを示し、「傾向スコア」とも呼ぶ。
【0018】
上記の式(1)に示すように、因果推論の損失関数MSE(f^)は、背景因子xが起きる期待値Eと、その状態で行動aが選択される期待値Ea~μ(a|x)の積を含み、この積Ea~μ(a|x)は過去の観測データにおいて背景因子xと行動aの組が現れる確率を示す分布として得ることができる。式(1)の期待値に過去のデータの確率分布を入力し、大括弧[]内の値を最小化すれば推論の精度を上げることができる。損失関数としては、式(1)に示すように傾向スコアμ(a|x)の逆数でサンプルに重み付けする手法がとられる。以下、傾向スコアμ(a|x)を「重みμ(a|x)」とも呼ぶ。
【0019】
なお、学習により得られた重みμ(a|x)のモデルの精度が低い場合、又は、重みが極端な値を取る場合には、因果推論の結果が不安定化することがある。
【0020】
式(1)にプラグイン(代入)する重みμ(a|x)を求めるためには、教師あり学習により重みμ(a|x)のモデルを学習する。そして、学習したモデルを用いて重みの予測値μ^(a|x)を求めて式(1)にプラグインする。この場合の損失関数MSE(f^)は以下のようになる。
【0021】
【数2】
【0022】
なお、現実には期待値E[y|x]は教示情報として得られないので、ノイズの乗った実際の観測データyを用いる。それでも、式(2)は期待値に対して二乗損失なので、
【0023】
【数3】
と分解できる。式(3)の右辺第2項はノイズを示し、そのノイズ分散は予測モデルf^によらない定数であるので、精度の評価において無視することができる。
【0024】
上記のように、重みμ(a|x)のモデルを学習して重みの予測値μ^(a|x)を求め、これを損失関数の式(2)にプラグインして予測モデルf(x,a)を学習する推定方法(以下、「プラグイン推定」又は「2段階推定」とも呼ぶ。)では、サンプル数が無限大であると仮定すると、行動が一様に分布する場合の損失(「De-biased損失」と呼ぶ。)を精度良く推定することができる。
【0025】
しかし、上記のプラグイン推定では、予測モデルfの仮説によっては、楽観的な部分があり得る。上記の損失関数を用いた学習は、観測データに基づいて、最も良く見えるモデルを最適化により選ぶ手法である。しかし、データの量が少ないためにたまたま訓練誤差が小さい仮説、即ち、楽観的な損失を示す仮説が存在すると、その仮説を採用しやすくなるため、推定が不安定化してしまう。
【0026】
そこで、本実施形態では、因果推論を行うモデル(以下、「因果推論モデル」とも呼ぶ。)の学習において、仮説(即ち、予測モデルf)の評価値が楽観的にならないように、つまり悲観的になるようにする。具体的には、予測モデルfの損失を大きくすることにより、予測モデルfが基づく仮説に対する評価が楽観的になりすぎないようにする。言い換えると、たまたま良いパラメータが存在しやすくなるような極端な重み付けを避けることで、予測モデルfの評価が楽観的にならないようにする。また、本当に良い予測モデルの評価値はあまり下げないようにすることで、最適性が高いパラメータでは悲観度が小さくなるようにする。これにより、モデルによる推定が不安定になることを防止する。
【0027】
<第1実施形態>
次に、本開示の第1実施形態に係る学習装置について説明する。
[ハードウェア構成]
図1は、第1実施形態に係る学習装置100のハードウェア構成を示すブロック図である。図示のように、学習装置100は、インタフェース(I/F)11と、プロセッサ12と、メモリ13と、記録媒体14と、データベース(DB)15と、を備える。
【0028】
I/F11は、外部装置との間でデータの入出力を行う。具体的に、学習装置100は、学習の対象となる因果推論モデルに関する説明変数の情報をI/F11を通じて取得する。また、学習装置100は、I/F11を通じて、所定の行動に対する結果を観測データとして取得する。
【0029】
プロセッサ12は、CPU(Central Processing Unit)などのコンピュータであり、予め用意されたプログラムを実行することにより学習装置100の全体を制御する。なお、プロセッサ12は、GPU(Graphics Processing Unit)またはFPGA(Field-Programmable Gate Array)であってもよい。プロセッサ12は、後述する生徒モデル学習処理を実行する。
【0030】
メモリ13は、ROM(Read Only Memory)、RAM(Random Access Memory)などにより構成される。メモリ13は、プロセッサ12による各種の処理の実行中に作業メモリとしても使用される。
【0031】
記録媒体14は、ディスク状記録媒体、半導体メモリなどの不揮発性で非一時的な記録媒体であり、学習装置100に対して着脱可能に構成される。記録媒体14は、プロセッサ12が実行する各種のプログラムを記録している。学習装置100が各種の処理を実行する際には、記録媒体14に記録されているプログラムがメモリ13にロードされ、プロセッサ12により実行される。
【0032】
DB15は、学習装置100が学習に使用するデータを記憶する。具体的に、DB15には、学習の対象となる因果推論モデルの説明変数に関する情報が記憶される。例えば、図1に示すように、患者に対する医療処置の効果を予測する因果推論モデルでは、説明変数に関する情報として、患者の年齢、性別などの属性が記憶される。また、DB15は、実際に行われた行動に対して得られた結果の観測データを記憶する。さらに、DB15には、因果推論モデルの学習において精度の評価に使用される精度指標、具体的には損失関数に関する情報が記憶される。
【0033】
[機能構成]
図3は、第1実施形態に係る学習装置100の機能構成を示すブロック図である。学習装置100は、機能的には、学習データ記憶部21と、学習データ取得部22と、損失関数記憶部23と、損失関数取得部24と、学習部25と、を備える。
【0034】
学習データ記憶部21は、因果推論モデルの学習に用いる学習データを記憶する。学習データ記憶部21は、例えばDB15により実現される。学習データは、説明変数、行動、及び、その行動の結果を含む。行動の結果は、観測データとして得られ、学習データ記憶部21に記憶される。学習データ取得部22は、学習データ記憶部21から学習データを取得し、学習部25に出力する。
【0035】
損失関数記憶部23は、学習の対象となる因果推論モデルの評価指標を与える損失関数を記憶する。損失関数記憶部23は、例えばメモリ13又はDB15により実現される。損失関数の具体例については後述するが、本実施形態では、損失関数は、迷惑モデルを部分的に含むものを用いる。「迷惑モデル」とは、最終的な出力としては必要がないが、損失の計算において必要となる予測値を計算するためのモデルを言う。損失関数取得部24は、取得した損失関数を学習部25へ出力する。
【0036】
学習部25は、学習データと損失関数とを用いて、因果推論モデルの評価値である損失を計算し、損失を最小化するように因果推論モデルの学習を行う。ここで、損失関数は、前述のように、因果推論モデルの評価値である損失が楽観的にならないように、つまり悲観的になるように定義されている。具体的には、損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失関数を悲観的に推定するように定義されている。学習部25は、このような損失関数を用いて、因果推論モデルを学習し、学習により得られた因果推論モデルを出力する。
【0037】
[学習処理]
次に、学習装置100による学習処理について説明する。図4は、学習装置100による学習処理のフローチャートである。この処理は、図2に示すプロセッサ12が予め用意されたプログラムを実行し、図3に示す各要素として動作することにより実現される。
【0038】
まず、損失関数取得部24は、損失関数記憶部23から、学習に用いる損失関数を取得する(ステップS11)。次に、学習データ取得部22は、学習データ記憶部21から、学習データを取得する(ステップS12)。次に、学習部25は、取得した損失関数と学習データを用いて、因果推論モデルの学習を行う(ステップS13)。次に、学習部25は、所定の学習終了条件が具備されたか否かを判定する(ステップS14)。学習終了条件は、例えば、学習データを全て用いて学習したこと、学習中のモデルの精度が所定レベルに達したことなどが用いられる。学習終了条件が具備されていない場合(ステップS14:No)、学習部25は学習を継続する。一方、学習終了条件が具備された場合(ステップS14:Yes)、学習処理は終了する。
【0039】
[実施例]
以下、第1実施形態の実施例について説明する。なお、以下の説明において登場する「目的関数」はいずれも「損失関数」の例である。
【0040】
(第1実施例)
一般に、損失関数に代入する未知の量を推定するためのモデルを「迷惑モデル」と呼ぶ。迷惑モデルは、損失の計算上必要なパラメータであるので推定するが、そのパラメータ自体を知りたいわけではないという意味で迷惑モデルと呼ばれる。先の[基本説明]の欄における傾向スコアの予測モデルμ(a|x)は迷惑モデルの一例である。
【0041】
いま、迷惑モデルνに関する目的関数をLν(ν)とする。目的関数Lν(ν)は、例えばクロスエントロピー損失などであり、推定したい因果推論モデルのパラメータθに依存しないものとする。また、推定したい因果推論モデルのパラメータθに関する目的関数をL(θ;ν)とする。目的関数L(θ;ν)は、例えば平均二乗誤差(MSE)などである。
【0042】
損失関数が迷惑モデルを含む場合、一般的には迷惑モデルを学習し、迷惑モデルによる予測値を損失関数に代入して損失を計算する。この手法を前述のようにプラグイン推定とも呼ぶ。プラグイン推定では、まず目的関数Lν(ν)を学習により最適化して迷惑モデルνの予測値ν^を求め、この予測値ν^を目的関数L(θ;ν)に代入して、目的関数L(θ;ν^)を最小化するパラメータθ^を求める。
【0043】
これに対し、第1実施例に係る学習装置は、通常のプラグイン推定の代わりに敵対的同時最適化を行い、因果推論モデルのパラメータθ^を以下の式で求める。
【数4】
【0044】
学習時において、式(4)に示すように、基本的に迷惑モデルνは最大化され、パラメータθは最小化される。よって、迷惑モデルνは、αLν(ν)を最小化しつつ、L(θ;ν)を最大化するように学習される。一方で、パラメータθは、迷惑モデルνが最大化しようとするL(θ;ν)を最小化するように学習される。これにより、迷惑モデルνを最大化しようとする動きはパラメータθを最小化する動きにより制約を受け、かつ、パラメータθを最小化しようとする動きは迷惑モデルνを最大化しようとする動きにより制約を受ける。このように、迷惑モデルνとパラメータθが敵対的に動作し、両者を同時に最適化するため、この手法を敵対的同時最適化と呼ぶ。
【0045】
こうして、迷惑モデルνは、、データから計算される自身の確からしさを表すLν(ν)が適切な範囲、即ち、所定以上に確からしい範囲に維持される。また、迷惑モデルνは、ハイパーパラメータαにより制御される所定値以上に確からしい範囲内に維持された状態で、自身を最大化することにより損失L(θ;ν)を最大化しようとする。このように、損失関数
L(θ;ν)-αLν(ν)
は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている。
【0046】
なお、LνとLの関数形に関して適当な仮定のもとで、制約付き最適化と正則化は同一視できる。即ち、確からしさの制約度合いと、正則化の強さαに一対一対応があり、対応する制約付き最適化の解と正則化付き最適化の解とは一致する。よって、後でパラメータαを交差検証などで選択する前提では、パラメータαを用いた正則化により、迷惑モデルνを所定以上に確からしい範囲内に維持することができる。
【0047】
(第2実施例)
第2実施例は、第1実施例を具体化した実施例であり、因果推論モデルの損失関数において、迷惑モデルの目的関数を重みとして用いるものである。
【0048】
迷惑モデルνに関する目的関数をLν(ν)とする。なお、目的関数Lν(ν)は、推定したい因果推論モデルのパラメータθに依存しないものとする。また、推定したい因果推論モデルのパラメータθに関する重み付き目的関数を以下のようにL(θ;ν)とする。
【0049】
【数5】
この目的関数は、損失関数l(θ)に迷惑モデルνの出力を重みω(ν)として掛けたものとなっている。なお、「i」はサンプル番号を示す。
【0050】
第1実施例と同様に本実施形態による敵対的同時最適化を適用すると、推定したい因果推論モデルのパラメータθは以下の式で与えられる。
【0051】
【数6】
例えば、迷惑モデルνを傾向スコアμ(a|x)のモデルとし、重みは
ω=1/μ(a|x
などとしてもよい。また、迷惑モデルνに関する目的関数Lν(ν)は、傾向スコアのモデルが行動を精度良く予測する場合に損失が小さくなるクロスエントロピーなどの判別損失を用いてもよい。
【0052】
式(6)においても、第1実施例における式(4)と同様に、迷惑モデルνは、データから計算される自身の確からしさを表すLν(ν)が適切な範囲、即ち、所定以上に確からしい範囲に維持される。また、迷惑モデルνは、ハイパーパラメータαにより制御される所定値以上に確からしい範囲内に維持された状態で、自身を最大化することにより重み付き損失ω(ν)l(θ)を最大化しようとする。こうして、損失関数
ω(ν)l(θ)-αLν(ν)
は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている。
【0053】
なお、式(6)において、重みω(ν)を最大化するように迷惑モデルνを学習すると、学習が進むにつれて重みω(ν)が大きくなっていく。重みが極端に大きくなると、重みに対する実質的なデータサイズが小さくなるため、推定分散が増加する。そこで、重みを正規化する項を導入すると、以下の式が得られる。
【0054】
【数7】
式(7)では、重みω(ν)に正規化項1/Σω(ν)を掛けることにより、重みを1に正規化している。式(7)の手法は、式(6)の自己正規化版ということができる。
【0055】
(第3実施例)
第3実施例は、本実施形態の手法を目的変数変換法に応用したものである。因果推論では、ある背景因子xのもとで行動aが選択された場合と選択されなかった場合の結果の差を効果として推定することが多い。これを条件付き因果効果(以下、「CATE:Conditional Average Treatment Effect」)とも呼ぶ。ある背景因子xのもとで行動aを取った場合の因果効果は以下の式で与えられる。
【0056】
【数8】
しかし、式(8)では行動aを選択した場合と選択しなかった場合の観測データを必要とするため、現実にはCATEτ(x)に対する正解を与えることはできない。
【0057】
これに対し、目的変数変換法は、CATEτ(x)にノイズが乗った値を得ることはできるとの考えに基づく。目的変数変換法により結果yを目的変数zに置き換えると、変換後の目的変数zは以下の式で与えられる。
【数9】
【0058】
式(9)において、行動aが選択された場合には第2項が0となり、行動aが選択されなかった場合は第1項が0となるため、いずれにしても実際に観測されたデータと傾向スコアμ(x)を用いて目的変数zを計算することができる。ここで、傾向スコアμ(x)=μ(a=1|x)の予測値μ^が正しいとき、目的変数zの期待値はCATEτ(x)に一致する。即ち、目的変数zを背景因子xに回帰したCATE推定モデルτ^は、式(8)の期待値E[ya=1-ya=0|x]にノイズが乗ったものと考えることができ、サンプル数無限大で真のCATEに一致する。よって、目的変数zを背景因子xに回帰することでCATE推定モデルτ^を学習することができる。
【0059】
具体的には、以下のように変換後の目的変数zを傾向スコアμの関数に置き換える。
【数10】
【0060】
そして、前述の敵対的同時最適化によりCATE推定モデルτ^を以下のように定義する。
【数11】
ここで、NLL(Negative Log Likelihood)は、傾向スコアμに関する元々の損失関数であり、例えばクロスエントロピーなどである。
【0061】
式(11)において、傾向スコアμは、中括弧{}内の第2項「αNLL(μ,(x,a))」を最小化するとともに、第1項である損失関数l(z μ,τ)を最大化するように学習される。一方で、パラメータτは、迷惑モデルμが最大化しようとする損失関数l(z μ,τ)を最小化するように学習される。その結果、損失関数
{l(z μ,τ)-αNLL(μ,(x,a))}
は、迷惑モデルμが所定以上に確からしい範囲において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失関数を悲観的に推定するものとなる。
【0062】
(第4実施例)
第4実施例は第3実施例と同様に条件付き因果効果CATEを推定する手法であるが、第3実施例における目的変数変換法の代わりにダブルロバストラーナー(DRL:Doubly Robust Learner、以下「DRL」とも呼ぶ。)を用いる。
【0063】
条件付き因果効果CATEは前述の式(8)により表される。ここで、DRLでは、まず、行動a∈{0,1}ごとのデータに対し、以下のように潜在アウトカム予測モデルf^,f^を学習する。
【0064】
【数12】
即ち、行動a=1のときの結果yを予測する予測モデルf^(x)と、行動a=0のときの結果yを予測する予測モデルf^(x)とが個別に学習される。
【0065】
次に、行動ごとのデータと傾向スコアμに対して、目的変数変換法を用い、以下のように変換後の目的変数zμを定義する。
【数13】
【0066】
学習により得られた予測モデルf^(x)の予測値f^(x)と、予測モデルf^(x)の予測値f^(x)は式(13)にプラグインされる。
【0067】
式(13)においては、まず、行動a=1のときの予測値f^(x)と、行動a=0のときの予測値f^(x)との差を取る。これに加えて、行動a=1のときの結果y と予測値f^(x)との残差を傾向スコアμ(x)の逆数で重み付けして加算する。さらに、行動a=0のときの結果y と予測値f^(x)との残差を1-μ(x)の逆数で重み付けして減算する。即ち、個別に学習済みの予測モデルf^(x)の予測値f^(x)と、予測モデルf^(x)の予測値f^(x)とを変換後の目的変数zμにプラグインしている点が第3実施例と異なる。
【0068】
変換後の目的変数zμは、基本的に予測モデルの予測値が正しければ正しい値となるし、予測モデルの予測値が正しくなくても、傾向スコアμのモデルが正しければ残差が調整されて正しい値となる。この意味で、2重にロバスト(Doubly Robust)と呼ばれる。
【0069】
そして、この変換後の目的変数zμを用いて、以下に示すCATEのモデルτを学習する。
【数14】
【0070】
式(14)は式(11)と同様であり、傾向スコアμは、中括弧{}内の第2項「αNLL(μ,(x,a))」を最小化するとともに、第1項である損失関数l(z μ,τ)を最大化するように学習される。一方で、パラメータτは、迷惑モデルμが最大化しようとする損失関数l(z μ,τ)を最小化するように学習される。その結果、損失関数
{l(z μ,τ)-αNLL(μ,(x,a))}
は、迷惑モデルμが所定以上に確からしい範囲において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失関数を悲観的に推定するものとなる。
【0071】
<第2実施形態>
図5は、第2実施形態の学習装置の機能構成を示すブロック図である。図示のように、学習装置70は、取得手段71と、学習手段72とを備える。
【0072】
図6は、第2実施形態の学習装置による処理のフローチャートである。取得手段71は、説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得する(ステップS71)。学習手段72は、最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習する(ステップS72)。ここで、損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている。
【0073】
<適用分野>
上記の学習により得られた因果推論モデルは、各種の分野に適用することができる。例えば、医療の分野において、因果推論モデルは、薬効・医療処置の効果の予測に用いることができる。具体的には、図1に示すように、説明変数として患者の属性を用い、行動として患者に対する医療処置を用い、医療処置後の患者の状態などを結果として用いることができる。また、医療の分野において、因果推論モデルは、化学特性の予測、実験の最適化などにも適用することができる。
【0074】
また、マーケティングの分野では、因果推論モデルは、価格弾力性や交差弾力性の推定、価格最適化やダイナミックプライシング、他商品の在庫を考慮した需要予測や在庫最適化、個別的商品推薦などに適用することができる。また、政策・教育の分野においては、因果推論モデルは、政策効果の予測や評価、問題の推薦などに適用することができる。
【0075】
上記の実施形態の一部又は全部は、以下の付記のようにも記載されうるが、以下には限られない。
【0076】
(付記1)
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得する取得手段と、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習する学習手段と、
を備え、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている学習装置。
【0077】
(付記2)
前記学習手段は、前記迷惑モデルと前記損失関数とを同時かつ敵対的に最適化する付記1に記載の学習装置。
【0078】
(付記3)
前記学習手段は、前記迷惑モデルに関する損失関数と、前記因果推論を行うモデルに関する損失関数とを用いて学習を行う付記1に記載の学習装置。
【0079】
(付記4)
前記損失関数は、前記迷惑モデルを重みとして含む付記1に記載の学習装置。
【0080】
(付記5)
前記損失関数は、前記迷惑モデルを損失に対する重みとして用いた重み付き損失を算出する付記1に記載の学習装置。
【0081】
(付記6)
前記損失関数は、前記因果推論を行うモデルによる条件付き因果効果の推定を含む付記1に記載の学習装置。
【0082】
(付記7)
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得し、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習し、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されている学習方法。
【0083】
(付記8)
説明変数、行動、及び、前記行動の結果の情報を含む学習データを取得し、
最終的な出力としては必要のない推定対象である迷惑モデルを部分的に含む損失関数に基づいて、前記学習データを用いて、因果推論を行うモデルを学習する処理をコンピュータに実行させ、
前記損失関数は、迷惑モデルが所定以上に確からしい範囲内において、最も悪い値を用いることで、迷惑モデルの不確実性に関して損失を悲観的に推定するように定義されているプログラム。
【0084】
以上、実施形態及び実施例を参照して本開示を説明したが、本開示は上記実施形態及び実施例に限定されるものではない。本開示の構成や詳細には、本開示のスコープ内で当業者が理解し得る様々な変更をすることができる。
【符号の説明】
【0085】
12 プロセッサ
21 学習データ記憶部
22 学習データ取得部
23 損失関数記憶部
24 損失関数取得部
25 学習部
図1
図2
図3
図4
図5
図6