diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py index c171b88800ed02..af5e11c83a6ad2 100644 --- a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py +++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py @@ -9,13 +9,13 @@ import jax.numpy as jnp import joblib import optax +import wandb from flax import jax_utils, struct, traverse_util from flax.serialization import from_bytes, to_bytes from flax.training import train_state from flax.training.common_utils import shard from tqdm.auto import tqdm -import wandb from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule diff --git a/examples/research_projects/jax-projects/big_bird/train.py b/examples/research_projects/jax-projects/big_bird/train.py index 3840918d16ae1a..ce37b7f975bb3a 100644 --- a/examples/research_projects/jax-projects/big_bird/train.py +++ b/examples/research_projects/jax-projects/big_bird/train.py @@ -2,11 +2,11 @@ from dataclasses import replace import jax +import wandb from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step from datasets import load_dataset from flax import jax_utils -import wandb from transformers import BigBirdTokenizerFast diff --git a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py index 2a39955e347fe9..1bfbc4cd5c36f3 100644 --- a/examples/research_projects/vqgan-clip/VQGAN_CLIP.py +++ b/examples/research_projects/vqgan-clip/VQGAN_CLIP.py @@ -4,12 +4,12 @@ import imageio import torch import torchvision +import wandb from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan from loaders import load_vqgan from PIL import Image from torch import nn -import wandb from transformers import CLIPModel, CLIPTokenizerFast from utils import get_device, get_timestamp, show_pil