diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml index be1ae39..24af6fa 100644 --- a/.github/workflows/black.yml +++ b/.github/workflows/black.yml @@ -9,6 +9,6 @@ jobs: steps: - uses: actions/checkout@v2 with: - python-version: '3.10' + python-version: '3.11' - name: Black Check uses: psf/black@22.8.0 diff --git a/.github/workflows/doc.yml b/.github/workflows/doc.yml index ba5fb3b..5d84512 100644 --- a/.github/workflows/doc.yml +++ b/.github/workflows/doc.yml @@ -11,14 +11,14 @@ jobs: - name: Setup Python environment uses: actions/setup-python@v2 with: - python-version: '3.10' + python-version: '3.11' - name: Install dependencies run: | sudo apt install pandoc - pip3 install -r requirements.txt + pip install -r dev-requirements.txt + pip install -r requirements.txt pip install . - pip3 install h5py --upgrade --no-dependencies - pip3 install cached-property + - name: Build documentation run: | cd docs diff --git a/.github/workflows/flake8.yml b/.github/workflows/flake8.yml index bb71a80..1e4a252 100644 --- a/.github/workflows/flake8.yml +++ b/.github/workflows/flake8.yml @@ -9,7 +9,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.10" ] + python-version: [ "3.11" ] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index 7a3d1e9..0a56f07 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -10,7 +10,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: '3.10' + python-version: '3.11' - name: Install isort run: | pip install isort==5.10.1 diff --git a/.github/workflows/nbtest.yml b/.github/workflows/nbtest.yml index a482867..26370ff 100644 --- a/.github/workflows/nbtest.yml +++ b/.github/workflows/nbtest.yml @@ -12,7 +12,7 @@ jobs: - uses: actions/checkout@v2 - uses: actions/setup-python@v2 with: - python-version: "3.10" + python-version: "3.11" - name: Install dev requirements run: | pip3 install nbmake diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 12c5d1b..cc6aff5 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -11,7 +11,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10"] + python-version: ["3.11"] steps: - uses: actions/checkout@v2 @@ -22,8 +22,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install -r dev-requirements.txt + pip install -r requirements.txt - name: Install package run: | pip install . diff --git a/CI/unit_tests/models/test_huggingface_flax_model.py b/CI/unit_tests/models/_test_huggingface_flax_model.py similarity index 93% rename from CI/unit_tests/models/test_huggingface_flax_model.py rename to CI/unit_tests/models/_test_huggingface_flax_model.py index 7ef3867..7d77063 100644 --- a/CI/unit_tests/models/test_huggingface_flax_model.py +++ b/CI/unit_tests/models/_test_huggingface_flax_model.py @@ -46,7 +46,6 @@ def setup_class(cls): Create a model and data for the tests. The resnet config has a 1 dimensional input and a 2 dimensional output. """ - resnet_config = ResNetConfig( num_channels=2, embedding_size=64, @@ -88,3 +87,11 @@ def test_infinite_failure(self): """ with pytest.raises(NotImplementedError): self.model.compute_ntk(self.x, infinite=True) + + +if __name__ == "__main__": + test_class = TestFlaxHFModule() + test_class.setup_class() + + # test_class.test_infinite_failure() + test_class.test_ntk_shape() diff --git a/CI/unit_tests/utils/test_matrix_utils.py b/CI/unit_tests/utils/test_matrix_utils.py index 44c5a5a..bc73ded 100644 --- a/CI/unit_tests/utils/test_matrix_utils.py +++ b/CI/unit_tests/utils/test_matrix_utils.py @@ -54,7 +54,7 @@ def test_unscaled_eigenvalues(self): values, vectors = compute_eigensystem(matrix, normalize=False) - assert_array_equal(np.real(values), [1, 1]) + assert_array_equal(np.real(values), [1.0, 1.0]) def test_scaled_eigenvalues(self): """ diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 0000000..89184ee --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,10 @@ +isort>=5.13.2 +black>=24.4.0 +sphinx>=7.3.7 +sphinx_copybutton>=0.5.2 +sphinx_rtd_theme>=2.0.0 +nbsphinx>=0.9.3 +pytest>=8.1.1 +numpydoc>=1.7.0 +flake8>=7.0.0 +pre_commit>=3.7.0 diff --git a/requirements.txt b/requirements.txt index d31d588..71bb1d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,28 +1,17 @@ -numpy -matplotlib -sphinx -flake8 -black -ipython -numpydoc -optax -sphinx_copybutton -sphinx_rtd_theme -nbsphinx -tensorflow_probability -scipy -scikit-learn -# Temp fix of version of jax and jaxlib until the next release -jax<=0.4.25 -jaxlib<=0.4.25 -plotly -flax -tqdm -pandas +numpy>=1.26.4 +matplotlib>=3.8.4 +optax>=0.2.2 +tensorflow_probability>=0.24.0 +scipy>=1.13.0 +scikit-learn>=1.4.2 +plotly>=5.21.0 +flax>=0.8.2 +tqdm>=4.66.2 +pandas>=2.2.2 neural-tangents>=0.6.5 -tensorflow-datasets -isort -tensorflow -pyyaml -jupyter -transformers \ No newline at end of file +tensorflow-datasets>=4.9.4 +tensorflow>=2.16.1 +jupyter>=1.0.0 +transformers>=4.40.0 +jax>=0.4.26 +jaxlib>=0.4.26 \ No newline at end of file diff --git a/znnl/models/jax_model.py b/znnl/models/jax_model.py index 8c816db..eb1448b 100644 --- a/znnl/models/jax_model.py +++ b/znnl/models/jax_model.py @@ -25,7 +25,6 @@ ------- """ -from functools import partial from typing import Any, Callable, Optional, Sequence, Union import jax diff --git a/znnl/training_strategies/simple_training.py b/znnl/training_strategies/simple_training.py index 325bf7e..c0675e8 100644 --- a/znnl/training_strategies/simple_training.py +++ b/znnl/training_strategies/simple_training.py @@ -369,7 +369,7 @@ def train_model( state = self.model.model_state loading_bar = trange( - 1, epochs + 1, ncols=100, unit="batch", disable=self.disable_loading_bar + 0, epochs, ncols=100, unit="batch", disable=self.disable_loading_bar ) train_losses = []