-
Notifications
You must be signed in to change notification settings - Fork 5
/
dataloader.go
83 lines (75 loc) · 2.17 KB
/
dataloader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
package llmgo
import (
"bytes"
"encoding/binary"
"errors"
"io"
)
// Int32ByteLen is the number of bytes occupied by one int32 value. The
// loader uses it to convert between byte counts and int32 token counts.
const Int32ByteLen = 4
// DataLoader serves consecutive windows of int32 tokens from an in-memory
// slice, handing out batchSize*seqLength tokens per call to NextBatch.
type DataLoader struct {
	// filename is not set by any constructor visible in this file.
	// NOTE(review): presumably the source path — confirm or remove.
	filename string
	// batchSize and seqLength are multiplied together to get the number of
	// tokens returned per batch.
	batchSize int
	seqLength int
	// currentPosition is the index into data (in int32 elements, not bytes)
	// at which the next batch starts.
	currentPosition int64
	// fileSize is the number of int32 tokens held in data; despite the name
	// it is not a byte count.
	fileSize int64
	// NumBatches is the token count divided by batchSize*seqLength
	// (integer division), computed once at construction.
	NumBatches int
	// data holds the decoded tokens that batches are sliced from.
	data []int32
	// dataAll is not referenced by any code visible in this file.
	// NOTE(review): looks unused — confirm against the rest of the package.
	dataAll []int32
}
// NewDataLoader opens filename with Open, reads its entire contents, and
// builds a DataLoader that serves batches of batchSize*seqLength tokens.
//
// The source is fully consumed during construction, so it is closed before
// returning whenever the value returned by Open implements io.Closer.
// Any error from Open or from newDataLoader is returned unchanged.
func NewDataLoader(filename string, batchSize, seqLength int) (*DataLoader, error) {
	file, err := Open(filename)
	if err != nil {
		return nil, err
	}
	// Open is declared elsewhere in the package; going through io.Reader lets
	// the Close check compile whether Open returns an interface or a concrete
	// type.
	var src io.Reader = file
	if closer, ok := src.(io.Closer); ok {
		// The source was only read from, so a Close failure is not a reason
		// to discard successfully loaded data.
		defer func() { _ = closer.Close() }()
	}
	return newDataLoader(src, batchSize, seqLength)
}
// newDataLoader reads all of file, decodes it as little-endian int32 tokens,
// and returns a DataLoader positioned at the first token.
//
// Trailing bytes that do not make up a whole int32 are ignored. An error is
// returned when batchSize or seqLength is not positive, when reading or
// decoding fails, or when the input holds fewer than batchSize*seqLength+1
// tokens.
func newDataLoader(file io.Reader, batchSize, seqLength int) (*DataLoader, error) {
	// Reject non-positive dimensions up front: a zero product would otherwise
	// cause an integer divide-by-zero panic when NumBatches is computed.
	if batchSize <= 0 || seqLength <= 0 {
		return nil, errors.New("batch size and sequence length must be positive")
	}
	raw, err := io.ReadAll(file)
	if err != nil {
		return nil, err
	}
	size := len(raw)
	tokensPerBatch := batchSize * seqLength
	// One token beyond a full batch is required.
	if size < (tokensPerBatch+1)*Int32ByteLen {
		return nil, errors.New("error: file size is too small for the batch size and sequence length")
	}
	tokenCount := size / Int32ByteLen
	loader := &DataLoader{
		batchSize:  batchSize,
		seqLength:  seqLength,
		NumBatches: size / (tokensPerBatch * Int32ByteLen),
		data:       make([]int32, tokenCount),
		// fileSize is kept in tokens, not bytes, to match how positions are
		// tracked.
		fileSize: int64(tokenCount),
	}
	if err := binary.Read(bytes.NewReader(raw), binary.LittleEndian, loader.data); err != nil {
		return nil, err
	}
	return loader, nil
}
// newDataLoaderFromInts builds a DataLoader over an already-decoded token
// slice. The slice is stored as-is, not copied.
//
// An error is returned when batchSize or seqLength is not positive, or when
// data holds fewer than batchSize*seqLength+1 tokens.
func newDataLoaderFromInts(data []int32, batchSize, seqLength int) (*DataLoader, error) {
	// Reject non-positive dimensions up front: a zero product would otherwise
	// cause an integer divide-by-zero panic when NumBatches is computed.
	if batchSize <= 0 || seqLength <= 0 {
		return nil, errors.New("batch size and sequence length must be positive")
	}
	size := len(data)
	tokensPerBatch := batchSize * seqLength
	// One token beyond a full batch is required.
	if size < tokensPerBatch+1 {
		return nil, errors.New("error: file size is too small for the batch size and sequence length")
	}
	return &DataLoader{
		batchSize:  batchSize,
		seqLength:  seqLength,
		NumBatches: size / tokensPerBatch,
		data:       data,
		fileSize:   int64(size),
	}, nil
}
// Reset moves the loader back to the first token, so the next call to
// NextBatch starts from the beginning of the data.
func (loader *DataLoader) Reset() {
	loader.currentPosition = 0
}
// NextBatch returns the next batchSize*seqLength input tokens together with
// the targets, which are the same window shifted forward by one token.
//
// When the next window plus its one-token lookahead would run past the end of
// the data, the loader resets and serves the first window again. Both
// returned slices alias the loader's data; they are not copies.
//
// An error is returned, instead of panicking on an out-of-range slice, when
// the loader cannot supply even one full batch (for example a zero-value
// DataLoader or negative dimensions).
func (loader *DataLoader) NextBatch() ([]int32, []int32, error) {
	step := int64(loader.batchSize * loader.seqLength)
	// step+1 tokens are needed: the window itself and one token of lookahead
	// for the shifted targets.
	if step < 0 || step+1 > loader.fileSize {
		return nil, nil, errors.New("data loader holds too few tokens for one batch")
	}
	end := loader.currentPosition + step
	if end+1 > loader.fileSize {
		// Not enough tokens left for a full window; wrap around.
		loader.Reset()
		end = loader.currentPosition + step
	}
	// Positions index int32 tokens, not bytes, so no Int32ByteLen scaling.
	inputs := loader.data[loader.currentPosition:end]
	targets := loader.data[loader.currentPosition+1 : end+1]
	loader.currentPosition = end
	return inputs, targets, nil
}