Model explanations should be pragmatic, with RSA and RL

The interpretability community is currently heavily focused on post-hoc explanations, such as Integrated Gradients, TCAV, or SHAP, and on whether they are faithful. The faithfulness debate has produced the wonderful pair Attention is not Explanation and its rejoinder Attention is not not Explanation.

One dimension typically omitted in this field is whether explanations are actually useful.

One view of utility treats an explanation as communication, or as a reasoning tool, that improves predictive accuracy.

The Explanatory Multiplier

Bayes' rule formally relates the prediction made with an explanation, \(P(y|x,e)\), to the base model's prediction, \(P(y|x)\):

\[ P(y|x,e) = \frac{P(e|x,y)}{P(e|x)} \cdot P(y|x) \]

That fractional term is the exponentiated pointwise mutual information (PMI) between the explanation and the label, conditioned on the input: \(\log \frac{P(e|x,y)}{P(e|x)} = \text{PMI}(e; y \mid x)\). If the explanation \(e\) is equally likely regardless of the true label \(y\), the fraction becomes 1 (a PMI of 0), and the explanation provides zero marginal utility.

In other words, a good explanation must be discriminative.
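
A small numeric check makes the multiplier concrete. This is a minimal sketch with made-up numbers, not anything from a real model: a base model that is 50/50 between two hypothetical labels, `cat` and `dog`, and two hypothetical explainers over the explanations `whiskers` and `fur`. The discriminative explainer moves the posterior; the uninformative one has a multiplier of exactly 1 and leaves the prior untouched.

```python
import math


def multiplier(prior, explainer, e, y):
    """The factor P(e|x,y) / P(e|x) for a fixed input x.

    prior:     dict y -> P(y|x), the base model's prediction
    explainer: dict y -> dict e -> P(e|x,y)
    """
    p_e = sum(explainer[label][e] * prior[label] for label in prior)  # P(e|x)
    return explainer[y][e] / p_e


def posterior(prior, explainer, e):
    """P(y|x,e) = [P(e|x,y) / P(e|x)] * P(y|x) for every label y."""
    return {y: multiplier(prior, explainer, e, y) * prior[y] for y in prior}


prior = {"cat": 0.5, "dog": 0.5}
# Hypothetical explainers P(e|x,y), indexed [y][e].
discriminative = {"cat": {"whiskers": 0.9, "fur": 0.1},
                  "dog": {"whiskers": 0.2, "fur": 0.8}}
uninformative = {"cat": {"whiskers": 0.5, "fur": 0.5},
                 "dog": {"whiskers": 0.5, "fur": 0.5}}

m = multiplier(prior, discriminative, "whiskers", "cat")
print(round(m, 3), round(math.log(m), 3))  # multiplier 1.636, PMI 0.492
print({y: round(p, 3) for y, p in posterior(prior, discriminative, "whiskers").items()})
# {'cat': 0.818, 'dog': 0.182}
print({y: round(p, 3) for y, p in posterior(prior, uninformative, "whiskers").items()})
# {'cat': 0.5, 'dog': 0.5}
```

The only thing that matters is how differently the explanation behaves across labels; the absolute probability of `whiskers` under the uninformative explainer is irrelevant.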

The Rational Speech Acts (RSA) perspective and pragmatics

We can cast learning these explanations in the Rational Speech Acts (RSA) framework: a Bayesian model of communication in which speakers and listeners recursively reason about one another. RSA comes from the study of pragmatics, and one paper I really love in this direction is Learning Language Games Through Interaction.

In the context of model interpretability, the game looks like this:

  • Literal Explainer (E0): Generates a raw attribution map, e.g., standard Integrated Gradients: \(P(e|x,y)\).
  • Reasoning Guesser (G1): Updates the base prediction by Bayes' rule, treating E0 as the likelihood: \(P(y|x,e) \propto P(e|x,y)\,P(y|x)\).
  • Reasoning Explainer (E1): Modifies the explanation to maximize the guesser's chance of getting the right answer.

The pragmatic explainer E1 computes:

\[ E_1(e|x,y) \propto G_1(y|x,e) \cdot P(e|x) \]

This formulation naturally suppresses features the base model already perfectly understands. It operationalizes the command: "Tell me something I don't know."
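
The three levels can be sketched on a tiny discrete problem. This is an illustration under loud assumptions: a real E0 such as Integrated Gradients produces a continuous attribution map, not a distribution over three named features, so the E0 table below is hypothetical, as is the choice of a uniform base rate \(P(e|x)\). In this toy E0, `fur` is the most salient feature for both labels, so it carries no information about \(y\); E1 shifts mass away from it and toward `whiskers`.

```python
def normalize(scores):
    """Rescale a dict of nonnegative scores so the values sum to 1."""
    total = sum(scores.values())
    return {k: v / total for k, v in scores.items()}


def guesser(prior, e0, e):
    """G1(y|x,e) ∝ E0(e|x,y) * P(y|x): Bayes' rule with E0 as the likelihood.
    Assumes e has nonzero probability under at least one label."""
    return normalize({y: e0[y][e] * prior[y] for y in prior})


def pragmatic_explainer(prior, e0, base_rate, y):
    """E1(e|x,y) ∝ G1(y|x,e) * P(e|x), where base_rate supplies P(e|x)."""
    return normalize({e: guesser(prior, e0, e)[y] * base_rate[e] for e in base_rate})


prior = {"cat": 0.5, "dog": 0.5}  # base model P(y|x)
# Hypothetical literal explainer E0(e|x,y), indexed [y][e].
e0 = {"cat": {"fur": 0.7, "whiskers": 0.3, "bark": 0.0},
      "dog": {"fur": 0.7, "whiskers": 0.0, "bark": 0.3}}
uniform = {e: 1 / 3 for e in ("fur", "whiskers", "bark")}
# The base rate that E0 itself induces: P(e|x) = sum_y E0(e|x,y) P(y|x).
marginal = {e: sum(e0[y][e] * prior[y] for y in prior) for e in uniform}

print(guesser(prior, e0, "fur"))                        # {'cat': 0.5, 'dog': 0.5}
print(pragmatic_explainer(prior, e0, uniform, "cat"))   # fur ≈ 0.333, whiskers ≈ 0.667, bark 0.0
print(pragmatic_explainer(prior, e0, marginal, "cat"))  # fur ≈ 0.7, whiskers ≈ 0.3, bark 0.0
```

The last line shows why the base rate matters. If you plug in the marginal that E0 itself induces, then \(G_1(y|x,e) \cdot P(e|x) = E_0(e|x,y)\,P(y|x)\), and E1 collapses back to E0. The pragmatic effect only appears when \(P(e|x)\) is a separately chosen (or learned) base rate.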

Learning the explainer: RSA and RL

The question is how to learn a high-quality marginal \(P(e|x)\) — the base rate over explanations, independent of the label \(y\) — to plug into E1. It helps to treat the explanation as a bottleneck, where \(y\) depends on \(x\) only through \(e\):

\[ P(y|x) = \sum_{e \in E} P(y|e)\,P(e|x) \]
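
Under the bottleneck, prediction is just a marginalization over explanations. A minimal sketch, assuming a small discrete explanation set: the guesser table and the two explainers below are hypothetical. The guesser stays fixed and only the explainer is swapped; the more discriminative explainer raises \(P(\text{cat}|x)\) from about 0.65 to about 0.9.

```python
def predict_through_bottleneck(guesser, explainer):
    """P(y|x) = sum_e P(y|e) P(e|x): the guesser only ever sees the explanation."""
    labels = next(iter(guesser.values())).keys()
    return {y: sum(guesser[e][y] * p for e, p in explainer.items()) for y in labels}


# Hypothetical guesser P(y|e), learned once and then frozen.
guesser = {"fur": {"cat": 0.5, "dog": 0.5},
           "whiskers": {"cat": 1.0, "dog": 0.0},
           "bark": {"cat": 0.0, "dog": 1.0}}
# Two hypothetical explainers P(e|x) for the same cat photo x.
salient = {"fur": 0.7, "whiskers": 0.3, "bark": 0.0}
informative = {"fur": 0.2, "whiskers": 0.8, "bark": 0.0}

print(predict_through_bottleneck(guesser, salient))      # cat ≈ 0.65, dog ≈ 0.35
print(predict_through_bottleneck(guesser, informative))  # cat ≈ 0.9, dog ≈ 0.1
```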

This factorization lets us learn the guesser \(P(y|e)\) once and reuse it. There are two ways to then learn the explainer:

  • Bootstrap from humans, then refine. Imitate human data to get a guesser \(P_{human}(y|e)\) and explainer \(P_{human}(e|x)\); then learn a new explainer \(P'(e|x)\) that maximizes the accuracy of the frozen human guesser, regularized to stay close to human explanations by penalizing \(\text{KL}(P_{human}(e|x) \,\|\, P'(e|x))\), or with simpler overlap penalties.
  • Learn end-to-end via RL. Skip the separate networks and train a single model with RL to emit its own explanation \(e\) and reward it for the resulting accuracy on \(y\).
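
Both options can be sketched on the same toy problem. Everything here is an assumption for illustration: the three-symbol explanation set, the hypothetical `HUMAN_GUESSER` table, the penalty weight `beta`, learning rates, step counts, and the two hypothetical inputs `cat_photo` and `dog_photo`. Option 1 does exact gradient ascent on the expected probability the frozen guesser assigns to the true label, minus the KL penalty in the direction written above. Option 2 is plain REINFORCE with a running-mean baseline: one model with an explanation head \(P(e|x)\) and a prediction head that sees only \(e\), rewarded with 0/1 accuracy.

```python
import math
import random

EXPLANATIONS = ["fur", "whiskers", "bark"]
LABELS = ["cat", "dog"]

# Hypothetical frozen human guesser P_human(y|e), indexed [e][y].
HUMAN_GUESSER = {"fur": {"cat": 0.5, "dog": 0.5},
                 "whiskers": {"cat": 1.0, "dog": 0.0},
                 "bark": {"cat": 0.0, "dog": 1.0}}


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def refine_explainer(human_explainer, y, beta=0.1, lr=0.5, steps=500):
    """Option 1: gradient ascent on
    sum_e P'(e|x) P_human(y|e) - beta * KL(P_human(e|x) || P'(e|x))
    over the softmax logits of P'(e|x), for one input x with label y."""
    h = [human_explainer[e] for e in EXPLANATIONS]
    r = [HUMAN_GUESSER[e][y] for e in EXPLANATIONS]
    theta = [math.log(max(v, 1e-6)) for v in h]  # start (roughly) at the human explainer
    for _ in range(steps):
        p = softmax(theta)
        r_bar = sum(pj * rj for pj, rj in zip(p, r))
        # d/dtheta_j of expected accuracy is p_j (r_j - r_bar);
        # d/dtheta_j of KL(h || p) is p_j - h_j.
        theta = [t + lr * (pj * (rj - r_bar) - beta * (pj - hj))
                 for t, pj, rj, hj in zip(theta, p, r, h)]
    return dict(zip(EXPLANATIONS, softmax(theta)))


def train_end_to_end(steps=5000, lr=0.1, seed=0):
    """Option 2: explanation head P(e|x) and prediction head P(y|e),
    trained jointly with REINFORCE on 0/1 accuracy."""
    rng = random.Random(seed)
    data = [("cat_photo", "cat"), ("dog_photo", "dog")]  # hypothetical toy inputs
    theta = {x: [0.0] * len(EXPLANATIONS) for x, _ in data}
    # Seed the prediction head from the (clipped) human guesser so the symbols
    # keep their meaning; from a random init they would be arbitrary labels.
    phi = {e: [math.log(max(HUMAN_GUESSER[e][y], 0.1)) for y in LABELS]
           for e in EXPLANATIONS}
    baseline = 0.0
    for _ in range(steps):
        x, y = rng.choice(data)
        p_e = softmax(theta[x])
        i = rng.choices(range(len(EXPLANATIONS)), weights=p_e)[0]
        e = EXPLANATIONS[i]
        p_y = softmax(phi[e])
        k = rng.choices(range(len(LABELS)), weights=p_y)[0]
        reward = 1.0 if LABELS[k] == y else 0.0
        advantage = reward - baseline
        baseline += 0.01 * (reward - baseline)  # running mean reward as baseline
        # REINFORCE: d log softmax(z)[s] / d z_j = 1[j == s] - softmax(z)_j.
        theta[x] = [t + lr * advantage * (float(j == i) - pj)
                    for j, (t, pj) in enumerate(zip(theta[x], p_e))]
        phi[e] = [f + lr * advantage * (float(j == k) - pj)
                  for j, (f, pj) in enumerate(zip(phi[e], p_y))]
    explainer = {x: dict(zip(EXPLANATIONS, softmax(t))) for x, t in theta.items()}
    guesser = {e: dict(zip(LABELS, softmax(f))) for e, f in phi.items()}
    return explainer, guesser


refined = refine_explainer({"fur": 0.7, "whiskers": 0.3, "bark": 0.0}, "cat")
print({e: round(p, 3) for e, p in refined.items()})
explainer, guesser = train_end_to_end()
print({e: round(p, 3) for e, p in explainer["cat_photo"].items()})
```

In option 1, `whiskers` should rise from the human explainer's 0.3 to roughly 0.87 with these settings; the KL term is what stops it at 1.0. In option 2, each input should learn to emit the one explanation the prediction head can decode (`whiskers` for the cat photo, `bark` for the dog photo), because `fur` is useless to a guesser that never sees \(x\).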