# learn_concepts_dataset.py
#
# Learns a concept bank for a dataset and pickles one concept library per SVM
# regularization value.
import os
import pickle
import torch
import argparse
import numpy as np
from models import get_model
from concepts import learn_concept_bank
from data import get_concept_loaders
def config(argv=None):
    """Build the command-line parser and return the parsed arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``config()`` behaves exactly as a plain command-line entry point.

    Returns:
        argparse.Namespace with the attributes ``backbone_name``,
        ``dataset_name``, ``out_dir``, ``device``, ``seed``, ``num_workers``,
        ``batch_size``, ``C`` (list of floats) and ``n_samples``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--backbone-name", default="resnet18_cub", type=str,
                        help="Name of the backbone model to load.")
    parser.add_argument("--dataset-name", default="cub", type=str,
                        help="Name of the dataset the concept loaders are built from.")
    parser.add_argument("--out-dir", required=True, type=str,
                        help="Directory the concept library pickles are written to.")
    parser.add_argument("--device", default="cuda", type=str,
                        help="Device the backbone is moved to.")
    parser.add_argument("--seed", default=1, type=int, help="Random seed")
    parser.add_argument("--num-workers", default=4, type=int, help="Number of workers in the data loader.")
    parser.add_argument("--batch-size", default=100, type=int, help="Batch size in the concept loader.")
    # nargs="+" lets several regularization strengths be given in one run.
    parser.add_argument("--C", nargs="+", default=[0.01, 0.1], type=float, help="Regularization parameter for SVMs.")
    parser.add_argument("--n-samples", default=50, type=int,
                        help="Number of positive/negative samples used to learn concepts.")
    return parser.parse_args(argv)
def main():
    """Learn concept vectors for every concept of a dataset and pickle them.

    Parses the command line with ``config()``, loads the backbone, builds the
    positive/negative loaders for each concept, calls ``learn_concept_bank``
    once per concept, and writes one pickle per regularization value ``C`` to
    ``args.out_dir``. Each pickle maps concept name to the entry
    ``learn_concept_bank`` returned for that ``C``.
    """
    args = config()
    n_samples = args.n_samples

    # Create the output directory up front so a bad path fails before the
    # concept learning runs rather than after all the work is done.
    os.makedirs(args.out_dir, exist_ok=True)

    # Bottleneck part of model
    backbone, preprocess = get_model(args, args.backbone_name)
    backbone = backbone.to(args.device)
    backbone = backbone.eval()

    # One concept library (concept name -> CAV info) per regularization value.
    concept_libs = {C: {} for C in args.C}

    # Get the positive and negative loaders for each concept.
    concept_loaders = get_concept_loaders(args.dataset_name, preprocess, n_samples=n_samples,
                                          batch_size=args.batch_size, num_workers=args.num_workers,
                                          seed=args.seed)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    for concept_name, loaders in concept_loaders.items():
        pos_loader, neg_loader = loaders['pos'], loaders['neg']
        # Get CAV for each concept using positive/negative image split.
        # Use the same device the backbone was moved to; a hard-coded "cuda"
        # here would ignore --device.
        cav_info = learn_concept_bank(pos_loader, neg_loader, backbone, n_samples, args.C,
                                      device=args.device)

        # Store CAV train acc, val acc, margin info for each regularization parameter and each concept
        for C in args.C:
            concept_libs[C][concept_name] = cav_info[C]
            # NOTE(review): indices 1 and 2 are presumably the train/val
            # accuracies mentioned above -- confirm against learn_concept_bank.
            print(concept_name, C, cav_info[C][1], cav_info[C][2])

    # Save CAV results: one pickle per regularization value.
    for C, concept_lib in concept_libs.items():
        lib_path = os.path.join(args.out_dir, f"{args.dataset_name}_{args.backbone_name}_{C}_{args.n_samples}.pkl")
        with open(lib_path, "wb") as f:
            pickle.dump(concept_lib, f)
        print(f"Saved to: {lib_path}")

        total_concepts = len(concept_lib)
        print(f"File: {lib_path}, Total: {total_concepts}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()