MMMPlotSuite.marginal_curve
- MMMPlotSuite.marginal_curve(data=None, hdi_prob=0.94, aggregation=None, backend=None)
Plot marginal effects, showing the absolute rate of change.
Visualizes the instantaneous rate of change (derivative) of model outputs with respect to inputs, showing how much the output changes per unit change in the input at each sweep value.
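For intuition, a marginal effect is the derivative of the swept output with respect to the sweep value. The sketch below approximates it with finite differences (`np.gradient`) on a made-up saturating curve, `100 * (1 - exp(-2x))`, whose exact derivative is `200 * exp(-2x)`. Both the curve and the finite-difference scheme are assumptions for illustration; `SensitivityAnalysis.compute_marginal_effects` may compute derivatives differently.

```python
import numpy as np

# Hypothetical sensitivity curve (not pymc-marketing output): output of a
# saturating channel evaluated at each multiplicative sweep value.
sweeps = np.linspace(0.5, 1.5, 11)
output = 100.0 * (1.0 - np.exp(-2.0 * sweeps))

# Marginal effect = d(output) / d(sweep): central differences in the
# interior, one-sided differences at the two endpoints.
marginal = np.gradient(output, sweeps)

# Compare the estimates with the exact derivative 200 * exp(-2x).
exact = 200.0 * np.exp(-2.0 * sweeps)
for s, m, e in zip(sweeps, marginal, exact):
    print(f"sweep={s:.1f}  estimated={m:6.2f}  exact={e:6.2f}")
```

Because the curve saturates, each extra unit of input adds less output, so the estimated marginal effect falls as the sweep value grows.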
- Parameters:
  - data : xr.DataArray or xr.Dataset, optional
    Marginal effects data computed from sensitivity analysis. If a Dataset, it should contain a "marginal_effects" variable. If None, uses self.idata.sensitivity_analysis["marginal_effects"]. Must be precomputed using SensitivityAnalysis.compute_marginal_effects(...).
    This parameter allows:
    - Testing with mock marginal effects data
    - Plotting externally computed marginal effects
    - Comparing marginal effects from different models
  - hdi_prob : float, default 0.94
    HDI probability mass (between 0 and 1).
  - aggregation : dict, optional
    Aggregations to apply before plotting. Keys: "sum", "mean", or "median". Values: tuple of dimension names.
    Example: {"sum": ("channel",)} sums over channels before plotting.
  - backend : str | None, optional
    Backend to use for plotting. If None, uses the global backend configuration.
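As a rough illustration of what `aggregation` and `hdi_prob` do to the data before drawing, the sketch below reduces a mock array over a named dimension and then summarizes the remaining posterior samples at each sweep value with a highest-density interval. The dimension names `sample`, `sweep`, `channel` and the helpers `aggregate` and `hdi` are hypothetical; the plot suite works on labeled xarray objects, and its exact interval algorithm may differ from this narrowest-window version.

```python
import numpy as np

# Hypothetical layout: (posterior sample, sweep value, channel).
DIMS = ("sample", "sweep", "channel")
REDUCERS = {"sum": np.sum, "mean": np.mean, "median": np.median}


def aggregate(data, dims, aggregation):
    """Apply {"sum"|"mean"|"median": (dim, ...)} reductions by dimension name."""
    for how, names in aggregation.items():
        axes = tuple(dims.index(name) for name in names)
        data = REDUCERS[how](data, axis=axes)
        dims = tuple(d for d in dims if d not in names)
    return data, dims


def hdi(samples, prob=0.94):
    """Narrowest interval covering at least `prob` of 1-D `samples`."""
    x = np.sort(samples)
    n = x.size
    k = int(np.ceil(prob * n))  # number of points the interval must contain
    widths = x[k - 1:] - x[: n - k + 1]  # width of each window of k points
    i = int(np.argmin(widths))
    return float(x[i]), float(x[i + k - 1])


# Mock marginal effects: 500 samples x 11 sweep values x 3 channels.
rng = np.random.default_rng(0)
sweeps = np.linspace(0.5, 1.5, 11)
me = rng.normal(loc=1.0 / sweeps[None, :, None], scale=0.1, size=(500, 11, 3))

# Analogous to aggregation={"sum": ("channel",)}.
summed, summed_dims = aggregate(me, DIMS, {"sum": ("channel",)})

# One (lower, upper) band per sweep value: the kind of band hdi_prob controls.
bands = np.array([hdi(summed[:, j], prob=0.94) for j in range(summed.shape[1])])
```

Summing over `channel` leaves one curve over the sweep, and a larger `prob` widens each band.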
- Returns:
  PlotCollection: arviz_plots PlotCollection object containing the plot.
  Use .show() to display or .save("filename") to save. Unlike the legacy suite, which returned (Figure, Axes) or Axes, this provides a unified interface across all backends.
- Raises:
  - ValueError: If no marginal effects data is found in self.idata and no data is provided.
  - ValueError: If the "marginal_effects" variable is not found in the sensitivity_analysis group.
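The data lookup and error rules described for `data` and under Raises can be restated as a small function. This is a hypothetical sketch, not the library's code: `resolve_marginal_effects` is an illustrative name, a plain dict stands in for `xr.Dataset` and for the `sensitivity_analysis` group, and raising when a Dataset lacks `marginal_effects` is an assumption (the docs only say it "should contain" that variable).

```python
from collections.abc import Mapping
from typing import Any, Optional


def resolve_marginal_effects(data: Any, sensitivity_group: Optional[Mapping]) -> Any:
    """Hypothetical restatement of the documented data lookup rules.

    `sensitivity_group` stands in for self.idata.sensitivity_analysis
    (None when the group is missing); a dict stands in for xr.Dataset.
    """
    if data is not None:
        if isinstance(data, Mapping):
            # Dataset-like input: pull out the 'marginal_effects' variable.
            # Raising here is an assumption, not documented behavior.
            if "marginal_effects" not in data:
                raise ValueError("data must contain a 'marginal_effects' variable")
            return data["marginal_effects"]
        return data  # DataArray-like input is used as-is
    if sensitivity_group is None:
        raise ValueError("No marginal effects data found in idata and no data provided")
    if "marginal_effects" not in sensitivity_group:
        raise ValueError("'marginal_effects' not found in sensitivity_analysis group")
    return sensitivity_group["marginal_effects"]


stored = {"marginal_effects": [0.9, 0.7, 0.5]}
print(resolve_marginal_effects(None, stored))  # falls back to the stored group
print(resolve_marginal_effects([1.0, 0.8], stored))  # explicit data takes priority
```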
See also
- sensitivity_analysis: Plot raw sensitivity analysis results.
- uplift_curve: Plot uplift percentages (relative changes).
- LegacyMMMPlotSuite.marginal_curve: Legacy matplotlib-only implementation.
Notes
Breaking changes from the legacy implementation:
- Returns PlotCollection instead of (Figure, Axes) or Axes
- Cleaner implementation without monkey-patching
- No longer temporarily modifies self.idata.sensitivity_analysis
- New data parameter for passing data explicitly
Marginal effects show the slope of the sensitivity curve, helping identify where returns are diminishing most rapidly.
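To locate where returns diminish most rapidly, one can look at how the marginal effect itself changes across the sweep. The sketch uses a made-up S-shaped curve, `x**3 / (x**3 + 1)` (an assumption chosen so marginal effects first rise and then fall), and estimates both derivatives with `np.gradient`; the variable names are illustrative only.

```python
import numpy as np

# Hypothetical S-shaped (Hill-type) sensitivity curve over the sweep.
sweeps = np.linspace(0.5, 1.5, 11)
output = sweeps**3 / (sweeps**3 + 1.0)

marginal = np.gradient(output, sweeps)  # slope of the sensitivity curve
change = np.gradient(marginal, sweeps)  # how fast that slope itself moves

peak = sweeps[np.argmax(marginal)]  # largest return per extra unit of input
steepest_decline = sweeps[np.argmin(change)]  # returns shrinking fastest

print(f"marginal effect peaks near sweep={peak:.1f}")
print(f"returns diminish fastest near sweep={steepest_decline:.1f}")
```

On a plotted marginal curve, the second location corresponds to the steepest downward stretch of the line, after the peak.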
Examples
Compute and plot marginal effects:
```python
import numpy as np

from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis

# `mmm` is a fitted MMM instance
# Run sensitivity sweep
sweeps = np.linspace(0.5, 1.5, 11)
sa = SensitivityAnalysis(mmm.model, mmm.idata)
results = sa.run_sweep(
    var_input="channel_data",
    sweep_values=sweeps,
    var_names="channel_contribution",
    sweep_type="multiplicative",
)

# Compute marginal effects (derivatives)
me = sa.compute_marginal_effects(
    results,
    extend_idata=True,  # Store in idata
)

# Plot stored marginal effects
pc = mmm.plot.marginal_curve(hdi_prob=0.9)
pc.show()
```
Aggregate over channels:
```python
# Sum marginal effects over channels before plotting
pc = mmm.plot.marginal_curve(aggregation={"sum": ("channel",)})
pc.show()
```
Use a different backend:
```python
pc = mmm.plot.marginal_curve(backend="plotly")
pc.show()
```
Provide explicit data:
```python
# Reuses `sa` and `results` from the first example
marginal_data = sa.compute_marginal_effects(results)
pc = mmm.plot.marginal_curve(data=marginal_data)
pc.show()
```