Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More informative structure mismatch error message in Linear layer #92

Merged
merged 3 commits into from
Nov 21, 2024

Conversation

amifalk
Copy link
Contributor

@amifalk amifalk commented Nov 15, 2024

As mentioned in #89. With the following changes,

from penzai import pz
import jax

linear = pz.nn.Linear.from_config("MyModel/MyComplexLayer/Linear_2", 
                                  jax.random.PRNGKey(0),
                                  input_axes={"feature": 3},
                                  output_axes={})

linear(pz.nx.ones({"feature": 4}))

will throw an error with information about the layer name:

penzai.core.shapecheck.StructureMismatchError: (MyModel/MyComplexLayer/Linear_2) Mismatch while checking structures:
At root: Named shape mismatch between value {'feature': 4} and pattern {**var('B'), 'feature': 3}:
  Axis 'feature': Actual size 4 does not match expected 3

This makes it much easier to debug shape errors in complex models with many layers.

Copy link

google-cla bot commented Nov 15, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

penzai/nn/linear_and_affine.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@danieldjohnson danieldjohnson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, mind running uv run pyink penzai to re-format the changed files using pyink so that the checks pass?

Thanks much!

@amifalk
Copy link
Contributor Author

amifalk commented Nov 20, 2024

By the way, the .vscode setup files didn't work for me, and I had to manually call pyink with uvx pyink penzai. Edit: For some reason uv didn't automatically install the dev dependencies, but after calling uv pip install .[dev] it worked.

if isinstance(
self.weights,
Parameter | ParameterValue,
) and self.weights.label.endswith(".weights"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like pytype is raising an error here. I don't think it knows how to refine the type of self.weights. Would you mind either:

  • refactoring this so that you first assign to a local variable weights = self.weights and then change the rest of the function to refer to weights,
  • or just adding a comment # pytype: disable=attribute-error here to tell pytype you know what you're doing?

Copy link
Collaborator

@danieldjohnson danieldjohnson Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(You should be able to run the typechecker yourself with uv run pytype --jobs auto penzai as long as the dev dependencies are installed.)

Copy link
Contributor Author

@amifalk amifalk Nov 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, for someone reason locally disabling the error doesn't seem to work

    if isinstance(
        self.weights,
        Parameter | ParameterValue,
    ) and (  # pytype: disable=attribute-error
        self.weights.label.endswith(".weights")
    ):
      error_prefix = (  # pytype: disable=attribute-error
          f"({self.weights.label[:-8]}) "
      )
    else:
      error_prefix = ""
``

and neither does assigning self.weights to weights. Globally ignoring the attribute error does seem to work though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm that's weird. Does it work if you add

# pytype: disable=attribute-error

on a line on its own before the if block, and

# pytype: enable=attribute-error

on a line after?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That did the trick!

@danieldjohnson
Copy link
Collaborator

By the way, the .vscode setup files didn't work for me, and I had to manually call pyink with uvx pyink penzai. Edit: For some reason uv didn't automatically install the dev dependencies, but after calling uv pip install .[dev] it worked.

I'm not sure if this is the same thing you're running into, but I've been running into an issue where uvx (as run through VSCode) sometimes selects a default Python version that Black/Pyink doesn't support. I've tried to fix it in #97.

@danieldjohnson danieldjohnson merged commit e22d27f into google-deepmind:main Nov 21, 2024
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants