random for all datasets ready
whoisjones committed Nov 2, 2023
1 parent f2e2cc6 commit 690d109
Showing 3 changed files with 35 additions and 10 deletions.
30 changes: 27 additions & 3 deletions first_experiment/main.py
@@ -9,23 +9,46 @@
"imdb": {"text_column": ("text", None), "label_column": "label"},
"rte": {"text_column": ("sentence1", "sentence2"), "label_column": "label"},
"qnli": {"text_column": ("question", "sentence"), "label_column": "label"},
"sst2": {"text_column": ("text", None), "label_column": "label"},
"sst2": {"text_column": ("sentence", None), "label_column": "label"},
"snli": {"text_column": ("premise", "hypothesis"), "label_column": "label"},
}

+eval_splits = {
+"imdb": "test",
+"rte": "validation",
+"qnli": "validation",
+"sst2": "validation",
+"snli": "test",
+}

+dataset_prefixes = {
+"rte": "glue",
+"qnli": "glue",
+"sst2": "glue",
+}


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--dataset", type=str, default="imdb")
parser.add_argument("--dataset", type=str, default="imdb", choices=["imdb", "rte", "qnli", "sst2", "snli"])
parser.add_argument("--tam_model", type=str, default="distilbert-base-uncased")
parser.add_argument("--embedding_model", type=str, default=None)
parser.add_argument("--init_strategy", type=str, choices=["random", "closest-to-centeroid", "furthest-to-centeroid", "expected-gradients", "certainty"], default="random")
parser.add_argument("--stopping_criteria", type=str)
parser.add_argument("--dataset_size", type=int, nargs="+", default=[32, 64, 128, 256, 512, 1024, 2048, 4096, 0])
args = parser.parse_args()

-full_dataset = load_dataset(args.dataset)
+if args.dataset in dataset_prefixes:
+    full_dataset = load_dataset(dataset_prefixes[args.dataset], args.dataset)
+else:
+    full_dataset = load_dataset(args.dataset)
task_keys = task_to_keys[args.dataset]

full_dataset["test"] = full_dataset[eval_splits[args.dataset]]

+if args.dataset == "snli":
+    full_dataset = full_dataset.filter(lambda x: x["label"] != -1)

for dataset_size in args.dataset_size:
if dataset_size > 0:
dataset = select_fewshots(
@@ -40,5 +63,6 @@
train_classification(
args,
dataset,
+dataset_size,
task_keys
)
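
For orientation, a minimal sketch (not part of the commit) of how the new lookup tables resolve for a GLUE task such as rte, assuming the Hugging Face datasets library:

from datasets import load_dataset

# "rte" appears in dataset_prefixes, so it is loaded as a GLUE subset.
full_dataset = load_dataset("glue", "rte")

# GLUE test labels are hidden, so eval_splits maps "rte" to its validation
# split, which the script then exposes under the "test" key.
full_dataset["test"] = full_dataset["validation"]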
5 changes: 3 additions & 2 deletions first_experiment/selection_strategies.py
@@ -22,7 +22,7 @@ def select_fewshots(
dataset = random_selection(
full_dataset,
dataset_size,
-task_keys["label_column"]
+task_keys
)
elif args.init_strategy == "class-centeroid-closest":
if args.embedding_model is None:
@@ -65,12 +65,13 @@ def select_fewshots(
def random_selection(
dataset: DatasetDict,
num_total_samples: int,
-label_column: str
+task_keys: dict
) -> DatasetDict:
"""
Selects a fewshot dataset from the full dataset by randomly selecting examples.
"""
dataset = dataset.shuffle(seed=42)
+label_column = task_keys["label_column"]
id2label = dict(enumerate(dataset["train"].features[label_column].names))
num_samples_per_class = num_total_samples // len(dataset["train"].features[label_column].names)
counter = Counter({idx: 0 for idx in id2label.keys()})
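
The remainder of random_selection is cut off in this hunk; a hedged sketch of how such a class-balanced draw could finish, reusing the counter, id2label, and num_samples_per_class variables above (this loop is an illustration, not the repository's exact code):

selected_indices = []
for idx, example in enumerate(dataset["train"]):
    label = example[label_column]
    # Keep an example only while its class is still under-represented.
    if counter[label] < num_samples_per_class:
        selected_indices.append(idx)
        counter[label] += 1
    if len(selected_indices) == num_total_samples:
        break

dataset["train"] = dataset["train"].select(selected_indices)
return dataset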
10 changes: 5 additions & 5 deletions first_experiment/tam_training.py
@@ -12,6 +12,7 @@
def train_classification(
args: Namespace,
dataset: DatasetDict,
+dataset_size: int,
task_keys: dict
):
label_column = task_keys["label_column"]
@@ -28,14 +29,14 @@ def train_classification(

experiment_extension = (f"{args.tam_model}"
f"_{args.dataset}"
f"_{args.dataset_size}"
f"_{dataset_size}"
f"_{args.init_strategy}"
f"_{args.embedding_model if args.embedding_model is not None else ''}")
f"{'_' + args.embedding_model if args.embedding_model is not None else ''}")

log_path = PATH / experiment_extension

batch_size = 16
-total_steps = min(len(dataset["train"]) // batch_size * 3, 200)
+total_steps = max(len(dataset["train"]) // batch_size * 3, 200)
training_args = TrainingArguments(
output_dir=str(log_path),
learning_rate=2e-5,
@@ -44,7 +45,7 @@
num_train_epochs=total_steps * batch_size // len(dataset["train"]),
warmup_ratio=0.1,
weight_decay=0.01,
-logging_steps=5,
+logging_steps=10,
save_strategy="no",
push_to_hub=False,
)
@@ -60,7 +61,6 @@ def compute_metrics(eval_pred):
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
-eval_dataset=tokenized_dataset["validation"] if "validation" in tokenized_dataset else None,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
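
compute_metrics itself lies outside this hunk; a typical accuracy-based implementation for classification (an assumption for illustration, not necessarily the repository's code) would look like:

import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)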
