MMMPlotSuite.sensitivity_analysis
- MMMPlotSuite.sensitivity_analysis(data=None, hdi_prob=0.94, aggregation=None, backend=None)
Plot sensitivity analysis results showing the response to input changes.
Visualizes how model outputs (e.g., channel contributions) change as inputs (e.g., channel spend) are varied. Shows the mean response line and HDI bands across the sweep values.
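For each sweep value, the plot reduces the posterior samples to a mean and an HDI band. The NumPy sketch below shows that summary for an array with dims (sample, sweep). It is an illustration under assumptions, not the method's actual code: the HDI is taken to be the narrowest interval containing hdi_prob of the samples, and the helpers hdi_1d and summarize_sweep, along with the simulated response, are hypothetical.

```python
import numpy as np


def hdi_1d(samples, hdi_prob=0.94):
    """Hypothetical helper: narrowest interval holding hdi_prob of 1-D samples."""
    sorted_samples = np.sort(samples)
    n = sorted_samples.size
    # Number of sorted positions the interval must span.
    interval_size = int(np.floor(hdi_prob * n))
    # Width of every candidate interval that spans interval_size positions.
    widths = sorted_samples[interval_size:] - sorted_samples[: n - interval_size]
    start = int(np.argmin(widths))
    return float(sorted_samples[start]), float(sorted_samples[start + interval_size])


def summarize_sweep(response, hdi_prob=0.94):
    """Hypothetical helper: mean and HDI bounds per sweep value.

    response has shape (n_sample, n_sweep), i.e. dims ('sample', 'sweep').
    """
    mean = response.mean(axis=0)
    bounds = np.array(
        [hdi_1d(response[:, j], hdi_prob) for j in range(response.shape[1])]
    )
    return mean, bounds[:, 0], bounds[:, 1]


# Simulated posterior: response grows linearly with the multiplier, plus noise.
rng = np.random.default_rng(0)
sweeps = np.linspace(0.5, 1.5, 11)
response = sweeps[None, :] * 100.0 + rng.normal(0.0, 5.0, size=(2000, sweeps.size))
mean, lower, upper = summarize_sweep(response, hdi_prob=0.94)
```

The three returned arrays correspond to the mean line and the lower/upper edges of the band drawn against the sweep axis; a larger hdi_prob widens the band.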
- Parameters:
- data
xr.DataArray or xr.Dataset, optional
Sensitivity analysis data with the required dimensions:
  - ‘sample’: MCMC samples
  - ‘sweep’: sweep values (e.g., multipliers)
If a Dataset, it should contain an ‘x’ variable. If None, uses self.idata.sensitivity_analysis. This parameter allows:
  - Testing with mock sensitivity analysis results
  - Plotting external sweep results
  - Comparing different sensitivity analyses
- hdi_prob
float, default 0.94
Probability mass of the HDI (between 0 and 1).
- aggregation
dict, optional
Aggregations to apply before plotting. Keys: “sum”, “mean”, or “median”. Values: tuples of dimension names.
Example: {"sum": ("channel",)} sums over channels before plotting.
- backend
str | None, optional
Backend to use for plotting. If None, uses the global backend configuration.
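To illustrate the shape data is expected to have and what an aggregation entry such as {"sum": ("channel",)} does to it, here is a NumPy-only sketch. It assumes each aggregation reduces over named dimensions the way xarray's DataArray.sum(dim="channel") does on a real xr.DataArray. The helper apply_aggregation and the extra channel dimension in the mock array are hypothetical, not part of pymc_marketing.

```python
import numpy as np

# Hypothetical mapping from aggregation keys to NumPy reducers.
_REDUCERS = {"sum": np.sum, "mean": np.mean, "median": np.median}


def apply_aggregation(values, dims, aggregation):
    """Reduce values (with dimension names dims) as described by aggregation.

    aggregation maps "sum" / "mean" / "median" to a tuple of dimension names;
    entries are applied in dict order. Returns the reduced array and the
    names of the dimensions that remain.
    """
    dims = tuple(dims)
    for op, agg_dims in aggregation.items():
        if op not in _REDUCERS:
            raise ValueError(f"Unsupported aggregation: {op!r}")
        missing = [d for d in agg_dims if d not in dims]
        if missing:
            raise ValueError(f"Dimensions not found: {missing}")
        axes = tuple(dims.index(d) for d in agg_dims)
        values = _REDUCERS[op](values, axis=axes)
        dims = tuple(d for d in dims if d not in agg_dims)
    return values, dims


# Mock results with dims ('sample', 'sweep', 'channel'), scaled by the sweep value.
rng = np.random.default_rng(42)
sweeps = np.linspace(0.5, 1.5, 11)
values = rng.gamma(2.0, 1.0, size=(500, sweeps.size, 3)) * sweeps[None, :, None]
summed, remaining = apply_aggregation(
    values, ("sample", "sweep", "channel"), {"sum": ("channel",)}
)
```

After summing over channel, the result keeps only the ‘sample’ and ‘sweep’ dimensions, so a single response curve is drawn instead of one per channel.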
- Returns:
PlotCollection
arviz_plots PlotCollection object containing the plot. Use .show() to display it or .save("filename") to save it. Unlike the legacy suite, which returned (Figure, Axes) or Axes, this provides a unified interface across all backends.
- Raises:
ValueError
If no sensitivity analysis data is found in self.idata and no data is provided.
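The data parameter and the ValueError above imply a simple precedence rule: explicit data wins, otherwise the stored idata group is used, and a Dataset is unwrapped to its ‘x’ variable. The sketch below is a hypothetical reconstruction of that rule in plain Python, not the method's source; resolve_sensitivity_data is an invented name, and a SimpleNamespace stands in for the InferenceData object.

```python
from collections.abc import Mapping
from types import SimpleNamespace


def resolve_sensitivity_data(idata, data=None):
    """Hypothetical: choose the data to plot.

    Explicit data takes precedence; otherwise fall back to
    idata.sensitivity_analysis. A mapping-like container (a dict here; an
    xr.Dataset also behaves as a mapping of variable names) is unwrapped
    to its 'x' variable.
    """
    if data is None:
        data = getattr(idata, "sensitivity_analysis", None)
        if data is None:
            raise ValueError(
                "No sensitivity analysis data found in idata and no data provided."
            )
    if isinstance(data, Mapping):
        if "x" not in data:
            raise ValueError("Dataset input must contain an 'x' variable.")
        data = data["x"]
    return data


# Toy stand-in for an idata object that already holds sweep results.
idata = SimpleNamespace(sensitivity_analysis={"x": [1.0, 2.0, 3.0]})
print(resolve_sensitivity_data(idata))              # falls back to the stored group
print(resolve_sensitivity_data(idata, data=[9.0]))  # explicit data takes precedence
```

Because explicit data is only read, never written back, this ordering is consistent with the "no side effects on self.idata" note.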
See also
uplift_curve: Plot uplift percentages (derived from sensitivity analysis).
marginal_curve: Plot marginal effects (derived from sensitivity analysis).
LegacyMMMPlotSuite.sensitivity_analysis: Legacy matplotlib-only implementation.
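Both related plots are described as derived from the sensitivity analysis results. One plausible derivation, sketched here purely as an assumption and not as what uplift_curve or marginal_curve actually compute: the marginal effect as the finite-difference slope of the mean response along the sweep grid, and uplift as the percentage change relative to a reference sweep value (1.0 is assumed here). The helpers marginal_effect and uplift_pct are hypothetical.

```python
import numpy as np


def marginal_effect(sweeps, response_mean):
    """Hypothetical: finite-difference slope d(response)/d(sweep) on the grid."""
    return np.gradient(response_mean, sweeps)


def uplift_pct(sweeps, response_mean, reference=1.0):
    """Hypothetical: percent change vs. the response at the assumed reference sweep."""
    ref_idx = int(np.argmin(np.abs(sweeps - reference)))
    return 100.0 * (response_mean / response_mean[ref_idx] - 1.0)


sweeps = np.linspace(0.5, 1.5, 11)
response_mean = 200.0 * sweeps  # toy linear mean response
slopes = marginal_effect(sweeps, response_mean)
uplift = uplift_pct(sweeps, response_mean)
```

For this linear toy response the slope is constant and the uplift is zero at the reference multiplier; a saturating response would instead show slopes that shrink as the multiplier grows.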
Notes
Breaking changes from the legacy implementation:
- Returns PlotCollection instead of (Figure, Axes) or Axes
- Removes the ax, subplot_kwargs, and plot_kwargs parameters (use backend methods instead)
- Cleaner implementation without monkey-patching
- Adds a data parameter for passing data explicitly (no side effects on self.idata)
Examples
Run a sweep and plot the results:

    import numpy as np
    from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis

    # Run sensitivity sweep
    sweeps = np.linspace(0.5, 1.5, 11)
    sa = SensitivityAnalysis(mmm.model, mmm.idata)
    results = sa.run_sweep(
        var_input="channel_data",
        sweep_values=sweeps,
        var_names="channel_contribution",
        sweep_type="multiplicative",
        extend_idata=True,  # Store in idata
    )

    # Plot stored results
    pc = mmm.plot.sensitivity_analysis(hdi_prob=0.9)
    pc.show()
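The sweep grid np.linspace(0.5, 1.5, 11) holds eleven multipliers from 0.5 to 1.5 in steps of 0.1. A minimal NumPy sketch of a multiplicative sweep, assuming sweep_type="multiplicative" simply scales channel_data by each multiplier (the toy spend array is invented for illustration); under that assumption the multiplier 1.0 reproduces the observed input.

```python
import numpy as np

sweeps = np.linspace(0.5, 1.5, 11)  # 0.5, 0.6, ..., 1.5
# Toy (date, channel) spend matrix; not real data.
channel_data = np.array([[100.0, 50.0], [120.0, 40.0]])

# Assumed semantics: one scaled copy of the input per sweep value,
# giving an array with dims (sweep, date, channel).
scaled_inputs = sweeps[:, None, None] * channel_data[None, :, :]
```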
Aggregate over channels:

    pc = mmm.plot.sensitivity_analysis(
        hdi_prob=0.9,
        aggregation={"sum": ("channel",)},
    )
    pc.show()
Use a different backend:

    pc = mmm.plot.sensitivity_analysis(backend="plotly")
    pc.show()
Provide explicit data:

    external_results = sa.run_sweep(...)  # Not stored in idata
    pc = mmm.plot.sensitivity_analysis(data=external_results)
    pc.show()