選定理由
Alibabaグループが大好きな(Computer Visionの)Causal Machine Learning、Nanyang大学との共同研究、NIPS2020。この分野のサーベイ論文はこちら: [Liu2022]
Paper: https://proceedings.neurips.cc/paper/2020/hash/1091660f3dff84fd648efe31391c5524-Abstract.html
Code: https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch
概要
[ビジネス課題]現実世界のデータはロングテール(クラス間不均衡)であることが一般的である。特に対象となるクラス数が多い場合は、1サンプルに多くのクラスが入り込むため、どんなに収集コストを上げても必ずロングテールとならざるを得ない(Zipf’s lawと呼ばれる)。
[技術課題]ロングテールなデータセットでの学習時の確率的勾配降下法のモメンタム(M)は交絡因子となりうる。少数派クラス(tail class)を学習する際も多数派クラス(head class)に対して学習した特徴表現(D)への射影はモメンタムの影響を受けたバックドアパスによる学習になってしまう。(以下図)
[従来]weighting/re-samplingを用いた学習時バイアス(背景の学習)の低減があるが、これらは根本的解決でなく多数派クラスにアンダーフィットし、少数派クラスにオーバーフィットする。結果、tail classはhead classに誤分類されやすい。又、クラス頻度が既知であるという前提はオンライン学習を困難にしてしまう。
[提案]学習時はバックドア基準による因果効果推定を行うことで(確率的勾配降下法の)モメンタムによる交絡因子(M)を排除し、一方で中間変数(D)によってクラス間の共起性を捉える効果は残しながら学習を行う。さらに推論時は直接効果を反事実の枠組みで計算する。
[効果] Long-tailed CIFAR-10/-100, ImageNet-LT などのロングテールな画像分類データセットでSOTA。又、セグメンテーションと物体検出タスク向けのLVIS にてSOTAの性能となった。
モメンタムの導入はHMC(ハミルトニアンモンテカルロ法)に近い概念に感じる。
De-confounded学習とTDE推論
学習時: De-confounded Training
モメンタムの交絡を排除するためにMからXへの依存を介入操作により除外する(上図左)。
上記式のようにdoオペレータはバックドア基準による調整化公式により展開する。ここで確率的勾配降下法のモメンタムの数はほぼ無限であり、上記式も計算困難である。
そこで調整化公式適用後は逆確率重みづけの式により近似可能であり、モメンタムを排除できる。
逆確率重みづけの式はエネルギーモデルに定式化することができ、最終的に因果効果はソフトマックスの出力から計算可能である。
推論時: TDE Inference
[Pearl2001]にて定義されている総合直接効果(TDE: Total Direct Effect)を最大にするクラスを推論結果とする。ここで x0 はヌル入力である。
実際には上記式に展開される。αは直接効果と間接効果のバランスを調整するハイパーパラメータである。
実験
データセットは画像分類タスク向けに Long-tailed CIFAR-10/-100, ImageNet-LT、物体検出とセグメンテーション向けに LVIS を使用した。
様々な線形、非線形の分類モデルと比較した。De-confounded学習とTDE推論両方を組み合わせると多くの場合で性能が改善した。
ImageNet-LTデータセットでの実験結果。few-shotでの改善効果が大きい。
GradCAMにて可視化を行うと、提案手法はbaselineよりもクラス特有の領域が可視化されている。例えばdressはheadクラス、kimonoはtailクラスであり、baselineのkimonoの識別時にはheadクラスのdressの表現特徴に交絡してしまっているが、提案手法では着物の箇所のみ励起している。
実装
以下クラスにて実装されている。
BBoxHeadがmmdetが提供しているクラスで、ConvFCBBoxHeadに因果推論のロジックが実装されている。実際に使用されているのはパラメータ限定版がShared2FCBBoxHead、Shared4Conv1FCBBoxHeadでconfigで使用されている例のように後者の方が使われている。
ConvFCBBoxHead内のhead部のforward計算に教師ラベルが前景か背景かが必要になる。よって以下のようにbbox_targetsをforwardに渡してあげる必要がある。
https://github.com/KaihuaTang/Long-Tailed-Recognition.pytorch/blob/master/lvis1.0/mmdet/models/roi_heads/htc_roi_head.py#L106
Top comments (0)