Time-Slice-Cross-Validation and Parameter Stability#

In this notebook we illustrate how to perform time-slice cross-validation for a media mix model. This is an important step for evaluating the stability and quality of the model: we look not only at out-of-sample predictions but also at the stability of the model parameters.

The following imports and configurations set up everything we need for the rest of the notebook.

We assume a model has already been trained using the functionality provided by the PyMC-Marketing library; the data generation and training steps are therefore not repeated here but are covered in a separate notebook. Those unfamiliar with these procedures are advised to refer to the “MMM Example Notebook.”

Prepare Notebook#

import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from pymc_marketing.mmm.time_slice_cross_validation import TimeSliceCrossValidator
from pymc_marketing.paths import data_dir

warnings.simplefilter(action="ignore", category=FutureWarning)

az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.facecolor"] = "white"


%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"
/Users/juanitorduz/Documents/pymc-marketing/pymc_marketing/pytensor_utils.py:34: FutureWarning: `pytensor.graph.basic.ancestors` was moved to `pytensor.graph.traversal.ancestors`. Calling it from the old location will fail in a future release.
  from pytensor.graph.basic import ancestors
/Users/juanitorduz/Documents/pymc-marketing/pymc_marketing/mmm/multidimensional.py:216: FutureWarning: This functionality is experimental and subject to change. If you encounter any issues or have suggestions, please raise them at: https://github.com/pymc-labs/pymc-marketing/issues/new
  warnings.warn(warning_msg, FutureWarning, stacklevel=1)
/Users/juanitorduz/Documents/pymc-marketing/pymc_marketing/mmm/time_slice_cross_validation.py:32: UserWarning: The pymc_marketing.mmm.builders module is experimental and its API may change without warning.
  from pymc_marketing.mmm.builders.yaml import build_mmm_from_yaml
seed: int = sum(map(ord, "mmm"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

Loading Data#

Here we load our geo-level dataset, which will be used in the time-slice CV steps below.

data_path = data_dir / "multidimensional_mock_data.csv"
data_df = pd.read_csv(data_path, parse_dates=["date"], index_col=0)
data_df.head()
        date            y          x1   x2  event_1  event_2  dayofyear  t    geo
0 2018-04-02  3984.662237  159.290009  0.0      0.0      0.0         92  0  geo_a
1 2018-04-09  3762.871794   56.194238  0.0      0.0      0.0         99  1  geo_a
2 2018-04-16  4466.967388  146.200133  0.0      0.0      0.0        106  2  geo_a
3 2018-04-23  3864.219373   35.699276  0.0      0.0      0.0        113  3  geo_a
4 2018-04-30  4441.625278  193.372577  0.0      0.0      0.0        120  4  geo_a
X = data_df.drop(columns=["y"])
y = data_df["y"]
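
Before configuring the cross-validator, it is worth confirming the panel structure, since the splits below are defined over dates rather than rows. A quick sanity-check sketch (not part of the original workflow):

# Sanity check: the dataset is a weekly panel with one row per (date, geo).
print(data_df["geo"].unique())  # -> ['geo_a' 'geo_b']
print(data_df["date"].nunique(), "unique weekly dates")  # -> 179 for this dataset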

Specify Time-Slice-Cross-Validation Strategy#

The main idea of time-slice cross-validation is to fit the model on a time slice of the data and then evaluate it on the next slice, repeating this process across the dataset. Because we want to simulate a production-like environment in which the training data is enlarged over time, we let the training slice grow with each iteration (an expanding window).
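
To make the mechanics concrete, here is a small self-contained sketch of an expanding-window split over weekly dates (illustrative values only; the actual splitting is handled by TimeSliceCrossValidator below):

import pandas as pd

# Illustrative expanding-window split: each iteration extends the training
# window by `step_size` dates and evaluates on the next `forecast_horizon`.
dates = pd.date_range("2018-04-02", periods=20, freq="W-MON")
n_init, forecast_horizon, step_size = 10, 4, 1

n_iterations = (len(dates) - n_init - forecast_horizon) // step_size + 1
for i in range(n_iterations):
    train_end = n_init + i * step_size
    train = dates[:train_end]
    test = dates[train_end : train_end + forecast_horizon]
    print(f"Iteration {i}: train -> {train[-1].date()}, "
          f"test {test[0].date()} .. {test[-1].date()}")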

Data Leakage

It is very important to avoid data leakage when performing time-slice cross-validation. This means that during training the model should not see any data from the future. This also includes any data pre-processing steps!

For example, we need to compute scaling statistics (such as the cost share of each channel) for each training time slice independently if we want to avoid data leakage. Other sources of data leakage include using a global feature for the trend component. In our case we simply use an increasing variable t, so we are safe: we just increase it by one for each time slice.
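
As a minimal illustration of leakage-safe preprocessing (toy data and names, not the library's internals): any statistic used for scaling must be computed on the training slice and then reused, unchanged, on the test slice.

import pandas as pd

spend = pd.Series([10.0, 20.0, 30.0, 40.0, 50.0])
train, test = spend.iloc[:3], spend.iloc[3:]

scale = train.max()         # computed from the training slice only
train_scaled = train / scale
test_scaled = test / scale  # reuse the training statistic; no peeking ahead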

Run Time-Slice-Cross-Validation Loop#

Depending on the business requirements, we need to decide the initial number of observations used to fit the model (n_init) and the forecast horizon (forecast_horizon). For this example, we use the first 163 dates to fit the model and then predict the next 12 weekly observations (about 3 months).

# Initialize cross-validator
cv = TimeSliceCrossValidator(
    n_init=163,
    forecast_horizon=12,
    date_column="date",
    step_size=1,
)
# We can check how many splits we will have.
# As a reference, with step_size=1 the number of splits is:
# n_iterations = n_dates - n_init - forecast_horizon + 1
n_splits = cv.get_n_splits(X, y)
print(f"Number of splits: {n_splits}")
Number of splits: 5
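
We can cross-check this number against the unique dates in the data (a quick sketch; recall each date appears once per geo):

# Cross-check: splits are defined over unique dates, not rows.
n_dates = X["date"].nunique()  # 179 weekly dates
print(n_dates - 163 - 12 + 1)  # -> 5, matching cv.get_n_splits(X, y)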

Let’s run it!

For more details on build_mmm_from_yaml, consult the PyMC-Marketing documentation on Model Deployment.

Alternatively, you can load a model that was saved to MLflow via pymc_marketing.mlflow.log_inference_data, or autologged via pymc_marketing.mlflow.autolog(log_mmm=True), from the PyMC-Marketing MLflow module.

results = cv.run(
    X,
    y,
    # Pass a sampler_config to control the sampler and speed things up
    sampler_config={
        "tune": 1_000,
        "draws": 1_000,
        "chains": 4,
        "random_seed": seed,
        "target_accept": 0.90,
        "nuts_sampler": "numpyro",
    },
    yaml_path=data_dir / "config_files" / "multi_dimensional_example_model.yml",
)

Sampling: [y]

Sampling: [y]

Sampling: [y]

Sampling: [y]

Sampling: [y]
# We can view the cross-validation results!
# The results object is an ArviZ InferenceData instance
results
arviz.InferenceData
    • <xarray.Dataset> Size: 700MB
      Dimensions:                                  (cv: 5, chain: 4, draw: 1000,
                                                    channel: 2, changepoint: 5,
                                                    geo: 2, control: 2,
                                                    fourier_mode: 4, date: 167)
      Coordinates:
        * cv                                       (cv) object 40B 'Iteration 0' .....
        * chain                                    (chain) int64 32B 0 1 2 3
        * draw                                     (draw) int64 8kB 0 1 2 ... 998 999
        * channel                                  (channel) <U2 16B 'x1' 'x2'
        * changepoint                              (changepoint) int64 40B 0 1 2 3 4
        * geo                                      (geo) <U5 40B 'geo_a' 'geo_b'
        * control                                  (control) <U7 56B 'event_1' 'eve...
        * fourier_mode                             (fourier_mode) <U5 80B 'sin_1' ....
        * date                                     (date) datetime64[ns] 1kB 2018-0...
      Data variables: (12/20)
          adstock_alpha                            (cv, chain, draw, channel) float64 320kB ...
          delta                                    (cv, chain, draw, changepoint, geo) float64 2MB ...
          delta_b                                  (cv, chain, draw) float64 160kB ...
          gamma_control                            (cv, chain, draw, control) float64 320kB ...
          gamma_fourier                            (cv, chain, draw, geo, fourier_mode) float64 1MB ...
          gamma_fourier_b                          (cv, chain, draw) float64 160kB ...
          ...                                       ...
          fourier_contribution                     (cv, chain, draw, date, geo, fourier_mode) float64 214MB ...
          intercept_contribution_original_scale    (cv, chain, draw, geo) float64 320kB ...
          total_media_contribution_original_scale  (cv, chain, draw) float64 160kB ...
          trend_effect_contribution                (cv, chain, draw, date, geo) float64 53MB ...
          y_original_scale                         (cv, chain, draw, date, geo) float64 53MB ...
          yearly_seasonality_contribution          (cv, chain, draw, date, geo) float64 53MB ...
      Attributes:
          created_at:                 2025-12-20T13:44:59.248841+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              7.597176
          tuning_steps:               1000
          pymc_marketing_version:     0.17.1

    • <xarray.Dataset> Size: 115MB
      Dimensions:           (cv: 5, chain: 4, draw: 1000, date: 179, geo: 2)
      Coordinates:
        * cv                (cv) object 40B 'Iteration 0' ... 'Iteration 4'
        * chain             (chain) int64 32B 0 1 2 3
        * draw              (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
        * date              (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
        * geo               (geo) <U5 40B 'geo_a' 'geo_b'
      Data variables:
          y                 (cv, chain, draw, date, geo) float64 57MB 0.4308 ... 0....
          y_original_scale  (cv, chain, draw, date, geo) float64 57MB 3.581e+03 ......
      Attributes:
          created_at:                 2025-12-20T13:45:01.510622+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 988kB
      Dimensions:          (cv: 5, chain: 4, draw: 1000)
      Coordinates:
        * cv               (cv) object 40B 'Iteration 0' ... 'Iteration 4'
        * chain            (chain) int64 32B 0 1 2 3
        * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
      Data variables:
          acceptance_rate  (cv, chain, draw) float64 160kB 0.9825 0.989 ... 0.9515 1.0
          diverging        (cv, chain, draw) bool 20kB False False ... False False
          energy           (cv, chain, draw) float64 160kB -506.9 -507.3 ... -528.7
          lp               (cv, chain, draw) float64 160kB -527.1 -523.7 ... -544.4
          n_steps          (cv, chain, draw) int64 160kB 63 63 63 63 ... 63 63 63 63
          step_size        (cv, chain, draw) float64 160kB 0.04739 0.04739 ... 0.05671
          tree_depth       (cv, chain, draw) int64 160kB 6 6 6 6 6 6 6 ... 7 6 6 6 6 6
      Attributes:
          created_at:     2025-12-20T13:44:59.257358+00:00
          arviz_version:  0.22.0

    • <xarray.Dataset> Size: 231MB
      Dimensions:                                         (cv: 5, chain: 1,
                                                           draw: 1000, date: 179,
                                                           geo: 2, control: 2,
                                                           fourier_mode: 4,
                                                           channel: 2, changepoint: 5)
      Coordinates:
        * cv                                              (cv) object 40B 'Iteratio...
        * chain                                           (chain) int64 8B 0
        * draw                                            (draw) int64 8kB 0 1 ... 999
        * date                                            (date) datetime64[ns] 1kB ...
        * geo                                             (geo) <U5 40B 'geo_a' 'ge...
        * control                                         (control) <U7 56B 'event_...
        * fourier_mode                                    (fourier_mode) <U5 80B 's...
        * channel                                         (channel) <U2 16B 'x1' 'x2'
        * changepoint                                     (changepoint) int64 40B 0...
      Data variables: (12/22)
          y_original_scale                                (cv, chain, draw, date, geo) float64 14MB ...
          intercept_contribution                          (cv, chain, draw, geo) float64 80kB ...
          y_sigma                                         (cv, chain, draw) float64 40kB ...
          control_contribution                            (cv, chain, draw, date, geo, control) float64 29MB ...
          yearly_seasonality_contribution_original_scale  (cv, chain, draw, date, geo) float64 14MB ...
          total_media_contribution_original_scale         (cv, chain, draw) float64 40kB ...
          ...                                              ...
          control_contribution_original_scale             (cv, chain, draw, date, geo, control) float64 29MB ...
          channel_contribution                            (cv, chain, draw, date, geo, channel) float64 29MB ...
          yearly_seasonality_contribution                 (cv, chain, draw, date, geo) float64 14MB ...
          intercept_contribution_original_scale           (cv, chain, draw, geo) float64 80kB ...
          delta                                           (cv, chain, draw, changepoint, geo) float64 400kB ...
          delta_b                                         (cv, chain, draw) float64 40kB ...
      Attributes:
          created_at:                 2025-07-26T08:20:31.433730+00:00
          arviz_version:              0.21.0
          inference_library:          pymc
          inference_library_version:  5.25.1
          pymc_marketing_version:     0.15.1

    • <xarray.Dataset> Size: 14MB
      Dimensions:  (cv: 5, chain: 1, draw: 1000, date: 179, geo: 2)
      Coordinates:
        * cv       (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * chain    (chain) int64 8B 0
        * draw     (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-08-30
        * geo      (geo) <U5 40B 'geo_a' 'geo_b'
      Data variables:
          y        (cv, chain, draw, date, geo) float64 14MB 2.658 2.098 ... 2.466
      Attributes:
          created_at:                 2025-07-26T08:20:31.438500+00:00
          arviz_version:              0.21.0
          inference_library:          pymc
          inference_library_version:  5.25.1
          pymc_marketing_version:     0.15.1

    • <xarray.Dataset> Size: 15kB
      Dimensions:  (cv: 5, date: 167, geo: 2)
      Coordinates:
        * cv       (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * date     (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-06-07
        * geo      (geo) <U5 40B 'geo_a' 'geo_b'
      Data variables:
          y        (cv, date, geo) float64 13kB 0.4794 0.5206 0.4527 ... 0.6063 0.5798
      Attributes:
          created_at:                 2025-12-20T13:44:59.258392+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              7.597176
          tuning_steps:               1000

    • <xarray.Dataset> Size: 82kB
      Dimensions:        (cv: 5, date: 167, geo: 2, channel: 2, control: 2)
      Coordinates:
        * cv             (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * date           (date) datetime64[ns] 1kB 2018-04-02 ... 2021-06-07
        * geo            (geo) <U5 40B 'geo_a' 'geo_b'
        * channel        (channel) <U2 16B 'x1' 'x2'
        * control        (control) <U7 56B 'event_1' 'event_2'
      Data variables:
          channel_data   (cv, date, geo, channel) float64 27kB 159.3 0.0 ... 72.29 0.0
          channel_scale  (cv, geo, channel) float64 160B 498.3 497.2 ... 498.3 497.2
          control_data   (cv, date, geo, control) float64 27kB 0.0 0.0 0.0 ... 0.0 0.0
          dayofyear      (cv, date) float64 7kB 92.0 99.0 106.0 ... 144.0 151.0 158.0
          target_data    (cv, date, geo) float64 13kB 3.985e+03 ... 4.894e+03
          target_scale   (cv, geo) float64 80B 8.312e+03 8.441e+03 ... 8.441e+03
          trend_t        (cv, date) float64 7kB 0.0 7.0 14.0 ... 1.155e+03 1.162e+03
      Attributes:
          created_at:                 2025-12-20T13:44:59.260742+00:00
          arviz_version:              0.22.0
          inference_library:          numpyro
          inference_library_version:  0.19.0
          sampling_time:              7.597176
          tuning_steps:               1000

    • <xarray.Dataset> Size: 95kB
      Dimensions:    (cv: 5, date: 167, geo: 2)
      Coordinates:
        * cv         (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * date       (date) datetime64[ns] 1kB 2018-04-02 2018-04-09 ... 2021-06-07
        * geo        (geo) object 16B 'geo_a' 'geo_b'
      Data variables:
          x1         (cv, date, geo) float64 13kB 159.3 159.3 56.19 ... 72.29 72.29
          x2         (cv, date, geo) float64 13kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_1    (cv, date, geo) float64 13kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          event_2    (cv, date, geo) float64 13kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
          dayofyear  (cv, date, geo) float64 13kB 92.0 92.0 99.0 ... 151.0 158.0 158.0
          t          (cv, date, geo) float64 13kB 0.0 0.0 1.0 ... 165.0 166.0 166.0
          y          (cv, date, geo) float64 13kB 3.985e+03 4.395e+03 ... 4.894e+03

    • <xarray.Dataset> Size: 88kB
      Dimensions:        (cv: 5, date: 179, geo: 2, channel: 2, control: 2)
      Coordinates:
        * cv             (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
        * date           (date) datetime64[ns] 1kB 2018-04-02 ... 2021-08-30
        * geo            (geo) <U5 40B 'geo_a' 'geo_b'
        * channel        (channel) <U2 16B 'x1' 'x2'
        * control        (control) <U7 56B 'event_1' 'event_2'
      Data variables:
          channel_data   (cv, date, geo, channel) float64 29kB 159.3 0.0 ... 219.4 0.0
          channel_scale  (cv, geo, channel) float64 160B 498.3 497.2 ... 498.3 497.2
          control_data   (cv, date, geo, control) float64 29kB 0.0 0.0 0.0 ... 0.0 0.0
          dayofyear      (cv, date) float64 7kB 92.0 99.0 106.0 ... 228.0 235.0 242.0
          target_data    (cv, date, geo) float64 14kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0
          target_scale   (cv, geo) float64 80B 8.312e+03 8.441e+03 ... 8.441e+03
          trend_t        (cv, date) float64 7kB 0.0 7.0 14.0 ... 1.239e+03 1.246e+03
      Attributes:
          created_at:                 2025-12-20T13:45:01.515695+00:00
          arviz_version:              0.22.0
          inference_library:          pymc
          inference_library_version:  5.26.1

    • <xarray.Dataset> Size: 80B
      Dimensions:   (cv: 5)
      Coordinates:
        * cv        (cv) object 40B 'Iteration 0' 'Iteration 1' ... 'Iteration 4'
      Data variables:
          metadata  (cv) object 40B {'X_train':           date          x1         ...
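
Every group in the returned InferenceData carries an extra cv dimension, so a single iteration can be pulled out with standard xarray selection; for example (a sketch):

# Select the posterior of a single CV iteration via the "cv" coordinate.
posterior_it0 = results.posterior.sel(cv="Iteration 0")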

Model Diagnostics#

First, we evaluate whether there are any divergences in the model (we can extend the analysis with more model diagnostics, as sketched below).

# Let's check if there are any divergences
diverging_count = int(results.sample_stats["diverging"].values.sum())
print("Diverging transitions:", diverging_count)
Diverging transitions: 0

We have no divergences in the model 😃!
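
As one example of extending the diagnostics, we could also check R-hat values; a minimal sketch (assuming adstock_alpha is in the posterior group, as shown above):

# az.rhat reduces over the chain and draw dimensions, so the "cv"
# dimension is preserved: one R-hat per parameter, per time slice.
rhat = az.rhat(results.posterior[["adstock_alpha"]])
print(float(rhat["adstock_alpha"].max()))  # values near 1 indicate convergence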

Evaluate Parameter Stability#

Next, we look at the stability of the model parameters. For a good model, these should not change abruptly over time. We plot each parameter across CV iterations below and follow the plots with a small numeric summary.

  • Adstock Alpha

cv.plot.param_stability(
    results=results,
    parameter=["adstock_alpha"],
    dims={"geo": ["geo_a"]},
);
  • Saturation Beta

cv.plot.param_stability(
    results,
    parameter=["saturation_beta"],
    dims={"geo": ["geo_a", "geo_b"]},
);
  • Saturation Lambda

cv.plot.param_stability(
    results,
    parameter=["saturation_lam"],
    # dims={"geo": ["geo_a", "geo_b"]}
);
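
To complement the plots with numbers, we can summarize the posterior mean of a parameter per iteration directly from the posterior group (a minimal sketch):

# Posterior mean of adstock_alpha per CV iteration and channel.
alpha_means = results.posterior["adstock_alpha"].mean(dim=["chain", "draw"])
print(alpha_means.to_dataframe().unstack("channel"))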

The parameters seem stable over time. This implies that the estimated ROAS will not change abruptly over time.

Evaluate Out of Sample Predictions#

Finally, we evaluate the out-of-sample predictions. To begin, we plot the posterior predictive distributions for each iteration, for both the training and test data.

# Plot model predictions across time slices
cv.plot.cv_predictions(
    results,
    # dims={"geo": ["geo_b"]} # to plot specific dimensions only
);
/Users/juanitorduz/Documents/pymc-marketing/pymc_marketing/mmm/plot.py:3259: UserWarning: The figure layout has changed to tight
  plt.tight_layout(rect=[0, 0.07, 1, 1])
[Figure: posterior predictive distributions for each CV iteration, training and test periods]

Overall, the out of sample predictions look very good 🚀!

We can quantify the model performance using the Continuous Ranked Probability Score (CRPS).

“The CRPS — Continuous Ranked Probability Score — is a score function that compares a single ground truth value to a Cumulative Distribution Function. It can be used as a metric to evaluate a model’s performance when the target variable is continuous and the model predicts the target’s distribution; Examples include Bayesian Regression or Bayesian Time Series models.”

For a nice explanation of the CRPS, check out this blog post.

In PyMC-Marketing, we provide the function crps to compute this metric. We can use it to compute the CRPS score for each iteration.
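
As a rough illustration of what this metric computes (not the library's implementation), the CRPS of a predictive sample can be estimated with the well-known identity CRPS = E|X − y| − ½ E|X − X′|:

import numpy as np

def empirical_crps(samples: np.ndarray, y: float) -> float:
    """Sample-based CRPS estimate for a single observation y."""
    # First term: expected absolute error between draws and the observation.
    # Second term: half the expected absolute difference between pairs of draws.
    return np.mean(np.abs(samples - y)) - 0.5 * np.mean(
        np.abs(samples[:, None] - samples[None, :])
    )

rng_demo = np.random.default_rng(0)
print(empirical_crps(rng_demo.normal(size=1_000), y=0.1))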

# Compute the CRPS score for each iteration and plot!
cv.plot.cv_crps(
    results,
    # dims={"geo": ["geo_b"]} # to plot specific dimensions only
);
/Users/juanitorduz/Documents/pymc-marketing/pymc_marketing/mmm/plot.py:3780: UserWarning: The figure layout has changed to tight
  plt.tight_layout()
[Figure: CRPS per CV iteration, training and test data]

Even though the visual results look great, we see that the CRPS mildly decreases for the training data while it increases for the test data as the training set grows. This is a sign that we are overfitting the model to the training data. Strategies to overcome this issue include applying regularization techniques and re-evaluating the model specification; this should be an iterative process.

%load_ext watermark
%watermark -n -u -v -iv -w -p pymc_marketing,pytensor,numpyro
Last updated: Sat Dec 20 2025

Python implementation: CPython
Python version       : 3.13.11
IPython version      : 9.8.0

pymc_marketing: 0.17.1
pytensor      : 2.35.1
numpyro       : 0.19.0

arviz         : 0.22.0
matplotlib    : 3.10.8
numpy         : 2.3.5
pandas        : 2.3.3
pymc_marketing: 0.17.1

Watermark: 2.5.1