[BUG] Failing to call CVRPEnv.local_search #212
Labels: bug (Something isn't working)
Comments
Hi @ShuN6211! Thanks a lot, I pushed some hotfixes :)

```python
# IPython/Jupyter magics: reload edited modules automatically
%load_ext autoreload
%autoreload 2

import torch
from rl4co.models.zoo import AttentionModelPolicy
from rl4co.envs import CVRPEnv

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize environment and policy
env = CVRPEnv(generator_params={'num_loc': 50})
policy = AttentionModelPolicy(env_name=env.name).to(device)
td_init = env.reset(batch_size=[8]).to(device)

# Rollout policy (untrained here, you may load a pre-trained model first)
out = policy(
    td_init.clone(), phase="test", decode_type="greedy", return_actions=True
)

# Get initial rewards
rewards = env.get_reward(td_init, out["actions"])
print(f"Rewards: {rewards.mean():.3f}")

# Improve actions using local search (this call raised an error before the hotfix)
improved_actions = env.local_search(td_init.cpu(), out["actions"].cpu())
rewards = env.get_reward(td_init.cpu(), improved_actions)
print(f"Rewards: {rewards.mean():.3f}")
```

The above should work!
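For readers wondering what a local search step does to the `actions` of the snippet above, here is a minimal, pure-Python sketch of the general idea. It is not rl4co's `local_search` implementation; it assumes that actions are a flat sequence of node indices with `0` as the depot, that `coords[0]` is the depot, and that the reward is the negative total route length. All helper names (`split_routes`, `two_opt`, `local_search`, `reward`) are hypothetical and chosen for illustration only.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def split_routes(actions: Sequence[int]) -> List[List[int]]:
    """Split a flat action sequence into routes (assumption: 0 is the depot)."""
    routes, current = [], []
    for node in actions:
        if node == 0:
            if current:  # consecutive depot visits (padding) are skipped
                routes.append(current)
            current = []
        else:
            current.append(node)
    if current:
        routes.append(current)
    return routes


def route_length(route: Sequence[int], coords: Sequence[Point]) -> float:
    """Euclidean length of depot -> route -> depot, with the depot at coords[0]."""
    path = [0, *route, 0]
    return sum(math.dist(coords[a], coords[b]) for a, b in zip(path, path[1:]))


def two_opt(route: List[int], coords: Sequence[Point]) -> List[int]:
    """Intra-route 2-opt: reverse segments while doing so strictly shortens the route."""
    best = list(route)
    improved = True
    while improved:
        improved = False
        for i in range(len(best) - 1):
            for j in range(i + 1, len(best)):
                candidate = best[:i] + best[i:j + 1][::-1] + best[j + 1:]
                if route_length(candidate, coords) < route_length(best, coords) - 1e-12:
                    best = candidate
                    improved = True
    return best


def local_search(actions: Sequence[int], coords: Sequence[Point]) -> List[int]:
    """Apply 2-opt to each route independently and re-flatten with depot separators."""
    out: List[int] = []
    for route in split_routes(actions):
        out += [0, *two_opt(route, coords)]
    return out + [0]


def reward(actions: Sequence[int], coords: Sequence[Point]) -> float:
    """Hypothetical reward: negative total length over all routes (higher is better)."""
    return -sum(route_length(r, coords) for r in split_routes(actions))
```

Because 2-opt only reorders customers inside a route, each route keeps the same set of customers, so vehicle capacity stays feasible and the reward can only stay equal or improve. For example, with `coords = [(0, 0), (0, 1), (1, 1), (1, 0)]`, the tour `[0, 1, 3, 2, 0]` is turned into `[0, 1, 2, 3, 0]`.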
Yup! I only realized afterwards that you had also opened a PR; I went straight for the fix 🤣 Thanks~
Not a problem 🤣 Your work was so quick, I'm impressed!! Thanks!!
Describe the bug
Calling `CVRPEnv.local_search` raises an error, so the call fails.
To Reproduce
System info
Reason and Possible fixes
#211
Checklist