Lightweight MMM を Google Colab で実行する方法【2025年5月版】

Google ColabでLightweight MMMを実行するための手順を解説します。公式デモの修正方法や必要なライブラリのインストール方法を詳しく紹介。2025年5月時点での環境に対応した実行方法を説明し、エラーが発生しないように丁寧に解説します。
目次
1 はじめに
マーケティング施策の効果を統計的に推定する手法であるMarketing Mix Modeling(MMM)は、最近のサードパーティクッキー規制の流れを受けて、ますます注目を集めています。
MMMを手軽に試してみたいと思ったときに、「Google ColabでLightweight MMMライブラリを使用する」というのが選択肢の一つになるのではないでしょうか。
実際、Lightweight MMMの公式リポジトリにはGoogle Colab用のデモが提供されていますし、日本語による記事も比較的充実しています。
しかし、このデモが作成されてから少し時間が経っているため、最新のGoogle Colab環境でそのまま実行すると、いくつかのエラーが発生することがあります。
そこで本ブログ記事では、Google ColabでLightweight MMMを試してみたいという方に向けて、2025年5月時点での実行手順を解説します。なお、本記事ではMMMの理論やモデリングの技術については扱いません。
2 手順1 : ライブラリのインストール
まず、Lightweight MMMライブラリをインストールします。
現時点で最新のバージョンが0.1.9
です。lightweight_mmm==0.1.9
とバージョンを指定せずにインストールすることもできますが、念のため明記しておきます。
# LightweightMMMライブラリのインストール
!pip install lightweight_mmm==0.1.9
実行すると、「セッションを再起動する」という警告が出ると思うので、[セッションを再起動する]をクリックしてください。
また、次のようなエラーも出力されるはずです。
どうやら3.8.0
以降のmatplotlib
が必要みたいなので、バージョンを変更します。ここでも「セッションを再起動する」という警告が出ると思いますが、再起動をクリックすればよいです。
# matplotlibのバージョン変更
!pip install -U matplotlib==3.8.0
3 手順2 : ライブラリのインポート
Lightweight MMMはJAXとNumPyroを基盤にしています。最初にjax.numpy
をインポートし、その後numpyro
をインポートします。
import jax.numpy as jnp
import numpyro
次に、lightweight_mmm
をインポートします。
# LMMMをimportする
from lightweight_mmm import lightweight_mmm
from lightweight_mmm import optimize_media
from lightweight_mmm import plot
from lightweight_mmm import preprocessing
from lightweight_mmm import utils
その他、必要なライブラリがあれば読み込んでおきましょう。
ここでは、手元のcsvファイルをLightweight MMMの入力データにするために、pandas
をインポートしておきます。
import pandas as pd
4 手順3 : データの準備
今回は適当に乱数生成したダミーデータをcsvファイルとして用意しています。
Googleドライブにcsvファイルを格納してあるものとして、ドライブをマウントします。
from google.colab import drive
drive.mount('/content/drive')
csvファイルを読み込み、pandas.Dataframeとして格納します。
# データの読み込み
df = pd.read_csv('/content/drive/MyDrive/[your_path]/[file_name].csv')
次に、説明変数を メディア・広告関連の変数 / その他の外生変数 / (費用関連の変数; optional) に分け、リストに格納します。また、時系列を表すカラムと目的変数も定義しておきます。
list_media = [
'GoogleAdSpend',
'YouTubeSpend'
]
list_extra = [
'OfficialEvent'
]
date_feature = 'Week'
target = 'Sales'
SEED = 42
data_size = df.shape[0]
次に、Lightweight MMMで扱えるように、pandas.Dataframeをjax.numpyに変換します。
lightweight_mmm
ライブラリにはutils.dataframe_to_jax()
が用意されているので、これを使います。
media_data, extra_features, target, costs = utils.dataframe_to_jax( dataframe = df, media_features = list_media, extra_features = list_extra, date_feature = date_feature, target = target )
学習データとテストデータに分割します。
# データを分割する
split_point = data_size - 13
# メディア関連の変数
media_data_train = media_data[:split_point, ...]
media_data_test = media_data[split_point:, ...]
# 外生変数
extra_features_train = extra_features[:split_point, ...]
extra_features_test = extra_features[split_point:, ...]
# 目的変数
target_train = target[:split_point]
データをスケーリングします。
media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean, multiply_by=0.15)
media_data_train = media_scaler.fit_transform(media_data_train)
extra_features_train = extra_features_scaler.fit_transform(extra_features_train)
target_train = target_scaler.fit_transform(target_train)
costs = cost_scaler.fit_transform(costs)
5 手順4 : JAXエラーの対応
ここまででデータの準備は完了しました。
次にモデルを学習するのですが、ここでLightweight MMMとJAXのバージョン干渉が発生します。
実際、モデルをフィッティングするコードを実行してみると、エラーが出るはずです。
mmm = lightweight_mmm.LightweightMMM(model_name="carryover")
number_warmup=1000
number_samples=1000
# MCMCサンプリングを行う
mmm.fit(
media=media_data_train,
media_prior=costs,
target=target_train,
extra_features=extra_features_train,
number_warmup=number_warmup,
number_samples=number_samples,
seed=SEED
)
このエラーは、Lightweight MMMが想定するJAXのバージョンとGoogle Colabに入っているJAXのバージョンが異なることが原因で起きています。
lightweight_mmm 0.1.9
は、JAXの古いバージョン (0.3系) を想定して実装されていますが、最新のGoogle Colab環境には JAX 0.4系 がインストールされています。lightweight_mmm 0.1.9
で実装されている
jax.numpy.where(condition=(data == 0), x=1, y=data)
というキーワード引数付き呼び出しはJAX 0.4系 からは許されず、
jax.numpy.where(cond, x, y)
とする必要があります。
次のようにGoogle Colab環境でJAXのバージョンを0.3系にダウングレードすればよいはずなのですが、筆者はうまく実行できませんでした。
!pip install jax==0.3.25 jaxlib==0.3.25
そこで、応急処置ではありますがライブラリを一時的にパッチして対応します。
import jax.numpy as jnp
import lightweight_mmm.media_transforms as mt
def _apply_exponent_safe(data, exponent):
exponent_safe = jnp.where((data == 0), 1, data) ** exponent
return jnp.where((data == 0), 0, exponent_safe)
mt.apply_exponent_safe = _apply_exponent_safe # fit() より前に1回だけ実行
6 手順5 : パラメータの推定
モデルのインスタンスを作成し、パラメータをMCMCサンプリングによって推定します。
mmm = lightweight_mmm.LightweightMMM(model_name="carryover")
number_warmup=1000
number_samples=1000
# MCMCサンプリングを行う
mmm.fit(
media=media_data_train,
media_prior=costs,
target=target_train,
extra_features=extra_features_train,
number_warmup=number_warmup,
number_samples=number_samples,
seed=SEED
)
以下により、各パラメータが推定できていることが確認できると思います。
mmm.print_summary()
7 おわりに
本ブログ記事では、Google ColabでLightweight MMMを動かすために、公式のデモを修正する手順を紹介しました。
本格的にMMMをプロジェクトで運用する場合は、専用の環境を構築すると思いますが、デモとして試してみたい場合にはGoogle ColabでLightweight MMMを使うというのは良い選択肢だと思います。
本記事で解説した手順に沿ってコードを実行すれば、エラーなしでMMMを実行できるはずです。ぜひ試してみてください!