Questions about CNN policy input channel [question] #1195

DavidLudl opened this issue Jul 19, 2024 · 0 comments

Hello,

I am learning how to implement a custom CNN policy and environment with Stable-Baselines3 (SB3). I am following the "Custom Feature Extractor" example at this link:
https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html

I am confused about the number of channels, which the example defines as observation_space.shape[0]. When I examine the observation space in Gymnasium:

```python
import gymnasium as gym

# Atari environments need the Atari extras (ale-py) installed.
env = gym.make("BreakoutNoFrameskip-v4")
print("Observation Space Shape: ", env.observation_space.shape)
print("Image Channel: ", env.observation_space.shape[0])
```

The output is:

```
Observation Space Shape:  (210, 160, 3)
Image Channel:  210
```
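The printed shape reads as (height, width, channels), so index 0 is the image height rather than the channel count. The sketch below makes that explicit; `describe_image_shape` is a hypothetical helper (not part of Gymnasium or SB3), and it assumes the channel axis is the smallest of the three dimensions, which holds for RGB or grayscale frames like Breakout's.

```python
# Hypothetical helper (not a Gymnasium/SB3 API): identify the channel axis of
# a 3-D image shape, ASSUMING channels is the smallest dimension.

LAYOUTS = {0: "channel-first (C, H, W)", 2: "channel-last (H, W, C)"}


def describe_image_shape(shape):
    """Return which axis holds the channels and the resulting layout name."""
    if len(shape) != 3:
        raise ValueError(f"expected a 3-D image shape, got {shape!r}")
    # Index of the smallest dimension (first one wins on ties).
    channel_axis = min(range(3), key=lambda i: shape[i])
    return {
        "channel_axis": channel_axis,
        "n_channels": shape[channel_axis],
        "layout": LAYOUTS.get(channel_axis, "ambiguous"),
    }


shape = (210, 160, 3)  # shape printed above
info = describe_image_shape(shape)
print(info["layout"])      # channel-last (H, W, C)
print(info["n_channels"])  # 3
print(shape[0])            # 210: the height, not the channel count
```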

When I execute the code from the link, there is no error. However, if I use the last element of the observation space shape instead, n_input_channels = observation_space.shape[2], which I assumed is the correct channel count, an error is raised.

So I would like to ask: does SB3 reorder the observation space shape? And when I define my own environment, should I set the observation space shape as C * H * W or H * W * C (that is, where should the channel dimension go)?
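One explanation consistent with this behavior, to be checked against the SB3 documentation for your version: SB3 ships a `VecTransposeImage` wrapper (in `stable_baselines3.common.vec_env`) that converts channel-last image observations to channel-first, and it appears to apply it automatically when it detects a channel-last image space. In that case the feature extractor would see a shape like (3, 210, 160), so shape[0] is 3 and shape[2] is 160. The pure-Python sketch below only illustrates that reordering; `hwc_to_chw_shape` and `hwc_to_chw_image` are hypothetical helpers, not SB3 functions.

```python
# Minimal sketch (pure Python, no NumPy) of a channel-last -> channel-first
# conversion. Both helpers are hypothetical, for illustration only.

def hwc_to_chw_shape(shape):
    """Reorder an (H, W, C) shape tuple to (C, H, W)."""
    height, width, channels = shape
    return (channels, height, width)


def hwc_to_chw_image(image):
    """Transpose a nested-list image indexed [h][w][c] into [c][h][w]."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[h][w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


chw = hwc_to_chw_shape((210, 160, 3))
print(chw)     # (3, 210, 160)
print(chw[0])  # 3: after the transpose, shape[0] is the channel count

# A tiny 2x2 RGB "image": pixel (h, w) holds [R, G, B].
tiny = [[[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9], [10, 11, 12]]]
print(hwc_to_chw_image(tiny)[0])  # red plane: [[1, 4], [7, 10]]
```

With NumPy, the same pixel transpose would be `np.transpose(image, (2, 0, 1))`.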

Thank you for your time and help.
