Causality in machine learning
By OMKAR MURALIDHARAN, NIALL CARDIN, TODD PHILLIPS, AMIR NAJMI
Given recent advances and interest in machine learning, those of us with traditional statistical training have had occasion to ponder the similarities and differences between the fields. Many of the distinctions are due to culture and tooling, but there are also differences in thinking which run deeper. Take, for instance, how each field views the provenance of the training data when building predictive models. For most of ML, the training data is a given, often presumed to be representative of the data against which the prediction model will be deployed, but not much else. With a few notable exceptions, ML abstracts away from the data generating mechanism, and hence sees the data as raw material from which predictions are to be extracted. Indeed, machine learning generally lacks the vocabulary to capture the distinction between observational data and randomized data that statistics finds crucial. To contrast machine learning with statistics is not the object of this post (we can do such a post if there is sufficient interest). Rather, the focus of this post is on combining observational data with randomized data in model training, especially in a machine learning setting. The method we describe is applicable to prediction systems employed to make decisions when choosing between uncertain alternatives.
An obvious attempt to fix this is to upweight randomized data in training, or even train the model solely on the randomized data. Unfortunately, such direct attempts perform poorly. Let's say 1% of the data are randomized. If we train on both randomized and observational data, the observational data will push estimates off course due to model misspecification. However, training solely on randomized data will suffer from data sparseness because we are training on only 1% of the data. Nor is upweighting the randomized data much of a solution: we reduce the influence of observational data only to the extent we reduce its role in modeling. Thus, upweighting is tantamount to throwing away non-randomized data.
The problem is that the model doesn’t know the random data is random, so it uses it in exactly the same way as it uses any data. As we observed at the start of this post, standard machine learning techniques don’t distinguish between randomized and observational data the way statistical models do. To make better estimates, we need the randomized data to play a different role than the observational data in model training.
What is the right role for randomized data? There is probably more than one good answer to that question. For instance, one could imagine shrinking the unbiased, high-variance estimates from randomized data towards the potentially-biased, low-variance observational estimates. This is not the approach we chose for our application but we nonetheless do use the observational estimates to reduce the variance of the estimates made from randomized data.
Previously we used separate models to learn the effects of prominence and quality. The quality model (Eq 2) took the estimates of the prominence model (Eq 3) as an input. While this does achieve some of our goals, it has its own problems. Firstly, this set-up is clunky: we have to maintain two models, and changes in performance are the result of complex interactions between the two. Also, updating either is made harder by its relationship to the other. Secondly, the prominence model fails to take advantage of the information in the quality model.
The approach we have found most effective is best motivated as a refinement of the simple model we used to estimate the causal effect of prominence from our holdback in Eq 3. Let the quality score for each story be the log odds prediction of uptake without prominence. For the model in Eq 2 it would simply be $\beta_1 X_1$ where $\beta_1$ is the coefficient of $X_1$ estimated by ML. Assume for a moment that the quality score component is given. An improved estimate of the causal effect of prominence is possible from estimating $\beta_{\pi}$ in the model
Predicting and intervening
Most of the prediction literature assumes that predictions are made by a passive observer who has no influence on the phenomenon. On the other hand, most prediction systems are used to make decisions about how to intervene in a phenomenon. Often, the assumption of non-influence is quite reasonable, say if we predict whether or not it will rain in order to determine if we should carry an umbrella. In this case, whether or not we decide to carry an umbrella clearly doesn't affect the weather. But at other times, matters are less clear. For instance, if the predictions are used to decide between uncertain alternative scenarios then we observe only the outcomes which were realized. In this framing, the decisions we make influence our future training data. Depending on how the model is structured, we typically use the information we gain from realized factual scenarios to assess probabilities associated with unrealized counterfactual scenarios. But this involves extrapolation and hence the counterfactual prediction might be less accurate. Some branches of machine learning (e.g. multi-armed bandits and reinforcement learning) adopt this framing of choice between alternative scenarios in order to study optimal tradeoffs between exploration and exploitation. Our goal here is specifically to evaluate and improve counterfactual predictions.
Why would we care about the prediction accuracy of unrealized scenarios? There are a number of reasons. First, our decision not to choose a particular scenario might be incorrect, but we might never learn this because we never generate data to contradict the prediction. Second, real-world prediction systems are constantly being updated and improved, and knowledge of errors would help us target model development efforts. Finally, a more niche reason is the use in auction mechanisms such as second-pricing, where the winner (predicted highest) must pay the value the runner-up is predicted to have realized.
Let's start with a simple example to illustrate the problems of predicting and intervening. Suppose a mobile carrier builds a "churn" model to predict which of its customers are likely to discontinue their service in the next three months. The carrier offers a special renewal deal to those who were predicted by the model as most likely to churn. When we analyze the set of customers who have accepted the special deal (and hence not churned), we don't immediately know which customers would have continued their service anyway versus those who renewed because of the special deal. This lack of information has some consequences:
- we cannot directly measure churn prediction accuracy on customers to whom we made offers
- we cannot directly measure if the offer was effective (did its benefit exceed its cost)
- we must (somehow) account for the intervention when training future churn models
In this simple example, we could of course have run an experiment where we have a randomized control group to whom we would have made a special offer but did not (a "holdback" group). This gives us a way to answer each of the questions above. But what if we are faced with many different situations and many possible interventions, with the objective to select the best intervention in each case?
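For the simple churn case, the holdback comparison can be sketched as follows. This is only an illustration under stated assumptions: the churn probabilities, the holdback rate and the function names are all hypothetical, and the data are simulated rather than real.

```python
import random

def simulate_holdback(n_customers=20000, holdback_rate=0.1, seed=0):
    """Simulate offer-eligible customers, a random fraction of whom are held back."""
    rng = random.Random(seed)
    # Hypothetical churn probabilities, chosen only for illustration.
    p_churn_without_offer = 0.30
    p_churn_with_offer = 0.20
    rows = []
    for _ in range(n_customers):
        # Randomized: being held back does not depend on anything about the customer.
        held_back = rng.random() < holdback_rate
        p = p_churn_without_offer if held_back else p_churn_with_offer
        churned = rng.random() < p
        rows.append((held_back, churned))
    return rows

def churn_rate(rows, held_back):
    """Observed churn rate in the holdback group or in the group that got the offer."""
    outcomes = [churned for hb, churned in rows if hb == held_back]
    return sum(outcomes) / len(outcomes)

rows = simulate_holdback()
# Because assignment was random, the difference in churn rates estimates the offer's effect.
effect = churn_rate(rows, held_back=True) - churn_rate(rows, held_back=False)
print(f"estimated reduction in churn from the offer: {effect:.3f}")
```

The holdback group also gives us customers on whom churn predictions can be scored directly, since no intervention changed their outcome.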
Let's consider a more complex problem to serve as the running example in this post. Suppose an online newspaper provides a section on their front page called Recommended for you where they highlight news stories they believe a given user will find interesting (the NYTimes does this for its subscribers). We can imagine a sophisticated algorithm to maximize the number of stories a user will find valuable. It may predict which stories to highlight and in which order, based on the user's location, reading history and the topic of news stories.
Figure 1: A hypothetical example of personalized news recommendations
Those stories which are recommended are likely to be uptaken both because the algorithm works and because of the sheer prominence of the recommendation. If the model is more complex (say, multiple levels of prominence, interaction effects), holdback experiments won't scale because they would leave little opportunity to take the optimal action. Randomized experiments are "costly" because we do something different from what we think is best. We want to bring a small amount of randomization into the machine learning process, but do it in a manner that uses randomization effectively.
Predicting well on counterfactuals is usually harder than predicting well on the observed data because the decision making process creates confounding associations in the data. Continuing our news story recommendation example, the correct decision rule will tend to make recommendations which are most likely to be uptaken. If we try to estimate the effect of recommendation prominence by comparing how often users read recommended stories against stories not recommended, the association between our prediction and prominence would probably dominate — after all, the algorithm chooses to make prominent only those stories which appear likely to be of interest to the reader.
If we want to use predictions to make good decisions, we have to answer the following questions:
- How do we measure accuracy on counterfactuals?
- Are there counterfactual predictions we should avoid in decision making?
- How can we construct our prediction system to do well on counterfactuals?
A problem of counterfactual prediction
Before we describe solutions, it is important to be precise about the problem we are trying to address. First off, let's be clear that if the model we are training is the true model (i.e. it correctly specifies the generating mechanism), there is no problem. Likelihood theory guarantees that we will estimate the true model in an unbiased way and asymptotically converge upon it. The problem is that every real-world model is misspecified in some way, and this is what leads to poor counterfactual estimates.
For illustration purposes, assume there is a single level of prominence and that the true model is binomial for the binary event of uptake $Y$. This true model is described by the GLM equation
$$\mathrm{logit}(EY) = \beta_1 X_1 + \beta_2 X_2 + \beta_{\pi} X_{\pi} \tag{Eq 1} $$
where $X_1$ and $X_2$ are continuous features we use to estimate the relevance of the news story to the user, and $X_{\pi}$ is the binary variable indicating whether the story was made prominent. Let's define $\beta_1 X_1 + \beta_2 X_2$ to be the quality score of the model. We wish to estimate $\beta_{\pi}$, the true log odds effect of prominence on uptake $Y$.
Now suppose we fit the following misspecified model using maximum likelihood
$$ \mathrm{logit}(EY) = \beta_1 X_1 + \beta_{\pi} X_{\pi} \tag{Eq 2} $$
This model is misspecified because our quality score is missing $X_2$. Our estimate of $\beta_{\pi}$ will pick up the projection of $X_2$ onto $X_{\pi}$ (and onto $X_1$). In general, we will have misattribution to the extent that errors in our model are correlated with $X_{\pi}$. This isn't anything new. If all we care about is prediction on observed $Y$, we do fine, at least to the extent $X_2$ can be projected on the space spanned by $X_1$ and $X_{\pi}$. The fact that our estimate of $\beta_{\pi}$ is not unbiased isn't a concern because our predictions are unbiased (i.e. correct on average on the logit scale). The problem only arises when we use the model to predict on observations where the distribution of predictors is different from the training distribution — this of course happens when we are deciding on which stories to administer prominence. Depending on the situation, this could be a big deal, and so it has been at Google.
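A small simulation can make this misattribution concrete. The sketch below is an assumption-laden illustration, not the production setup: the coefficient values, the decision threshold and the helper functions (`sigmoid`, `solve`, `fit_logistic`) are hypothetical, and the fit is a plain Newton's method written out so that the example needs nothing beyond the standard library. It generates data from Eq 1 with prominence assigned by a rule that depends on the true quality score, then fits both Eq 1 and the misspecified Eq 2.

```python
import math
import random

def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def solve(a, b):
    """Solve a small linear system a x = b by Gaussian elimination with pivoting."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (m[i][n] - sum(m[i][j] * x[j] for j in range(i + 1, n))) / m[i][i]
    return x

def fit_logistic(xs, ys, iters=8):
    """Maximum likelihood fit of logit(EY) = sum_j beta_j x_j (no intercept) by Newton's method."""
    k = len(xs[0])
    beta = [0.0] * k
    for _ in range(iters):
        grad = [0.0] * k
        hess = [[0.0] * k for _ in range(k)]
        for x, y in zip(xs, ys):
            p = sigmoid(sum(b * v for b, v in zip(beta, x)))
            w = p * (1.0 - p)
            for i in range(k):
                grad[i] += (y - p) * x[i]
                for j in range(k):
                    hess[i][j] += w * x[i] * x[j]
        beta = [b + s for b, s in zip(beta, solve(hess, grad))]
    return beta

rng = random.Random(0)
BETA_1, BETA_2, BETA_PI = 1.0, 1.0, 0.5  # hypothetical true coefficients
full_x, reduced_x, ys = [], [], []
for _ in range(20000):
    x1, x2 = rng.gauss(0, 1), rng.gauss(0, 1)
    # The decision rule makes a story prominent when its true quality score is high,
    # so X_pi is correlated with the feature X_2 that Eq 2 omits.
    x_pi = 1.0 if BETA_1 * x1 + BETA_2 * x2 > 1.0 else 0.0
    p = sigmoid(BETA_1 * x1 + BETA_2 * x2 + BETA_PI * x_pi)
    ys.append(1.0 if rng.random() < p else 0.0)
    full_x.append([x1, x2, x_pi])
    reduced_x.append([x1, x_pi])

full_beta_pi = fit_logistic(full_x, ys)[2]        # Eq 1, correctly specified
reduced_beta_pi = fit_logistic(reduced_x, ys)[1]  # Eq 2, missing X_2
print(f"true {BETA_PI:.2f}, Eq 1 fit {full_beta_pi:.2f}, Eq 2 fit {reduced_beta_pi:.2f}")
```

In this toy setting the Eq 2 estimate of $\beta_{\pi}$ comes out well above the true value, because prominence is standing in for the omitted $X_2$.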
In theory, we could use a holdback experiment to estimate the effect of prominence where we randomly do not recommend stories which we would otherwise have recommended. We can estimate the causal effect of prominence as the difference in log odds of uptake between stories which were recommended and those which were eligible (i.e. would have been recommended) but were randomly not recommended. The value of $\beta_{\pi}$ in the following GLM equation is the causal estimate we seek:
$$ \mathrm{logit}(EY) = \beta_{\pi} X_{\pi} + \beta_e X_e \tag{Eq 3} $$
where $X_e$ is the binary variable denoting the story was eligible for recommendation and $\beta_e$ its associated coefficient. Since we only recommend eligible stories, $X_{\pi}=1$ implies $X_e=1$, and $X_{\pi} \neq X_e$ occurs only in our holdback.
Observe that $\beta_{\pi}$ is estimated as the difference in $\mathrm{logit}(EY)$ when $X_{\pi} = 1$, $X_e = 1$ and when $X_{\pi} = 0$, $X_e = 1$. Why we use this roundabout GLM model to express a simple odds ratio calculation will become clearer further on. The point is that this method works to estimate the causal effect of prominence because randomization breaks the correlation between $X_2$ and $X_{\pi}$ (see an earlier post for a more detailed discussion on this point). As per Eq 2, we can apply this estimate of $\beta_{\pi}$ from our randomized holdback in estimating $\beta_1$ on observational data.
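As a worked instance of the odds ratio calculation behind Eq 3, the estimate can be computed directly from uptake counts among eligible stories. This is a minimal sketch; the function name and the counts are hypothetical and chosen only for illustration.

```python
import math

def logit(p):
    return math.log(p / (1.0 - p))

def prominence_effect_from_holdback(uptakes_shown, n_shown, uptakes_held, n_held):
    """Estimate beta_pi in Eq 3 as a difference in log odds among eligible stories."""
    p_shown = uptakes_shown / n_shown  # X_pi = 1, X_e = 1
    p_held = uptakes_held / n_held     # X_pi = 0, X_e = 1 (randomly held back)
    return logit(p_shown) - logit(p_held)

# Hypothetical counts: 50% uptake when recommended, 20% uptake in the holdback.
beta_pi_hat = prominence_effect_from_holdback(5000, 10000, 200, 1000)
print(round(beta_pi_hat, 3))
```

With these counts the odds are 1 when recommended and 0.25 in the holdback, so the estimate is the log of an odds ratio of 4.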
Checking accuracy with randomization — realism vs. interpretability
The best and most obvious way to tell how well we are predicting the effects of counterfactual actions is to randomly take those counterfactual actions a fraction of the time and see what happens. In our news recommendation example, we can randomly decide to recommend or not some stories, and see if our decision-time prediction of the change in uptake rates is correct. As we saw, this works because randomization breaks the correlations between our chosen action (whether to recommend) and other decision inputs (the quality of the recommendation).
In a complex system, randomization can still be surprisingly subtle because there are often multiple ways to randomize. For example, we can randomize inputs to the decision procedure, or directly randomize decisions. The former approach will tend to produce more realistic outcomes, but can be harder to understand, and may not give us adequate data to assess unlikely decisions. The latter approach is usually easier to understand, but can produce unrealistic outcomes.
To show how subtle things can be, let's go back to our example earlier where we computed the causal effect of prominence by running a holdback experiment. What we did there was to randomly not recommend stories we would have recommended. But this is just one kind of random perturbation. This particular procedure allows us to estimate (and hence check on) what statisticians call treatment on the treated. In other words, we estimate the average effect of prominence on the uptake of stories we recommend. This is different than the average effect of prominence across the population of stories. What we miss is the effect of prominence on the kinds of stories we never recommend. Suppose the effect of prominence is significantly lower for a news topic that no one finds interesting, say, news on the proceedings of the local chamber of commerce (PLCC). If we never recommend stories of the PLCC, they won't contribute to our holdback experiment and hence we will never learn that our estimates for such stories were too high. We could fix this problem by recommending random stories (as well as randomly suppressing recommendations) and hence directly measure the average effect of recommendation on all stories. But this might not be quite what we want either — it is unclear what we are learning from the unnatural scenario of recommending news of the intolerable PLCC. These recommendations might themselves cause users to react in an atypical manner, perhaps by not bothering to look at recommendations further down the list.
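The distinction can be written compactly in potential-outcome notation. The notation is an assumption introduced here for illustration: $Y(1)$ and $Y(0)$ denote the uptake of a story with and without prominence, and the effect is shown as a difference in expected uptake, though the same contrast applies on the log odds scale.

$$ \mathrm{ATT} = E[\,Y(1) - Y(0) \mid X_e = 1\,], \qquad \mathrm{ATE} = E[\,Y(1) - Y(0)\,] $$

The holdback targets the first quantity, since only eligible stories ($X_e = 1$) ever enter it. The two coincide only if the effect of prominence is the same for eligible and ineligible stories, which is exactly what the PLCC example calls into question.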
Random expression and suppression of recommendation are examples of what we called randomizing the decision. The alternative we mentioned was to randomize inputs to the decision. We could achieve this by adding random noise to each news story's quality score and feeding it into the decision procedure. If the amount of random noise is commensurate with the variability in the quality score then we will truly generate a realistic set of perturbations. The data we collect from these randomized perturbations will tend to be near the quality threshold for recommendation, which is usually where data is most valuable. On the other hand, this data might be less interpretable — all sorts of decisions might change, not just the stories whose scores we randomized. The impact of each individual perturbation is not easily separated and this can make the data harder to use for modeling and analysis of prediction errors.
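The two styles of randomization can be contrasted in a toy simulation. This is a sketch under simplifying assumptions: a single fixed quality threshold stands in for the decision procedure, each story is decided independently, and the threshold, noise scale and flip probability are hypothetical values.

```python
import random

rng = random.Random(0)
THRESHOLD = 1.0   # hypothetical quality threshold for recommendation
NOISE_SD = 0.3    # hypothetical noise scale for input randomization
FLIP_PROB = 0.05  # hypothetical rate for decision randomization

scores = [rng.gauss(0.0, 1.0) for _ in range(20000)]

def randomize_decision(score):
    """Take the usual decision, but flip it outright with small probability."""
    decision = score > THRESHOLD
    if rng.random() < FLIP_PROB:
        decision = not decision
    return decision

def randomize_input(score):
    """Perturb the quality score, then apply the usual decision rule."""
    return score + rng.gauss(0.0, NOISE_SD) > THRESHOLD

def mean_distance_of_flips(randomized_rule):
    """Average distance from the threshold among stories whose decision changed."""
    flipped = [abs(s - THRESHOLD) for s in scores
               if randomized_rule(s) != (s > THRESHOLD)]
    return sum(flipped) / len(flipped)

dist_decision = mean_distance_of_flips(randomize_decision)
dist_input = mean_distance_of_flips(randomize_input)
print(f"decision randomization: {dist_decision:.2f}, input randomization: {dist_input:.2f}")
```

In this simplified setting, input randomization changes decisions mostly for stories close to the threshold, while decision randomization changes them everywhere, including for stories far from it. The sketch does not capture the interpretability cost of input randomization, which arises when stories compete for slots and one perturbed score changes other decisions too.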
There is no easy way to make the choice between realism and interpretability. We’ve often chosen artificial randomization that is easy to understand, since it is more likely to produce data useful for multiple applications, and subsequently checked the results against more realistic randomization. Happily, in our case, we found that answers between these two approaches were in good agreement.
The No Fake Numbers Principle
Automated decision systems are often mission-critical, so it is important that everything which goes into them is accurate and checkable. We can’t reliably check counterfactual predictions for actions we can’t randomize. These facts lead to the No Fake Numbers (NFN) principle:
Avoid decisions based on predictions for counterfactual actions you cannot take.
In other words, NFN says not to use predictions of unobservable quantities to make decisions.
Now why would anyone ever place demands on a prediction system which run counter to this seemingly reasonable principle? In our work, violations of this principle have arisen when we've wanted to impose invariants on the decisions we make, usually with good intentions. The problem isn't the invariants themselves but rather the hypothetical nature of their premises.
For example, suppose the online newspaper wishes to enforce a policy of "platform neutrality" whereby the quality of recommendations should be the same on iPhone and Android mobile phones. Perhaps the newspaper wishes to ensure that users would see the same recommendations regardless of the type of phone they use. However, this is a slippery notion. iPhone users might actually have different aggregate usage patterns and preferences from Android users, making a naive comparison inappropriate.
One way to address this is to use techniques to derive causal inference from observational analysis to predict what an iPhone user would do if she were using Android. There is valuable literature on this (e.g. Rubin Causal Model) but, fundamentally, you really need to understand what you are doing with a causal model. That means the model needs careful, manual attention from an experienced modeler, and a nuanced understanding of its strengths and weaknesses. This is why observational analysis models are used only when there is no alternative, where the inference made by the model is carefully traded off against its assumptions. Such fine inspection is not feasible for production models, which are updated by many people (and automatically over time), require ongoing automated monitoring, and whose predictions are used for many applications.
Why is NFN necessary? First, it is generally a better way to design decision systems. Actions violating NFN can never be realized. This usually means they are not directly important for decisions, and the system can be improved by thinking about its goals more carefully.
Second, and more importantly, mission-critical systems have much stronger requirements than one-off analyses — we need to be able to monitor them for correctness, to define clear notions of improvement, to check that their behavior is stable, and to debug problems. For example, suppose our system uses predictions for what iPhone users would do on Android. Crucially, we can never check how accurate these predictions are, since we cannot randomly force an iPhone user to use Android. If those predictions drift over time, we have no way to tell if the system is working well or not. One-off analyses might be able to get away with using observational analysis techniques, but critical systems need ongoing, robust, direct validation with randomized data.
In summary, the NFN principle cautions us against imposing requirements whose solutions may have unintended consequences we cannot easily detect. As with any principle, we would override it with great caution.
Using randomization in training
The previous sections described how to use randomization to check our prediction models and guide the design of our decision systems. This section describes how to go further, and directly incorporate randomized data into the systems.
Large scale prediction systems are often bad at counterfactual predictions out of the box. This is because a large scale prediction system is almost certainly misspecified. Thus even very sophisticated ones suffer from a kind of overfitting. These systems don’t classically overfit — they use cross-validation or progressive validation to avoid that — but they tend to overfit to the observed distribution of the data. As we described earlier, this factual distribution doesn’t match the counterfactual distribution of data and hence the model can fail to generalize. When deciding which stories to recommend, we need predictions with and without the prominence of recommendation — but that means we'll need good predictions for interesting stories which don't get recommended and uninteresting stories which are recommended, rare and strange parts of the factual data.
An obvious attempt to fix this is to upweight randomized data in training, or even train the model solely on the randomized data. Unfortunately, such direct attempts perform poorly. Let's say 1% of the data are randomized. If we train on both randomized and observational data, the observational data will push estimates off course due to model misspecification. However, training solely on randomized data will suffer from data sparseness because you are training on only 1% of the data. Nor is upweighting the randomized data much of a solution — we reduce the influence of observational data only to the extent we reduce its role in modeling. Thus, upweighting is tantamount to throwing away non-randomized data.
The problem is that the model doesn’t know the random data is random, so it uses it in exactly the same way as it uses any data. As we observed at the start of this post, standard machine learning techniques don’t distinguish between randomized and observational data the way statistical models do. To make better estimates, we need the randomized data to play a different role than the observational data in model training.
What is the right role for randomized data? There is probably more than one good answer to that question. For instance, one could imagine shrinking the unbiased, high-variance estimates from randomized data towards the potentially-biased, low-variance observational estimates. This is not the approach we chose for our application but we nonetheless do use the observational estimates to reduce the variance of the estimates made from randomized data.
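To make the shrinkage idea concrete, here is one generic form such an estimator could take. This is an illustrative sketch only, not the approach used in this post; the superscripts and the weight $\lambda$ are notation introduced here. Writing $\hat{\beta}_{\pi}^{\mathrm{rand}}$ for the estimate from randomized data (unbiased, variance $v_r$) and $\hat{\beta}_{\pi}^{\mathrm{obs}}$ for the estimate from observational data (bias $b$, variance $v_o$), and assuming the two are independent:

$$ \hat{\beta}_{\pi}^{\mathrm{shrunk}} = \lambda \, \hat{\beta}_{\pi}^{\mathrm{rand}} + (1 - \lambda) \, \hat{\beta}_{\pi}^{\mathrm{obs}}, \qquad \mathrm{MSE}(\lambda) = \lambda^2 v_r + (1 - \lambda)^2 (v_o + b^2) $$

Minimizing the mean squared error over $\lambda$ gives

$$ \lambda^{*} = \frac{v_o + b^2}{v_r + v_o + b^2} $$

so the randomized estimate gets more weight when the observational estimate is noisy or badly biased. The catch is that $b$ is exactly the quantity we cannot observe directly, so in practice $\lambda$ would itself have to be estimated or chosen conservatively.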
Previously we used separate models to learn the effects of prominence and quality. The quality model (Eq 2) took the estimates of the prominence model (Eq 3) as an input. While this does achieve some of our goals, it has its own problems. Firstly, this set-up is clunky: we have to maintain two models, and changes in performance are the result of complex interactions between the two. Also, updating either is made harder by its relationship to the other. Secondly, the prominence model fails to take advantage of the information in the quality model.
Suppose instead that the quality effect $\hat{\beta_1} X_1$ were known. We could then estimate the prominence effects on randomized data alone, treating the quality score as a fixed input:
$$ \mathrm{logit}(EY) = \mathrm{offset}(\hat{\beta_1} X_1) + \beta_{\pi} X_{\pi} + \beta_e X_e \tag{Eq 4} $$
where $\mathrm{offset}$ is an abuse of R syntax to indicate that this component is given and not estimated. The estimate of $\beta_{\pi}$ in this model is still unbiased but by accounting for the (presumed) known quality effect, we reduce the variability of our estimate. In reality, the quality score is not known, and is estimated from the observational data. But regardless, randomization ensures that the estimate from this procedure will be unbiased. As long as we employ an estimate of quality score that is better than nothing, we account for some of the variability and hence reduce estimator variance. We have every reason to be optimistic about a decent-but-misspecified model.
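A minimal sketch of fitting a model like Eq 4 may help make the role of the offset concrete. It uses plain NumPy and Newton's method on synthetic data. Everything here is an assumption made for illustration: the function name `fit_logistic_with_offset`, the simulated coefficients, the choice to treat $X_e$ as a column of ones on randomized rows, and the simplification that the given quality coefficient equals the true one.

```python
import numpy as np

def fit_logistic_with_offset(X, y, offset, n_iter=50, tol=1e-10):
    """Fit logit(E[y]) = offset + X @ beta by Newton's method.

    The offset is treated as given: it enters the linear predictor
    but no coefficient is estimated for it.
    """
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(offset + X @ beta)))
        grad = X.T @ (y - p)                       # score vector
        hess = (X * (p * (1.0 - p))[:, None]).T @ X  # observed information
        step = np.linalg.solve(hess, grad)
        beta = beta + step
        if np.max(np.abs(step)) < tol:
            break
    return beta

# Synthetic randomized data (all values below are hypothetical).
rng = np.random.default_rng(0)
n = 20000
x1 = rng.normal(size=n)                            # quality feature X_1
x_pi = rng.integers(0, 2, size=n).astype(float)    # prominence, assigned at random
x_e = np.ones(n)                                   # assumed: X_e = 1 on randomized rows

beta1_given = 1.0                                  # quality coefficient, taken as known
true_beta_pi, true_beta_e = 0.8, -1.0
logits = beta1_given * x1 + true_beta_pi * x_pi + true_beta_e * x_e
y = (rng.random(n) < 1.0 / (1.0 + np.exp(-logits))).astype(float)

# Only the prominence coefficients are estimated; quality enters as an offset.
X_rand = np.column_stack([x_pi, x_e])
beta_pi_hat, beta_e_hat = fit_logistic_with_offset(
    X_rand, y, offset=beta1_given * x1
)
print("beta_pi estimate:", beta_pi_hat)
print("beta_e estimate:", beta_e_hat)
```

In this toy setting the quality offset is exactly right, which is the idealized case. With an estimated quality score the same call applies, with `offset` set to the estimated score.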
The procedure above involves first training a model entirely on observational data and then using the quality score thus derived to estimate a second model for prominence, trained on randomized data. The astute reader will note that the observational model itself estimates a prominence effect which we discard. It turns out we can do even better by co-training the quality score together with the prominence. Consider the following iterative updating procedure:
- On non-randomized data, use the model $$ \mathrm{logit}(EY) = \beta_1 X_1 + \mathrm{offset}(\hat{\beta_{\pi}} X_{\pi}) $$ and only update the quality score coefficients (here just $\beta_1$).
- On randomized data, use the model $$ \mathrm{logit}(EY) = \mathrm{offset}(\hat{\beta_1} X_1) + \beta_{\pi} X_{\pi} + \beta_e X_e$$ and only update the prominence coefficients (here $\beta_{\pi}$ and $\beta_e$).
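The alternating updates above can be sketched in a few lines of NumPy. This is a toy illustration under stated assumptions, not the production procedure: the helper names (`fit_with_offset`, `simulate`, `co_train`), the simulated coefficients, the fixed number of rounds, and the treatment of $X_e$ as an indicator that equals 1 on randomized rows and 0 elsewhere are all hypothetical choices made for this example.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_with_offset(X, y, offset, beta_init, n_iter=25):
    """Newton steps for logit(E[y]) = offset + X @ beta, with the offset held fixed."""
    beta = beta_init.copy()
    for _ in range(n_iter):
        p = sigmoid(offset + X @ beta)
        grad = X.T @ (y - p)
        hess = (X * (p * (1.0 - p))[:, None]).T @ X
        beta = beta + np.linalg.solve(hess, grad)
    return beta

def simulate(n, randomized, rng, b1=1.0, b_pi=0.8, b_e=-0.3):
    """Synthetic clicks; prominence follows quality unless it is randomized."""
    x1 = rng.normal(size=n)
    if randomized:
        x_pi = rng.integers(0, 2, size=n).astype(float)
    else:
        x_pi = (x1 + rng.normal(size=n) > 0).astype(float)
    x_e = np.full(n, 1.0 if randomized else 0.0)   # assumed experiment indicator
    y = (rng.random(n) < sigmoid(b1 * x1 + b_pi * x_pi + b_e * x_e)).astype(float)
    return x1, x_pi, x_e, y

def co_train(obs, rand, n_rounds=10):
    x1_o, xpi_o, _, y_o = obs
    x1_r, xpi_r, xe_r, y_r = rand
    beta_1 = np.zeros(1)      # quality coefficient
    beta_prom = np.zeros(2)   # (beta_pi, beta_e)
    for _ in range(n_rounds):
        # Non-randomized data: update quality only, prominence enters as an offset.
        beta_1 = fit_with_offset(
            x1_o[:, None], y_o, beta_prom[0] * xpi_o, beta_1
        )
        # Randomized data: update prominence only, quality enters as an offset.
        beta_prom = fit_with_offset(
            np.column_stack([xpi_r, xe_r]), y_r, beta_1[0] * x1_r, beta_prom
        )
    return beta_1[0], beta_prom[0], beta_prom[1]

rng = np.random.default_rng(0)
obs = simulate(50000, randomized=False, rng=rng)
rand = simulate(10000, randomized=True, rng=rng)
beta_1_hat, beta_pi_hat, beta_e_hat = co_train(obs, rand)
print("beta_1:", beta_1_hat, "beta_pi:", beta_pi_hat, "beta_e:", beta_e_hat)
```

The sketch runs a fixed number of rounds rather than testing for convergence, and the simulated model is correctly specified, so it says nothing about how the procedure behaves under the misspecification that motivates it. Its purpose is only to show which coefficients are updated on which slice of data.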
Conclusion
In this post we described how some randomized data may be applied both to check and improve the accuracy of a machine learning system trained largely on observational data. We also shared some of the subtleties of randomization applied to causal modeling. While we've spent years trying to understand and overcome issues arising from counterfactual (and "counter-usual", atypical) predictions, there is much we have still to learn. And yet the ideas we describe here have already been deployed to solve some long-standing prediction problems at Google. We hope they will be useful to you as well.
Insightful. Thanks for sharing!
First, what guarantees that the co-training procedure described before the conclusion is stable? That is to say, why shouldn't the estimated values in 1. and 2. oscillate?
Second, how are X_{\pi} and X_{e} set at prediction time? Are they both always set to 1?
If you train the models with likelihood maximization and the likelihood function for each model has a single maximum then it must converge.
I am interested in the answer to the second question as well: "Second, how are X_{\pi} and X_{e} set at prediction time? Are they both always set to 1?"
Very interesting, looking forward to the machine learning vs statistics...
I am interested in reading your comparison of machine learning and statistics.