# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MMM related plotting class.
Examples
--------
Quickstart with MMM:
.. code-block:: python
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM
import pandas as pd
# Minimal dataset
X = pd.DataFrame(
{
"date": pd.date_range("2025-01-01", periods=12, freq="W-MON"),
"C1": [100, 120, 90, 110, 105, 115, 98, 102, 108, 111, 97, 109],
"C2": [80, 70, 95, 85, 90, 88, 92, 94, 91, 89, 93, 87],
}
)
y = pd.Series(
[230, 260, 220, 240, 245, 255, 235, 238, 242, 246, 233, 249], name="y"
)
mmm = MMM(
date_column="date",
channel_columns=["C1", "C2"],
target_column="y",
adstock=GeometricAdstock(l_max=10),
saturation=LogisticSaturation(),
)
mmm.fit(X, y)
mmm.sample_posterior_predictive(X)
# Posterior predictive time series
_ = mmm.plot.posterior_predictive(var="y", hdi_prob=0.9)
# Posterior contributions over time (e.g., channel_contribution)
_ = mmm.plot.contributions_over_time(var=["channel_contribution"], hdi_prob=0.9)
# Channel saturation scatter plot (scaled space by default)
_ = mmm.plot.saturation_scatterplot(original_scale=False)
Wrap a custom PyMC model
------------------------
Requirements
- posterior_predictive plots: an `az.InferenceData` with a `posterior_predictive` group
containing the variable(s) you want to plot with a `date` coordinate.
- contributions_over_time plots: a `posterior` group with time-series variables (with `date`).
- saturation plots: a `constant_data` dataset with variables:
- `channel_data`: dims include `("date", "channel", ...)`
- `channel_scale`: dims include `("channel", ...)`
- `target_scale`: scalar or broadcastable to the curve dims
and a `posterior` variable named `channel_contribution` (or
`channel_contribution_original_scale` if plotting `original_scale=True`).
.. code-block:: python
import numpy as np
import pandas as pd
import pymc as pm
from pymc_marketing.mmm.plot import MMMPlotSuite
dates = pd.date_range("2025-01-01", periods=30, freq="D")
y_obs = np.random.normal(size=len(dates))
with pm.Model(coords={"date": dates}):
sigma = pm.HalfNormal("sigma", 1.0)
pm.Normal("y", 0.0, sigma, observed=y_obs, dims="date")
idata = pm.sample_prior_predictive(random_seed=1)
idata.extend(pm.sample(draws=200, chains=2, tune=200, random_seed=1))
idata.extend(pm.sample_posterior_predictive(idata, random_seed=1))
plot = MMMPlotSuite(idata)
_ = plot.posterior_predictive(var="y", hdi_prob=0.9)
Custom contributions_over_time
------------------------------
.. code-block:: python
import numpy as np
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from pymc_marketing.mmm.plot import MMMPlotSuite
dates = pd.date_range("2025-01-01", periods=30, freq="D")
x = np.linspace(0, 2 * np.pi, len(dates))
series = np.sin(x)
with pm.Model(coords={"date": dates}):
    # A free RV makes the model sampleable; the Deterministic wraps the fixed series
    scale = pm.Normal("scale", 1.0, 0.1)
    pm.Deterministic("component", scale * pt.as_tensor_variable(series), dims="date")
    idata = pm.sample_prior_predictive(random_seed=2)
    idata.extend(pm.sample(draws=50, chains=1, tune=50, random_seed=2))
plot = MMMPlotSuite(idata)
_ = plot.contributions_over_time(var=["component"], hdi_prob=0.9)
Saturation plots with a custom model
------------------------------------
.. code-block:: python
import numpy as np
import pandas as pd
import xarray as xr
import pymc as pm
from pymc_marketing.mmm.plot import MMMPlotSuite
dates = pd.date_range("2025-01-01", periods=20, freq="W-MON")
channels = ["C1", "C2"]
# Create constant_data required for saturation plots
channel_data = xr.DataArray(
np.random.rand(len(dates), len(channels)),
dims=("date", "channel"),
coords={"date": dates, "channel": channels},
name="channel_data",
)
channel_scale = xr.DataArray(
np.ones(len(channels)),
dims=("channel",),
coords={"channel": channels},
name="channel_scale",
)
target_scale = xr.DataArray(1.0, name="target_scale")
# Build a toy model that yields a matching posterior var
with pm.Model(coords={"date": dates, "channel": channels}):
# A fake contribution over time per channel (dims must include date & channel)
contrib = pm.Normal("channel_contribution", 0.0, 1.0, dims=("date", "channel"))
idata = pm.sample_prior_predictive(random_seed=3)
idata.extend(pm.sample(draws=50, chains=1, tune=0, random_seed=3))
# Attach constant_data to idata; add_groups registers it as a proper group
idata.add_groups(
    constant_data=xr.Dataset(
        {
            "channel_data": channel_data,
            "channel_scale": channel_scale,
            "target_scale": target_scale,
        }
    )
)
plot = MMMPlotSuite(idata)
_ = plot.saturation_scatterplot(original_scale=False)
Notes
-----
- `MMM` exposes this suite via the `mmm.plot` property, which internally passes the model's
`idata` into `MMMPlotSuite`.
- Any PyMC model can use `MMMPlotSuite` directly if its `InferenceData` contains the needed
groups/variables described above.
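Configure the plotting backend globally (a minimal sketch; per-call
``backend=`` arguments still take precedence over the global setting):

.. code-block:: python

    from pymc_marketing.mmm.config import mmm_plot_config

    # All subsequent plots default to plotly unless overridden per call
    mmm_plot_config["plot.backend"] = "plotly"
    pc = mmm.plot.posterior_predictive(var="y")
    pc.show()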
"""
import itertools
import warnings
import arviz as az
import arviz_plots as azp
import numpy as np
import xarray as xr
from arviz_base import convert_to_datatree
from arviz_base.labels import DimCoordLabeller, NoVarLabeller, mix_labellers
from arviz_plots import PlotCollection
from pymc_marketing.mmm.config import mmm_plot_config
__all__ = ["MMMPlotSuite"]
WIDTH_PER_COL: float = 10.0
HEIGHT_PER_ROW: float = 4.0
class MMMPlotSuite:
"""Media Mix Model Plot Suite.
Provides methods for visualizing the posterior predictive distribution,
contributions over time, and saturation curves for a Media Mix Model.
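Examples
--------
Construct the suite directly from an existing ``InferenceData``
(a minimal sketch; the NetCDF file path is hypothetical):

.. code-block:: python

    import arviz as az

    from pymc_marketing.mmm.plot import MMMPlotSuite

    idata = az.from_netcdf("fitted_mmm.nc")  # hypothetical results file
    suite = MMMPlotSuite(idata)
    pc = suite.posterior_predictive(var="y")
    pc.show()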
"""
def __init__(
self,
idata: xr.Dataset | az.InferenceData,
):
self.idata = idata
def _get_additional_dim_combinations(
self,
data: xr.Dataset,
variable: str,
ignored_dims: set[str],
) -> tuple[list[str], list[tuple]]:
"""Identify dimensions to plot over and get their coordinate combinations."""
if variable not in data:
raise ValueError(f"Variable '{variable}' not found in the dataset.")
all_dims = list(data[variable].dims)
additional_dims = [d for d in all_dims if d not in ignored_dims]
if additional_dims:
additional_coords = [data.coords[d].values for d in additional_dims]
dim_combinations = list(itertools.product(*additional_coords))
else:
# If no extra dims, just treat as a single combination
dim_combinations = [()]
return additional_dims, dim_combinations
def _get_posterior_predictive_data(
self,
idata: xr.Dataset | None,
) -> xr.Dataset:
"""Retrieve the posterior_predictive group from either provided or self.idata."""
if idata is not None:
return idata
# Otherwise, check if self.idata has posterior_predictive
if (
not hasattr(self.idata, "posterior_predictive") # type: ignore
or self.idata.posterior_predictive is None # type: ignore
):
raise ValueError(
"No posterior_predictive data found in 'self.idata'. "
"Please run 'MMM.sample_posterior_predictive()' or provide "
"an external 'idata' argument."
)
return self.idata.posterior_predictive # type: ignore
def _validate_dims(
self,
dims: dict[str, str | int | list],
all_dims: list[str],
) -> None:
"""Validate that provided dims exist in the model's dimensions and values."""
if dims:
for key, val in dims.items():
if key not in all_dims:
raise ValueError(
f"Dimension '{key}' not found in idata dimensions."
)
valid_values = self.idata.posterior.coords[key].values
if isinstance(val, (list, tuple, np.ndarray)):
for v in val:
if v not in valid_values:
raise ValueError(
f"Value '{v}' not found in dimension '{key}'."
)
else:
if val not in valid_values:
raise ValueError(
f"Value '{val}' not found in dimension '{key}'."
)
def _dim_list_handler(
self, dims: dict[str, str | int | list] | None
) -> tuple[list[str], list[tuple]]:
"""Extract keys, values, and all combinations for list-valued dims."""
dims_lists = {
k: v
for k, v in (dims or {}).items()
if isinstance(v, (list, tuple, np.ndarray))
}
if dims_lists:
dims_keys = list(dims_lists.keys())
dims_values = [
v if isinstance(v, (list, tuple, np.ndarray)) else [v]
for v in dims_lists.values()
]
dims_combos = list(itertools.product(*dims_values))
else:
dims_keys = []
dims_combos = [()]
return dims_keys, dims_combos
def _resolve_backend(self, backend: str | None) -> str:
"""Resolve backend parameter to actual backend string."""
return backend or mmm_plot_config["plot.backend"]
def _get_data_or_fallback(
self,
data: xr.Dataset | None,
idata_attr: str,
data_name: str,
) -> xr.Dataset:
"""Get data from parameter or fall back to self.idata attribute.
Parameters
----------
data : xr.Dataset or None
Data provided by user.
idata_attr : str
Attribute name on self.idata to use as fallback (e.g., "posterior").
data_name : str
Human-readable name for error messages (e.g., "posterior data").
Returns
-------
xr.Dataset
The data to use.
Raises
------
ValueError
If data is None and self.idata doesn't have the required attribute.
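Examples
--------
A sketch of the fallback behavior:

.. code-block:: python

    # Returns the explicit dataset when one is given...
    ds = self._get_data_or_fallback(my_dataset, "posterior", "posterior data")
    # ...and falls back to self.idata.posterior when data is None
    ds = self._get_data_or_fallback(None, "posterior", "posterior data")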
"""
if data is None:
if not hasattr(self.idata, idata_attr):
raise ValueError(
f"No {data_name} found in 'self.idata' and no 'data' argument provided. "
f"Please ensure 'self.idata' contains a '{idata_attr}' group or provide 'data' explicitly."
)
data = getattr(self.idata, idata_attr)
return data
# ------------------------------------------------------------------------
# Main Plotting Methods
# ------------------------------------------------------------------------
def posterior_predictive(
self,
var: str | None = None,
idata: xr.Dataset | None = None,
hdi_prob: float = 0.85,
backend: str | None = None,
) -> PlotCollection:
"""Plot posterior predictive distributions over time.
Visualizes posterior predictive samples as time series, showing the median
line and highest density interval (HDI) bands. Useful for checking model fit
and understanding prediction uncertainty.
Parameters
----------
var : str, optional
Variable name to plot from posterior_predictive group. If None, uses "y".
idata : xr.Dataset, optional
Dataset containing posterior predictive samples with a "date" coordinate.
If None, uses self.idata.posterior_predictive.
This parameter allows:
- Testing with mock data without modifying self.idata
- Plotting external posterior predictive samples
- Comparing different model fits side-by-side
hdi_prob : float, default 0.85
Probability mass for HDI interval (between 0 and 1).
backend : str, optional
Plotting backend to use. Options: "matplotlib", "plotly", "bokeh".
If None, uses global config via mmm_plot_config["plot.backend"].
Default is "matplotlib".
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Unlike the legacy suite which returned ``(Figure, Axes)``,
this provides a unified interface across all backends.
Raises
------
ValueError
If no posterior_predictive data found in self.idata and no idata provided.
ValueError
If hdi_prob is not between 0 and 1.
See Also
--------
LegacyMMMPlotSuite.posterior_predictive : Legacy matplotlib-only implementation
Notes
-----
Breaking changes from legacy implementation:
- Returns PlotCollection instead of (Figure, Axes)
- Different interface for saving and displaying plots
Examples
--------
Basic usage:
.. code-block:: python
mmm.sample_posterior_predictive(X)
pc = mmm.plot.posterior_predictive()
pc.show()
Plot with different HDI probability:
.. code-block:: python
pc = mmm.plot.posterior_predictive(hdi_prob=0.94)
pc.show()
Save to file:
.. code-block:: python
pc = mmm.plot.posterior_predictive()
pc.save("posterior_predictive.png")
Use different backend:
.. code-block:: python
pc = mmm.plot.posterior_predictive(backend="plotly")
pc.show()
Provide explicit data:
.. code-block:: python
external_pp = xr.Dataset(...) # Custom posterior predictive
pc = mmm.plot.posterior_predictive(idata=external_pp)
pc.show()
Direct instantiation pattern:
.. code-block:: python
from pymc_marketing.mmm.plot import MMMPlotSuite
mps = MMMPlotSuite(custom_idata)
pc = mps.posterior_predictive()
pc.show()
"""
if not 0 < hdi_prob < 1:
raise ValueError("HDI probability must be between 0 and 1.")
# Resolve backend
backend = self._resolve_backend(backend)
# 1. Retrieve or validate posterior_predictive data
pp_data = self._get_posterior_predictive_data(idata)
# 2. Determine variable to plot
if var is None:
var = "y"
main_var = var
# 3. Identify additional dims & get all combos
ignored_dims = {"chain", "draw", "date", "sample"}
additional_dims, _ = self._get_additional_dim_combinations(
data=pp_data, variable=main_var, ignored_dims=ignored_dims
)
# 4. Prepare subplots
pc = azp.PlotCollection.wrap(
pp_data[main_var].to_dataset(),
cols=additional_dims,
col_wrap=1,
figure_kwargs={
"sharex": True,
},
backend=backend,
)
# plot hdi
hdi = pp_data.azstats.hdi(hdi_prob)
pc.map(
azp.visuals.fill_between_y,
x=pp_data["date"],
y_bottom=hdi.sel(ci_bound="lower"),
y_top=hdi.sel(ci_bound="upper"),
alpha=0.2,
color="C0",
)
# plot median line
pc.map(
azp.visuals.line_xy,
x=pp_data["date"],
y=pp_data.median(dim=["chain", "draw"]),
color="C0",
)
# add labels
pc.map(azp.visuals.labelled_x, text="Date")
pc.map(azp.visuals.labelled_y, text="Posterior Predictive")
pc.map(
azp.visuals.labelled_title,
subset_info=True,
labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(),
)
return pc
def contributions_over_time(
self,
var: list[str],
data: xr.Dataset | None = None,
hdi_prob: float = 0.85,
dims: dict[str, str | int | list] | None = None,
backend: str | None = None,
) -> PlotCollection:
"""Plot time-series contributions for specified variables.
Visualizes how variables contribute over time, showing the median line and
HDI bands. Useful for understanding channel contributions, intercepts, or
other time-varying effects in your model.
Parameters
----------
var : list of str
Variable names to plot from the posterior group. Must have a "date" dimension.
Examples: ["channel_contribution"], ["intercept"], ["channel_contribution", "intercept"].
data : xr.Dataset, optional
Dataset containing posterior data with variables in `var`.
If None, uses self.idata.posterior.
This parameter allows:
- Testing with mock data without modifying self.idata
- Plotting external results not stored in self.idata
- Comparing different posterior samples side-by-side
- Avoiding unintended side effects on self.idata
hdi_prob : float, default 0.85
Probability mass for HDI interval (between 0 and 1).
dims : dict[str, str | int | list], optional
Dimension filters to apply. Keys are dimension names, values are either:
- Single value: {"country": "US", "user_type": "new"}
- List of values: {"country": ["US", "UK"]}
If provided, only the selected slice(s) will be plotted.
backend : str, optional
Plotting backend to use. Options: "matplotlib", "plotly", "bokeh".
If None, uses global config via mmm_plot_config["plot.backend"].
Default is "matplotlib".
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Unlike the legacy suite which returned ``(Figure, Axes)``,
this provides a unified interface across all backends.
Raises
------
ValueError
If hdi_prob is not between 0 and 1.
ValueError
If no posterior data found in self.idata and no data argument provided.
ValueError
If any variable in `var` not found in data.
See Also
--------
LegacyMMMPlotSuite.contributions_over_time : Legacy matplotlib-only implementation
Notes
-----
Breaking changes from legacy implementation:
- Returns PlotCollection instead of (Figure, Axes)
- Variable names must be passed as a list (as in the legacy implementation)
Examples
--------
Basic usage - plot channel contributions:
.. code-block:: python
mmm.fit(X, y)
pc = mmm.plot.contributions_over_time(var=["channel_contribution"])
pc.show()
Plot multiple variables together:
.. code-block:: python
pc = mmm.plot.contributions_over_time(
var=["channel_contribution", "intercept"]
)
pc.show()
Filter by dimension:
.. code-block:: python
pc = mmm.plot.contributions_over_time(
var=["channel_contribution"], dims={"geo": "US"}
)
pc.show()
Filter with multiple dimension values:
.. code-block:: python
pc = mmm.plot.contributions_over_time(
var=["channel_contribution"], dims={"geo": ["US", "UK"]}
)
pc.show()
Use different backend:
.. code-block:: python
pc = mmm.plot.contributions_over_time(
var=["channel_contribution"], backend="plotly"
)
pc.show()
Provide explicit data (option 1 - via data parameter):
.. code-block:: python
custom_posterior = xr.Dataset(...)
pc = mmm.plot.contributions_over_time(
var=["my_contribution"], data=custom_posterior
)
pc.show()
Provide explicit data (option 2 - direct instantiation):
.. code-block:: python
from pymc_marketing.mmm.plot import MMMPlotSuite
mps = MMMPlotSuite(custom_idata)
pc = mps.contributions_over_time(var=["my_contribution"])
pc.show()
"""
if not 0 < hdi_prob < 1:
raise ValueError("HDI probability must be between 0 and 1.")
# Get data with fallback to self.idata.posterior
data = self._get_data_or_fallback(data, "posterior", "posterior data")
# Validate data has the required variables
missing_vars = [v for v in var if v not in data]
if missing_vars:
raise ValueError(
f"Variables {missing_vars} not found in data. "
f"Available variables: {list(data.data_vars)}"
)
# Resolve backend
backend = self._resolve_backend(backend)
main_var = var[0]
ignored_dims = {"chain", "draw", "date"}
da = data[var]
# Apply dims filtering if provided
if dims:
    self._validate_dims(dims, list(da[main_var].dims))
    # xarray's .sel handles scalar and list-valued selections uniformly
    da = da.sel(dims)
additional_dims, _ = self._get_additional_dim_combinations(
data=da, variable=main_var, ignored_dims=ignored_dims
)
# 4. Prepare subplots
pc = azp.PlotCollection.wrap(
da,
cols=additional_dims,
col_wrap=1,
figure_kwargs={
"sharex": True,
},
backend=backend,
)
# plot hdi
hdi = da.azstats.hdi(hdi_prob)
pc.map(
azp.visuals.fill_between_y,
x=da["date"],
y_bottom=hdi.sel(ci_bound="lower"),
y_top=hdi.sel(ci_bound="upper"),
alpha=0.2,
color="C0",
)
# plot median line
pc.map(
azp.visuals.line_xy,
x=da["date"],
y=da.median(dim=["chain", "draw"]),
color="C0",
)
# add labels
pc.map(azp.visuals.labelled_x, text="Date")
pc.map(azp.visuals.labelled_y, text="Posterior Value")
pc.map(
azp.visuals.labelled_title,
subset_info=True,
labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(),
)
return pc
def saturation_scatterplot(
self,
original_scale: bool = False,
constant_data: xr.Dataset | None = None,
posterior_data: xr.Dataset | None = None,
dims: dict[str, str | int | list] | None = None,
backend: str | None = None,
) -> PlotCollection:
"""Plot saturation scatter plot showing channel spend vs contributions.
Creates scatter plots of actual channel spend (X-axis) against channel
contributions (Y-axis), one subplot per channel. Useful for understanding
the saturation behavior and diminishing returns of each marketing channel.
Parameters
----------
original_scale : bool, default False
Whether to plot in original scale (True) or scaled space (False).
If True, requires channel_contribution_original_scale in posterior.
constant_data : xr.Dataset, optional
Dataset containing constant_data group with required variables:
- 'channel_data': Channel spend data (dims include "date", "channel")
- 'channel_scale': Scaling factor per channel (if original_scale=True)
- 'target_scale': Target scaling factor (if original_scale=True)
If None, uses self.idata.constant_data.
This parameter allows:
- Testing with mock constant data
- Plotting with alternative scaling factors
- Comparing different data scenarios
posterior_data : xr.Dataset, optional
Dataset containing posterior group with channel contribution variables.
Must contain 'channel_contribution' or 'channel_contribution_original_scale'.
If None, uses self.idata.posterior.
This parameter allows:
- Testing with mock posterior samples
- Plotting external inference results
- Comparing different model fits
dims : dict[str, str | int | list], optional
Dimension filters to apply. Examples:
- {"geo": "US"} - Single value
- {"geo": ["US", "UK"]} - Multiple values
If provided, only the selected slice(s) will be plotted.
backend : str, optional
Plotting backend to use. Options: "matplotlib", "plotly", "bokeh".
If None, uses global config via mmm_plot_config["plot.backend"].
Default is "matplotlib".
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Unlike the legacy suite which returned ``(Figure, Axes)``,
this provides a unified interface across all backends.
Raises
------
ValueError
If required data not found in self.idata and not provided explicitly.
ValueError
If 'channel_data' not found in constant_data.
ValueError
If original_scale=True but channel_contribution_original_scale not in posterior.
See Also
--------
saturation_curves : Add posterior predictive curves to this scatter plot
LegacyMMMPlotSuite.saturation_scatterplot : Legacy matplotlib-only implementation
Notes
-----
Breaking changes from legacy implementation:
- Returns PlotCollection instead of (Figure, Axes)
- No longer accepts ``**kwargs`` for matplotlib customization (use backend-specific methods)
- Different grid layout algorithm
Examples
--------
Basic usage (scaled space):
.. code-block:: python
mmm.fit(X, y)
pc = mmm.plot.saturation_scatterplot()
pc.show()
Plot in original scale:
.. code-block:: python
mmm.add_original_scale_contribution_variable(var=["channel_contribution"])
pc = mmm.plot.saturation_scatterplot(original_scale=True)
pc.show()
Filter by dimension:
.. code-block:: python
pc = mmm.plot.saturation_scatterplot(dims={"geo": "US"})
pc.show()
Use different backend:
.. code-block:: python
pc = mmm.plot.saturation_scatterplot(backend="plotly")
pc.show()
Provide explicit data:
.. code-block:: python
custom_constant = xr.Dataset(...)
custom_posterior = xr.Dataset(...)
pc = mmm.plot.saturation_scatterplot(
constant_data=custom_constant, posterior_data=custom_posterior
)
pc.show()
"""
# Resolve backend
backend = self._resolve_backend(backend)
# Get constant_data and posterior_data with fallback
constant_data = self._get_data_or_fallback(
constant_data, "constant_data", "constant data"
)
posterior_data = self._get_data_or_fallback(
posterior_data, "posterior", "posterior data"
)
# Validate required variables exist
if "channel_data" not in constant_data:
raise ValueError(
"'channel_data' variable not found in constant_data. "
f"Available variables: {list(constant_data.data_vars)}"
)
# Identify additional dimensions beyond 'date' and 'channel'
cdims = constant_data.channel_data.dims
additional_dims = [dim for dim in cdims if dim not in ("date", "channel")]
# Validate dims and remove filtered dims from additional_dims
if dims:
    self._validate_dims(dims, list(constant_data.channel_data.dims))
    additional_dims = [d for d in additional_dims if d not in dims]
channel_contribution = (
"channel_contribution_original_scale"
if original_scale
else "channel_contribution"
)
if channel_contribution not in posterior_data:
    raise ValueError(
        f"No '{channel_contribution}' variable found in posterior_data. "
        "Add an original-scale deterministic first:\n"
        "    mmm.add_original_scale_contribution_variable(\n"
        "        var=['channel_contribution', ...]\n"
        "    )"
    )
# Apply dims filtering to channel_data and channel_contribution
channel_data = constant_data.channel_data
channel_contrib = posterior_data[channel_contribution]
if dims:
    # xarray's .sel handles scalar and list-valued selections uniformly
    channel_data = channel_data.sel(dims)
    channel_contrib = channel_contrib.sel(dims)
pc = azp.PlotCollection.grid(
channel_contrib.mean(dim=["chain", "draw"]).to_dataset(),
cols=additional_dims,
rows=["channel"],
aes={"color": ["channel"]},
backend=backend,
)
pc.map(
azp.visuals.scatter_xy,
x=channel_data,
)
pc.map(azp.visuals.labelled_x, text="Channel Data", ignore_aes={"color"})
pc.map(
azp.visuals.labelled_y, text="Channel Contributions", ignore_aes={"color"}
)
pc.map(
azp.visuals.labelled_title,
subset_info=True,
labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(),
ignore_aes={"color"},
)
return pc
def saturation_curves(
self,
curve: xr.DataArray,
original_scale: bool = False,
constant_data: xr.Dataset | None = None,
posterior_data: xr.Dataset | None = None,
n_samples: int = 10,
hdi_probs: float | list[float] | None = None,
random_seed: int | np.random.Generator | None = None,
dims: dict[str, str | int | list] | None = None,
backend: str | None = None,
) -> PlotCollection:
"""Overlay saturation scatter plots with posterior predictive curves and HDI bands.
Builds on saturation_scatterplot() by adding:
- Sample curves from the posterior distribution
- HDI bands showing uncertainty
- Smooth saturation curves over the scatter plot
Parameters
----------
curve : xr.DataArray
Posterior predictive saturation curves with required dimensions:
- "chain", "draw": MCMC samples
- "x": Input values for curve evaluation
- "channel": Channel names
Generate using: ``mmm.saturation.sample_curve(...)``
original_scale : bool, default False
Plot in original scale (True) or scaled space (False).
If True, requires channel_contribution_original_scale in posterior.
constant_data : xr.Dataset, optional
Dataset containing constant_data group. If None, uses self.idata.constant_data.
This parameter allows testing with mock data and plotting alternative scenarios.
posterior_data : xr.Dataset, optional
Dataset containing posterior group. If None, uses self.idata.posterior.
This parameter allows testing with mock posterior samples and comparing model fits.
n_samples : int, default 10
Number of sample curves to draw per subplot.
Set to 0 to show only HDI bands without individual samples.
hdi_probs : float or list of float, optional
HDI probability levels for credible intervals.
Examples: 0.94 (single band), [0.5, 0.94] (multiple bands).
If None, no HDI bands are drawn.
random_seed : int or np.random.Generator, optional
Seed or random number generator for reproducible curve sampling.
If None, uses ``np.random.default_rng()``.
dims : dict[str, str | int | list], optional
Dimension filters to apply. Examples:
- {"geo": "US"}
- {"geo": ["US", "UK"]}
If provided, only the selected slice(s) will be plotted.
backend : str, optional
Plotting backend to use. Options: "matplotlib", "plotly", "bokeh".
If None, uses global config via mmm_plot_config["plot.backend"].
Default is "matplotlib".
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Raises
------
ValueError
If curve is missing required dimensions ("x" or "channel").
ValueError
If original_scale=True but channel_contribution_original_scale not in posterior.
See Also
--------
saturation_scatterplot : Base scatter plot without curves
LegacyMMMPlotSuite.saturation_curves : Legacy matplotlib-only implementation
Notes
-----
Breaking changes from legacy implementation:
- Returns PlotCollection instead of (Figure, Axes)
- No longer accepts the colors, subplot_kwargs, or rc_params parameters
- Different HDI calculation (uses arviz_plots instead of custom)
Examples
--------
Generate and plot saturation curves:
.. code-block:: python
# Generate curves using saturation transformation
curve = mmm.saturation.sample_curve(
idata=mmm.idata.posterior[["saturation_beta", "saturation_lam"]],
max_value=2.0,
)
pc = mmm.plot.saturation_curves(curve)
pc.show()
Add HDI bands:
.. code-block:: python
pc = mmm.plot.saturation_curves(curve, hdi_probs=[0.5, 0.94])
pc.show()
Original scale with custom seed:
.. code-block:: python
import numpy as np
rng = np.random.default_rng(42)
mmm.add_original_scale_contribution_variable(var=["channel_contribution"])
pc = mmm.plot.saturation_curves(
curve, original_scale=True, n_samples=15, random_seed=rng
)
pc.show()
Filter by dimension:
.. code-block:: python
pc = mmm.plot.saturation_curves(curve, dims={"geo": "US"})
pc.show()
"""
# Get constant_data and posterior_data with fallback
constant_data = self._get_data_or_fallback(
constant_data, "constant_data", "constant data"
)
posterior_data = self._get_data_or_fallback(
posterior_data, "posterior", "posterior data"
)
contrib_var = (
"channel_contribution_original_scale"
if original_scale
else "channel_contribution"
)
if original_scale and contrib_var not in posterior_data:
    raise ValueError(
        f"No '{contrib_var}' variable found in posterior_data. "
        "Add an original-scale deterministic first:\n"
        "    mmm.add_original_scale_contribution_variable(\n"
        "        var=['channel_contribution', ...]\n"
        "    )"
    )
# Validate curve dimensions
if "x" not in curve.dims:
raise ValueError("curve must have an 'x' dimension")
if "channel" not in curve.dims:
raise ValueError("curve must have a 'channel' dimension")
if original_scale:
curve_data = curve * constant_data.target_scale
curve_data["x"] = curve_data["x"] * constant_data.channel_scale
else:
curve_data = curve
curve_data = curve_data.rename("saturation_curve")
# Validate requested dims against the scatter-data dimensions;
# saturation_scatterplot performs the actual filtering and faceting.
if dims:
    self._validate_dims(dims, list(constant_data.channel_data.dims))
# create the saturation scatterplot
pc = self.saturation_scatterplot(
original_scale=original_scale,
constant_data=constant_data,
posterior_data=posterior_data,
dims=dims,
backend=backend,
)
# add the hdi bands
if hdi_probs is not None:
# Robustly handle hdi_probs as float, list, tuple, or np.ndarray
if isinstance(hdi_probs, (float, int)):
hdi_probs_iter = [hdi_probs]
elif isinstance(hdi_probs, (list, tuple, np.ndarray)):
hdi_probs_iter = hdi_probs
else:
raise TypeError("hdi_probs must be a float, list, tuple, or np.ndarray")
for hdi_prob in hdi_probs_iter:
hdi = curve_data.azstats.hdi(hdi_prob)
pc.map(
azp.visuals.fill_between_y,
x=curve_data["x"],
y_bottom=hdi.sel(ci_bound="lower"),
y_top=hdi.sel(ci_bound="upper"),
alpha=0.2,
)
if n_samples > 0:
## sample the curves
rng = np.random.default_rng(random_seed)
# Stack the two dimensions
stacked = curve_data.stack(sample=("chain", "draw"))
# Sample from the stacked dimension
idx = rng.choice(stacked.sizes["sample"], size=n_samples, replace=False)
# Select and unstack
sampled_curves = stacked.isel(sample=idx)
# plot the sampled curves
pc.map(
azp.visuals.multiple_lines, x_dim="x", data=sampled_curves, alpha=0.2
)
return pc
def budget_allocation_roas(
self,
samples: xr.Dataset,
dims: dict[str, str | int | list] | None = None,
dims_to_group_by: list[str] | str | None = None,
backend: str | None = None,
) -> PlotCollection:
"""Plot ROI (Return on Ad Spend) distributions for budget allocation scenarios.
Visualizes the posterior distribution of ROI for each channel given a budget
allocation. Useful for comparing ROI across channels and understanding
optimization trade-offs.
Parameters
----------
samples : xr.Dataset
Dataset from budget allocation optimization containing:
- 'channel_contribution_original_scale': Channel contributions
- 'allocation': Allocated budget per channel
- 'channel' dimension
Typically obtained from: ``mmm.allocate_budget_to_maximize_response(...)``
dims : dict[str, str | int | list], optional
Dimension filters to apply. Examples:
- {"geo": "US"}
- {"geo": ["US", "UK"]}
If provided, only the selected slice(s) will be plotted.
dims_to_group_by : list[str] | str | None, optional
Dimension(s) to group by for overlaying distributions.
When specified, all ROAS distributions for each coordinate of that
dimension will be plotted together for comparison.
- None (default): Each distribution plotted separately
- Single string: Group by that dimension (e.g., "geo")
- List of strings: Group by multiple dimensions (e.g., ["geo", "segment"])
backend : str | None, optional
Backend to use for plotting. If None, uses global backend configuration.
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Raises
------
ValueError
If 'channel' dimension not found in samples.
ValueError
If required variables not found in samples.
See Also
--------
LegacyMMMPlotSuite.budget_allocation : Legacy bar chart method (different purpose)
Notes
-----
This method is NEW in MMMPlotSuite v2 and serves a different purpose
than the legacy ``budget_allocation()`` method:
- **New method** (this): Shows ROAS distributions (KDE plots)
- **Legacy method**: Shows bar charts comparing spend vs contributions
To use the legacy method, set: ``mmm_plot_config["plot.use_v2"] = False``
Examples
--------
Basic usage with budget optimization results:
.. code-block:: python
allocation_results = mmm.allocate_budget_to_maximize_response(
total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0}
)
pc = mmm.plot.budget_allocation_roas(allocation_results)
pc.show()
Group by geography to compare ROAS across regions:
.. code-block:: python
pc = mmm.plot.budget_allocation_roas(
allocation_results, dims_to_group_by="geo"
)
pc.show()
Filter and group:
.. code-block:: python
pc = mmm.plot.budget_allocation_roas(
allocation_results, dims={"segment": "premium"}, dims_to_group_by="geo"
)
pc.show()
"""
# Get the channels from samples
if "channel" not in samples.dims:
raise ValueError(
"Expected 'channel' dimension in samples dataset, but none found."
)
# Check for required variables in samples
if "channel_contribution_original_scale" not in samples.data_vars:
raise ValueError(
"Expected a variable containing 'channel_contribution_original_scale' in samples, but none found."
)
if "allocation" not in samples:
raise ValueError(
"Expected 'allocation' variable in samples, but none found."
)
# Channel contributions are stored under this fixed variable name
channel_contrib_var = "channel_contribution_original_scale"
all_dims = list(samples.dims)
# Validate and apply dims filtering
if dims:
    self._validate_dims(dims=dims, all_dims=all_dims)
    samples = samples.sel(dims)
channel_contribution = samples[channel_contrib_var].sum(dim="date")
channel_contribution.name = "channel_contribution"
# ROAS = total contribution per channel divided by the allocated spend
roa_da = channel_contribution / samples.allocation
roa_dt = convert_to_datatree(roa_da)
if isinstance(dims_to_group_by, str):
dims_to_group_by = [dims_to_group_by]
if dims_to_group_by:
    # Start from a single group, then split it successively by each dimension
    grouped = {"all": roa_dt.copy()}
for dim in dims_to_group_by:
new_grouped = {}
for curr_k, curr_group in grouped.items():
curr_coords = curr_group.posterior.coords[dim].values
new_grouped.update(
{
f"{curr_k}, {dim}: {key}": curr_group.sel({dim: key})
for key in curr_coords
}
)
grouped = new_grouped
# Strip the seed "all, " prefix from the group labels
grouped_roa_dt = {}
prefix = "all, "
for k, v in grouped.items():
if k.startswith(prefix):
grouped_roa_dt[k[len(prefix) :]] = v
else:
grouped_roa_dt[k] = v
else:
grouped_roa_dt = roa_dt
pc = azp.plot_dist(
grouped_roa_dt,
kind="kde",
sample_dims=["sample"],
backend=backend,
labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(),
)
if dims_to_group_by:
pc.add_legend(dim="model", title="")
return pc
def allocated_contribution_by_channel_over_time(
self,
samples: xr.Dataset,
hdi_prob: float = 0.85,
backend: str | None = None,
) -> PlotCollection:
"""Plot channel contributions over time from budget allocation optimization.
Visualizes how contributions from each channel evolve over time given an
optimized budget allocation. Shows mean contribution lines per channel with
HDI uncertainty bands.
Parameters
----------
samples : xr.Dataset
Dataset from budget allocation optimization containing channel
contributions over time. Required dimensions:
- 'channel': Channel names
- 'date': Time dimension
- 'sample': MCMC samples
Required variables:
- Variable containing 'channel_contribution' (e.g., 'channel_contribution'
or 'channel_contribution_original_scale')
Typically obtained from: ``mmm.allocate_budget_to_maximize_response(...)``
hdi_prob : float, default 0.85
Probability mass for HDI interval (between 0 and 1).
backend : str | None, optional
Backend to use for plotting. If None, uses global backend configuration.
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Unlike the legacy suite which returned ``(Figure, Axes)``,
this provides a unified interface across all backends.
Raises
------
ValueError
If required dimensions ('channel', 'date', 'sample') not found in samples.
ValueError
If no variable containing 'channel_contribution' found in samples.
See Also
--------
budget_allocation_roas : Plot ROAS distributions from the same allocation results
LegacyMMMPlotSuite.allocated_contribution_by_channel_over_time : Legacy implementation
Notes
-----
Breaking changes from legacy implementation:
- Returns PlotCollection instead of (Figure, Axes)
- No longer accepts the scale_factor, lower_quantile, upper_quantile, figsize, or ax parameters
- Now uses HDI instead of quantiles for uncertainty
- Automatic handling of extra dimensions (creates subplots)
Examples
--------
Basic usage with budget optimization results:
.. code-block:: python
allocation_results = mmm.allocate_budget_to_maximize_response(
total_budget=100_000, budget_bounds={"lower": 0.5, "upper": 2.0}
)
pc = mmm.plot.allocated_contribution_by_channel_over_time(
allocation_results
)
pc.show()
Custom HDI probability:
.. code-block:: python
pc = mmm.plot.allocated_contribution_by_channel_over_time(
allocation_results, hdi_prob=0.94
)
pc.show()
Use different backend:
.. code-block:: python
pc = mmm.plot.allocated_contribution_by_channel_over_time(
allocation_results, backend="plotly"
)
pc.show()
"""
# Check for expected dimensions and variables
if "channel" not in samples.dims:
raise ValueError(
"Expected 'channel' dimension in samples dataset, but none found."
)
if "date" not in samples.dims:
raise ValueError(
"Expected 'date' dimension in samples dataset, but none found."
)
if "sample" not in samples.dims:
raise ValueError(
"Expected 'sample' dimension in samples dataset, but none found."
)
# Check if any variable contains channel contributions
if not any(
"channel_contribution" in var_name for var_name in samples.data_vars
):
raise ValueError(
"Expected a variable containing 'channel_contribution' in samples, but none found."
)
# Get channel contributions data
channel_contrib_var = next(
var_name
for var_name in samples.data_vars
if "channel_contribution" in var_name
)
# Identify extra dimensions beyond 'channel', 'date', and 'sample'
all_dims = list(samples[channel_contrib_var].dims)
ignored_dims = {"channel", "date", "sample"}
extra_dims = [dim for dim in all_dims if dim not in ignored_dims]
pc = azp.PlotCollection.wrap(
samples[channel_contrib_var].to_dataset(),
cols=extra_dims,
aes={"color": ["channel"]},
col_wrap=1,
figure_kwargs={
"sharex": True,
},
backend=backend,
)
# plot hdi
hdi = samples[channel_contrib_var].azstats.hdi(hdi_prob, dim="sample")
pc.map(
azp.visuals.fill_between_y,
x=samples[channel_contrib_var]["date"],
y_bottom=hdi.sel(ci_bound="lower"),
y_top=hdi.sel(ci_bound="upper"),
alpha=0.2,
)
# plot mean contribution line
pc.map(
azp.visuals.line_xy,
x=samples[channel_contrib_var]["date"],
y=samples[channel_contrib_var].mean(dim="sample"),
)
pc.map(azp.visuals.labelled_x, text="Date", ignore_aes={"color"})
pc.map(
azp.visuals.labelled_y, text="Channel Contribution", ignore_aes={"color"}
)
pc.map(
azp.visuals.labelled_title,
subset_info=True,
labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(),
ignore_aes={"color"},
)
pc.add_legend(dim="channel")
return pc
def _sensitivity_analysis_plot(
self,
data: xr.DataArray | xr.Dataset,
hdi_prob: float = 0.94,
aggregation: dict[str, tuple[str, ...] | list[str]] | None = None,
backend: str | None = None,
) -> PlotCollection:
"""Private helper for plotting sensitivity analysis results.
This is an internal method that performs the core plotting logic for
sensitivity analysis visualizations. Public methods (sensitivity_analysis,
uplift_curve, marginal_curve) handle data retrieval and call this helper.
Parameters
----------
data : xr.DataArray or xr.Dataset
Sensitivity analysis data to plot. Must have required dimensions:
- 'sample': MCMC samples
- 'sweep': Sweep values (e.g., multipliers or input values)
If Dataset, should contain 'x' variable.
IMPORTANT: This parameter is REQUIRED with no fallback to self.idata.
This design maintains separation of concerns - public methods handle
data retrieval, this helper handles pure plotting.
hdi_prob : float, default 0.94
HDI probability mass (between 0 and 1).
aggregation : dict, optional
Aggregations to apply before plotting.
Keys are operations ("sum", "mean", "median"), values are dimension tuples.
Example: {"sum": ("channel",)} sums over the channel dimension.
backend : str | None, optional
Backend to use for plotting. If None, uses global backend configuration.
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Note: Y-axis label is NOT set by this helper. Public methods calling
this helper should set appropriate labels (e.g., "Contribution",
"Uplift (%)", "Marginal Effect").
Raises
------
ValueError
If data is missing required dimensions ('sample', 'sweep').
Notes
-----
Design rationale for REQUIRED data parameter:
- **Separation of concerns**: Public methods handle data location/retrieval
(from self.idata.sensitivity_analysis, self.idata.posterior, etc.),
this helper handles pure visualization logic.
- **Testability**: Easy to test plotting logic with mock data.
- **Cleaner implementation**: No monkey-patching or state manipulation.
- **Flexibility**: Can be reused for different data sources without
coupling to self.idata structure.
This is a PRIVATE method (starts with _) and should not be called directly
by users. Use public methods instead:
- sensitivity_analysis(): General sensitivity analysis plots
- uplift_curve(): Uplift percentage plots
- marginal_curve(): Marginal effects plots
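Examples
--------
Internal usage sketch (mirrors what the public wrappers do; the
``sensitivity_analysis`` group and its ``x`` variable are assumptions
about how the sweep results were persisted):

.. code-block:: python

    pc = self._sensitivity_analysis_plot(
        data=self.idata.sensitivity_analysis["x"],
        hdi_prob=0.9,
        aggregation={"sum": ("channel",)},
    )
    pc.map(azp.visuals.labelled_y, text="Contribution")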
"""
# Handle Dataset or DataArray
x = data["x"] if isinstance(data, xr.Dataset) else data
# Validate dimensions
required_dims = {"sample", "sweep"}
if not required_dims.issubset(set(x.dims)):
raise ValueError(
f"Data must have dimensions {required_dims}, got {set(x.dims)}"
)
# Coerce to a numeric dtype; warn and keep the original dtype on failure
try:
    x = x.astype(float)
except Exception as err:
    warnings.warn(
        f"Failed to cast sensitivity analysis data to float: {err}",
        RuntimeWarning,
        stacklevel=2,
    )
# Apply aggregations
if aggregation:
    for op, agg_dims in aggregation.items():
        dims_list = [d for d in agg_dims if d in x.dims]
        if not dims_list:
            continue
        if op == "sum":
            x = x.sum(dim=dims_list)
        elif op == "mean":
            x = x.mean(dim=dims_list)
        elif op == "median":
            x = x.median(dim=dims_list)
        else:
            raise ValueError(
                f"Unknown aggregation '{op}'; expected 'sum', 'mean', or 'median'."
            )
# Facet over everything except the sample & sweep dimensions
plot_dims = [d for d in x.dims if d not in ("sample", "sweep")]
pc = azp.PlotCollection.wrap(
x.to_dataset(),
cols=plot_dims,
col_wrap=2,
figure_kwargs={
"sharex": True,
},
backend=backend,
)
# plot hdi
hdi = x.azstats.hdi(hdi_prob, dim="sample")
pc.map(
azp.visuals.fill_between_y,
x=x["sweep"],
y_bottom=hdi.sel(ci_bound="lower"),
y_top=hdi.sel(ci_bound="upper"),
alpha=0.4,
color="C0",
)
# plot aggregated line
pc.map(
azp.visuals.line_xy,
x=x["sweep"],
y=x.mean(dim="sample"),
color="C0",
)
# add labels
pc.map(azp.visuals.labelled_x, text="Sweep")
pc.map(
azp.visuals.labelled_title,
subset_info=True,
labeller=mix_labellers((NoVarLabeller, DimCoordLabeller))(),
)
return pc
def sensitivity_analysis(
self,
data: xr.DataArray | xr.Dataset | None = None,
hdi_prob: float = 0.94,
aggregation: dict[str, tuple[str, ...] | list[str]] | None = None,
backend: str | None = None,
) -> PlotCollection:
"""Plot sensitivity analysis results showing response to input changes.
Visualizes how model outputs (e.g., channel contributions) change as inputs
(e.g., channel spend) are varied. Shows mean response line and HDI bands
across sweep values.
Parameters
----------
data : xr.DataArray or xr.Dataset, optional
Sensitivity analysis data with required dimensions:
- 'sample': MCMC samples
- 'sweep': Sweep values (e.g., multipliers)
If Dataset, should contain 'x' variable.
If None, uses self.idata.sensitivity_analysis.
This parameter allows:
- Testing with mock sensitivity analysis results
- Plotting external sweep results
- Comparing different sensitivity analyses
hdi_prob : float, default 0.94
HDI probability mass (between 0 and 1).
aggregation : dict, optional
Aggregations to apply before plotting.
Keys: "sum", "mean", or "median"
Values: tuple of dimension names
Example: ``{"sum": ("channel",)}`` sums over channels before plotting.
backend : str | None, optional
Backend to use for plotting. If None, uses global backend configuration.
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``,
this provides a unified interface across all backends.
Raises
------
ValueError
If no sensitivity analysis data found in self.idata and no data provided.
See Also
--------
uplift_curve : Plot uplift percentages (derived from sensitivity analysis)
marginal_curve : Plot marginal effects (derived from sensitivity analysis)
LegacyMMMPlotSuite.sensitivity_analysis : Legacy matplotlib-only implementation
Notes
-----
Breaking changes from legacy implementation:
- Returns PlotCollection instead of (Figure, Axes) or Axes
- No longer accepts the ax, subplot_kwargs, or plot_kwargs parameters (use backend-specific methods)
- Cleaner implementation without monkey-patching
- Data parameter for explicit data passing (no side effects on self.idata)
Examples
--------
Run sweep and plot results:
.. code-block:: python
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis
# Run sensitivity sweep
sweeps = np.linspace(0.5, 1.5, 11)
sa = SensitivityAnalysis(mmm.model, mmm.idata)
results = sa.run_sweep(
var_input="channel_data",
sweep_values=sweeps,
var_names="channel_contribution",
sweep_type="multiplicative",
extend_idata=True, # Store in idata
)
# Plot stored results
pc = mmm.plot.sensitivity_analysis(hdi_prob=0.9)
pc.show()
Aggregate over channels:
.. code-block:: python
pc = mmm.plot.sensitivity_analysis(
hdi_prob=0.9, aggregation={"sum": ("channel",)}
)
pc.show()
Use different backend:
.. code-block:: python
pc = mmm.plot.sensitivity_analysis(backend="plotly")
pc.show()
Provide explicit data:
.. code-block:: python
external_results = sa.run_sweep(...) # Not stored in idata
pc = mmm.plot.sensitivity_analysis(data=external_results)
pc.show()
"""
# Retrieve data if not provided
data = self._get_data_or_fallback(
data, "sensitivity_analysis", "sensitivity analysis results"
)
pc = self._sensitivity_analysis_plot(
data=data, hdi_prob=hdi_prob, aggregation=aggregation, backend=backend
)
pc.map(azp.visuals.labelled_y, text="Contribution")
return pc
def uplift_curve(
self,
data: xr.DataArray | xr.Dataset | None = None,
hdi_prob: float = 0.94,
aggregation: dict[str, tuple[str, ...] | list[str]] | None = None,
backend: str | None = None,
) -> PlotCollection:
"""Plot uplift curves showing percentage change relative to baseline.
Visualizes relative percentage changes in model outputs (e.g., channel
contributions) as inputs are varied, compared to a reference point.
Shows mean uplift line and HDI bands.
Parameters
----------
data : xr.DataArray or xr.Dataset, optional
Uplift curve data computed from sensitivity analysis.
If Dataset, should contain 'uplift_curve' variable.
If None, uses self.idata.sensitivity_analysis['uplift_curve'].
Must be precomputed using:
``SensitivityAnalysis.compute_uplift_curve_respect_to_base(...)``
This parameter allows:
- Testing with mock uplift curve data
- Plotting externally computed uplift curves
- Comparing uplift curves from different models
hdi_prob : float, default 0.94
HDI probability mass (between 0 and 1).
aggregation : dict, optional
Aggregations to apply before plotting.
Keys: "sum", "mean", or "median"
Values: tuple of dimension names
Example: ``{"sum": ("channel",)}`` sums over channels before plotting.
backend : str | None, optional
Backend to use for plotting. If None, uses global backend configuration.
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``,
this provides a unified interface across all backends.
Raises
------
ValueError
If no uplift curve data found in self.idata and no data provided.
ValueError
If 'uplift_curve' variable not found in sensitivity_analysis group.
See Also
--------
sensitivity_analysis : Plot raw sensitivity analysis results
marginal_curve : Plot marginal effects (absolute changes)
LegacyMMMPlotSuite.uplift_curve : Legacy matplotlib-only implementation
Notes
-----
Breaking changes from legacy implementation:
- Returns PlotCollection instead of (Figure, Axes) or Axes
- Cleaner implementation without monkey-patching
- No longer modifies self.idata.sensitivity_analysis temporarily
- Data parameter for explicit data passing
Examples
--------
Compute and plot uplift curve:
.. code-block:: python
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis
# Run sensitivity sweep
sweeps = np.linspace(0.5, 1.5, 11)
sa = SensitivityAnalysis(mmm.model, mmm.idata)
results = sa.run_sweep(
var_input="channel_data",
sweep_values=sweeps,
var_names="channel_contribution",
sweep_type="multiplicative",
)
# Compute uplift relative to baseline (ref=1.0)
uplift = sa.compute_uplift_curve_respect_to_base(
results,
ref=1.0,
extend_idata=True, # Store in idata
)
# Plot stored uplift curve
pc = mmm.plot.uplift_curve(hdi_prob=0.9)
pc.show()
Aggregate over channels:
.. code-block:: python
pc = mmm.plot.uplift_curve(aggregation={"sum": ("channel",)})
pc.show()
Use different backend:
.. code-block:: python
pc = mmm.plot.uplift_curve(backend="plotly")
pc.show()
Provide explicit data:
.. code-block:: python
uplift_data = sa.compute_uplift_curve_respect_to_base(results, ref=1.0)
pc = mmm.plot.uplift_curve(data=uplift_data)
pc.show()
"""
# Retrieve data if not provided
if data is None:
sa_group = self._get_data_or_fallback(
None, "sensitivity_analysis", "sensitivity analysis results"
)
if isinstance(sa_group, xr.Dataset):
if "uplift_curve" not in sa_group:
raise ValueError(
"Expected 'uplift_curve' in idata.sensitivity_analysis. "
"Use SensitivityAnalysis.compute_uplift_curve_respect_to_base(..., extend_idata=True)."
)
data = sa_group["uplift_curve"]
else:
raise ValueError(
"sensitivity_analysis does not contain 'uplift_curve'. Did you persist it to idata?"
)
# Handle Dataset input
if isinstance(data, xr.Dataset):
if "uplift_curve" in data:
data = data["uplift_curve"]
elif "x" in data:
data = data["x"]
else:
raise ValueError("Dataset must contain 'uplift_curve' or 'x' variable.")
# Delegate to the shared plotting helper with explicit data
pc = self._sensitivity_analysis_plot(
data=data,
hdi_prob=hdi_prob,
aggregation=aggregation,
backend=backend,
)
pc.map(azp.visuals.labelled_y, text="Uplift (%)")
return pc
def marginal_curve(
self,
data: xr.DataArray | xr.Dataset | None = None,
hdi_prob: float = 0.94,
aggregation: dict[str, tuple[str, ...] | list[str]] | None = None,
backend: str | None = None,
) -> PlotCollection:
"""Plot marginal effects showing absolute rate of change.
Visualizes the instantaneous rate of change (derivative) of model outputs
with respect to inputs. Shows how much output changes per unit change in
input at each sweep value.
Parameters
----------
data : xr.DataArray or xr.Dataset, optional
Marginal effects data computed from sensitivity analysis.
If Dataset, should contain 'marginal_effects' variable.
If None, uses self.idata.sensitivity_analysis['marginal_effects'].
Must be precomputed using:
``SensitivityAnalysis.compute_marginal_effects(...)``
This parameter allows:
- Testing with mock marginal effects data
- Plotting externally computed marginal effects
- Comparing marginal effects from different models
hdi_prob : float, default 0.94
HDI probability mass (between 0 and 1).
aggregation : dict, optional
Aggregations to apply before plotting.
Keys: "sum", "mean", or "median"
Values: tuple of dimension names
Example: ``{"sum": ("channel",)}`` sums over channels before plotting.
backend : str | None, optional
Backend to use for plotting. If None, uses global backend configuration.
Returns
-------
PlotCollection
arviz_plots PlotCollection object containing the plot.
Use ``.show()`` to display or ``.save("filename")`` to save.
Unlike the legacy suite which returned ``(Figure, Axes)`` or ``Axes``,
this provides a unified interface across all backends.
Raises
------
ValueError
If no marginal effects data found in self.idata and no data provided.
ValueError
If 'marginal_effects' variable not found in sensitivity_analysis group.
See Also
--------
sensitivity_analysis : Plot raw sensitivity analysis results
uplift_curve : Plot uplift percentages (relative changes)
LegacyMMMPlotSuite.marginal_curve : Legacy matplotlib-only implementation
Notes
-----
Breaking changes from legacy implementation:
- Returns PlotCollection instead of (Figure, Axes) or Axes
- Cleaner implementation without monkey-patching
- No longer modifies self.idata.sensitivity_analysis temporarily
- Data parameter for explicit data passing
Marginal effects show the **slope** of the sensitivity curve, helping
identify where returns are diminishing most rapidly.
Examples
--------
Compute and plot marginal effects:
.. code-block:: python
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis
# Run sensitivity sweep
sweeps = np.linspace(0.5, 1.5, 11)
sa = SensitivityAnalysis(mmm.model, mmm.idata)
results = sa.run_sweep(
var_input="channel_data",
sweep_values=sweeps,
var_names="channel_contribution",
sweep_type="multiplicative",
)
# Compute marginal effects (derivatives)
me = sa.compute_marginal_effects(
results,
extend_idata=True, # Store in idata
)
# Plot stored marginal effects
pc = mmm.plot.marginal_curve(hdi_prob=0.9)
pc.show()
Aggregate over channels:
.. code-block:: python
pc = mmm.plot.marginal_curve(aggregation={"sum": ("channel",)})
pc.show()
Use different backend:
.. code-block:: python
pc = mmm.plot.marginal_curve(backend="plotly")
pc.show()
Provide explicit data:
.. code-block:: python
marginal_data = sa.compute_marginal_effects(results)
pc = mmm.plot.marginal_curve(data=marginal_data)
pc.show()
"""
# Retrieve data if not provided
if data is None:
sa_group = self._get_data_or_fallback(
None, "sensitivity_analysis", "sensitivity analysis results"
)
if isinstance(sa_group, xr.Dataset):
if "marginal_effects" not in sa_group:
raise ValueError(
"Expected 'marginal_effects' in idata.sensitivity_analysis. "
"Use SensitivityAnalysis.compute_marginal_effects(..., extend_idata=True)."
)
data = sa_group["marginal_effects"]
else:
raise ValueError(
"sensitivity_analysis does not contain 'marginal_effects'. Did you persist it to idata?"
)
# Handle Dataset input
if isinstance(data, xr.Dataset):
if "marginal_effects" in data:
data = data["marginal_effects"]
elif "x" in data:
data = data["x"]
else:
raise ValueError(
"Dataset must contain 'marginal_effects' or 'x' variable."
)
# Delegate to the shared plotting helper with explicit data
pc = self._sensitivity_analysis_plot(
data=data,
hdi_prob=hdi_prob,
aggregation=aggregation,
backend=backend,
)
pc.map(azp.visuals.labelled_y, text="Marginal Effect")
return pc
def budget_allocation(self, *args, **kwargs):
"""
Create bar chart comparing allocated spend and channel contributions.
.. deprecated:: 0.18.0
This method was removed in MMMPlotSuite v2. The arviz_plots library
used in v2 doesn't support this specific chart type. See alternatives below.
Raises
------
NotImplementedError
This method is not available in MMMPlotSuite v2.
Notes
-----
Alternatives:
1. **For ROI distributions**: Use :meth:`budget_allocation_roas`
(different purpose but related to budget allocation)
2. **To use the old method**: Switch to legacy suite:
.. code-block:: python
from pymc_marketing.mmm import mmm_plot_config
mmm_plot_config["plot.use_v2"] = False
mmm.plot.budget_allocation(samples)
3. **Custom implementation**: Create bar chart using samples data:
.. code-block:: python
import matplotlib.pyplot as plt
channel_contrib = samples["channel_contribution"].mean(...)
allocated_spend = samples["allocation"]
# Build a custom bar chart with matplotlib (fuller sketch below)
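A fuller version of alternative 3 (a sketch; it assumes the allocation
samples carry a ``channel_contribution_original_scale`` variable with
``date``/``sample``/``channel`` dims and an ``allocation`` variable with
only a ``channel`` dim):

.. code-block:: python

    import matplotlib.pyplot as plt

    contrib = (
        samples["channel_contribution_original_scale"]
        .sum(dim="date")
        .mean(dim="sample")
    )
    spend = samples["allocation"]
    x = range(len(contrib.channel))
    width = 0.4
    fig, ax = plt.subplots()
    ax.bar([i - width / 2 for i in x], spend.values, width, label="Allocated spend")
    ax.bar([i + width / 2 for i in x], contrib.values, width, label="Contribution")
    ax.set_xticks(list(x))
    ax.set_xticklabels([str(c) for c in contrib.channel.values])
    ax.legend()
    fig.savefig("budget_allocation.png")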
See Also
--------
budget_allocation_roas : Plot ROI distributions by channel
Examples
--------
Use legacy suite temporarily:
.. code-block:: python
from pymc_marketing.mmm import mmm_plot_config
original = mmm_plot_config.get("plot.use_v2")
try:
mmm_plot_config["plot.use_v2"] = False
fig, ax = mmm.plot.budget_allocation(samples)
fig.savefig("budget.png")
finally:
mmm_plot_config["plot.use_v2"] = original
"""
raise NotImplementedError(
"budget_allocation() was removed in MMMPlotSuite v2.\n\n"
"The new arviz_plots-based implementation doesn't support this chart type.\n\n"
"Alternatives:\n"
" 1. For ROI distributions: use budget_allocation_roas()\n"
" 2. To use old method: set mmm_plot_config['plot.use_v2'] = False\n"
" 3. Implement custom bar chart using the samples data\n\n"
"See documentation: https://docs.pymc-marketing.io/en/latest/mmm/plotting_migration.html#budget-allocation"
)