MMMPlotSuite.uplift_curve

MMMPlotSuite.uplift_curve(data=None, hdi_prob=0.94, aggregation=None, backend=None)

Plot uplift curves showing percentage change relative to baseline.

Visualizes the relative percentage change in model outputs (e.g., channel contributions) as inputs are varied, compared with a reference point. Shows the mean uplift line and HDI bands.
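The idea of "percentage change relative to a reference point" can be illustrated with plain NumPy. This is a minimal sketch, assuming uplift is defined per posterior sample as the percentage change of the output with respect to its value at the reference sweep point (here 1.0). The numbers are made up for illustration, and this is not necessarily how SensitivityAnalysis.compute_uplift_curve_respect_to_base computes its result.

```python
import numpy as np

# Hypothetical multiplicative sweep over a channel's input; 1.0 is the reference point.
sweep_values = np.array([0.5, 1.0, 1.5])

# Hypothetical posterior samples of a model output (e.g., a channel contribution)
# at each sweep value, shape (n_samples, n_sweeps). Illustrative numbers only.
outputs = np.array([
    [50.0, 100.0, 140.0],
    [60.0, 120.0, 150.0],
])

# Find the column for the reference sweep value and keep it 2-D for broadcasting.
ref_idx = int(np.argmin(np.abs(sweep_values - 1.0)))
baseline = outputs[:, [ref_idx]]

# Percentage change of every sweep point relative to the baseline, per sample.
uplift_pct = 100.0 * (outputs - baseline) / baseline

# The mean across samples corresponds to the mean uplift line in the plot.
mean_uplift = uplift_pct.mean(axis=0)
```

Here mean_uplift is [-50.0, 0.0, 32.5]: halving the input halves the output in both samples, and the uplift at the reference point is zero by construction.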

Parameters:
data : xr.DataArray or xr.Dataset, optional

Uplift curve data computed from sensitivity analysis. If a Dataset, it should contain an ‘uplift_curve’ variable. If None, uses self.idata.sensitivity_analysis[‘uplift_curve’].

Must be precomputed using SensitivityAnalysis.compute_uplift_curve_respect_to_base(...). This parameter allows:

  • Testing with mock uplift curve data

  • Plotting externally computed uplift curves

  • Comparing uplift curves from different models
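For testing with mock data, an uplift curve can be built directly with xarray. This is a sketch under assumptions: the dimension names ("chain", "draw", "sweep", "channel") and channel labels are hypothetical and should be matched to whatever compute_uplift_curve_respect_to_base returns in your installed version.

```python
import numpy as np
import xarray as xr

# Hypothetical layout: chains x draws x sweep points x channels.
# These dimension names are assumptions, not a documented contract.
rng = np.random.default_rng(seed=0)
sweep = np.linspace(0.5, 1.5, 11)
channels = ["tv", "radio"]  # hypothetical channel labels

mock_uplift = xr.DataArray(
    rng.normal(loc=0.0, scale=0.1, size=(4, 100, sweep.size, len(channels))),
    dims=("chain", "draw", "sweep", "channel"),
    coords={"sweep": sweep, "channel": channels},
    name="uplift_curve",
)

# A Dataset also works, provided it carries an 'uplift_curve' variable;
# to_dataset() uses the DataArray's name as the variable name.
mock_ds = mock_uplift.to_dataset()
```

Either object could then be passed as mmm.plot.uplift_curve(data=mock_uplift); whether the plot accepts this exact dimension layout is an assumption worth checking against real output.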

hdi_prob : float, default 0.94

HDI probability mass (between 0 and 1).
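For intuition about hdi_prob: at each sweep point, the highest-density interval (HDI) is the narrowest interval containing that fraction of the posterior samples. The helper below is a hypothetical, stand-alone sketch of the standard shortest-interval computation for unimodal samples (similar in spirit to ArviZ's hdi); the plot computes its bands internally and does not call this function.

```python
import numpy as np


def hdi_interval(samples: np.ndarray, hdi_prob: float = 0.94) -> tuple[float, float]:
    """Hypothetical helper: shortest interval holding `hdi_prob` of the samples."""
    if not 0.0 < hdi_prob < 1.0:
        raise ValueError("hdi_prob must be between 0 and 1")
    sorted_samples = np.sort(np.asarray(samples, dtype=float))
    n = sorted_samples.size
    # Number of sorted positions each candidate interval spans.
    interval_size = int(np.floor(hdi_prob * n))
    n_intervals = n - interval_size
    # Widths of every candidate interval [x_i, x_(i + interval_size)].
    widths = sorted_samples[interval_size:] - sorted_samples[:n_intervals]
    best = int(np.argmin(widths))
    return float(sorted_samples[best]), float(sorted_samples[best + interval_size])
```

Lowering hdi_prob (for example to 0.9, as in the examples) narrows the band because fewer samples must fit inside it.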

aggregation : dict, optional

Aggregations to apply before plotting. Keys: “sum”, “mean”, or “median”. Values: tuples of dimension names.

Example: {"sum": ("channel",)} sums over channels before plotting.
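The aggregation mapping corresponds to ordinary xarray reductions. The sketch below uses a hypothetical helper, apply_aggregation, to show what {"sum": ("channel",)} amounts to on a DataArray; applying multiple operations in dictionary order is an assumption of this sketch, not documented behavior of the plot.

```python
import numpy as np
import xarray as xr


def apply_aggregation(da: xr.DataArray, aggregation: dict[str, tuple[str, ...]]) -> xr.DataArray:
    """Hypothetical helper: reduce `da` over the listed dims for each operation."""
    for op, dims in aggregation.items():
        if op not in {"sum", "mean", "median"}:
            raise ValueError(f"Unsupported aggregation: {op!r}")
        # DataArray.sum, .mean and .median all accept a `dim` argument.
        da = getattr(da, op)(dim=list(dims))
    return da


# Toy uplift values over sweep points and two hypothetical channels.
uplift = xr.DataArray(
    np.ones((3, 2)),
    dims=("sweep", "channel"),
    coords={"sweep": [0.5, 1.0, 1.5], "channel": ["tv", "radio"]},
)
summed = apply_aggregation(uplift, {"sum": ("channel",)})
```

After summing, the channel dimension is gone and one curve per sweep point remains (here each value is 2.0).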

backend : str | None, optional

Backend to use for plotting. If None, uses the global backend configuration.

Returns:
PlotCollection

An arviz_plots PlotCollection object containing the plot.

Use .show() to display it or .save("filename") to save it. Unlike the legacy suite, which returned (Figure, Axes) or Axes, this provides a unified interface across all backends.

Raises:
ValueError

If no uplift curve data is found in self.idata and no data is provided.

ValueError

If the ‘uplift_curve’ variable is not found in the sensitivity_analysis group.
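The data-resolution rules described under data, together with the two errors above, can be summarized in a short sketch. resolve_uplift_data is a hypothetical function, not part of pymc_marketing; a SimpleNamespace stands in for a real InferenceData, and the error for a Dataset lacking ‘uplift_curve’ is an assumption, since only the idata case is documented to raise.

```python
from types import SimpleNamespace

import xarray as xr


def resolve_uplift_data(data, idata):
    """Hypothetical sketch of how `data` could be resolved before plotting."""
    if data is None:
        # Fall back to the sensitivity_analysis group stored on idata.
        group = getattr(idata, "sensitivity_analysis", None)
        if group is None:
            raise ValueError("No uplift curve data found in idata and no data provided.")
        if "uplift_curve" not in group:
            raise ValueError("'uplift_curve' variable not found in sensitivity_analysis group.")
        return group["uplift_curve"]
    if isinstance(data, xr.Dataset):
        # Assumption: a Dataset without 'uplift_curve' is rejected the same way.
        if "uplift_curve" not in data:
            raise ValueError("Dataset must contain an 'uplift_curve' variable.")
        return data["uplift_curve"]
    # A DataArray is used as-is.
    return data


# Usage with a stand-in object in place of a real InferenceData.
fake_idata = SimpleNamespace(
    sensitivity_analysis=xr.Dataset({"uplift_curve": ("sweep", [-5.0, 0.0, 4.0])})
)
resolved = resolve_uplift_data(None, fake_idata)
```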

See also

sensitivity_analysis

Plot raw sensitivity analysis results

marginal_curve

Plot marginal effects (absolute changes)

LegacyMMMPlotSuite.uplift_curve

Legacy matplotlib-only implementation

Notes

Breaking changes from the legacy implementation:

  • Returns PlotCollection instead of (Figure, Axes) or Axes

  • Cleaner implementation without monkey-patching

  • No longer temporarily modifies self.idata.sensitivity_analysis

  • New data parameter for passing uplift curve data explicitly

Examples

Compute and plot an uplift curve:

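import numpy as np
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis

# Assumes `mmm` is an already fitted MMM instance
# Run a multiplicative sensitivity sweep scaling channel inputs from 0.5x to 1.5x
sweeps = np.linspace(0.5, 1.5, 11)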
sa = SensitivityAnalysis(mmm.model, mmm.idata)
results = sa.run_sweep(
    var_input="channel_data",
    sweep_values=sweeps,
    var_names="channel_contribution",
    sweep_type="multiplicative",
)

# Compute uplift relative to baseline (ref=1.0)
uplift = sa.compute_uplift_curve_respect_to_base(
    results,
    ref=1.0,
    extend_idata=True,  # Store in idata
)

# Plot stored uplift curve
pc = mmm.plot.uplift_curve(hdi_prob=0.9)
pc.show()

Aggregate over channels:

pc = mmm.plot.uplift_curve(aggregation={"sum": ("channel",)})
pc.show()

Use a different backend:

pc = mmm.plot.uplift_curve(backend="plotly")
pc.show()

Provide explicit data:

uplift_data = sa.compute_uplift_curve_respect_to_base(results, ref=1.0)
pc = mmm.plot.uplift_curve(data=uplift_data)
pc.show()