"""Bigram character-level language model: a learned lookup table of next-character logits."""

import torch
import torch.nn as nn
import torch.nn.functional as F


class Bigram(nn.Module):
    """
    Bigram Language Model 'neural net', simply a lookup table of logits for the
    next character given a previous character.
    """

    def __init__(self, config):
        super().__init__()
        n = config.vocab_size
        # logits[i, j] = unnormalized log-probability that character j follows character i.
        # Initialized to zeros, i.e. a uniform distribution over next characters.
        self.logits = nn.Parameter(torch.zeros((n, n)))

    def get_block_size(self):
        """Return the context length the model needs (one previous character)."""
        return 1  # this model only needs one previous character to predict the next

    def forward(self, idx, targets=None):
        """
        Compute next-character logits for each position in `idx`.

        Args:
            idx: LongTensor of token ids, shape (batch, seq).
            targets: optional LongTensor of same shape with the desired next
                token ids; positions equal to -1 are ignored in the loss.

        Returns:
            (logits, loss): logits has shape (batch, seq, vocab_size);
            loss is a scalar cross-entropy tensor, or None if no targets given.
        """
        # 'forward pass', lol — indexing the parameter table IS the whole model
        logits = self.logits[idx]

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )

        return logits, loss