8.3 Approximation variationnelle

Dans l’étape E de l’algorithme EM il faut calculer l’espérance conditionnelle \[\begin{equation} \tag{8.3} Q(\theta, \theta') =\mathbb E_{\theta'}[\log \mathbb P_{\theta}(\mathbf A, \mathbf Z)|\mathbf A]. \end{equation}\] Il faut alors connaitre la loi conditionnelle \(\mathbb P_{\theta'}(\mathbf Z|\mathbf A)\) qui n’est pas tractable d’après (8.2).

Approximiation variationnelle. Une solution naturelle consiste à approcher cette loi \(\mathbb P_{\theta'}(\mathbf Z|\mathbf A)\) par une loi plus simple. Concrètement, on se donne une famille \(\mathcal Q\) de lois simples et connues sur \(\mathbf Z\) et on cherche la loi la plus proche de \(\mathbb P_{\theta'}(\mathbf Z|\mathbf A)\) dans cette classe \(\mathcal Q\) en terme de la divergence de Kullback-Leibler: \[\begin{equation} \tag{8.4} \tilde{\mathbb Q} =\arg\min_{\mathbb Q\in\mathcal Q}\mathrm{KL}\left(\mathbb Q~\|~\mathbb P_{\theta'}(\mathbf Z|\mathbf A)\right). \end{equation}\]

On appelle cette démarche de l’approximiation variationnelle. Concernant l’algorithme EM, on remplace alors \(\mathbb P_{\theta'}(\mathbf Z|\mathbf A)\) par cette loi approchée \(\tilde{\mathbb Q}\) et au lieu de \(Q(\theta, \theta')=\mathbb E_{\theta'}[\log \mathbb P_{\theta}(\mathbf A, \mathbf Z)|\mathbf A]\) on utilise \(\mathbb E_{\tilde{\mathbb Q}} [\log \mathbb P_{\theta}(\mathbf A, \mathbf Z)]\). Un algorithme de type EM avec de l’approximation variationnelle est dit algorithme VEM pour variational EM-algorithm.

Si la famille de lois \(\mathcal Q\) contient la vraie loi \(\mathbb P_{\theta^k}(\mathbf Z|\mathbf A)\), on obtient \(\tilde{\mathbb Q}=\mathbb P_{\theta^k}(\mathbf Z|\mathbf A)\), car la divergence de Kullback-Leibler est telle que \(\mathrm{KL}(\mathbb Q~\|~\mathbb P)\geq 0\) pour toutes lois \(\mathbb Q, \mathbb P\) avec égalité si et seulement si \(\mathbb Q= \mathbb P\). Autrement dit, on utilise la loi exacte, et l’algorithme VEM revient à l’algorithme EM classique.


Algorithme VEM.

Entrée: observations \(\mathbf A\), nombre de blocs \(Q\), point initial \(\theta^0\).

Sortie: Paramètre \(\hat\theta^{k+1}\), loi variationnelle \(\tilde{\mathbb Q}\).

Procédure: A l’itération \(k\):

  • VE-step: Calculer \[\tilde{\mathbb Q} =\arg\min_{\mathbb Q\in\mathcal Q}\mathrm{KL}\left(\mathbb Q~\|~\mathbb P_{\theta'}(\mathbf Z|\mathbf A)\right).\]
  • M-step: Calculer \[\begin{align*} \theta^{k+1} =\arg\max_{\theta\in\Theta}\mathbb E_{\tilde{\mathbb Q}}[\log \mathbb P_{\theta}(\mathbf A, \mathbf Z)]. \end{align*}\]

Approximation de champ moyen. Rappelons que dans le contexte du SBM \(\mathbf Z=(Z_1,\dots,Z_n)\) est un vecteur aléatoire discrète à valeurs dans \(\{1,\dots,Q\}^n\).

Lorsqu’on choisit comme famille de lois \(\mathcal Q\) pour \(\mathbf Z\) la famille de lois factorisées à valeurs dans \(\{1,\dots,Q\}^n\), l’approximation est dite approximation de champ moyen ou mean field approximation. Dans ce cas, toute loi \(\mathbb Q\in\mathcal Q\) est entièrement décrite par les probabilités \(\mathbb Q(Z_i=q)\), et nous utilisons ces probabilités pour paramétrer la famille \(\mathcal Q\). Ainsi, pour \(\mathbf z=(z_1,\dots,z_n)\in\{1,\dots,Q\}^n\), on a \[\begin{equation} \mathbb Q(\mathbf z)=\prod_{i=1}^n \mathbb Q(Z_i=z_i)=\prod_{i=1}^n\prod_{q=1}^Q\tau_{i,q}^{z_{i,q}}, \tag{8.5} \end{equation}\]\(z_{i,q}=\mathbb{1}\{z_i=q\}\) et \(\tau_{i,q}=\mathbb Q(Z_i=q)\) et \(\sum_{q=1}^Q\tau_{i,q}=1\) pour tout \(1\le i\le n\). On appelle les \((\tau_{i,q})_{i,q}\) les paramètres variationnels et on considère la valeur de \(\tau_{i,q}\) comme une approximation de la probabilité conditionnelle \(\mathbb P_{\theta'}(Z_i=q|\mathbf A)\) que le \(i\)-ème noeud appartient au bloc \(q\).

Les \((\tau_{i,q})_{i,q}\) définissent un soft clustering des noeuds, car ce sont des probabilités à valeurs dans l’intervalle \([0,1]\). On peut en déduire une hard clustering par le maximum a posterior (MAP) défini par \[\hat Z_i=\arg\max_{1\le q\le Q}\tau_{i,q}, \quad i=1,\dots,n,\] En pratique, les paramètres \((\hat\tau_{i,q})_{i,q}\) que l’on obtient à la fin de l’algorithme VEM sont souvent tout près de 0 et 1, ce qui veut dire qu’ils définissent (presque) un hard clustering. Néanmoins, parfois il y a une observation difficile à classer, et donc les valeurs des paramètres variationnels correspondants ne converge ni vers 0, ni vers 1.

Minimisation de la divergence de Kullback-Leibler. On peut réécrire le problème de minimisation de la divergence de Kullback-Leibler dans (8.4) comme \[\begin{align} \begin{split} \tilde{\mathbb Q} &=\arg\min_{\mathbb Q\in\mathcal Q}\mathrm{KL}\left(\mathbb Q~\|~\mathbb P_{\theta'}(\mathbf Z|\mathbf A)\right)\\ &=\arg\min_{\mathbb Q\in\mathcal Q}\mathbb E_{\mathbb Q}\left[\log\frac{\mathbb Q(\mathbf Z)}{\mathbb P_{\theta'}(\mathbf Z|\mathbf A)}\right]\\ &=\arg\min_{\mathbb Q\in\mathcal Q}\Big\{\underbrace{\mathbb E_{\mathbb Q}\left[\log \mathbb Q(\mathbf Z)\right]}_{=:-\mathcal H(\mathbb Q)} - \mathbb E_{\mathbb Q}\Big[\underbrace{\log \mathbb P_{\theta'}(\mathbf Z|\mathbf A)}_{=\log \mathbb P_{\theta'}(\mathbf A,\mathbf Z)-\log \mathbb P_{\theta'}(\mathbf A)}\Big]\Big\}\\ &=\arg\max_{\mathbb Q\in\mathcal Q}\Big\{\underbrace{\mathcal H(\mathbb Q) + \mathbb E_{\mathbb Q}\left[\log \mathbb P_{\theta'}(\mathbf A,\mathbf Z) \right]}_{=:J(\mathbb Q, \theta')}\Big\}. \end{split} \tag{8.6} \end{align}\] On appelle \(\mathcal H(\mathbb Q)\) l’entropie de la loi \(\mathbb Q\). A la dernière ligne, le terme \(\log \mathbb P_{\theta'}(\mathbf A)\) a disapru, car il ne dépend pas de \(\mathbb Q\).

Le fait de réécrire le problème avec la loi des données complètes \(\mathbb P_{\theta'}(\mathbf A,\mathbf Z)\) au lieu de la loi a posteriori \(\mathbb P_{\theta'}(\mathbf Z|\mathbf A)\)

Calculs dans le SBM. Soit \(\mathbb Q\) une loi factorisée sur \(\{1,\dots,Q\}^n\) de la forme de (8.5). En utilisant (8.1), on réécrit \(J(\mathbb Q, \theta')\) dans le SBM comme \[\begin{align} J&(\mathbb Q, \theta') = -\mathbb E_{\mathbb Q}\left[\log \mathbb Q(\mathbf Z) \right] + \mathbb E_{\mathbb Q}\left[\log \mathbb P_{\theta'}(\mathbf A,\mathbf Z) \right]\notag\\ &=-\sum_{i=1}^n\sum_{q=1}^Q \tau_{i,q}\log \tau_{i,q} +\mathbb E_{\mathbb Q}\left[\sum_{q=1}^Q \sum_{i=1}^nZ_{i,q} \log \pi_q + \sum_{q=1}^Q\sum_{\ell=1}^Q \sum_{i<j} {Z_{i,q} Z_{j,\ell}} \log F(A_{i,j}; \gamma_{q,\ell})\right]\notag\\ &=\sum_{i=1}^n\sum_{q=1}^Q \tau_{i,q}\log\frac{\pi_q}{\tau_{i,q}} +\sum_{q=1}^Q\sum_{\ell=1}^Q \sum_{i<j} \tau_{i,q}\tau_{j,\ell} \log F(A_{i,j}; \gamma_{q,\ell}), \tag{8.7} \end{align}\] car \(\mathbb E_{\mathbb Q}[Z_{i,q}]=\tau_{i,q}\) et par indépendance de \(Z_{i}\) et \(Z_{j}\) sous \(\mathbb Q\).

Proposition 8.1 (VE-step) Dans un SBM, en utilisant une approximation de champ moyen, la solution de (8.6) est donnée par la solution \(\hat \tau =\{\hat \tau_{i,q}\}_{i,q}\) qui vérifie l’équation de point fixe suivante \[ \forall 1\le i \le n , \forall 1\le q \le Q, \quad \hat \tau_{i,q} \propto \pi_q \prod_{j\neq i} \prod_{\ell=1}^Q [F(A_{i,j}; \gamma_{q,\ell})]^{\hat \tau_{j\ell}}, \]\(\propto\) signifie ‘proportionnel à’. La constante est obtenue à partir de la contrainte de \(\forall i, \sum_q \tau_{i,q}= 1\).
Proof. On cherche les points critiques de \(J(\mathbb Q, \theta')\) en dérivant l’expression donnée en (8.7) sans oublier les contraintes \(\forall i, \sum_q \tau_{i,q}= 1\)).