Skip to content

Commit

Permalink
remove spend scaling step from budget optimizer (pymc-labs#1070)
Browse files Browse the repository at this point in the history
* remove spend scaling step

* re-run nb
  • Loading branch information
juanitorduz authored and jsnyde0 committed Sep 30, 2024
1 parent 0592fcb commit 210a53b
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 109 deletions.
188 changes: 106 additions & 82 deletions docs/source/notebooks/mmm/mmm_budget_allocation_example.ipynb

Large diffs are not rendered by default.

40 changes: 13 additions & 27 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2239,28 +2239,14 @@ def plot_budget_allocation(
The matplotlib figure object and axis containing the plot.
"""
if original_scale:
channel_contributions = (
samples["channel_contributions"]
.mean(dim=["sample"])
.mean(dim=["date"])
.values
* self.get_target_transformer()["scaler"].scale_
)
channel_contributions = (
samples["channel_contributions"].mean(dim=["date", "sample"]).to_numpy()
)

allocate_spend = (
np.array(list(self.optimal_allocation_dict.values()))
* self.channel_transformer["scaler"].scale_
)
if original_scale:
channel_contributions *= self.get_target_transformer()["scaler"].scale_

else:
channel_contributions = (
samples["channel_contributions"]
.mean(dim=["sample"])
.mean(dim=["date"])
.values
)
allocate_spend = np.array(list(self.optimal_allocation_dict.values()))
allocated_spend = np.array(list(self.optimal_allocation_dict.values()))

if ax is None:
fig, ax = plt.subplots(figsize=figsize)
Expand All @@ -2274,11 +2260,11 @@ def plot_budget_allocation(

bars1 = ax.bar(
index,
allocate_spend,
allocated_spend,
bar_width,
color="b",
color="C0",
alpha=opacity,
label="Allocate Spend",
label="Allocated Spend",
)

ax2 = ax.twinx()
Expand All @@ -2287,19 +2273,19 @@ def plot_budget_allocation(
index + bar_width,
channel_contributions,
bar_width,
color="r",
color="C1",
alpha=opacity,
label="Channel Contributions",
)

ax.set_xlabel("Channels")
ax.set_ylabel("Allocate Spend", color="b")
ax.set_ylabel("Allocate Spend", color="C0")
ax.tick_params(axis="x", rotation=90)
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(self.channel_columns)

ax.set_ylabel("Allocate Spend", color="b", labelpad=10)
ax2.set_ylabel("Channel Contributions", color="r", labelpad=10)
ax.set_ylabel("Allocate Spend", color="C0", labelpad=10)
ax2.set_ylabel("Channel Contributions", color="C1", labelpad=10)

ax.grid(False)
ax2.grid(False)
Expand Down

0 comments on commit 210a53b

Please sign in to comment.