From e160fdcdfc846cd62cc39e77f2822d1596d3a33a Mon Sep 17 00:00:00 2001 From: Khari Date: Fri, 28 Jun 2024 11:56:51 -0400 Subject: [PATCH 1/5] =?UTF-8?q?Feature=20Integration=20-=20Added=20Kalman?= =?UTF-8?q?=20based=20gradfilter=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- grokfast.py | 57 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/grokfast.py b/grokfast.py index 3a0acd5..2c48080 100644 --- a/grokfast.py +++ b/grokfast.py @@ -9,16 +9,20 @@ def gradfilter_ma( grads: Optional[Dict[str, deque]] = None, window_size: int = 100, lamb: float = 5.0, - filter_type: Literal['mean', 'sum'] = 'mean', + filter_type: Literal["mean", "sum"] = "mean", warmup: bool = True, - trigger: bool = False, # For ablation study. + trigger: bool = False, # For ablation study. ) -> Dict[str, deque]: if grads is None: - grads = {n: deque(maxlen=window_size) for n, p in m.named_parameters() if p.requires_grad and p.grad is not None} + grads = { + n: deque(maxlen=window_size) + for n, p in m.named_parameters() + if p.requires_grad and p.grad is not None + } for n, p in m.named_parameters(): if p.requires_grad and p.grad is not None: - grads[n].append(p.grad.data.detach()) # .cpu()) + grads[n].append(p.grad.data.detach()) # .cpu()) # Modify the gradients. if not warmup or len(grads[n]) == window_size and not trigger: @@ -40,11 +44,52 @@ def gradfilter_ema( lamb: float = 2.0, ) -> Dict[str, torch.Tensor]: if grads is None: - grads = {n: p.grad.data.detach() for n, p in m.named_parameters() if p.requires_grad and p.grad is not None} + grads = { + n: p.grad.data.detach() + for n, p in m.named_parameters() + if p.requires_grad and p.grad is not None + } for n, p in m.named_parameters(): if p.requires_grad and p.grad is not None: grads[n] = grads[n] * alpha + p.grad.data.detach() * (1 - alpha) p.grad.data = p.grad.data + grads[n] * lamb - return grads \ No newline at end of file + return grads + + +def gradfilter_kalman( + m: nn.Module, + grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None, + process_noise: float = 1e-4, + measurement_noise: float = 1e-2, + lamb: float = 2.0, +) -> Dict[str, Dict[str, torch.Tensor]]: + if grads is None: + grads = { + n: {"x": torch.zeros_like(p.grad.data), "P": torch.ones_like(p.grad.data)} + for n, p in m.named_parameters() + if p.requires_grad and p.grad is not None + } + + for n, p in m.named_parameters(): + if p.requires_grad and p.grad is not None: + # Prediction step + x_pred = grads[n]["x"] + P_pred = grads[n]["P"] + process_noise + + # Update step + y = p.grad.data - x_pred + S = P_pred + measurement_noise + K = P_pred / S + x = x_pred + K * y + P = (1 - K) * P_pred + + # Store updated state + grads[n]["x"] = x + grads[n]["P"] = P + + # Apply the filtered gradient + p.grad.data = p.grad.data + x * lamb + + return grads From 108378abcf3ef00341a568ec68e6bf1fd174c2e2 Mon Sep 17 00:00:00 2001 From: Khari Date: Fri, 28 Jun 2024 11:57:22 -0400 Subject: [PATCH 2/5] =?UTF-8?q?Refactor=20-=20Added=20Kalman=20filter=20in?= =?UTF-8?q?tegration=20into=20main=20files=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 96 +++++++---- main_imdb.py | 429 +++++++++++++++++++++++++++++++++++++++----------- main_mnist.py | 174 +++++++++++++------- main_qm9.py | 148 +++++++++++------ 4 files changed, 608 insertions(+), 239 deletions(-) diff --git a/main.py b/main.py index c026fad..7fc7116 100644 --- a/main.py +++ b/main.py @@ -14,8 +14,7 @@ class Block(nn.Module): - """Causal transformer block - """ + """Causal transformer block""" def __init__(self, dim, num_heads): super().__init__() @@ -33,7 +32,7 @@ def forward(self, x): (len(x), len(x)), -float("Inf"), device=x.device, dtype=x.dtype ) attn_mask = torch.triu(attn_mask, diagonal=1) - attn_mask[torch.isnan(attn_mask)] = 0.0 # fixes all 'nan' on 'mps' device + attn_mask[torch.isnan(attn_mask)] = 0.0 # fixes all 'nan' on 'mps' device x = self.ln_1(x) a, _ = self.attn(x, x, x, attn_mask=attn_mask, need_weights=False) @@ -44,8 +43,7 @@ def forward(self, x): class Decoder(nn.Module): - """Causal Transformer decoder - """ + """Causal Transformer decoder""" def __init__(self, dim=128, num_layers=2, num_heads=4, num_tokens=97, seq_len=5): super().__init__() @@ -71,8 +69,7 @@ def forward(self, x): def multiplication_mod_p_data(p, eq_token, op_token): - """x◦y = x/y (mod p) for 0 ≤ x < p, 0 < y < p - """ + """x◦y = x/y (mod p) for 0 ≤ x < p, 0 < y < p""" x = torch.arange(p) y = torch.arange(1, p) x, y = torch.cartesian_prod(x, y).T @@ -107,7 +104,7 @@ def main(args): ).to(device) nparams = sum([p.numel() for p in model.parameters() if p.requires_grad]) print(model) - print(f'Total number of parameters: {nparams}') + print(f"Total number of parameters: {nparams}") data = multiplication_mod_p_data(args.p, eq_token, op_token) @@ -170,11 +167,29 @@ def main(args): if args.filter == "none": pass elif args.filter == "ma": - grads = gradfilter_ma(model, grads=grads, window_size=args.window_size, lamb=args.lamb, trigger=trigger) + grads = gradfilter_ma( + model, + grads=grads, + window_size=args.window_size, + lamb=args.lamb, + trigger=trigger, + ) elif args.filter == "ema": - grads = gradfilter_ema(model, grads=grads, alpha=args.alpha, lamb=args.lamb) + grads = gradfilter_ema( + model, grads=grads, alpha=args.alpha, lamb=args.lamb + ) + elif args.filter == "kalman": + grads = gradfilter_kalman( + model, + grads=grads, + process_noise=args.process_noise, + measurement_noise=args.measurement_noise, + lamb=args.lamb, + ) else: - raise ValueError(f"Invalid gradient filter type `{args.filter}`") + raise ValueError( + f"Invalid gradient filter type `{args.filter}`" + ) ####### @@ -194,7 +209,11 @@ def main(args): val_loss.append(total_loss / valid_data.shape[-1]) if args.save_weights: - do_save = e <= 500 or (e > 500 and (e + 1) % 100 == 0) or e == int(args.budget) // steps_per_epoch - 1 + do_save = ( + e <= 500 + or (e > 500 and (e + 1) % 100 == 0) + or e == int(args.budget) // steps_per_epoch - 1 + ) else: do_save = (e + 1) % 100 == 0 if do_save: @@ -222,18 +241,18 @@ def main(args): plt.close() results = { - 'its': its, - 'train_acc': train_acc, - 'train_loss': train_loss, - 'val_acc': val_acc, - 'val_loss': val_loss, + "its": its, + "train_acc": train_acc, + "train_loss": train_loss, + "val_acc": val_acc, + "val_loss": val_loss, } if args.save_weights: net_its.append(e) nets.append(copy.deepcopy(model.state_dict())) - results['net_its'] = net_its - results['net'] = nets + results["net_its"] = net_its + results["net"] = nets torch.save(results, f"results/res_{args.label}.pt") @@ -252,37 +271,46 @@ def main(args): parser.add_argument("--optimizer", default="Adam") # Grokfast - parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none") + parser.add_argument( + "--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none" + ) parser.add_argument("--alpha", type=float, default=0.99) parser.add_argument("--window_size", type=int, default=100) parser.add_argument("--lamb", type=float, default=5.0) + parser.add_argument("--process_noise", type=float, default=1e-4) + parser.add_argument("--measurement_noise", type=float, default=1e-2) # Ablation studies - parser.add_argument("--two_stage", action='store_true') - parser.add_argument("--save_weights", action='store_true') + parser.add_argument("--two_stage", action="store_true") + parser.add_argument("--save_weights", action="store_true") args = parser.parse_args() - filter_str = ('_' if args.label != '' else '') + args.filter - window_size_str = f'_w{args.window_size}' - alpha_str = f'_a{args.alpha:.3f}'.replace('.', '') - lamb_str = f'_l{int(args.lamb)}' + filter_str = ("_" if args.label != "" else "") + args.filter + window_size_str = f"_w{args.window_size}" + alpha_str = f"_a{args.alpha:.3f}".replace(".", "") + lamb_str = f"_l{int(args.lamb)}" - if args.filter == 'none': - filter_suffix = '' - elif args.filter == 'ma': + if args.filter == "none": + filter_suffix = "" + elif args.filter == "ma": filter_suffix = window_size_str + lamb_str - elif args.filter == 'ema': + elif args.filter == "ema": filter_suffix = alpha_str + lamb_str + elif args.filter == "kalman": + filter_suffix = ( + f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "") + + lamb_str + ) else: raise ValueError(f"Unrecognized filter type {args.filter}") - optim_suffix = '' + optim_suffix = "" if args.weight_decay != 0: - optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '') + optim_suffix = optim_suffix + f"_wd{args.weight_decay:.1e}".replace(".", "") if args.lr != 1e-3: - optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}' + optim_suffix = optim_suffix + f"_lrx{int(args.lr / 1e-3)}" args.label = args.label + filter_str + filter_suffix + optim_suffix - print(f'Experiment results saved under name: {args.label}') + print(f"Experiment results saved under name: {args.label}") main(args) diff --git a/main_imdb.py b/main_imdb.py index 6cabff9..cdc11ed 100644 --- a/main_imdb.py +++ b/main_imdb.py @@ -6,8 +6,8 @@ warnings.filterwarnings("ignore") -import numpy as np # linear algebra -import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) +import numpy as np # linear algebra +import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) import seaborn as sns import matplotlib.pyplot as plt from tqdm import tqdm @@ -22,7 +22,7 @@ def L2(model): - L2_ = 0. + L2_ = 0.0 for p in model.parameters(): L2_ += torch.sum(p**2) return L2_ @@ -34,8 +34,16 @@ def rescale(model, alpha): class SentimentRNN(nn.Module): - def __init__(self, no_layers, vocab_size, hidden_dim, embedding_dim, output_dim, drop_prob=0.0): - super(SentimentRNN,self).__init__() + def __init__( + self, + no_layers, + vocab_size, + hidden_dim, + embedding_dim, + output_dim, + drop_prob=0.0, + ): + super(SentimentRNN, self).__init__() self.output_dim = output_dim self.hidden_dim = hidden_dim @@ -46,9 +54,13 @@ def __init__(self, no_layers, vocab_size, hidden_dim, embedding_dim, output_dim, # embedding and LSTM layers self.embedding = nn.Embedding(vocab_size, embedding_dim) - #lstm - self.lstm = nn.LSTM(input_size=embedding_dim,hidden_size=self.hidden_dim, - num_layers=no_layers, batch_first=True) + # lstm + self.lstm = nn.LSTM( + input_size=embedding_dim, + hidden_size=self.hidden_dim, + num_layers=no_layers, + batch_first=True, + ) # dropout layer self.dropout = nn.Dropout(drop_prob) @@ -57,16 +69,16 @@ def __init__(self, no_layers, vocab_size, hidden_dim, embedding_dim, output_dim, self.fc = nn.Linear(self.hidden_dim, output_dim) self.sig = nn.Sigmoid() - self.register_buffer('device_checker', torch.zeros(0), False) + self.register_buffer("device_checker", torch.zeros(0), False) - def forward(self,x,hidden): + def forward(self, x, hidden): batch_size = x.size(0) # embeddings and lstm_out embeds = self.embedding(x) # shape: B x S x Feature since batch = True - #print(embeds.shape) #[50, 500, 1000] + # print(embeds.shape) #[50, 500, 1000] lstm_out, hidden = self.lstm(embeds, hidden) - lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim) + lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim) # dropout and fully connected layer out = self.dropout(lstm_out) @@ -78,17 +90,21 @@ def forward(self,x,hidden): # reshape to be batch_size first sig_out = sig_out.view(batch_size, -1) - sig_out = sig_out[:, -1] # get last batch of labels + sig_out = sig_out[:, -1] # get last batch of labels # return last sigmoid output and hidden state return sig_out, hidden def init_hidden(self, batch_size): - ''' Initializes hidden state ''' + """Initializes hidden state""" # Create two new tensors with sizes n_layers x batch_size x hidden_dim, # initialized to zero, for hidden state and cell state of LSTM - h0 = torch.zeros((self.no_layers, batch_size, self.hidden_dim)).to(self.device_checker.device) - c0 = torch.zeros((self.no_layers, batch_size, self.hidden_dim)).to(self.device_checker.device) + h0 = torch.zeros((self.no_layers, batch_size, self.hidden_dim)).to( + self.device_checker.device + ) + c0 = torch.zeros((self.no_layers, batch_size, self.hidden_dim)).to( + self.device_checker.device + ) hidden = (h0, c0) return hidden @@ -106,19 +122,19 @@ def main(args): alpha = args.init_scale device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - base_csv = './IMDB-Dataset.csv' + base_csv = "./IMDB-Dataset.csv" df = pd.read_csv(base_csv) - X, y = df[:args.size]['review'].values, df[:args.size]['sentiment'].values + X, y = df[: args.size]["review"].values, df[: args.size]["sentiment"].values x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y) def preprocess_string(s): # Remove all non-word characters (everything except numbers and letters) - s = re.sub(r"[^\w\s]", '', s) + s = re.sub(r"[^\w\s]", "", s) # Replace all runs of whitespaces with no space - s = re.sub(r"\s+", '', s) + s = re.sub(r"\s+", "", s) # replace digits with no space - s = re.sub(r"\d", '', s) + s = re.sub(r"\d", "", s) return s @@ -126,62 +142,241 @@ def tockenize(x_train, y_train, x_val, y_val): word_list = [] stop_words = { - 'a', 'about', 'above', 'after', 'again', 'against', 'ain', 'all', 'am', 'an', - 'and', 'any', 'are', 'aren', "aren't", 'as', 'at', 'be', 'because', 'been', - 'before', 'being', 'below', 'between', 'both', 'but', 'by', 'can', 'couldn', "couldn't", - 'd', 'did', 'didn', "didn't", 'do', 'does', 'doesn', "doesn't", 'doing', 'don', - "don't", 'down', 'during', 'each', 'few', 'for', 'from', 'further', 'had', 'hadn', - "hadn't", 'has', 'hasn', "hasn't", 'have', 'haven', "haven't", 'having', 'he', 'her', - 'here', 'hers', 'herself', 'him', 'himself', 'his', 'how', 'i', 'if', 'in', - 'into', 'is', 'isn', "isn't", 'it', "it's", 'its', 'itself', 'just', 'll', - 'm', 'ma', 'me', 'mightn', "mightn't", 'more', 'most', 'mustn', "mustn't", 'my', - 'myself', 'needn', "needn't", 'no', 'nor', 'not', 'now', 'o', 'of', 'off', - 'on', 'once', 'only', 'or', 'other', 'our', 'ours', 'ourselves', 'out', 'over', - 'own', 're', 's', 'same', 'shan', "shan't", 'she', "she's", 'should', "should've", - 'shouldn', "shouldn't", 'so', 'some', 'such', 't', 'than', 'that', "that'll", 'the', - 'their', 'theirs', 'them', 'themselves', 'then', 'there', 'these', 'they', 'this', 'those', - 'through', 'to', 'too', 'under', 'until', 'up', 've', 'very', 'was', 'wasn', - "wasn't", 'we', 'were', 'weren', "weren't", 'what', 'when', 'where', 'which', 'while', - 'who', 'whom', 'why', 'will', 'with', 'won', "won't", 'wouldn', "wouldn't", 'y', - 'you', "you'd", "you'll", "you're", "you've", 'your', 'yours', 'yourself', 'yourselves', + "a", + "about", + "above", + "after", + "again", + "against", + "ain", + "all", + "am", + "an", + "and", + "any", + "are", + "aren", + "aren't", + "as", + "at", + "be", + "because", + "been", + "before", + "being", + "below", + "between", + "both", + "but", + "by", + "can", + "couldn", + "couldn't", + "d", + "did", + "didn", + "didn't", + "do", + "does", + "doesn", + "doesn't", + "doing", + "don", + "don't", + "down", + "during", + "each", + "few", + "for", + "from", + "further", + "had", + "hadn", + "hadn't", + "has", + "hasn", + "hasn't", + "have", + "haven", + "haven't", + "having", + "he", + "her", + "here", + "hers", + "herself", + "him", + "himself", + "his", + "how", + "i", + "if", + "in", + "into", + "is", + "isn", + "isn't", + "it", + "it's", + "its", + "itself", + "just", + "ll", + "m", + "ma", + "me", + "mightn", + "mightn't", + "more", + "most", + "mustn", + "mustn't", + "my", + "myself", + "needn", + "needn't", + "no", + "nor", + "not", + "now", + "o", + "of", + "off", + "on", + "once", + "only", + "or", + "other", + "our", + "ours", + "ourselves", + "out", + "over", + "own", + "re", + "s", + "same", + "shan", + "shan't", + "she", + "she's", + "should", + "should've", + "shouldn", + "shouldn't", + "so", + "some", + "such", + "t", + "than", + "that", + "that'll", + "the", + "their", + "theirs", + "them", + "themselves", + "then", + "there", + "these", + "they", + "this", + "those", + "through", + "to", + "too", + "under", + "until", + "up", + "ve", + "very", + "was", + "wasn", + "wasn't", + "we", + "were", + "weren", + "weren't", + "what", + "when", + "where", + "which", + "while", + "who", + "whom", + "why", + "will", + "with", + "won", + "won't", + "wouldn", + "wouldn't", + "y", + "you", + "you'd", + "you'll", + "you're", + "you've", + "your", + "yours", + "yourself", + "yourselves", } for sent in x_train: for word in sent.lower().split(): word = preprocess_string(word) - if word not in stop_words and word != '': + if word not in stop_words and word != "": word_list.append(word) corpus = Counter(word_list) # sorting on the basis of most common words - corpus_ = sorted(corpus,key=corpus.get,reverse=True)[:1000] + corpus_ = sorted(corpus, key=corpus.get, reverse=True)[:1000] # creating a dict - onehot_dict = {w:i+1 for i,w in enumerate(corpus_)} + onehot_dict = {w: i + 1 for i, w in enumerate(corpus_)} def _padding(sentences, seq_len): features = np.zeros((len(sentences), seq_len), dtype=int) for i, review in enumerate(sentences): if len(review) != 0: - features[i, -len(review):] = np.array(review)[:seq_len] + features[i, -len(review) :] = np.array(review)[:seq_len] return features # tockenize final_list_train, final_list_test = [], [] for sent in x_train: - final_list_train.append([onehot_dict[preprocess_string(word)] for word in sent.lower().split() - if preprocess_string(word) in onehot_dict.keys()]) + final_list_train.append( + [ + onehot_dict[preprocess_string(word)] + for word in sent.lower().split() + if preprocess_string(word) in onehot_dict.keys() + ] + ) final_list_train = _padding(final_list_train, 500) for sent in x_val: - final_list_test.append([onehot_dict[preprocess_string(word)] for word in sent.lower().split() - if preprocess_string(word) in onehot_dict.keys()]) + final_list_test.append( + [ + onehot_dict[preprocess_string(word)] + for word in sent.lower().split() + if preprocess_string(word) in onehot_dict.keys() + ] + ) final_list_test = _padding(final_list_test, 500) - encoded_train = [1 if label =='positive' else 0 for label in y_train] - encoded_test = [1 if label =='positive' else 0 for label in y_val] - return np.array(final_list_train), np.array(encoded_train), np.array(final_list_test), np.array(encoded_test), onehot_dict + encoded_train = [1 if label == "positive" else 0 for label in y_train] + encoded_test = [1 if label == "positive" else 0 for label in y_val] + return ( + np.array(final_list_train), + np.array(encoded_train), + np.array(final_list_test), + np.array(encoded_test), + onehot_dict, + ) # create Tensor datasets - x_train, y_train, x_test, y_test, vocab = tockenize(x_train, y_train, x_test, y_test) + x_train, y_train, x_test, y_test, vocab = tockenize( + x_train, y_train, x_test, y_test + ) train_data = TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) valid_data = TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test)) @@ -195,30 +390,44 @@ def _padding(sentences, seq_len): # define model no_layers = 2 - vocab_size = len(vocab) + 1 #extra 1 for padding + vocab_size = len(vocab) + 1 # extra 1 for padding embedding_dim = 64 output_dim = 1 hidden_dim = 256 - model = SentimentRNN(no_layers, vocab_size, hidden_dim, embedding_dim, output_dim, drop_prob=0.0) + model = SentimentRNN( + no_layers, vocab_size, hidden_dim, embedding_dim, output_dim, drop_prob=0.0 + ) model.to(device) - + rescale(model, alpha) L2_ = L2(model) # loss and optimization functions criterion = nn.BCELoss() - optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) + optimizer = torch.optim.AdamW( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) ####### train_loss_min = np.Inf valid_loss_min = np.Inf - train_acc_max = 0. - valid_acc_max = 0. + train_acc_max = 0.0 + valid_acc_max = 0.0 # train for some number of epochs - log_steps, train_accs, test_accs, test_losses, train_losses, train_avg_losses, test_avg_losses, train_avg_accs, test_avg_accs = [], [], [], [], [], [], [], [], [] + ( + log_steps, + train_accs, + test_accs, + test_losses, + train_losses, + train_avg_losses, + test_avg_losses, + train_avg_accs, + test_avg_accs, + ) = ([], [], [], [], [], [], [], [], []) step, epoch = 0, 0 grads = None @@ -236,7 +445,7 @@ def _padding(sentences, seq_len): model.train() inputs, labels = inputs.to(device), labels.to(device) - # initialize hidden state + # initialize hidden state h = model.init_hidden(inputs.shape[0]) # Creating new variables for the hidden state, otherwise # we'd backprop through the entire training history @@ -268,9 +477,25 @@ def _padding(sentences, seq_len): if args.filter == "none": pass elif args.filter == "ma": - grads = gradfilter_ma(model, grads=grads, window_size=args.window_size, lamb=args.lamb, trigger=trigger) + grads = gradfilter_ma( + model, + grads=grads, + window_size=args.window_size, + lamb=args.lamb, + trigger=trigger, + ) elif args.filter == "ema": - grads = gradfilter_ema(model, grads=grads, alpha=args.alpha, lamb=args.lamb) + grads = gradfilter_ema( + model, grads=grads, alpha=args.alpha, lamb=args.lamb + ) + elif args.filter == "kalman": + grads = gradfilter_kalman( + model, + grads=grads, + process_noise=args.process_noise, + measurement_noise=args.measurement_noise, + lamb=args.lamb, + ) else: raise ValueError(f"Invalid gradient filter type `{args.filter}`") @@ -299,8 +524,10 @@ def _padding(sentences, seq_len): test_accs.append(val_acc) if (step + 1) % 10 == 0: - tqdm.write(f'step : {step} train_loss : {loss.item()} val_loss : {val_loss.item()}\n' - f'train_accuracy : {train_acc} val_accuracy : {val_acc}') + tqdm.write( + f"step : {step} train_loss : {loss.item()} val_loss : {val_loss.item()}\n" + f"train_accuracy : {train_acc} val_accuracy : {val_acc}" + ) step += 1 pbar.update() @@ -319,12 +546,14 @@ def _padding(sentences, seq_len): test_epoch_acc = test_epoch_acc / test_size test_avg_accs.append(test_epoch_acc) - tqdm.write(f"Epochs: {epoch} | epoch avg. acc: {epoch_acc:.3f} | " - f"test avg. acc: {test_epoch_acc:.3f}") + tqdm.write( + f"Epochs: {epoch} | epoch avg. acc: {epoch_acc:.3f} | " + f"test avg. acc: {test_epoch_acc:.3f}" + ) if (epoch + 1) % 100 == 0 or step == args.iterations - 1: - title = (f"IMDb Binary Sentiment Analysis") + title = f"IMDb Binary Sentiment Analysis" plt.plot(np.arange(step), train_accs, label="train") plt.plot(np.arange(step), test_accs, label="val") @@ -349,18 +578,21 @@ def _padding(sentences, seq_len): plt.savefig(f"results/imdb_loss_{args.label}.png", dpi=150) plt.close() - torch.save({ - 'its': np.arange(len(train_losses)), - 'its_avg': np.arange(len(train_avg_losses)), - 'train_acc': train_accs, - 'train_loss': train_losses, - 'train_avg_acc': train_avg_accs, - 'train_avg_loss': train_avg_losses, - 'val_acc': test_accs, - 'val_loss': test_losses, - 'val_avg_acc': test_avg_accs, - 'val_avg_loss': test_avg_losses, - }, f"results/imdb_{args.label}.pt") + torch.save( + { + "its": np.arange(len(train_losses)), + "its_avg": np.arange(len(train_avg_losses)), + "train_acc": train_accs, + "train_loss": train_losses, + "train_avg_acc": train_avg_accs, + "train_avg_loss": train_avg_losses, + "val_acc": test_accs, + "val_loss": test_losses, + "val_avg_acc": test_avg_accs, + "val_avg_loss": test_avg_losses, + }, + f"results/imdb_{args.label}.pt", + ) epoch += 1 @@ -375,38 +607,49 @@ def _padding(sentences, seq_len): parser.add_argument("--weight_decay", type=float, default=1.0) parser.add_argument("--gradient_clip", type=float, default=5.0) parser.add_argument("--size", type=int, default=1000) - parser.add_argument("--init_scale", type=float, default=6.0) # init_scale 1.0 no grokking / init_scale 6.0 grokking + parser.add_argument( + "--init_scale", type=float, default=6.0 + ) # init_scale 1.0 no grokking / init_scale 6.0 grokking # Grokfast - parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none") + parser.add_argument( + "--filter", type=str, choices=["none", "ma", "ema", "kalman"], default="none" + ) + parser.add_argument("--process_noise", type=float, default=1e-4) + parser.add_argument("--measurement_noise", type=float, default=1e-2) parser.add_argument("--alpha", type=float, default=0.99) parser.add_argument("--window_size", type=int, default=100) parser.add_argument("--lamb", type=float, default=5.0) args = parser.parse_args() - model_suffix = f'size{args.size}_alpha{args.init_scale:.4f}' + model_suffix = f"size{args.size}_alpha{args.init_scale:.4f}" - filter_str = ('_' if args.label != '' else '') + args.filter - window_size_str = f'_w{args.window_size}' - alpha_str = f'_a{args.alpha:.3f}'.replace('.', '') - lamb_str = f'_l{int(args.lamb)}' + filter_str = ("_" if args.label != "" else "") + args.filter + window_size_str = f"_w{args.window_size}" + alpha_str = f"_a{args.alpha:.3f}".replace(".", "") + lamb_str = f"_l{int(args.lamb)}" - if args.filter == 'none': - filter_suffix = '' - elif args.filter == 'ma': + if args.filter == "none": + filter_suffix = "" + elif args.filter == "ma": filter_suffix = window_size_str + lamb_str - elif args.filter == 'ema': + elif args.filter == "ema": filter_suffix = alpha_str + lamb_str + elif args.filter == "kalman": + filter_suffix = ( + f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "") + + lamb_str + ) else: raise ValueError(f"Unrecognized filter type {args.filter}") - optim_suffix = '' + optim_suffix = "" if args.weight_decay != 0: - optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '') + optim_suffix = optim_suffix + f"_wd{args.weight_decay:.1e}".replace(".", "") if args.lr != 1e-3: - optim_suffix = optim_suffix + f'_lrx{int(args.lr / 0.0003)}' + optim_suffix = optim_suffix + f"_lrx{int(args.lr / 0.0003)}" args.label = args.label + model_suffix + filter_str + filter_suffix + optim_suffix - print(f'Experiment results saved under name: {args.label}') + print(f"Experiment results saved under name: {args.label}") main(args) diff --git a/main_mnist.py b/main_mnist.py index 9f90dda..4bfa6e6 100644 --- a/main_mnist.py +++ b/main_mnist.py @@ -24,12 +24,13 @@ def cycle(iterable): def compute_accuracy(network, dataset, device, N=2000, batch_size=50): - """Computes accuracy of `network` on `dataset`. - """ + """Computes accuracy of `network` on `dataset`.""" with torch.no_grad(): N = min(len(dataset), N) batch_size = min(batch_size, N) - dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) + dataset_loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True + ) correct = 0 total = 0 for x, labels in islice(dataset_loader, N // batch_size): @@ -41,43 +42,41 @@ def compute_accuracy(network, dataset, device, N=2000, batch_size=50): def compute_loss(network, dataset, loss_function, device, N=2000, batch_size=50): - """Computes mean loss of `network` on `dataset`. - """ + """Computes mean loss of `network` on `dataset`.""" with torch.no_grad(): N = min(len(dataset), N) batch_size = min(batch_size, N) - dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) - loss_fn = loss_function_dict[loss_function](reduction='sum') + dataset_loader = torch.utils.data.DataLoader( + dataset, batch_size=batch_size, shuffle=True + ) + loss_fn = loss_function_dict[loss_function](reduction="sum") one_hots = torch.eye(10, 10).to(device) total = 0 points = 0 for x, labels in islice(dataset_loader, N // batch_size): y = network(x.to(device)) - if loss_function == 'CrossEntropy': + if loss_function == "CrossEntropy": total += loss_fn(y, labels.to(device)).item() - elif loss_function == 'MSE': + elif loss_function == "MSE": total += loss_fn(y, one_hots[labels]).item() points += len(labels) return total / points optimizer_dict = { - 'AdamW': torch.optim.AdamW, - 'Adam': torch.optim.Adam, - 'SGD': torch.optim.SGD + "AdamW": torch.optim.AdamW, + "Adam": torch.optim.Adam, + "SGD": torch.optim.SGD, } activation_dict = { - 'ReLU': nn.ReLU, - 'Tanh': nn.Tanh, - 'Sigmoid': nn.Sigmoid, - 'GELU': nn.GELU + "ReLU": nn.ReLU, + "Tanh": nn.Tanh, + "Sigmoid": nn.Sigmoid, + "GELU": nn.GELU, } -loss_function_dict = { - 'MSE': nn.MSELoss, - 'CrossEntropy': nn.CrossEntropyLoss -} +loss_function_dict = {"MSE": nn.MSELoss, "CrossEntropy": nn.CrossEntropyLoss} def main(args): @@ -93,14 +92,26 @@ def main(args): np.random.seed(args.seed) # load dataset - train = torchvision.datasets.MNIST(root=args.download_directory, train=True, - transform=torchvision.transforms.ToTensor(), download=True) - test = torchvision.datasets.MNIST(root=args.download_directory, train=False, - transform=torchvision.transforms.ToTensor(), download=True) + train = torchvision.datasets.MNIST( + root=args.download_directory, + train=True, + transform=torchvision.transforms.ToTensor(), + download=True, + ) + test = torchvision.datasets.MNIST( + root=args.download_directory, + train=False, + transform=torchvision.transforms.ToTensor(), + download=True, + ) train = torch.utils.data.Subset(train, range(args.train_points)) - train_loader = torch.utils.data.DataLoader(train, batch_size=args.batch_size, shuffle=True) + train_loader = torch.utils.data.DataLoader( + train, batch_size=args.batch_size, shuffle=True + ) - assert args.activation in activation_dict, f"Unsupported activation function: {args.activation}" + assert ( + args.activation in activation_dict + ), f"Unsupported activation function: {args.activation}" activation_fn = activation_dict[args.activation] # create model @@ -119,17 +130,20 @@ def main(args): for p in mlp.parameters(): p.data = args.initialization_scale * p.data nparams = sum([p.numel() for p in mlp.parameters() if p.requires_grad]) - print(f'Number of parameters: {nparams}') + print(f"Number of parameters: {nparams}") # create optimizer - assert args.optimizer in optimizer_dict, f"Unsupported optimizer choice: {args.optimizer}" - optimizer = optimizer_dict[args.optimizer](mlp.parameters(), lr=args.lr, weight_decay=args.weight_decay) + assert ( + args.optimizer in optimizer_dict + ), f"Unsupported optimizer choice: {args.optimizer}" + optimizer = optimizer_dict[args.optimizer]( + mlp.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) # define loss function assert args.loss_function in loss_function_dict loss_fn = loss_function_dict[args.loss_function]() - train_losses, test_losses, train_accuracies, test_accuracies = [], [], [], [] norms, last_layer_norms, log_steps = [], [], [] grads = None @@ -138,11 +152,21 @@ def main(args): one_hots = torch.eye(10, 10).to(device) with tqdm(total=args.optimization_steps, dynamic_ncols=True) as pbar: for x, labels in islice(cycle(train_loader), args.optimization_steps): - do_log = (steps < 30) or (steps < 150 and steps % 10 == 0) or steps % log_freq == 0 + do_log = ( + (steps < 30) + or (steps < 150 and steps % 10 == 0) + or steps % log_freq == 0 + ) if do_log: - train_losses.append(compute_loss(mlp, train, args.loss_function, device, N=len(train))) - train_accuracies.append(compute_accuracy(mlp, train, device, N=len(train))) - test_losses.append(compute_loss(mlp, test, args.loss_function, device, N=len(test))) + train_losses.append( + compute_loss(mlp, train, args.loss_function, device, N=len(train)) + ) + train_accuracies.append( + compute_accuracy(mlp, train, device, N=len(train)) + ) + test_losses.append( + compute_loss(mlp, test, args.loss_function, device, N=len(test)) + ) test_accuracies.append(compute_accuracy(mlp, test, device, N=len(test))) log_steps.append(steps) @@ -150,15 +174,15 @@ def main(args): "L: {0:1.1e}|{1:1.1e}. A: {2:2.1f}%|{3:2.1f}%".format( train_losses[-1], test_losses[-1], - train_accuracies[-1] * 100, + train_accuracies[-1] * 100, test_accuracies[-1] * 100, ) ) y = mlp(x.to(device)) - if args.loss_function == 'CrossEntropy': + if args.loss_function == "CrossEntropy": loss = loss_fn(y, labels.to(device)) - elif args.loss_function == 'MSE': + elif args.loss_function == "MSE": loss = loss_fn(y, one_hots[labels]) optimizer.zero_grad() @@ -171,9 +195,25 @@ def main(args): if args.filter == "none": pass elif args.filter == "ma": - grads = gradfilter_ma(mlp, grads=grads, window_size=args.window_size, lamb=args.lamb, trigger=trigger) + grads = gradfilter_ma( + mlp, + grads=grads, + window_size=args.window_size, + lamb=args.lamb, + trigger=trigger, + ) elif args.filter == "ema": - grads = gradfilter_ema(mlp, grads=grads, alpha=args.alpha, lamb=args.lamb) + grads = gradfilter_ema( + mlp, grads=grads, alpha=args.alpha, lamb=args.lamb + ) + elif args.filter == "kalman": + grads = gradfilter_kalman( + mlp, + grads=grads, + process_noise=args.process_noise, + measurement_noise=args.measurement_noise, + lamb=args.lamb, + ) else: raise ValueError(f"Invalid gradient filter type `{args.filter}`") @@ -185,7 +225,7 @@ def main(args): pbar.update(1) if do_log: - title = (f"MNIST Image Classification") + title = f"MNIST Image Classification" plt.plot(log_steps, train_accuracies, label="train") plt.plot(log_steps, test_accuracies, label="val") @@ -210,16 +250,19 @@ def main(args): plt.savefig(f"results/mnist_loss_{args.label}.png", dpi=150) plt.close() - torch.save({ - 'its': log_steps, - 'train_acc': train_accuracies, - 'train_loss': train_losses, - 'val_acc': test_accuracies, - 'val_loss': test_losses, - }, f"results/mnist_{args.label}.pt") + torch.save( + { + "its": log_steps, + "train_acc": train_accuracies, + "train_loss": train_losses, + "val_acc": test_accuracies, + "val_loss": test_losses, + }, + f"results/mnist_{args.label}.pt", + ) -if __name__ == '__main__': +if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--label", default="") parser.add_argument("--seed", type=int, default=0) @@ -238,33 +281,42 @@ def main(args): parser.add_argument("--activation", type=str, default="ReLU") # Grokfast - parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none") + parser.add_argument( + "--filter", type=str, choices=["none", "ma", "ema", "kalman"], default="none" + ) + parser.add_argument("--process_noise", type=float, default=1e-4) + parser.add_argument("--measurement_noise", type=float, default=1e-2) parser.add_argument("--alpha", type=float, default=0.99) parser.add_argument("--window_size", type=int, default=100) parser.add_argument("--lamb", type=float, default=5.0) args = parser.parse_args() - filter_str = ('_' if args.label != '' else '') + args.filter - window_size_str = f'_w{args.window_size}' - alpha_str = f'_a{args.alpha:.3f}'.replace('.', '') - lamb_str = f'_l{args.lamb:.2f}'.replace('.', '') + filter_str = ("_" if args.label != "" else "") + args.filter + window_size_str = f"_w{args.window_size}" + alpha_str = f"_a{args.alpha:.3f}".replace(".", "") + lamb_str = f"_l{args.lamb:.2f}".replace(".", "") - if args.filter == 'none': - filter_suffix = '' - elif args.filter == 'ma': + if args.filter == "none": + filter_suffix = "" + elif args.filter == "ma": filter_suffix = window_size_str + lamb_str - elif args.filter == 'ema': + elif args.filter == "ema": filter_suffix = alpha_str + lamb_str + elif args.filter == "kalman": + filter_suffix = ( + f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "") + + lamb_str + ) else: raise ValueError(f"Unrecognized filter type {args.filter}") - optim_suffix = '' + optim_suffix = "" if args.weight_decay != 0: - optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '') + optim_suffix = optim_suffix + f"_wd{args.weight_decay:.1e}".replace(".", "") if args.lr != 1e-3: - optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}' + optim_suffix = optim_suffix + f"_lrx{int(args.lr / 1e-3)}" args.label = args.label + filter_str + filter_suffix + optim_suffix - print(f'Experiment results saved under name: {args.label}') + print(f"Experiment results saved under name: {args.label}") main(args) diff --git a/main_qm9.py b/main_qm9.py index 6b18cb3..a227d7a 100644 --- a/main_qm9.py +++ b/main_qm9.py @@ -21,11 +21,11 @@ def __init__(self, num_node_features, num_edge_features): conv1_net = nn.Sequential( nn.Linear(num_edge_features, 32), nn.ReLU(), - nn.Linear(32, num_node_features * 32)) + nn.Linear(32, num_node_features * 32), + ) conv2_net = nn.Sequential( - nn.Linear(num_edge_features, 32), - nn.ReLU(), - nn.Linear(32, 32 * 16)) + nn.Linear(num_edge_features, 32), nn.ReLU(), nn.Linear(32, 32 * 16) + ) self.conv1 = NNConv(num_node_features, 32, conv1_net) self.conv2 = NNConv(32, 16, conv2_net) self.fc_1 = nn.Linear(16, 32) @@ -33,19 +33,23 @@ def __init__(self, num_node_features, num_edge_features): def forward(self, data): batch, x, edge_index, edge_attr = ( - data.batch, data.x, data.edge_index, data.edge_attr) + data.batch, + data.x, + data.edge_index, + data.edge_attr, + ) # First graph conv layer x = F.relu(self.conv1(x, edge_index, edge_attr)) # Second graph conv layer x = F.relu(self.conv2(x, edge_index, edge_attr)) - x = global_add_pool(x,batch) + x = global_add_pool(x, batch) x = F.relu(self.fc_1(x)) output = self.out(x) return output def L2(model): - L2_ = 0. + L2_ = 0.0 for p in model.parameters(): L2_ += torch.sum(p**2) return L2_ @@ -61,13 +65,13 @@ def main(args): torch.manual_seed(args.seed) alpha = args.init_scale - #size = 1000 + # size = 1000 epochs = int(100 * 50000 / args.size) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load the QM9 small molecule dataset - dset = QM9('.') - dset = dset[:args.size] + dset = QM9(".") + dset = dset[: args.size] train_set, test_set = random_split(dset, [int(args.size / 2), int(args.size / 2)]) trainloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) testloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True) @@ -77,13 +81,15 @@ def main(args): net = ExampleNet(qm9_node_feats, qm9_edge_feats) # initialize an optimizer with some reasonable parameters - optimizer = torch.optim.AdamW(net.parameters(), lr=args.lr, weight_decay=args.weight_decay) - target_idx = 1 # index position of the polarizability label + optimizer = torch.optim.AdamW( + net.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + target_idx = 1 # index position of the polarizability label net.to(device) - + rescale(net, alpha) L2_ = L2(net) - + train_best = 1e10 test_best = 1e10 @@ -94,7 +100,7 @@ def main(args): for total_epochs in tqdm.trange(epochs): epoch_loss = 0 total_graphs_train = 0 - + for batch in trainloader: net.train() batch.to(device) @@ -113,9 +119,25 @@ def main(args): if args.filter == "none": pass elif args.filter == "ma": - grads = gradfilter_ma(net, grads=grads, window_size=args.window_size, lamb=args.lamb, trigger=trigger) + grads = gradfilter_ma( + net, + grads=grads, + window_size=args.window_size, + lamb=args.lamb, + trigger=trigger, + ) elif args.filter == "ema": - grads = gradfilter_ema(net, grads=grads, alpha=args.alpha, lamb=args.lamb) + grads = gradfilter_ema( + net, grads=grads, alpha=args.alpha, lamb=args.lamb + ) + elif args.filter == "kalman": + grads = gradfilter_kalman( + net, + grads=grads, + process_noise=args.process_noise, + measurement_noise=args.measurement_noise, + lamb=args.lamb, + ) else: raise ValueError(f"Invalid gradient filter type `{args.filter}`") @@ -154,8 +176,10 @@ def main(args): ####### - tqdm.tqdm.write(f"Epochs: {total_epochs} | epoch avg. loss: {train_avg_loss:.3f} | " - f"test avg. loss: {test_avg_loss:.3f}") + tqdm.tqdm.write( + f"Epochs: {total_epochs} | epoch avg. loss: {train_avg_loss:.3f} | " + f"test avg. loss: {test_avg_loss:.3f}" + ) if (total_epochs + 1) % 100 == 0 or total_epochs == epochs - 1: @@ -172,27 +196,38 @@ def main(args): plt.savefig(f"results/qm9_loss_{args.label}.png", dpi=150) plt.close() - torch.save({ - 'its': np.arange(len(train_losses)), - 'its_avg': np.arange(len(train_avg_losses)), - 'train_acc': None, - 'train_loss': train_losses, - 'train_avg_loss': train_avg_losses, - 'val_acc': None, - 'val_loss': test_losses, - 'val_avg_loss': test_avg_losses, - 'train_best': train_best, - 'val_best': test_best, - }, f"results/qm9_{args.label}.pt") + torch.save( + { + "its": np.arange(len(train_losses)), + "its_avg": np.arange(len(train_avg_losses)), + "train_acc": None, + "train_loss": train_losses, + "train_avg_loss": train_avg_losses, + "val_acc": None, + "val_loss": test_losses, + "val_avg_loss": test_avg_losses, + "train_best": train_best, + "val_best": test_best, + }, + f"results/qm9_{args.label}.pt", + ) ####### fig, ax = plt.subplots(1, 1, figsize=(4.2, 4.2)) - ax.plot((np.arange(len(test_losses))+1)[::20], np.mean(np.array(test_losses).reshape(-1, 20), axis=1), color='#ff7f0e') - ax.plot((np.arange(len(train_losses))+1)[::20], np.mean(np.array(train_losses).reshape(-1, 20), axis=1), color='#1f77b4') - ax.set_xscale('log') - ax.set_yscale('log') + ax.plot( + (np.arange(len(test_losses)) + 1)[::20], + np.mean(np.array(test_losses).reshape(-1, 20), axis=1), + color="#ff7f0e", + ) + ax.plot( + (np.arange(len(train_losses)) + 1)[::20], + np.mean(np.array(train_losses).reshape(-1, 20), axis=1), + color="#1f77b4", + ) + ax.set_xscale("log") + ax.set_yscale("log") ax.set_ylim(1e-2, 1000) ax.set_ylabel("MSE", fontsize=15) @@ -212,38 +247,49 @@ def main(args): parser.add_argument("--lr", type=float, default=1e-3) parser.add_argument("--weight_decay", type=float, default=0) parser.add_argument("--size", type=int, default=100) - parser.add_argument("--init_scale", type=float, default=3.0) # init_scale 1.0 no grokking / init_scale 3.0 grokking + parser.add_argument( + "--init_scale", type=float, default=3.0 + ) # init_scale 1.0 no grokking / init_scale 3.0 grokking # Grokfast - parser.add_argument("--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none") + parser.add_argument( + "--filter", type=str, choices=["none", "ma", "ema", "kalman"], default="none" + ) + parser.add_argument("--process_noise", type=float, default=1e-4) + parser.add_argument("--measurement_noise", type=float, default=1e-2) parser.add_argument("--alpha", type=float, default=0.99) parser.add_argument("--window_size", type=int, default=100) parser.add_argument("--lamb", type=float, default=5.0) args = parser.parse_args() - filter_str = ('_' if args.label != '' else '') + args.filter - window_size_str = f'_w{args.window_size}' - alpha_str = f'_a{args.alpha:.3f}'.replace('.', '') - lamb_str = f'_l{args.lamb:.2f}'.replace('.', '') + filter_str = ("_" if args.label != "" else "") + args.filter + window_size_str = f"_w{args.window_size}" + alpha_str = f"_a{args.alpha:.3f}".replace(".", "") + lamb_str = f"_l{args.lamb:.2f}".replace(".", "") - model_suffix = f'size{args.size}_alpha{args.init_scale:.4f}' + model_suffix = f"size{args.size}_alpha{args.init_scale:.4f}" - if args.filter == 'none': - filter_suffix = '' - elif args.filter == 'ma': + if args.filter == "none": + filter_suffix = "" + elif args.filter == "ma": filter_suffix = window_size_str + lamb_str - elif args.filter == 'ema': + elif args.filter == "ema": filter_suffix = alpha_str + lamb_str + elif args.filter == "kalman": + filter_suffix = ( + f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "") + + lamb_str + ) else: raise ValueError(f"Unrecognized filter type {args.filter}") - optim_suffix = '' + optim_suffix = "" if args.weight_decay != 0: - optim_suffix = optim_suffix + f'_wd{args.weight_decay:.1e}'.replace('.', '') + optim_suffix = optim_suffix + f"_wd{args.weight_decay:.1e}".replace(".", "") if args.lr != 1e-3: - optim_suffix = optim_suffix + f'_lrx{int(args.lr / 1e-3)}' + optim_suffix = optim_suffix + f"_lrx{int(args.lr / 1e-3)}" args.label = args.label + model_suffix + filter_str + filter_suffix + optim_suffix - print(f'Experiment results saved under name: {args.label}') + print(f"Experiment results saved under name: {args.label}") - main(args) \ No newline at end of file + main(args) From cffce40efaf122e514e1d00d1f61bfbfb1b2c66a Mon Sep 17 00:00:00 2001 From: Khari Date: Fri, 28 Jun 2024 11:57:59 -0400 Subject: [PATCH 3/5] =?UTF-8?q?Refactor=20-=20Updated=20README=20with=20Ka?= =?UTF-8?q?lman=20Filter=20instructions=20and=20information=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 11efe75..5d00a7c 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,8 @@ loss.backwards() # Calculate the gradients. grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb) ### Option 2: Grokfast-MA (has argument window_size, lamb) # grads = gradfilter_ma(model, grads=grads, window_size=window_size, lamb=lamb) +### Option 3: Grokfast-Kalman (has arguments process_noise, measurement_noise, lamb) +# grads = gradfilter_kalman(model, grads=grads, process_noise=process_noise, measurement_noise=measurement_noise, lamb=lamb optimizer.step() # Call the optimizer. # ... logging & other codes. @@ -155,6 +157,14 @@ def gradfilter_ma( - `warmup: bool = True`: If true, filter is not applied until the queue is filled. - `trigger: bool = False`: For ablation study only. If true, the filter is simply not applied. +3. Grokfast-Kalman (`gradfilter_kalman`) + + - `m: nn.Module`: Model that contains every trainable parameters. + - `grads: Optional[Dict[str, Dict[str, torch.Tensor]]] = None`: Running memory (Kalman filter state). Initialize by setting it to `None`. Feed the output of the method recursively after on. + - `process_noise: float = 1e-4`: Process noise parameter for the Kalman filter. + - `measurement_noise: float = 1e-2`: Measurement noise parameter for the Kalman filter. + - `lamb: float = 2.0`: Amplifying factor hyperparameter of the filter. + --- ## Reproduction @@ -242,9 +252,9 @@ python main_qm9.py --label test --alpha 0.9 --lamb 1.0 --weight_decay 0.01 These recommendations are based on my experiences during the experiments shown in the main manuscript. This may not work perfectly to every other problems, and maybe more intelligent techniques can do better jobs than this procedure. So, please take these as one possible starting guidelines for designing your own filters. -1. **Cutoff parameters**: The work uses MA/EMA filters to implement the filtering techniques. The cutoff frequency is determined by the _window size_ for the MA filter, and the _momentum parameter_ for the EMA filter. +1. **Cutoff parameters**: The work uses MA/EMA/Kalman filters to implement the filtering techniques. The cutoff frequency is determined by the _window size_ for the MA filter, the _momentum parameter_ for the EMA filter, and the _process noise_ and _measurement noise_ for the Kalman filter. 1. **Roughly figure out the amount of acceleration you want to achieve.** For example, in the main manuscript, the cutoff parameters are determined based on the original grokking report, where experiments shows generalization happening X100 slower than overfitting. Therefore, we want *N=100* times faster acceleration. - 2. **Set the pivotal values for the cutoff parameter search.** For MA, I started to set the window size of "w=N=100" and for EMA, I began with the momentum parameter alpha that satisfies "alpha^{N} = alpha^{100} = 0.1" (which is roughly alpha ~ 0.98). + 2. **Set the pivotal values for the cutoff parameter search.** For MA, I started to set the window size of "w=N=100" and for EMA, I began with the momentum parameter alpha that satisfies "alpha^{N} = alpha^{100} = 0.1" (which is roughly alpha ~ 0.98). For the Kalman filter, start with process_noise=1e-4 and measurement_noise=1e-2.** These are reasonable starting points, but you may need to adjust them based on your specific task. 3. **Perform hyperparameter search near the pivot values.** I swept across hyperparameter values near the values set in (1.b). 3. **Weight decay**: The weight decay is set in the optimizer constructor as usual (e.g., `optimizer = optim.Adam(m.parameters(), weight_decay=wd)`). 1. **Start from the default weight decay of that task.** For example, the value chosen by the most widely used Github repository of that task. From 80eddc921d5a448d78488a6a5e7a5dc47a8303d0 Mon Sep 17 00:00:00 2001 From: Khari Date: Fri, 28 Jun 2024 12:05:18 -0400 Subject: [PATCH 4/5] =?UTF-8?q?Refactor=20-=20Changed=20naming=20conventio?= =?UTF-8?q?n=20for=20consistency=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 6 +++--- main_imdb.py | 6 +++--- main_mnist.py | 6 +++--- main_qm9.py | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 7fc7116..f473816 100644 --- a/main.py +++ b/main.py @@ -178,7 +178,7 @@ def main(args): grads = gradfilter_ema( model, grads=grads, alpha=args.alpha, lamb=args.lamb ) - elif args.filter == "kalman": + elif args.filter == "kal": grads = gradfilter_kalman( model, grads=grads, @@ -272,7 +272,7 @@ def main(args): # Grokfast parser.add_argument( - "--filter", type=str, choices=["none", "ma", "ema", "fir"], default="none" + "--filter", type=str, choices=["none", "ma", "ema", "kal"], default="none" ) parser.add_argument("--alpha", type=float, default=0.99) parser.add_argument("--window_size", type=int, default=100) @@ -296,7 +296,7 @@ def main(args): filter_suffix = window_size_str + lamb_str elif args.filter == "ema": filter_suffix = alpha_str + lamb_str - elif args.filter == "kalman": + elif args.filter == "kal": filter_suffix = ( f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "") + lamb_str diff --git a/main_imdb.py b/main_imdb.py index cdc11ed..0dae97f 100644 --- a/main_imdb.py +++ b/main_imdb.py @@ -488,7 +488,7 @@ def _padding(sentences, seq_len): grads = gradfilter_ema( model, grads=grads, alpha=args.alpha, lamb=args.lamb ) - elif args.filter == "kalman": + elif args.filter == "kal": grads = gradfilter_kalman( model, grads=grads, @@ -613,7 +613,7 @@ def _padding(sentences, seq_len): # Grokfast parser.add_argument( - "--filter", type=str, choices=["none", "ma", "ema", "kalman"], default="none" + "--filter", type=str, choices=["none", "ma", "ema", "kal"], default="none" ) parser.add_argument("--process_noise", type=float, default=1e-4) parser.add_argument("--measurement_noise", type=float, default=1e-2) @@ -635,7 +635,7 @@ def _padding(sentences, seq_len): filter_suffix = window_size_str + lamb_str elif args.filter == "ema": filter_suffix = alpha_str + lamb_str - elif args.filter == "kalman": + elif args.filter == "kal": filter_suffix = ( f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "") + lamb_str diff --git a/main_mnist.py b/main_mnist.py index 4bfa6e6..1de3d54 100644 --- a/main_mnist.py +++ b/main_mnist.py @@ -206,7 +206,7 @@ def main(args): grads = gradfilter_ema( mlp, grads=grads, alpha=args.alpha, lamb=args.lamb ) - elif args.filter == "kalman": + elif args.filter == "kal": grads = gradfilter_kalman( mlp, grads=grads, @@ -282,7 +282,7 @@ def main(args): # Grokfast parser.add_argument( - "--filter", type=str, choices=["none", "ma", "ema", "kalman"], default="none" + "--filter", type=str, choices=["none", "ma", "ema", "kal"], default="none" ) parser.add_argument("--process_noise", type=float, default=1e-4) parser.add_argument("--measurement_noise", type=float, default=1e-2) @@ -302,7 +302,7 @@ def main(args): filter_suffix = window_size_str + lamb_str elif args.filter == "ema": filter_suffix = alpha_str + lamb_str - elif args.filter == "kalman": + elif args.filter == "kal": filter_suffix = ( f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "") + lamb_str diff --git a/main_qm9.py b/main_qm9.py index a227d7a..8ebff3d 100644 --- a/main_qm9.py +++ b/main_qm9.py @@ -130,7 +130,7 @@ def main(args): grads = gradfilter_ema( net, grads=grads, alpha=args.alpha, lamb=args.lamb ) - elif args.filter == "kalman": + elif args.filter == "kal": grads = gradfilter_kalman( net, grads=grads, @@ -253,7 +253,7 @@ def main(args): # Grokfast parser.add_argument( - "--filter", type=str, choices=["none", "ma", "ema", "kalman"], default="none" + "--filter", type=str, choices=["none", "ma", "ema", "kal"], default="none" ) parser.add_argument("--process_noise", type=float, default=1e-4) parser.add_argument("--measurement_noise", type=float, default=1e-2) @@ -275,7 +275,7 @@ def main(args): filter_suffix = window_size_str + lamb_str elif args.filter == "ema": filter_suffix = alpha_str + lamb_str - elif args.filter == "kalman": + elif args.filter == "kal": filter_suffix = ( f"_p{args.process_noise:.1e}_m{args.measurement_noise:.1e}".replace(".", "") + lamb_str From 0deb198f266d743701625e3e9d11325bdedfcaf0 Mon Sep 17 00:00:00 2001 From: Khari Date: Fri, 28 Jun 2024 13:32:12 -0400 Subject: [PATCH 5/5] =?UTF-8?q?Refactor=20-=20Initialize=20state=20covaria?= =?UTF-8?q?nce=20(P)=20with=20measurement=20noise=20for=20better=20initial?= =?UTF-8?q?=20estimates=20=E2=9C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- grokfast.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/grokfast.py b/grokfast.py index 2c48080..9356437 100644 --- a/grokfast.py +++ b/grokfast.py @@ -67,7 +67,10 @@ def gradfilter_kalman( ) -> Dict[str, Dict[str, torch.Tensor]]: if grads is None: grads = { - n: {"x": torch.zeros_like(p.grad.data), "P": torch.ones_like(p.grad.data)} + n: { + "x": torch.zeros_like(p.grad.data), + "P": torch.ones_like(p.grad.data) * measurement_noise, + } for n, p in m.named_parameters() if p.requires_grad and p.grad is not None } @@ -90,6 +93,6 @@ def gradfilter_kalman( grads[n]["P"] = P # Apply the filtered gradient - p.grad.data = p.grad.data + x * lamb + p.grad.data += x * lamb return grads