diff --git a/lightweight_mmm/optimize_media.py b/lightweight_mmm/optimize_media.py index 002b557..996ea84 100644 --- a/lightweight_mmm/optimize_media.py +++ b/lightweight_mmm/optimize_media.py @@ -71,7 +71,7 @@ def _objective_function( media_values = jnp.tile( media_values / media_input_shape[0], reps=media_input_shape[0]) # Distribute budget of each channels across time. - media_values = jnp.reshape(a=media_values, newshape=media_input_shape) + media_values = jnp.reshape(a=media_values, shape=media_input_shape) media_values = media_scaler.transform(media_values) return -jnp.sum( media_mix_model.predict(