Let Bayes tune Bayes: hyperparameter optimization for causal MMMs with Optuna
You’ve spent hours fitting your Bayesian Media Mix Model. You’ve tried five different Fourier orders for seasonality. You’ve tested adstock lags from 4 to 12 weeks. Yet your model still insists that your best-performing channel deserves zero budget, and your posterior predictive checks look terrible. Stakeholders start questioning whether this “sophisticated” Bayesian approach is worth the effort.
The truth is simple: even the most principled causal model can fail if the hyperparameters are wrong. And building Bayesian models is hard. Really hard.
But there’s a systematic way to fix this. In this post, I show how to turn weeks of manual tuning into a repeatable optimization process by combining Optuna’s Bayesian optimization with PyMC-Marketing, using CRPS (Continuous Ranked Probability Score) to properly evaluate probabilistic forecasts.
Note: This tutorial uses synthetic data where I know the ground truth, so I can validate the approach. Real-world results will vary, but the general method still applies.
The hyperparameter dilemma in causal MMM
MMMs encode our understanding of marketing causality: channels have diminishing returns (saturation), effects persist over time (adstock), and seasonal patterns influence baseline sales (seasonality). These aren’t just statistical constructs. They’re hypotheses about real causal mechanisms in your marketing system.
But even with the right causal structure, poor hyperparameter choices lead to misleading conclusions and wasted computation. Should you use 3 or 7 Fourier components for seasonality? This choice sounds technical, but it can directly change whether your model recommends cutting or increasing spend on a channel.
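To make the three mechanisms concrete, here is a minimal pure-Python sketch of what each one does. It illustrates the ideas and is not PyMC-Marketing's implementation; the function names and the `alpha`, `lam`, and `n_order` parameters are my own labels for the decay rate, saturation steepness, and Fourier order.

```python
import math


def geometric_adstock(spend, alpha, l_max):
    """Carry each period's spend into later periods with weights alpha**lag."""
    out = []
    for t in range(len(spend)):
        total = 0.0
        # Only the most recent l_max periods contribute
        for lag in range(min(l_max, t + 1)):
            total += (alpha ** lag) * spend[t - lag]
        out.append(total)
    return out


def logistic_saturation(x, lam):
    """Diminishing returns: rises quickly at first, then flattens toward 1."""
    return (1 - math.exp(-lam * x)) / (1 + math.exp(-lam * x))


def fourier_features(day_of_year, n_order, period=365.25):
    """Return 2 * n_order sine/cosine seasonality features for one date."""
    feats = []
    for k in range(1, n_order + 1):
        angle = 2 * math.pi * k * day_of_year / period
        feats.extend([math.sin(angle), math.cos(angle)])
    return feats


# A single burst of spend keeps contributing until it falls outside l_max
carried = geometric_adstock([1.0, 0.0, 0.0, 0.0], alpha=0.5, l_max=3)
print(carried)  # [1.0, 0.5, 0.25, 0.0]
```

The two knobs tuned later in this post show up directly: `l_max` decides how many past periods can still contribute, and `n_order` decides how many sine/cosine pairs the seasonal component gets.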
Why traditional tuning fails
The core idea borrows from traditional machine learning: use a validation set and score models based on a defined error metric. But common approaches to hyperparameter tuning often collapse under the computational weight of Bayesian inference:
- Grid search becomes computationally prohibitive when each model fit takes 10–15 minutes with proper MCMC.
- Manual tuning based on domain expertise often devolves into time-consuming trial and error.
- Default values from packages may not suit your specific business context and can lead to poor convergence.
On top of that, Bayesian models aren’t just about point predictions. They produce full posterior distributions that capture uncertainty. Traditional error metrics like RMSE or MAE don’t tell the full story.
What I needed was an efficient way to explore the hyperparameter space, one that respects the probabilistic nature of Bayesian models.
This is where Bayesian optimization meets proper scoring rules. Bayesian optimization (implemented via Optuna) efficiently explores the hyperparameter space using information from previous trials. Proper scoring rules like CRPS give us a principled way to evaluate probabilistic forecasts on a test set. Together, they form a systematic loop:
- Propose new hyperparameters (adstock decay, Fourier order, etc.)
- Fit the model with PyMC-Marketing
- Evaluate the predictive distribution on held-out data using CRPS
- Update the search strategy to focus on promising regions
Within a few dozen iterations, this approach can outperform weeks of manual tuning while remaining fully Bayesian and grounded in uncertainty-aware evaluation.
CRPS: the right metric for probabilistic forecasts
Understanding CRPS through analogy
Imagine you’re a weather forecaster. RMSE would only evaluate whether you correctly predicted “72 degrees.” CRPS evaluates your entire probabilistic forecast: “70% chance of 70–75 degrees, 20% chance of 65–70 degrees, 10% chance of 75–80 degrees.” It rewards you for being confident when you’re right and penalizes overconfidence when you’re wrong. CRPS combines both accuracy (how close is your prediction?) and calibration (is your uncertainty appropriate?).
The mathematics behind CRPS
For those who appreciate the mathematical foundations, CRPS is defined as:
$$
\mathrm{CRPS}(F, y) = \int_{-\infty}^{\infty} \bigl( F(x) - \mathbb{1}\{x \ge y\} \bigr)^2 \, dx
$$
Where F is your predictive cumulative distribution function and y is the observed value.
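Intuitively, CRPS integrates the squared difference between your predicted CDF and a step function at the observed value. It is expressed in the same units as the target, and for a point forecast it reduces to the absolute error. Two properties drive the score:
- Reliability: How well-calibrated are your probability statements?
- Sharpness (closely related to what the forecasting literature calls resolution): How narrow are your predictions when you’re confident?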
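When the forecast is only available as samples, a standard identity rewrites the integral as CRPS = E|X − y| − ½·E|X − X′|, where X and X′ are independent draws from the predictive distribution. The sketch below implements that estimator from scratch for a single observation. It is meant to show what the score rewards, not how any particular library computes it, and the three forecasts are made-up examples.

```python
import random


def crps_from_samples(samples, y):
    """Estimate CRPS for one observation via E|X - y| - 0.5 * E|X - X'|."""
    n = len(samples)
    # Accuracy term: average distance between forecast samples and the outcome
    accuracy = sum(abs(s - y) for s in samples) / n
    # Spread term: average distance between pairs of forecast samples
    spread = sum(abs(a - b) for a in samples for b in samples) / (n * n)
    return accuracy - 0.5 * spread


rng = random.Random(0)
y_obs = 100.0
sharp = [rng.gauss(100.0, 5.0) for _ in range(500)]    # centred and narrow
vague = [rng.gauss(100.0, 50.0) for _ in range(500)]   # centred but very wide
biased = [rng.gauss(130.0, 5.0) for _ in range(500)]   # narrow but off target

for name, forecast in [("sharp", sharp), ("vague", vague), ("biased", biased)]:
    print(name, round(crps_from_samples(forecast, y_obs), 2))
```

The centred, narrow forecast scores best; both the overly wide one and the confident-but-wrong one are penalized.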
In practice, with samples from your posterior predictive distribution, you compute it efficiently:
```python
from pymc_marketing.metrics import crps

# y_true: actual observed values (shape: n_observations)
# y_pred: posterior predictive samples (shape: n_samples, n_observations)
crps_score = crps(y_true, y_pred)
```

Why CRPS beats traditional metrics
Here’s a comparison of the metrics available for evaluating MMM performance:
| Metric | What it measures | Pros | Cons | Best for |
|---|---|---|---|---|
| RMSE | Point estimate error | Simple to interpret, widely understood | Ignores uncertainty, sensitive to outliers | Deterministic models |
| MAPE | Percentage error | Scale-independent, business-friendly | Undefined at zero, ignores uncertainty | Point forecasts |
| WAIC | In-sample fit (penalized) | Accounts for complexity, Bayesian native | Not true holdout, hard to interpret | Model comparison |
| CRPS | Full predictive distribution | Proper scoring rule, true generalization | Less intuitive, computationally heavier | Probabilistic forecasts |
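The key point: CRPS is a proper scoring rule, meaning its expected value is minimized when your predictive distribution matches the true data-generating process. You can’t game it by predicting overly wide or narrow intervals: forecasts that are too wide lose on sharpness, and forecasts that are too narrow are penalized whenever the observation lands in the tails. It naturally balances accuracy and appropriate uncertainty.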
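A small simulation makes the "you can't game it" claim tangible. This sketch uses the well-known closed-form CRPS for a Gaussian forecast and compares three forecasters who all get the mean right but state different uncertainty. The standard-normal "truth" and the three candidate sigmas are assumptions chosen for illustration.

```python
import math
import random


def gaussian_crps(mu, sigma, y):
    """Closed-form CRPS of a Normal(mu, sigma) forecast for observation y."""
    z = (y - mu) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)
    cdf = 0.5 * (1 + math.erf(z / math.sqrt(2)))
    return sigma * (z * (2 * cdf - 1) + 2 * pdf - 1 / math.sqrt(math.pi))


# Simulated outcomes from the "true" process: Normal(0, 1)
rng = random.Random(1)
observations = [rng.gauss(0.0, 1.0) for _ in range(5000)]


def mean_crps(sigma):
    """Average CRPS of a Normal(0, sigma) forecast over all observations."""
    return sum(gaussian_crps(0.0, sigma, y) for y in observations) / len(observations)


# Overconfident (0.25), honest (1.0), and overly cautious (4.0) forecasters
scores = {sigma: mean_crps(sigma) for sigma in (0.25, 1.0, 4.0)}
best_sigma = min(scores, key=scores.get)
print(scores, best_sigma)
```

The forecaster whose stated uncertainty matches the real spread gets the lowest average score, which is exactly the behavior a proper scoring rule guarantees in expectation.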
CRPS in practice: a simple example
Here’s how CRPS evaluation looks in code:
```python
import pandas as pd
from pymc_marketing.metrics import crps
from pymc_marketing.mmm import MMM


def compute_test_crps(
    mmm: MMM,
    X_test: pd.DataFrame,
    y_test: pd.Series,
) -> float:
    """Compute CRPS on test set using posterior predictive."""
    # Sample posterior predictive for test set and store it on mmm.idata
    mmm.sample_posterior_predictive(X_test, extend_idata=True)
    # Extract predictions (shape: n_chains, n_draws, n_observations)
    y_pred_samples = mmm.idata.posterior_predictive["y"].values
    # Rescale to original scale
    target_scale = float(mmm.idata.constant_data["target_scale"].values)
    y_pred_rescaled = y_pred_samples * target_scale
    # Flatten chains and draws into one sample axis: (n_samples, n_observations)
    n_chains, n_draws, n_obs = y_pred_rescaled.shape
    y_pred_reshaped = y_pred_rescaled.reshape(n_chains * n_draws, n_obs)
    # Compute CRPS
    return float(crps(y_test.values, y_pred_reshaped))
```

Bayesian optimization: smart search for expensive models
The computational challenge
A properly converged Bayesian MMM with 2000 draws, 2000 tuning steps, and 4 chains typically takes 10–15 minutes on modern hardware. Testing 100 hyperparameter combinations naively would require 17–25 hours of computation. And that assumes you know which 100 combinations are worth trying.
Bayesian optimization, specifically Tree-structured Parzen Estimators (TPE) as implemented in Optuna, addresses this by learning from each trial to focus on promising regions of the hyperparameter space. Instead of exhaustive search, it builds a probabilistic model of the objective function.
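To give a feel for the idea, here is a deliberately simplified, greedy toy version of it in pure Python. It is not Optuna's implementation (real TPE samples candidates, uses priors, and handles many parameters). It only shows the core move: split past trials into "good" and "rest", model where each group lives, and propose the value that looks most like the good group. `toy_score` is a hypothetical stand-in for test CRPS with its minimum placed at a lag of 10.

```python
import math


def kde(points, x, bandwidth=1.0):
    """Average of Gaussian bumps centred on past trial values."""
    return sum(math.exp(-0.5 * ((x - p) / bandwidth) ** 2) for p in points) / len(points)


def propose(history, candidates, gamma=0.25):
    """Pick the candidate most typical of good trials relative to the rest.

    history: list of (value, score) pairs, where a lower score is better.
    """
    ranked = sorted(history, key=lambda pair: pair[1])
    n_good = max(1, int(gamma * len(ranked)))
    good = [value for value, _ in ranked[:n_good]]
    rest = [value for value, _ in ranked[n_good:]]
    # Maximize the density ratio good(x) / rest(x)
    return max(candidates, key=lambda x: kde(good, x) / (kde(rest, x) + 1e-12))


def toy_score(lag):
    """Hypothetical stand-in for test CRPS: best at lag = 10."""
    return (lag - 10) ** 2


candidates = list(range(4, 13))
# Four initial trials, then six guided proposals
history = [(lag, toy_score(lag)) for lag in (4, 6, 8, 12)]
for _ in range(6):
    lag = propose(history, candidates)
    history.append((lag, toy_score(lag)))

best_lag, best_score = min(history, key=lambda pair: pair[1])
print(best_lag, best_score)  # 10 0
```

Each proposal costs almost nothing compared with an MCMC fit, which is why spending a little effort on choosing the next trial pays off.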
Why this works well for MMM optimization
TPE balances exploring new regions with exploiting promising areas. Each trial learns from all previous trials, and TPE handles discrete parameters like Fourier orders without trouble. On top of that, the objective function can prune a trial as soon as its convergence checks fail, so poorly converged fits never feed a misleading score back into the search.
One caveat worth noting: Bayesian optimization can still overfit to your specific test set. In production, consider using time series cross-validation or multiple hold-out periods.
Here’s how I configure Optuna for MMM optimization:
```python
RANDOM_SEED = 42  # Assumed value; any fixed seed works

# Configuration for optimization vs final model
OPTUNA_DRAWS = 500   # Fewer draws during search (risk: unreliable estimates)
OPTUNA_TUNE = 500    # Shorter warmup (risk: poor adaptation)
OPTUNA_CHAINS = 2    # Minimum chains (risk: harder to detect convergence issues)
FINAL_DRAWS = 2000   # Production-quality for final model
FINAL_TUNE = 2000    # Thorough convergence
FINAL_CHAINS = 4     # Robust diagnostics

# Convergence thresholds - lenient during search, strict for final
OPTUNA_DIVERGENCE_THRESHOLD = 0.10  # 10% divergences OK during search
OPTUNA_RHAT_THRESHOLD = 1.10        # R-hat < 1.10 (concerning but acceptable for search)
OPTUNA_ESS_THRESHOLD = 100          # Assumed value: lenient ESS floor during search
FINAL_DIVERGENCE_THRESHOLD = 0.01   # <1% divergences for production
FINAL_RHAT_THRESHOLD = 1.01         # Strict R-hat requirement
FINAL_ESS_THRESHOLD = 400           # Minimum effective sample size
```

These reduced settings during optimization are a calculated risk. I accept potentially unreliable individual estimates in exchange for exploring more hyperparameter combinations. The final model always gets validated with full MCMC settings.
What you can (and shouldn’t) tune
Unlike traditional ML, not all parameters are fair game for optimization in Bayesian modeling. In an MMM, some hyperparameters define the structure of your causal assumptions, while others can safely be treated as knobs to optimize for predictive performance.
You can tune the Fourier order controlling seasonal complexity, the adstock lag length, and even the functional form of effects like adstock or saturation (e.g., Hill vs. logistic). These choices shape how the model captures marketing dynamics without altering its causal foundations.
However, you should not use Bayesian optimization to select priors. Priors encode external beliefs and domain knowledge, such as expected ROI ranges or plausible decay rates, and optimizing them purely for predictive accuracy defeats the Bayesian purpose. In other words: use Bayesian optimization to make your model smarter, not to make it forget what it’s supposed to believe.
Implementation: the complete playbook
Step 1: Data preparation and train/test split
The foundation of this approach is a chronological train/test split. Unlike random splits common in traditional ML, time series data requires respecting temporal ordering:
```python
import polars as pl


def split_train_test(
    df: pl.DataFrame,
    test_size_weeks: int = 24,
) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Split data chronologically for time series validation."""
    n_total = df.shape[0]
    n_train = n_total - test_size_weeks
    # Sort by date to ensure chronological order
    df_sorted = df.sort("date")
    train_df = df_sorted[:n_train]
    test_df = df_sorted[n_train:]
    return train_df, test_df


# Hold out the last 24 weeks (80 weeks train, 24 weeks test: roughly a 77/23 split)
df_train, df_test = split_train_test(df, test_size_weeks=24)
```
A few things to consider about test set size:
- Larger test sets (20–30%) give more reliable performance estimates but leave less training data.
- Smaller test sets (10–15%) give more training data but potentially unstable metrics.
- It helps to align test periods with important business periods (e.g., holiday seasons).
Step 2: Convergence monitoring
Before I can trust a model’s predictions, I need to make sure the MCMC chains have converged. This function checks three diagnostic criteria:
```python
import arviz as az
from pymc_marketing.mmm import MMM


def check_convergence(
    mmm: MMM,
    divergence_threshold: float,
    rhat_threshold: float,
    ess_threshold: float,
    trial_number: int | None = None,  # Optional, e.g. for logging; unused here
) -> tuple[bool, dict]:
    """Check MCMC convergence diagnostics."""
    # 1. Check divergences (numerical instabilities)
    diverging = mmm.idata.sample_stats["diverging"]
    n_divergences = int(diverging.sum().item())
    total_samples = diverging.size  # n_chains * n_draws
    divergence_rate = n_divergences / total_samples
    # 2. Check R-hat (chain mixing)
    rhat = az.rhat(mmm.idata)
    max_rhat = float(rhat.to_array().max())
    # 3. Check ESS (effective sample size)
    ess = az.ess(mmm.idata)
    min_ess = float(ess.to_array().min())
    # Determine if converged
    converged = (
        divergence_rate <= divergence_threshold
        and max_rhat <= rhat_threshold
        and min_ess >= ess_threshold
    )
    return converged, {
        "divergence_rate": divergence_rate,
        "max_rhat": max_rhat,
        "min_ess": min_ess,
    }
```
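If R-hat feels like a black box, the classic Gelman-Rubin version is short enough to write by hand. The sketch below is for intuition only and is simpler than what ArviZ computes (ArviZ uses a rank-normalized split variant); the simulated chains are made-up examples of one healthy run and one run where two chains are stuck in a different region.

```python
import random
import statistics


def basic_rhat(chains):
    """Classic Gelman-Rubin R-hat for equal-length chains of one parameter."""
    n = len(chains[0])
    chain_means = [statistics.fmean(chain) for chain in chains]
    # Average variance inside each chain
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    # Variance between chain means, scaled to be comparable with `within`
    between = n * statistics.variance(chain_means)
    pooled = (n - 1) / n * within + between / n
    return (pooled / within) ** 0.5


rng = random.Random(0)
# Four chains exploring the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(4)]
# Two of four chains stuck around a different value
stuck = [[rng.gauss(shift, 1.0) for _ in range(500)] for shift in (0.0, 0.0, 3.0, 3.0)]

print(round(basic_rhat(mixed), 3), round(basic_rhat(stuck), 3))
```

When chains agree, the pooled and within-chain variances match and R-hat sits near 1. When they disagree, the between-chain term inflates the ratio, which is what the thresholds above are guarding against.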
Step 3: The objective function
This is where everything comes together. The objective function orchestrates model fitting, convergence checking, and CRPS evaluation:
```python
import optuna
from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation


def objective(trial: optuna.Trial) -> float:
    """Optuna objective using CRPS on test set."""
    # Suggest hyperparameters from search space
    yearly_seasonality = trial.suggest_int("yearly_seasonality", 1, 10)
    adstock_max_lag = trial.suggest_int("adstock_max_lag", 4, 12)
    # Create MMM with suggested parameters
    mmm = MMM(
        date_column="date",
        channel_columns=channel_columns,
        control_columns=control_columns,
        adstock=GeometricAdstock(l_max=adstock_max_lag),
        saturation=LogisticSaturation(),
        yearly_seasonality=yearly_seasonality,
    )
    # Fit on TRAINING data only
    mmm.fit(
        X=X_train,
        y=y_train,
        draws=OPTUNA_DRAWS,
        tune=OPTUNA_TUNE,
        chains=OPTUNA_CHAINS,
        nuts_sampler="numpyro",  # 2-10x faster than PyMC's default sampler in my tests
        random_seed=RANDOM_SEED + trial.number,
        progressbar=False,
    )
    # Check convergence - prune if failed
    converged, diagnostics = check_convergence(
        mmm=mmm,
        divergence_threshold=OPTUNA_DIVERGENCE_THRESHOLD,
        rhat_threshold=OPTUNA_RHAT_THRESHOLD,
        ess_threshold=OPTUNA_ESS_THRESHOLD,
    )
    if not converged:
        raise optuna.TrialPruned()
    # Compute CRPS on TEST set (the key metric)
    test_crps = compute_test_crps(mmm, X_test, y_test)
    return test_crps
```

Step 4: Running the optimization
With the objective function defined, running the optimization is straightforward:
```python
# Create Optuna study
study = optuna.create_study(
    study_name="mmm_crps_optimization",
    direction="minimize",  # Minimize test CRPS
    sampler=optuna.samplers.TPESampler(seed=RANDOM_SEED),
    # MedianPruner only acts on intermediate values reported via trial.report();
    # this objective prunes by raising TrialPruned on failed convergence instead
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5),
)
# Run optimization
study.optimize(
    objective,
    n_trials=20,  # Number of hyperparameter combinations to try
    show_progress_bar=True,
)
print(f"Best parameters: {study.best_params}")
print(f"Best test CRPS: {study.best_value:.2f}")
```

Step 5: Interpreting the results
Optuna provides visualization tools, such as optimization-history and parameter-importance plots, to understand the optimization process.
From the synthetic data optimization:
- Optimal yearly_seasonality: 3 (different from common defaults)
- Optimal adstock_max_lag: 10 weeks
- Best test CRPS: 294.71
A few things to keep in mind when interpreting these results. These “optimal” parameters are specific to this synthetic dataset and single test period. The difference between the best (294.71) and median (~310) CRPS is modest. In real applications, check whether different hyperparameters lead to meaningfully different business decisions. At this scale, CRPS differences of around 5 points (under 2%) can usually be ignored.
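To put "modest" into numbers, the gap can be expressed relative to the median trial. This is plain arithmetic on the two figures quoted above, with the approximate median taken as 310.

```python
best_crps = 294.71    # Best trial, from the results above
median_crps = 310.0   # Approximate median trial, from the results above

# Improvement of the best trial over a typical trial, as a fraction of the typical score
relative_gap = (median_crps - best_crps) / median_crps
print(f"{relative_gap:.1%}")  # 4.9%
```

A gap of roughly 5% is worth having, but it is small enough that I would check budget recommendations under both settings before reading much into it.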
Step 6: Final model with optimal parameters
With optimal hyperparameters identified, I refit on the full dataset with production-quality MCMC settings:
```python
# Create final model with optimal parameters
final_mmm = MMM(
    date_column="date",
    channel_columns=channel_columns,
    control_columns=control_columns,
    adstock=GeometricAdstock(l_max=study.best_params["adstock_max_lag"]),
    saturation=LogisticSaturation(),
    yearly_seasonality=study.best_params["yearly_seasonality"],
)
# Fit with production settings on FULL dataset
final_mmm.fit(
    X=X_full,
    y=y_full,
    draws=FINAL_DRAWS,    # 2000
    tune=FINAL_TUNE,      # 2000
    chains=FINAL_CHAINS,  # 4
    nuts_sampler="numpyro",
    random_seed=RANDOM_SEED,
)
# Verify convergence with strict thresholds
converged, diagnostics = check_convergence(
    final_mmm,
    divergence_threshold=FINAL_DIVERGENCE_THRESHOLD,  # <1% divergences
    rhat_threshold=FINAL_RHAT_THRESHOLD,              # R-hat very close to 1
    ess_threshold=FINAL_ESS_THRESHOLD,                # Substantial effective samples
)
```

Validating the optimized model
Prediction quality
The real test of the optimization is how well the model predicts both in-sample and out-of-sample. A few things stand out:
- The 95% credible intervals capture most actual values.
- Test set predictions maintain appropriate uncertainty.
- No obvious overfitting despite parameter optimization.
Performance metrics
| Dataset | CRPS | Note |
|---|---|---|
| Train | 265.93 | In-sample performance |
| Test | 260.06 | Out-of-sample performance |
The test and training CRPS values are close (a difference of about 6 points on a scale of 260+, roughly 2%), which suggests the model generalizes without obvious overfitting. As a rule of thumb, CRPS differences below about 5% of the absolute value are often within noise. What matters is that the model performs consistently across both sets, suggesting the hyperparameters aren’t overly tuned to the data.
Practical insights and recommendations
When to use this approach
This CRPS-based optimization works best when:
- You have sufficient data (at least 100 observations, ideally 2+ years for seasonal patterns).
- Multiple hyperparameters interact. Fourier orders, adstock lags, and saturation parameters have complex relationships with each other.
- You have computational time available. Even with optimization, expect up to several hours for complex models.
- The business context is stable. If your market changes rapidly, optimized parameters may quickly become outdated.
When not to use this approach
Be cautious about hyperparameter optimization when:
- Data is limited (fewer than 52 weeks). You risk overfitting to noise.
- Market dynamics are changing. Recent disruptions make historical patterns unreliable.
- Quick insights are needed. In Bayesian statistics, “good enough” with defaults often beats perfect after extensive tuning.
- Causal structure is uncertain. Fix your model specification and priors before optimizing hyperparameters.
What I learned from this synthetic data example
Optimal parameters may surprise you. My synthetic data preferred 3 Fourier components over the common default of 7. Don’t assume defaults are optimal.
Longer memory isn’t always better. The optimal 10-week adstock (vs. maximum 12) shows there’s a balance between capturing effects and model complexity.
NumPyro speeds up experimentation significantly. In my tests, NumPyro provided 2–10x speedup compared to PyMC’s default sampler, though actual speedup varies by model complexity.
Convergence monitoring matters. The two-tier approach (lenient for search, strict for final) prevents wasting computation on poorly converged models.
Advanced considerations and limitations
Methodological caveats
This simplified example should be extended for production systems. Here are the main issues to address:
Single test set overfitting. Optimizing on one test set risks selecting hyperparameters that work well for that specific period but not others. Production systems should use rolling window cross-validation, multiple hold-out periods, or business-cycle-aware splits (e.g., always test on Q4 if that’s your critical period).
Stationarity assumptions. This approach assumes the optimal hyperparameters are stable over time. In rapidly evolving markets, you may need time-varying hyperparameters, regular retraining schedules, or monitoring systems to detect when parameters become stale.
Computational cost-benefit. Sometimes the marginal improvement from optimization doesn’t justify the computational cost. Is a 5% CRPS improvement worth 10 hours of computation? Would that time be better spent improving data quality or rethinking the causal structure? Are stakeholders even sensitive enough to notice the improvement?
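For the single-test-set problem, rolling or expanding window validation is mostly an indexing exercise. Here is a minimal sketch of an expanding-window splitter in pure Python. The function name and its arguments are my own for illustration, and the 104-week example simply mirrors the dataset size used in this post.

```python
def expanding_window_splits(n_obs, n_splits, test_size):
    """Yield (train_indices, test_indices) with training data always before test data."""
    first_test_start = n_obs - n_splits * test_size
    if first_test_start <= 0:
        raise ValueError("Not enough observations for the requested splits.")
    for i in range(n_splits):
        test_start = first_test_start + i * test_size
        # Train on everything before the test window, test on the window itself
        yield range(0, test_start), range(test_start, test_start + test_size)


# 104 weekly observations, three folds of 12 test weeks each
folds = [
    (len(train), (test.start, test.stop))
    for train, test in expanding_window_splits(104, n_splits=3, test_size=12)
]
print(folds)  # [(68, (68, 80)), (80, (80, 92)), (92, (92, 104))]
```

Averaging test CRPS over such folds makes the selected hyperparameters less dependent on one particular period, at the cost of one model fit per fold and trial.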
Scaling and extensions
For production deployments, a few ideas worth exploring:
```python
import numpy as np

# Multi-objective optimization
def multi_objective(trial):
    # ... fit model ...
    # Optimize multiple metrics; create the study with
    # directions=["minimize", "minimize"] so both values are minimized
    test_crps = compute_test_crps(mmm, X_test, y_test)
    max_rhat = diagnostics["max_rhat"]  # Closer to 1 is better
    return test_crps, max_rhat

# Distributed optimization
study.optimize(
    objective,
    n_trials=100,
    n_jobs=4,  # 4 concurrent trials (thread-based; use shared storage to scale across processes)
)

# Cross-validation for robustness
def objective_with_cv(trial):
    crps_scores = []
    # time_series_split: a helper (not shown) that yields chronological folds
    for fold in time_series_split(df, n_splits=3):
        # ... fit and evaluate on each fold ...
        crps_scores.append(fold_crps)
    return np.mean(crps_scores)
```

Dealing with convergence failures
When trials frequently fail convergence checks, the problem often isn’t the hyperparameters but the model itself:
- Model misspecification. Poor convergence often signals that your model structure is too complex for the data.
- Identification issues. Some parameter combinations may be fundamentally unidentifiable without strong priors.
- Data quality. Outliers or data errors can cause convergence failures regardless of hyperparameters.
- Prior-data conflict. When priors strongly disagree with data, no amount of tuning will help.
Wrapping up
Hyperparameter optimization using CRPS and Optuna gives you a systematic approach to what has traditionally been an ad-hoc process in Bayesian MMM. It’s not a silver bullet, but it helps when applied to the right problems.
The practical takeaways: CRPS is the right metric for evaluating probabilistic forecasts (though small differences may not be meaningful). Bayesian optimization can efficiently explore parameter spaces, but watch out for overfitting to your test set. Proper train/test splits matter, though single splits have limitations compared to cross-validation. And convergence monitoring prevents waste, but persistent failures usually point to model problems, not hyperparameter issues.
This approach works best when you have enough data, stable business conditions, and the computational resources to invest. It’s one tool among many for building robust MMMs, not a replacement for domain expertise and careful model specification.
Next steps
If you decide hyperparameter optimization is right for your use case, here’s how I’d suggest getting started. Begin with just one or two hyperparameters before attempting complex multi-parameter optimization. Use domain knowledge to constrain your search spaces. Implement cross-validation rather than relying on a single train/test split. Monitor whether optimal parameters change over time. And always validate that optimized models lead to better business decisions, not just better metrics.
Hyperparameter optimization is a tool, not a goal. Focus first on getting your causal structure right, ensuring data quality, and understanding your business context. Only then does fine-tuning become worthwhile.
This tutorial demonstrated hyperparameter optimization for Bayesian Media Mix Models using synthetic data. The methodology combines Optuna’s Bayesian optimization with PyMC-Marketing’s MMM implementation, evaluated using CRPS as a proper scoring rule for probabilistic forecasts. All code examples are available in the accompanying notebook.