MMMPlotSuite.marginal_curve

MMMPlotSuite.marginal_curve(data=None, hdi_prob=0.94, aggregation=None, backend=None)

Plot marginal effects, showing the absolute rate of change.

Visualizes the instantaneous rate of change (derivative) of model outputs with respect to inputs, that is, how much the output changes per unit change in the input at each sweep value.
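To make "rate of change at each sweep value" concrete, the following NumPy sketch approximates the derivative of a response curve with finite differences. The response curve and the helper name `marginal_effects` are hypothetical, and the sketch assumes a one-dimensional sweep; it is an illustration, not the scheme used by SensitivityAnalysis.compute_marginal_effects.

```python
import numpy as np


def marginal_effects(sweep_values, response):
    """Hypothetical helper: finite-difference slope d(response)/d(sweep)."""
    # np.gradient uses central differences at interior points and
    # one-sided differences at the first and last sweep values.
    return np.gradient(response, sweep_values)


sweeps = np.linspace(0.5, 1.5, 11)
# Hypothetical saturating response curve, not output from a fitted model.
response = 1.0 - np.exp(-2.0 * sweeps)
me = marginal_effects(sweeps, response)

# A slope that shrinks across the sweep signals diminishing returns.
diminishing = bool(np.all(np.diff(me) < 0))
```

Here the slope falls steadily as the sweep value grows, which is the diminishing-returns pattern a marginal effects plot is meant to reveal.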

Parameters:
data : xr.DataArray or xr.Dataset, optional

Marginal effects data computed from a sensitivity analysis. If a Dataset, it should contain a ‘marginal_effects’ variable. If None, uses self.idata.sensitivity_analysis[‘marginal_effects’].

Must be precomputed using SensitivityAnalysis.compute_marginal_effects(...). This parameter allows:

  • Testing with mock marginal effects data

  • Plotting externally computed marginal effects

  • Comparing marginal effects from different models

hdi_prob : float, default 0.94

Probability mass of the highest density interval (HDI), between 0 and 1.
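For intuition about hdi_prob: the highest density interval is the narrowest interval that contains the requested fraction of posterior samples. ArviZ provides arviz.hdi for this; the sketch below is a minimal NumPy version for one-dimensional samples, written for illustration (the helper name `hdi_interval` and the synthetic draws are hypothetical) rather than taken from the plotting code.

```python
import numpy as np


def hdi_interval(samples, hdi_prob=0.94):
    """Hypothetical helper: narrowest interval holding `hdi_prob` of 1-D samples."""
    if not 0 < hdi_prob < 1:
        raise ValueError("hdi_prob must be between 0 and 1")
    sorted_samples = np.sort(np.asarray(samples, dtype=float))
    n = sorted_samples.size
    # Index distance between the lower and upper endpoints of each window.
    interval_size = int(np.floor(hdi_prob * n))
    n_candidates = n - interval_size
    # Width of every candidate window [x[i], x[i + interval_size]].
    widths = sorted_samples[interval_size:] - sorted_samples[:n_candidates]
    best = int(np.argmin(widths))
    return sorted_samples[best], sorted_samples[best + interval_size]


rng = np.random.default_rng(0)
draws = rng.normal(size=10_000)  # synthetic standard-normal draws
lower, upper = hdi_interval(draws, hdi_prob=0.94)
```

Sorting once and then scanning every fixed-size window keeps the cost at O(n log n); for a symmetric, unimodal distribution like this one the result is close to the central 94% interval.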

aggregation : dict, optional

Aggregations to apply before plotting. Keys are “sum”, “mean”, or “median”; values are tuples of dimension names.

Example: {"sum": ("channel",)} sums over channels before plotting.
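The aggregation dictionary maps a reduction name to the dimensions it collapses. A minimal xarray sketch of that mapping, assuming the marginal effects are an xr.DataArray with named dimensions, might look like the following; `apply_aggregation` and the dimension names `draw` and `sweep` are illustrative assumptions, not part of the documented API.

```python
import numpy as np
import xarray as xr

# Reductions accepted by the documented `aggregation` keys.
_ALLOWED = ("sum", "mean", "median")


def apply_aggregation(data, aggregation=None):
    """Hypothetical helper: reduce `data` as described by `aggregation`."""
    if not aggregation:
        return data
    for op, dims in aggregation.items():
        if op not in _ALLOWED:
            raise ValueError(f"Unsupported aggregation {op!r}; expected one of {_ALLOWED}")
        # e.g. {"sum": ("channel",)} -> data.sum(dim=["channel"])
        data = getattr(data, op)(dim=list(dims))
    return data


# Synthetic marginal effects with draw, sweep, and channel dimensions.
me = xr.DataArray(
    np.arange(24, dtype=float).reshape(2, 3, 4),
    dims=("draw", "sweep", "channel"),
)
summed = apply_aggregation(me, {"sum": ("channel",)})
```

After summing over channel, one curve per draw and sweep value remains, which is what a single combined marginal curve would be drawn from.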

backend : str | None, optional

Backend to use for plotting. If None, uses the global backend configuration.

Returns:
PlotCollection

An arviz_plots PlotCollection object containing the plot.

Use .show() to display it or .save("filename") to save it. Unlike the legacy suite, which returned (Figure, Axes) or Axes, this provides a unified interface across all backends.

Raises:
ValueError

If no marginal effects data are found in self.idata and no data is provided.

ValueError

If the ‘marginal_effects’ variable is not found in the sensitivity_analysis group.

See also

sensitivity_analysis

Plot raw sensitivity analysis results

uplift_curve

Plot uplift percentages (relative changes)

LegacyMMMPlotSuite.marginal_curve

Legacy matplotlib-only implementation

Notes

Breaking changes from legacy implementation:

  • Returns PlotCollection instead of (Figure, Axes) or Axes

  • Cleaner implementation without monkey-patching

  • No longer modifies self.idata.sensitivity_analysis temporarily

  • Adds a data parameter for passing marginal effects data explicitly

Marginal effects show the slope of the sensitivity curve: where the slope shrinks as the sweep value grows, returns are diminishing, and the sweep values where it falls fastest are where returns diminish most rapidly.

Examples

Compute and plot marginal effects:

import numpy as np
from pymc_marketing.mmm.sensitivity_analysis import SensitivityAnalysis

# Run sensitivity sweep
sweeps = np.linspace(0.5, 1.5, 11)
sa = SensitivityAnalysis(mmm.model, mmm.idata)
results = sa.run_sweep(
    var_input="channel_data",
    sweep_values=sweeps,
    var_names="channel_contribution",
    sweep_type="multiplicative",
)

# Compute marginal effects (derivatives)
me = sa.compute_marginal_effects(
    results,
    extend_idata=True,  # Store in idata
)

# Plot stored marginal effects
pc = mmm.plot.marginal_curve(hdi_prob=0.9)
pc.show()

Aggregate over channels:

pc = mmm.plot.marginal_curve(aggregation={"sum": ("channel",)})
pc.show()

Use a different backend:

pc = mmm.plot.marginal_curve(backend="plotly")
pc.show()

Provide explicit data:

marginal_data = sa.compute_marginal_effects(results)
pc = mmm.plot.marginal_curve(data=marginal_data)
pc.show()