Skip to content

Commit

Permalink
- Include transformers in requirements
Browse files Browse the repository at this point in the history
- run black
- fix imports
  • Loading branch information
knikolaou committed Nov 3, 2023
1 parent 1b5efc8 commit 8d79342
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 7 deletions.
16 changes: 15 additions & 1 deletion examples/HuggingFace_ResNet_Implementation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"import optax\n",
"\n",
"from znnl.models import HuggingFaceFlaxModel\n",
"\n",
"from transformers import ResNetConfig, FlaxResNetForImageClassification\n",
"import jax\n",
"print(jax.default_backend())"
]
Expand Down Expand Up @@ -174,6 +174,20 @@
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ tensorflow_probability
scipy
scikit-learn
jaxlib
jax
# jax
plotly
flax
tqdm
Expand All @@ -23,4 +23,5 @@ tensorflow-datasets
isort
tensorflow
pyyaml
jupyter
jupyter
transformers
1 change: 1 addition & 0 deletions znnl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Summary
-------
"""

from znnl.models.flax_model import FlaxModel
from znnl.models.huggingface_flax_model import HuggingFaceFlaxModel
from znnl.models.jax_model import JaxModel
Expand Down
1 change: 1 addition & 0 deletions znnl/models/flax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Summary
-------
"""

import logging
from typing import Callable, List, Sequence, Union

Expand Down
3 changes: 2 additions & 1 deletion znnl/models/huggingface_flax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
-------
Module for using a Flax model from Hugging Face in ZnNL.
"""

import logging
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as np
from flax import linen as nn
from transformers import FlaxPreTrainedModel, ResNetConfig
from transformers import FlaxPreTrainedModel

from znnl.models.jax_model import JaxModel

Expand Down
7 changes: 4 additions & 3 deletions znnl/models/jax_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Summary
-------
"""

from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union

import jax
Expand Down Expand Up @@ -77,16 +78,16 @@ def __init__(
self.optimizer = optimizer
self.input_shape = input_shape

# Initialized in self.init_model
self.rng = None

# Input shape is required if no full model is passed.
if pre_built_model is None and input_shape is None:
raise ValueError(
"Input shape must be specified if no pre-built model is passed."
"Model is yet to be constructed."
)

# Initialized in self.init_model
self.rng = None

# initialize the model state
if pre_built_model is None:
self.init_model(seed=seed)
Expand Down
1 change: 1 addition & 0 deletions znnl/models/nt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Summary
-------
"""

import logging
from typing import Callable, Sequence, Union

Expand Down

0 comments on commit 8d79342

Please sign in to comment.