MMMPlotSuite.saturation_curves

MMMPlotSuite.saturation_curves(curve, original_scale=False, constant_data=None, posterior_data=None, n_samples=10, hdi_probs=None, random_seed=None, dims=None, backend=None)

Overlay saturation scatter plots with posterior predictive curves and HDI bands.

Builds on saturation_scatterplot() by adding:

  • Sample curves drawn from the posterior distribution

  • HDI bands showing uncertainty

  • Smooth saturation curves over the scatter plot

Parameters:
curve : xr.DataArray

Posterior predictive saturation curves with the following required dimensions:

  • “chain”, “draw”: MCMC samples

  • “x”: input values at which the curve is evaluated

  • “channel”: channel names

Generate it with mmm.saturation.sample_curve(...).
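As a rough illustration of the expected layout, the NumPy sketch below builds a stand-in array with dimensions (chain, draw, x, channel). It assumes a logistic saturation of the form beta * (1 - exp(-lam * x)) / (1 + exp(-lam * x)) and uses random stand-in draws for beta and lam; the channel names and sizes are hypothetical. The real curve is a labelled xr.DataArray returned by sample_curve, not a bare ndarray.

```python
import numpy as np

# Hypothetical sizes and channel names, for illustration only.
n_chain, n_draw, n_x = 2, 50, 100
channels = ["tv", "radio"]

rng = np.random.default_rng(0)
# Assumed posterior draws for a logistic saturation: one beta and one lam
# per (chain, draw, channel). Gamma draws keep both parameters positive.
beta = rng.gamma(shape=2.0, scale=1.0, size=(n_chain, n_draw, len(channels)))
lam = rng.gamma(shape=3.0, scale=1.0, size=(n_chain, n_draw, len(channels)))

# Evaluation grid along the "x" dimension, from 0 up to max_value.
max_value = 2.0
x = np.linspace(0.0, max_value, n_x)


def logistic_saturation(x, lam):
    """(1 - exp(-lam * x)) / (1 + exp(-lam * x)); lies in [0, 1) for x >= 0."""
    return (1.0 - np.exp(-lam * x)) / (1.0 + np.exp(-lam * x))


# Broadcast to shape (chain, draw, x, channel), the dimension set the
# curve argument is documented to carry.
curve = beta[:, :, None, :] * logistic_saturation(
    x[None, None, :, None], lam[:, :, None, :]
)
dims = ("chain", "draw", "x", "channel")
print(dict(zip(dims, curve.shape)))
```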

original_scale : bool, default False

Plot in the original scale (True) or in the scaled space (False). If True, the posterior must contain channel_contribution_original_scale.

constant_data : xr.Dataset, optional

Dataset containing the constant_data group. If None, uses self.idata.constant_data. Pass a dataset explicitly to test with mock data or to plot alternative scenarios.

posterior_data : xr.Dataset, optional

Dataset containing the posterior group. If None, uses self.idata.posterior. Pass a dataset explicitly to test with mock posterior samples or to compare model fits.

n_samples : int, default 10

Number of sample curves to draw per subplot. Set to 0 to show only HDI bands without individual samples.

hdi_probs : float or list of float, optional

HDI probability levels for credible intervals. Examples: 0.94 (single band), [0.5, 0.94] (multiple bands). If None, no HDI bands are drawn.
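To make the band computation concrete, the sketch below implements the common shortest-interval HDI estimator on pooled chain and draw samples, separately at each x value, and normalizes hdi_probs as the description above implies (None gives no bands, a float gives one band, a list gives one band per level). The helper names are hypothetical; the method itself delegates to arviz_plots, whose HDI routine may differ in details such as how the interval length is rounded.

```python
import numpy as np


def normalize_hdi_probs(hdi_probs):
    """Map None / a float / a list of floats to a list of probabilities."""
    if hdi_probs is None:
        return []  # no HDI bands
    if isinstance(hdi_probs, (int, float)):
        return [float(hdi_probs)]  # single band
    return [float(p) for p in hdi_probs]  # one band per level


def hdi_1d(samples, prob):
    """Narrowest interval spanning a fraction `prob` of the sorted samples."""
    s = np.sort(np.asarray(samples, dtype=float))
    n = s.size
    k = int(np.floor(prob * n))  # index span of each candidate interval
    widths = s[k:] - s[: n - k]  # width of every candidate interval
    i = int(np.argmin(widths))  # keep the narrowest one
    return s[i], s[i + k]


def hdi_bands(curve, prob):
    """Per-x lower/upper HDI bounds for a (chain, draw, x) array of one channel."""
    n_chain, n_draw, n_x = curve.shape
    pooled = curve.reshape(n_chain * n_draw, n_x)  # pool chains and draws
    bounds = np.array([hdi_1d(pooled[:, j], prob) for j in range(n_x)])
    return bounds[:, 0], bounds[:, 1]


# Usage on random stand-in curves for a single channel.
rng = np.random.default_rng(1)
fake_curve = rng.normal(size=(2, 200, 5))
bands = {p: hdi_bands(fake_curve, p) for p in normalize_hdi_probs([0.5, 0.94])}
```

Any window of the sorted samples that covers the 94% mass contains a narrower window covering the 50% mass, so with this estimator the inner band is never wider than the outer one.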

random_seed : np.random.Generator, optional

Random number generator for reproducible curve sampling. If None, uses np.random.default_rng().
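One plausible way to honour n_samples and random_seed together is to pick that many (chain, draw) pairs at random and plot one curve per pair. The sketch below does this with NumPy; the helper name pick_sample_draws is hypothetical, and sampling without replacement (capped at the number of available draws) is an assumption rather than documented behaviour.

```python
import numpy as np


def pick_sample_draws(n_chain, n_draw, n_samples, random_seed=None):
    """Choose up to n_samples distinct (chain, draw) pairs to plot as curves.

    Hypothetical helper: n_samples=0 yields no individual curves, and
    random_seed=None falls back to np.random.default_rng().
    """
    rng = random_seed if random_seed is not None else np.random.default_rng()
    total = n_chain * n_draw
    n = min(n_samples, total)  # cannot draw more pairs than exist
    flat_idx = rng.choice(total, size=n, replace=False)
    # Convert flat indices back to (chain, draw) coordinates.
    chains, draws = np.unravel_index(flat_idx, (n_chain, n_draw))
    return list(zip(chains.tolist(), draws.tolist()))


# Generators seeded identically select the same curves.
first = pick_sample_draws(4, 1000, 15, np.random.default_rng(42))
second = pick_sample_draws(4, 1000, 15, np.random.default_rng(42))
```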

dims : dict[str, str | int | list], optional

Dimension filters to apply, for example:

  • {"geo": "US"}

  • {"geo": ["US", "UK"]}

If provided, only the selected slice(s) will be plotted.
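The dims filter can be pictured as narrowing each extra dimension to the requested values, with a scalar treated as a one-element list. The sketch below, with hypothetical helper names and a plain dict standing in for the model's coordinates, enumerates the slices that would remain; how the real method arranges those slices into subplots is not specified here.

```python
from itertools import product


def normalize_dims(dims):
    """Wrap scalar filter values in a list so every value is a list."""
    if dims is None:
        return {}
    return {
        name: list(value) if isinstance(value, (list, tuple)) else [value]
        for name, value in dims.items()
    }


def selected_slices(coords, dims):
    """Return every coordinate combination that survives the filter.

    coords maps each extra dimension (e.g. "geo") to all of its values;
    dimensions not mentioned in dims keep all of their values.
    """
    filters = normalize_dims(dims)
    names = list(coords)
    kept = [
        [v for v in coords[name] if name not in filters or v in filters[name]]
        for name in names
    ]
    return [dict(zip(names, combo)) for combo in product(*kept)]


# Usage with hypothetical coordinates.
coords = {"geo": ["US", "UK", "DE"]}
us_only = selected_slices(coords, {"geo": "US"})
us_and_uk = selected_slices(coords, {"geo": ["US", "UK"]})
```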

backend : str, optional

Plotting backend to use: “matplotlib”, “plotly”, or “bokeh”. If None, uses the global setting mmm_plot_config[“plot.backend”], whose default is “matplotlib”.
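The lookup order described above (explicit argument first, then the global setting) amounts to a few lines of logic. In this sketch a plain dict stands in for mmm_plot_config, and rejecting unknown names with ValueError is an assumption, not documented behaviour.

```python
# Hypothetical stand-in for the global mmm_plot_config object.
mmm_plot_config = {"plot.backend": "matplotlib"}

SUPPORTED_BACKENDS = ("matplotlib", "plotly", "bokeh")


def resolve_backend(backend=None, config=mmm_plot_config):
    """Return the explicit backend if given, otherwise the configured one."""
    chosen = backend if backend is not None else config["plot.backend"]
    if chosen not in SUPPORTED_BACKENDS:  # assumed validation step
        raise ValueError(
            f"Unsupported backend {chosen!r}; expected one of {SUPPORTED_BACKENDS}"
        )
    return chosen
```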

Returns:
PlotCollection

arviz_plots PlotCollection object containing the plot.

Use .show() to display or .save("filename") to save.

Raises:
ValueError

If curve is missing required dimensions (“x” or “channel”).

ValueError

If original_scale=True but channel_contribution_original_scale is not in the posterior.
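The two documented error conditions can be expressed as simple pre-checks. The sketch below is hypothetical: it takes the dimension names of curve and the variable names in the posterior as plain collections (with xarray these would be curve.dims and posterior.data_vars), and the error messages are illustrative.

```python
REQUIRED_CURVE_DIMS = ("x", "channel")
ORIGINAL_SCALE_VAR = "channel_contribution_original_scale"


def check_saturation_inputs(curve_dims, posterior_vars, original_scale=False):
    """Raise ValueError for the two documented failure cases; return None otherwise."""
    missing = [d for d in REQUIRED_CURVE_DIMS if d not in curve_dims]
    if missing:
        raise ValueError(f"curve is missing required dimensions: {missing}")
    if original_scale and ORIGINAL_SCALE_VAR not in posterior_vars:
        raise ValueError(
            f"original_scale=True requires {ORIGINAL_SCALE_VAR!r} in the posterior; "
            "call mmm.add_original_scale_contribution_variable("
            "var=['channel_contribution']) first"
        )
```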

See also

saturation_scatterplot

Base scatter plot without curves

LegacyMMMPlotSuite.saturation_curves

Legacy matplotlib-only implementation

Notes

Breaking changes from legacy implementation:

  • Returns PlotCollection instead of (Figure, Axes)

  • Removes the colors, subplot_kwargs, and rc_params parameters

  • Computes HDI bands differently (uses arviz_plots instead of a custom implementation)

Examples

Generate and plot saturation curves:

# Generate curves using saturation transformation
curve = mmm.saturation.sample_curve(
    idata=mmm.idata.posterior[["saturation_beta", "saturation_lam"]],
    max_value=2.0,
)
pc = mmm.plot.saturation_curves(curve)
pc.show()

Add HDI bands:

pc = mmm.plot.saturation_curves(curve, hdi_probs=[0.5, 0.94])
pc.show()

Original scale with custom seed:

import numpy as np

rng = np.random.default_rng(42)
mmm.add_original_scale_contribution_variable(var=["channel_contribution"])
pc = mmm.plot.saturation_curves(
    curve, original_scale=True, n_samples=15, random_seed=rng
)
pc.show()

Filter by dimension:

pc = mmm.plot.saturation_curves(curve, dims={"geo": "US"})
pc.show()