TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y' #321
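The message itself is the standard `TypeError` Python raises when a function declares positional-only parameters (the `/` marker from PEP 570) and a caller passes them by keyword. So the installed `jax.numpy.where` evidently accepts `condition`, `x` and `y` only positionally, while the lightweight_mmm code in the traceback calls it as `jnp.where(condition=..., x=..., y=...)`. A toy reproduction, where `where` is a hypothetical stand-in and not the JAX function:

```python
# Toy stand-in for illustration only; this is NOT jax.numpy.where.
def where(condition, x, y, /):
    """Pick x or y; the '/' makes all three parameters positional-only."""
    return x if condition else y


# Positional call: accepted.
print(where(True, 1, 2))  # prints 1

# Keyword call: rejected with the same kind of TypeError as in this issue.
try:
    where(condition=True, x=1, y=2)
except TypeError as err:
    print(err)
```

The exact wording of the printed message may differ slightly between Python versions, but it names the offending parameters just like the error above.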
Comments
Hi, I had the same issue. Have you resolved it?
Hi @ShirleyChai730: I haven't been able to resolve this issue yet.
Hi @ShirleyChai730
Thank you @Munger245. It works.
Thanks for pointing this out. I tried 0.4.20 and it still didn't work, but the older version 0.4.19 works.
@ShirleyChai730 I am also getting this error on a Mac M2. Which version of lightweight_mmm worked on your machine? Can you please share your requirements file here, along with your Python version?
@rahulmisal27: I am using the latest version of lightweight_mmm and it works.
@datainsight1 I tried installing jax and jaxlib 0.4.20 and still get the same error. How did you fix it?
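For requests like the one above, a small stdlib-only sketch that prints the Python version and the installed versions of the packages discussed in this thread, in requirements-file style. The helper name `collect_versions` and the package list are assumptions for illustration, not part of lightweight_mmm.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Packages discussed in this thread (hypothetical list); extend it if needed.
PACKAGES = ("jax", "jaxlib", "numpyro", "scipy", "lightweight_mmm")


def collect_versions(packages=PACKAGES):
    """Return the Python version plus each package's installed version.

    Packages that are not installed map to None.
    """
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    versions = collect_versions()
    print(f"# Python {versions.pop('python')}")
    for name, ver in versions.items():
        # Requirements-file style line, or a comment if the package is missing.
        print(f"{name}=={ver}" if ver else f"# {name} is not installed")
```

Pasting that output into a comment gives others the exact combination that worked (or failed) on your machine.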
In a fresh Python 3.10 environment I needed to pin these versions to get things working:
Hi there! I'm running into the same error in a Python 3.11 environment. Has anyone figured out which version of jax works in this environment?
Hi, I have the same issue. [7/13/24 edit]
I had the same error message, and installing jax and jaxlib version 0.4.20 did not work for me. I have since fixed it, and I'll list the steps I took below in case anyone has the same issue. First, I created a Python virtual environment using Anaconda with Python 3.10.14, as that's the latest version we know works according to
I encountered the same problem. My Python version is 3.11.5. In the end, I followed the instructions from the two issues and installed the following versions:
This was useful for me!
I am using Python 3.10. In my case I also had to update the numpyro library to make it work. Packages updated: scipy==1.12.0
Thank you! This setup worked for me with Python 3.11.7.
TypeError Traceback (most recent call last)
Cell In[9], line 4
2 number_warmup=100
3 number_samples=100
----> 4 mmm.fit(
5 media=media_data_train,
6 media_prior=costs,
7 target=target_train,
8 extra_features=extra_features_train,
9 number_warmup=number_warmup,
10 number_samples=number_samples,
11 seed=SEED)
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/lightweight_mmm.py:363, in LightweightMMM.fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed)
353 kernel = numpyro.infer.NUTS(
354 model=self._model_function,
355 target_accept_prob=target_accept_prob,
356 init_strategy=init_strategy)
358 mcmc = numpyro.infer.MCMC(
359 sampler=kernel,
360 num_warmup=number_warmup,
361 num_samples=number_samples,
362 num_chains=number_chains)
--> 363 mcmc.run(
364 rng_key=jax.random.PRNGKey(seed),
365 media_data=jnp.array(media),
366 extra_features=extra_features,
367 target_data=jnp.array(target),
368 media_prior=jnp.array(media_prior),
369 degrees_seasonality=degrees_seasonality,
370 frequency=seasonality_frequency,
371 transform_function=self._model_transform_function,
372 weekday_seasonality=weekday_seasonality,
373 custom_priors=custom_priors)
375 self.custom_priors = custom_priors
376 if media_names is not None:
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:638, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
636 else:
637 if self.chain_method == "sequential":
--> 638 states, last_state = _laxmap(partial_map_fn, map_args)
639 elif self.chain_method == "parallel":
640 states, last_state = pmap(partial_map_fn)(map_args)
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:166, in _laxmap(f, xs)
164 for i in range(n):
165 x = jit(_get_value_from_index)(xs, i)
--> 166 ys.append(f(x))
168 return tree_map(lambda *args: jnp.stack(args), *ys)
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:416, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
414 # Check if _sample_fn is None, then we need to initialize the sampler.
415 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 416 new_init_state = self.sampler.init(
417 rng_key,
418 self.num_warmup,
419 init_params,
420 model_args=args,
421 model_kwargs=kwargs,
422 )
423 init_state = new_init_state if init_state is None else init_state
424 sample_fn, postprocess_fn = self._get_cached_fns()
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/hmc.py:713, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
708 # vectorized
709 else:
710 rng_key, rng_key_init_model = jnp.swapaxes(
711 vmap(random.split)(rng_key), 0, 1
712 )
--> 713 init_params = self._init_state(
714 rng_key_init_model, model_args, model_kwargs, init_params
715 )
716 if self._potential_fn and init_params is None:
717 raise ValueError(
718 "Valid value of `init_params` must be provided with" " `potential_fn`."
719 )
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/hmc.py:657, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
651 if self._model is not None:
652 (
653 new_init_params,
654 potential_fn,
655 postprocess_fn,
656 model_trace,
--> 657 ) = initialize_model(
658 rng_key,
659 self._model,
660 dynamic_args=True,
661 init_strategy=self._init_strategy,
662 model_args=model_args,
663 model_kwargs=model_kwargs,
664 forward_mode_differentiation=self._forward_mode_differentiation,
665 )
666 if init_params is None:
667 init_params = new_init_params
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/util.py:656, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
646 model_kwargs = {} if model_kwargs is None else model_kwargs
647 substituted_model = substitute(
648 seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
649 substitute_fn=init_strategy,
650 )
651 (
652 inv_transforms,
653 replay_model,
654 has_enumerate_support,
655 model_trace,
--> 656 ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
657 # substitute param sites from model_trace to model so
658 # we don't need to generate again parameters of `numpyro.module`
659 model = substitute(
660 model,
661 data={
(...)
665 },
666 )
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/util.py:450, in _get_model_transforms(model, model_args, model_kwargs)
448 def _get_model_transforms(model, model_args=(), model_kwargs=None):
449 model_kwargs = {} if model_kwargs is None else model_kwargs
--> 450 model_trace = trace(model).get_trace(*model_args, **model_kwargs)
451 inv_transforms = {}
452 # model code may need to be replayed in the presence of deterministic sites
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/models.py:385, in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features)
380 elif transform_function == "carryover" and not transform_kwargs:
381 transform_kwargs = {"number_lags": 13 * 7}
383 media_transformed = numpyro.deterministic(
384 name="media_transformed",
--> 385 value=transform_function(media_data,
386 custom_priors=custom_priors,
387 **transform_kwargs if transform_kwargs else {}))
388 seasonality = media_transforms.calculate_seasonality(
389 number_periods=data_size,
390 degrees=degrees_seasonality,
391 frequency=frequency,
392 gamma_seasonality=gamma_seasonality)
393 # For national model's case
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/models.py:280, in transform_carryover(media_data, custom_priors, number_lags)
278 if media_data.ndim == 3:
279 exponent = jnp.expand_dims(exponent, axis=-1)
--> 280 return media_transforms.apply_exponent_safe(data=carryover, exponent=exponent)
File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/media_transforms.py:189, in apply_exponent_safe(data, exponent)
172 @jax.jit
173 def apply_exponent_safe(
174 data: jnp.ndarray,
175 exponent: jnp.ndarray,
176 ) -> jnp.ndarray:
177 """Applies an exponent to given data in a gradient safe way.
178
179 More info on the double jnp.where can be found:
(...)
187 The result of the exponent operation with the inputs provided.
188 """
--> 189 exponent_safe = jnp.where(condition=(data == 0), x=1, y=data) ** exponent
190 return jnp.where(condition=(data == 0), x=0, y=exponent_safe)
TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'
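Besides pinning older jax/jaxlib versions as suggested in the comments, a possible local workaround is to pass the arguments to `jnp.where` positionally in the installed `lightweight_mmm/media_transforms.py` (path taken from the traceback). The sketch below assumes you are willing to patch that file yourself; it keeps the same double-`jnp.where` logic shown at lines 189-190 of the traceback and only changes the call style. Any other keyword-style `jnp.where` calls in the package, if there are any, would need the same change. The sample inputs are illustrative.

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_exponent_safe(data: jnp.ndarray, exponent: jnp.ndarray) -> jnp.ndarray:
    """Raises data to exponent while keeping gradients finite where data == 0.

    Same double-where logic as in the traceback, but condition, x and y are
    passed positionally, which jnp.where accepts whether or not its
    parameters are positional-only.
    """
    # Swap zeros for 1 before exponentiating, so the infinite derivative of
    # x ** p at x == 0 (for p < 1) never enters the gradient computation.
    exponent_safe = jnp.where(data == 0, 1, data) ** exponent
    # Restore the zeros in the final result.
    return jnp.where(data == 0, 0, exponent_safe)


data = jnp.array([0.0, 1.0, 4.0])
print(apply_exponent_safe(data, jnp.array(0.5)))
```

The first `where` exists only to keep the gradient well-defined; the second one restores the mathematically correct value of 0 at the zero entries.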