MMMPlotSuite.uplift_curve
MMMPlotSuite.uplift_curve(data=None, hdi_prob=0.94, aggregation=None, backend=None)
Plot uplift curves showing percentage change relative to baseline.
Visualizes the percentage change in a model output (e.g., channel contributions) as an input is varied, relative to the output at a reference point. Shows the mean uplift as a line with HDI bands around it.
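The mean line and HDI bands can be illustrated with plain NumPy. This is a minimal sketch, assuming uplift is the draw-wise percentage change relative to the output at the reference sweep value and that the HDI is the narrowest interval covering the requested probability mass (the approach ArviZ uses). `percentage_uplift`, `hdi`, and the toy `tanh` response are hypothetical and are not pymc-marketing code; the library's exact definition may differ.

```python
import numpy as np


def percentage_uplift(samples: np.ndarray, ref_index: int) -> np.ndarray:
    """Draw-wise percentage change of each sweep value relative to the reference sweep.

    `samples` has shape (n_draws, n_sweeps).
    """
    ref = samples[:, [ref_index]]  # keep a 2-D column so it broadcasts across sweeps
    return 100.0 * (samples - ref) / ref


def hdi(draws: np.ndarray, prob: float = 0.94) -> tuple[float, float]:
    """Narrowest interval spanning floor(prob * n) + 1 sorted draws (1-D input)."""
    sorted_draws = np.sort(draws)
    n = sorted_draws.size
    window = int(np.floor(prob * n))
    # Width of every candidate interval that starts at sorted_draws[i]
    widths = sorted_draws[window:] - sorted_draws[: n - window]
    start = int(np.argmin(widths))
    return float(sorted_draws[start]), float(sorted_draws[start + window])


# Toy example (assumption): a saturating response whose curvature varies by draw.
rng = np.random.default_rng(0)
sweeps = np.linspace(0.5, 1.5, 11)  # multiplicative sweep; sweeps[5] is the baseline 1.0
beta = rng.normal(100.0, 5.0, size=(1000, 1))
lam = rng.uniform(0.5, 2.0, size=(1000, 1))
contrib = beta * np.tanh(lam * sweeps[None, :])  # shape (1000, 11)

uplift = percentage_uplift(contrib, ref_index=5)
mean_line = uplift.mean(axis=0)  # the plotted line
bands = [hdi(uplift[:, j], prob=0.94) for j in range(sweeps.size)]  # the plotted bands
```

Because the uplift is taken per draw before summarizing, the band at the reference point collapses to zero width, which is the behavior one would expect from a curve anchored at its baseline.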
- Parameters:
- data
xr.DataArray or xr.Dataset, optional. Uplift curve data computed from sensitivity analysis. If a Dataset, it should contain an 'uplift_curve' variable. If None, uses self.idata.sensitivity_analysis['uplift_curve'].
The uplift curve must be precomputed using SensitivityAnalysis.compute_uplift_curve_respect_to_base(...). This parameter allows:
  - Testing with mock uplift curve data
  - Plotting externally computed uplift curves
  - Comparing uplift curves from different models
- hdi_prob
float, default 0.94. Probability mass of the HDI bands (between 0 and 1).
- aggregation
dict, optional. Aggregations to apply before plotting. Keys: "sum", "mean", or "median". Values: tuples of dimension names.
Example: {"sum": ("channel",)} sums over channels before plotting.
- backend
str | None, optional. Backend to use for plotting. If None, uses the global backend configuration.
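The aggregation dictionary maps a reduction name to the dimensions it collapses. A minimal xarray sketch of that mapping follows; `apply_aggregation` is a hypothetical helper rather than the method's internal code, and the dimension names `chain`, `draw`, `sweep`, and `channel` are assumptions about the uplift data layout.

```python
from typing import Optional

import numpy as np
import xarray as xr


def apply_aggregation(da: xr.DataArray, aggregation: Optional[dict]) -> xr.DataArray:
    """Hypothetical helper: reduce `da` as described by an aggregation dict.

    Keys are "sum", "mean", or "median"; values are tuples of dimension names.
    """
    if not aggregation:
        return da  # nothing to reduce
    for op, dims in aggregation.items():
        if op not in {"sum", "mean", "median"}:
            raise ValueError(f"Unsupported aggregation: {op!r}")
        # DataArray.sum / .mean / .median all accept dim=<tuple of names>
        da = getattr(da, op)(dim=dims)
    return da


# Assumed layout: posterior draws x sweep values x channels.
da = xr.DataArray(
    np.ones((2, 50, 11, 3)),
    dims=("chain", "draw", "sweep", "channel"),
)
total = apply_aggregation(da, {"sum": ("channel",)})
# total.dims == ("chain", "draw", "sweep"); every value is 3.0 (three channels of ones)
```

Reducing before plotting means the HDI bands are computed on the aggregated quantity, not combined from per-channel bands.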
- Returns:
PlotCollection: arviz_plots PlotCollection object containing the plot.
Use .show() to display it or .save("filename") to save it. Unlike the legacy suite, which returned (Figure, Axes) or Axes, this provides a unified interface across all backends.
- Raises:
ValueError: If no uplift curve data is found in self.idata and no data is provided.
ValueError: If the 'uplift_curve' variable is not found in the sensitivity_analysis group.
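The documented lookup order (explicit data first, then the stored sensitivity_analysis group) and the two ValueError cases can be summarized in a short sketch. `resolve_uplift_data` is hypothetical and is not the library's implementation; `types.SimpleNamespace` stands in for an InferenceData object so the example is self-contained.

```python
from types import SimpleNamespace

import numpy as np
import xarray as xr


def resolve_uplift_data(idata, data=None) -> xr.DataArray:
    """Hypothetical helper mirroring the documented fallback and errors."""
    if data is not None:
        # A Dataset must carry an 'uplift_curve' variable; a DataArray is used as-is.
        return data["uplift_curve"] if isinstance(data, xr.Dataset) else data
    group = getattr(idata, "sensitivity_analysis", None)
    if group is None:
        raise ValueError("No uplift curve data found in idata and no data provided.")
    if "uplift_curve" not in group:
        raise ValueError("'uplift_curve' variable not found in sensitivity_analysis group.")
    return group["uplift_curve"]


# Stand-in for an idata object whose sensitivity_analysis group holds a stored curve.
curve = xr.DataArray(np.zeros((4, 3)), dims=("sweep", "channel"))
stored = SimpleNamespace(sensitivity_analysis=xr.Dataset({"uplift_curve": curve}))
resolved = resolve_uplift_data(stored)  # falls back to the stored variable
```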
See also
sensitivity_analysis: Plot raw sensitivity analysis results.
marginal_curve: Plot marginal effects (absolute changes).
LegacyMMMPlotSuite.uplift_curve: Legacy matplotlib-only implementation.
Notes
Breaking changes from the legacy implementation:
- Returns a PlotCollection instead of (Figure, Axes) or Axes
- Cleaner implementation without monkey-patching
- No longer temporarily modifies self.idata.sensitivity_analysis
- New data parameter for passing data explicitly
Examples
Compute and plot an uplift curve (assumes mmm is a fitted MMM):
    import numpy as np
    from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis

    # Run sensitivity sweep
    sweeps = np.linspace(0.5, 1.5, 11)
    sa = SensitivityAnalysis(mmm.model, mmm.idata)
    results = sa.run_sweep(
        var_input="channel_data",
        sweep_values=sweeps,
        var_names="channel_contribution",
        sweep_type="multiplicative",
    )

    # Compute uplift relative to baseline (ref=1.0)
    uplift = sa.compute_uplift_curve_respect_to_base(
        results,
        ref=1.0,
        extend_idata=True,  # Store in idata
    )

    # Plot stored uplift curve
    pc = mmm.plot.uplift_curve(hdi_prob=0.9)
    pc.show()
Aggregate over channels:
    pc = mmm.plot.uplift_curve(aggregation={"sum": ("channel",)})
    pc.show()
Use different backend:
    pc = mmm.plot.uplift_curve(backend="plotly")
    pc.show()
Provide explicit data:
    uplift_data = sa.compute_uplift_curve_respect_to_base(results, ref=1.0)
    pc = mmm.plot.uplift_curve(data=uplift_data)
    pc.show()
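Since the data parameter accepts externally built curves, a mock uplift curve can be constructed for testing. This is a sketch under assumptions: the dimension names (`chain`, `draw`, `sweep`, `channel`), the channel names, and the noise model are all hypothetical, and the layout produced by compute_uplift_curve_respect_to_base should be checked before relying on it. The plotting calls are left as comments because they require a fitted `mmm`.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(42)
sweeps = np.linspace(0.5, 1.5, 11)
channels = ["tv", "radio"]  # hypothetical channel names

# Mock uplift: roughly (sweep - 1) * 100 percent plus draw-level noise.
base = 100.0 * (sweeps - 1.0)
noise = rng.normal(0.0, 2.0, size=(2, 100, sweeps.size, len(channels)))
mock_uplift = xr.DataArray(
    base[None, None, :, None] + noise,
    dims=("chain", "draw", "sweep", "channel"),  # assumed layout
    coords={"sweep": sweeps, "channel": channels},
    name="uplift_curve",
)
mock_ds = mock_uplift.to_dataset()  # Dataset form with an 'uplift_curve' variable

# With a fitted model, either form could then be passed explicitly:
# pc = mmm.plot.uplift_curve(data=mock_uplift)
# pc.show()
```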