diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e8cafe8..501ae5d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - InitSegment.TweakSingleTrakLive changes an init segment to fit live streaming +- Made bits.Mask() function public + +### Changed + +- All readers and writers in bits package now stop working at first error and provides the first error as AccError() +- Renamed bits.AccErrReader, bits.AccErrEBSPReader, bits.AccErrWriter to corresponiding names without AccErr +- Renamed bits.SliceWriterError to bits.ErrSliceWrite ## [0.42.0] - 2024-01-26 diff --git a/aac/aac.go b/aac/aac.go index 34cd2bde..f9723f5c 100644 --- a/aac/aac.go +++ b/aac/aac.go @@ -76,7 +76,7 @@ var ReverseFrequencies = map[int]byte{ // DecodeAudioSpecificConfig - func DecodeAudioSpecificConfig(r io.Reader) (*AudioSpecificConfig, error) { - br := bits.NewAccErrReader(r) + br := bits.NewReader(r) asc := &AudioSpecificConfig{} audioObjectType := byte(br.Read(5)) @@ -150,11 +150,11 @@ func (a *AudioSpecificConfig) Encode(w io.Writer) error { } bw.Write(0x00, 3) // GASpecificConfig bw.Flush() - return bw.Error() + return bw.AccError() } // getFrequency - either from 4-bit index or 24-bit value -func getFrequency(br *bits.AccErrReader) (frequency int, ok bool) { +func getFrequency(br *bits.Reader) (frequency int, ok bool) { frequencyIndex := br.Read(4) if frequencyIndex == 0x0f { f := br.Read(24) diff --git a/aac/adts.go b/aac/adts.go index 8ba5a0fc..7ed0959a 100644 --- a/aac/adts.go +++ b/aac/adts.go @@ -59,7 +59,7 @@ func (a ADTSHeader) Encode() []byte { // DecodeADTSHeader by first looking for sync word func DecodeADTSHeader(r io.Reader) (header *ADTSHeader, offset int, err error) { - br := bits.NewAccErrReader(r) + br := bits.NewReader(r) tsPacketSize := 188 // Find sync 0xfff in first 188 bytes (MPEG-TS related) syncFound := false offset = 0 diff --git a/avc/pps.go b/avc/pps.go index 7ba54577..25c0f1df 100644 --- a/avc/pps.go +++ b/avc/pps.go @@ -52,7 +52,7 @@ func ParsePPSNALUnit(data []byte, spsMap map[uint32]*SPS) (*PPS, error) { pps := &PPS{} rd := bytes.NewReader(data) - reader := bits.NewAccErrEBSPReader(rd) + reader := bits.NewEBSPReader(rd) // Note! First byte is NAL Header naluHdr := reader.Read(8) diff --git a/avc/slice.go b/avc/slice.go index b50029c6..4ab4a203 100644 --- a/avc/slice.go +++ b/avc/slice.go @@ -66,16 +66,11 @@ func GetSliceTypeFromNALU(data []byte) (sliceType SliceType, err error) { r := bits.NewEBSPReader(bytes.NewReader((data[1:]))) // first_mb_in_slice - if _, err = r.ReadExpGolomb(); err != nil { - return - } - - // slice_type - var st uint - if st, err = r.ReadExpGolomb(); err != nil { - return + _ = r.ReadExpGolomb() + sliceType = SliceType(r.ReadExpGolomb()) + if r.AccError() != nil { + err = r.AccError() } - sliceType = SliceType(st) if sliceType > 9 { err = ErrInvalidSliceType return @@ -134,7 +129,7 @@ type SliceHeader struct { func ParseSliceHeader(nalu []byte, spsMap map[uint32]*SPS, ppsMap map[uint32]*PPS) (*SliceHeader, error) { sh := SliceHeader{} buf := bytes.NewBuffer(nalu) - r := bits.NewAccErrEBSPReader(buf) + r := bits.NewEBSPReader(buf) nalHdr := r.Read(8) naluType := GetNaluType(byte(nalHdr)) switch naluType { diff --git a/avc/sps.go b/avc/sps.go index d1d8ae26..509698f1 100644 --- a/avc/sps.go +++ b/avc/sps.go @@ -117,7 +117,7 @@ func ParseSPSNALUnit(data []byte, parseVUIBeyondAspectRatio bool) (*SPS, error) sps := &SPS{} rd := bytes.NewReader(data) - reader := bits.NewAccErrEBSPReader(rd) + reader := bits.NewEBSPReader(rd) // Note! First byte is NAL Header nalHdr := reader.Read(8) @@ -271,7 +271,7 @@ func (s *SPS) ChromaArrayType() byte { // parseVUI - parse VUI (Visual Usability Information) // if parseVUIBeyondAspectRatio is false, stop after AspectRatio has been parsed -func parseVUI(reader *bits.AccErrEBSPReader, parseVUIBeyondAspectRatio bool) *VUIParameters { +func parseVUI(reader *bits.EBSPReader, parseVUIBeyondAspectRatio bool) *VUIParameters { vui := &VUIParameters{} var err error aspectRatioInfoPresentFlag := reader.ReadFlag() @@ -342,7 +342,7 @@ func parseVUI(reader *bits.AccErrEBSPReader, parseVUIBeyondAspectRatio bool) *VU return vui } -func parseHrdParameters(r *bits.AccErrEBSPReader) *HrdParameters { +func parseHrdParameters(r *bits.EBSPReader) *HrdParameters { hp := &HrdParameters{} hp.CpbCountMinus1 = r.ReadExpGolomb() @@ -380,7 +380,7 @@ func GetSARfromIDC(index uint) (uint, uint, error) { return aspectRatioTable[index-1][0], aspectRatioTable[index-1][1], nil } -func readScalingList(reader *bits.AccErrEBSPReader, sizeOfScalingList int) ScalingList { +func readScalingList(reader *bits.EBSPReader, sizeOfScalingList int) ScalingList { scalingList := make([]int, sizeOfScalingList) lastScale := 8 nextScale := 8 diff --git a/bits/aeebspreader_test.go b/bits/aeebspreader_test.go deleted file mode 100644 index 594d5c2f..00000000 --- a/bits/aeebspreader_test.go +++ /dev/null @@ -1,53 +0,0 @@ -package bits_test - -import ( - "bytes" - "encoding/hex" - "testing" - - "github.com/Eyevinn/mp4ff/bits" -) - -func TestAccErrEBSPReader(t *testing.T) { - testCases := []struct { - hexData string - readBits int - expectedValue uint - expectedNrBytesRead int - expectedNrBitsRead int - expectedNrBitsReadInCurrentByte int - }{ - {"00", 0, 0, 0, 0, 8}, - {"0001", 1, 0, 1, 1, 1}, - {"0001", 8, 0x00, 1, 8, 8}, - {"0010", 12, 0x001, 2, 12, 4}, - {"0102", 16, 0x0102, 2, 16, 8}, - {"00000304", 24, 0x000004, 4, 32, 8}, // Note the start-code emulation 0x03 - } - for _, tc := range testCases { - t.Run("", func(t *testing.T) { - data, err := hex.DecodeString(tc.hexData) - if err != nil { - t.Error(err) - } - buf := bytes.NewBuffer(data) - r := bits.NewAccErrEBSPReader(buf) - if r.AccError() != nil { - t.Errorf("expected no error, got %v", r.AccError()) - } - v := r.Read(tc.readBits) - if v != tc.expectedValue { - t.Errorf("expected value %d, got %d", tc.expectedValue, v) - } - if r.NrBytesRead() != tc.expectedNrBytesRead { - t.Errorf("expected %d bytes read, got %d", tc.expectedNrBytesRead, r.NrBytesRead()) - } - if r.NrBitsRead() != tc.expectedNrBitsRead { - t.Errorf("expected %d bits read, got %d", tc.expectedNrBitsRead, r.NrBitsRead()) - } - if r.NrBitsReadInCurrentByte() != tc.expectedNrBitsReadInCurrentByte { - t.Errorf("expected %d bits read in current byte, got %d", tc.expectedNrBitsReadInCurrentByte, r.NrBitsReadInCurrentByte()) - } - }) - } -} diff --git a/bits/aereader_test.go b/bits/aereader_test.go deleted file mode 100644 index 5ec18298..00000000 --- a/bits/aereader_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package bits - -import ( - "bytes" - "io" - "testing" -) - -func TestAccErrReader(t *testing.T) { - input := []byte{0xff, 0x0f} // 1111 1111 0000 1111 - rd := bytes.NewReader(input) - reader := NewAccErrReader(rd) - - cases := []struct { - n int - want uint - }{ - {2, 3}, // 11 - {3, 7}, // 111 - {5, 28}, // 11100 - {3, 1}, // 001 - {3, 7}, // 111 - } - - for _, tc := range cases { - got := reader.Read(tc.n) - - if got != tc.want { - t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) - } - } - err := reader.AccError() - if err != nil { - t.Errorf("Got accumulated error: %s", err.Error()) - } -} - -func TestAccErrReaderSigned(t *testing.T) { - input := []byte{0xff, 0x0c} // 1111 1111 0000 1100 - rd := bytes.NewReader(input) - reader := NewAccErrReader(rd) - - cases := []struct { - n int - want int - }{ - {2, -1}, // 11 - {3, -1}, // 111 - {5, -4}, // 11100 - {3, 1}, // 001 - {3, -4}, // 100 - } - - for _, tc := range cases { - got := reader.ReadSigned(tc.n) - - if got != tc.want { - t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) - } - } - err := reader.AccError() - if err != nil { - t.Errorf("Got accumulated error: %s", err.Error()) - } -} - -func TestBadAccErrReader(t *testing.T) { - // Check that reading beyond EOF provides value = 0 after acc error - input := []byte{0xff, 0x0f} // 1111 1111 0000 1111 - rd := bytes.NewReader(input) - reader := NewAccErrReader(rd) - - cases := []struct { - err error - n int - want uint - }{ - {nil, 2, 3}, // 11 - {nil, 3, 7}, // 111 - {io.EOF, 12, 0}, // 0 because of error - {io.EOF, 3, 0}, // 0 because of acc error - {io.EOF, 3, 0}, // 0 because of acc error - } - - for _, tc := range cases { - got := reader.Read(tc.n) - - if got != tc.want { - t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) - } - } - err := reader.AccError() - if err != io.EOF { - t.Errorf("Wanted io.EOF but got %v", err) - } -} diff --git a/bits/bits.go b/bits/bits.go deleted file mode 100644 index a9f459ea..00000000 --- a/bits/bits.go +++ /dev/null @@ -1,144 +0,0 @@ -package bits - -import ( - "encoding/binary" - "io" -) - -// Writer writes bits into underlying io.Writer. Stops writing at first error. -// Errors that have occurred can later be checked with Error(). -type Writer struct { - wr io.Writer - err error // The first error caused by any write operation - out []byte // Slice of length 1 to avoid allocation at output - n int // current number of bits - v uint // current accumulated value -} - -// NewWriter - returns a new Writer -func NewWriter(w io.Writer) *Writer { - return &Writer{ - wr: w, - out: make([]byte, 1), - } -} - -// Write - write n bits from bits and save error state -func (w *Writer) Write(bits uint, n int) { - if w.err != nil { - return - } - w.v <<= uint(n) - w.v |= bits & mask(n) - w.n += n - for w.n >= 8 { - b := (w.v >> (uint(w.n) - 8)) & mask(8) - w.out[0] = uint8(b) - _, err := w.wr.Write(w.out) - if err != nil { - w.err = err - return - } - w.n -= 8 - } - w.v &= mask(8) -} - -// Flush - write remaining bits to the underlying io.Writer. -// bits will be left-shifted. -func (w *Writer) Flush() { - if w.err != nil { - return - } - if w.n != 0 { - b := (w.v << (8 - uint(w.n))) & mask(8) - if err := binary.Write(w.wr, binary.BigEndian, uint8(b)); err != nil { - w.err = err - return - } - } -} - -// Error - error that has occurred and stopped writing -func (w *Writer) Error() error { - return w.err -} - -// Reader - read bits from the given io.Reader -type Reader struct { - rd io.Reader - n int // current number of bits - v uint // current accumulated value -} - -// NewReader - return a new Reader -func NewReader(rd io.Reader) *Reader { - return &Reader{ - rd: rd, - } -} - -// Read - read n bits -func (r *Reader) Read(n int) (uint, error) { - var err error - - for r.n < n { - r.v <<= 8 - var b uint8 - err = binary.Read(r.rd, binary.BigEndian, &b) - if err != nil && err != io.EOF { - return 0, err - } - r.v |= uint(b) - - r.n += 8 - } - v := r.v >> uint(r.n-n) - - r.n -= n - r.v &= mask(r.n) - - return v, err -} - -// ReadFlag - read 1 bit into flag -func (r *Reader) ReadFlag() (bool, error) { - bit, err := r.Read(1) - if err != nil { - return false, err - } - return bit == 1, nil -} - -// MustRead - Read bits and panic if not possible -func (r *Reader) MustRead(n int) uint { - var err error - - for r.n < n { - r.v <<= 8 - var b uint8 - err = binary.Read(r.rd, binary.BigEndian, &b) - if err != nil { - panic("Reading error") - } - r.v |= uint(b) - - r.n += 8 - } - v := r.v >> uint(r.n-n) - - r.n -= n - r.v &= mask(r.n) - - return v -} - -// MustReadFlag - read 1 bit into flag and panic if not possible -func (r *Reader) MustReadFlag() bool { - return r.MustRead(1) == 1 -} - -// mask - n-bit binary mask -func mask(n int) uint { - return (1 << uint(n)) - 1 -} diff --git a/bits/aebyteswriter.go b/bits/bytewriter.go similarity index 60% rename from bits/aebyteswriter.go rename to bits/bytewriter.go index 40b1facc..e797fc03 100644 --- a/bits/aebyteswriter.go +++ b/bits/bytewriter.go @@ -5,26 +5,27 @@ import ( "io" ) -// AccErrByteWriter - writer that wraps an io.Writer and accumulates error -type AccErrByteWriter struct { +// ByteWriter - writer that wraps an io.Writer and accumulates error. +// Only the first error is saved, but any later calls will not panic. +type ByteWriter struct { w io.Writer err error } -// NewAccErrByteWriter - create accumulated error writer around io.Writer -func NewAccErrByteWriter(w io.Writer) *AccErrByteWriter { - return &AccErrByteWriter{ +// NewByteWriter creates accumulated error writer around io.Writer. +func NewByteWriter(w io.Writer) *ByteWriter { + return &ByteWriter{ w: w, } } // AccError - return accumulated error -func (a *AccErrByteWriter) AccError() error { +func (a *ByteWriter) AccError() error { return a.err } // WriteUint8 - write a byte -func (a *AccErrByteWriter) WriteUint8(b byte) { +func (a *ByteWriter) WriteUint8(b byte) { if a.err != nil { return } @@ -32,7 +33,7 @@ func (a *AccErrByteWriter) WriteUint8(b byte) { } // WriteUint16 - write uint16 -func (a *AccErrByteWriter) WriteUint16(u uint16) { +func (a *ByteWriter) WriteUint16(u uint16) { if a.err != nil { return } @@ -40,7 +41,7 @@ func (a *AccErrByteWriter) WriteUint16(u uint16) { } // WriteUint32 - write uint32 -func (a *AccErrByteWriter) WriteUint32(u uint32) { +func (a *ByteWriter) WriteUint32(u uint32) { if a.err != nil { return } @@ -48,7 +49,7 @@ func (a *AccErrByteWriter) WriteUint32(u uint32) { } // WriteUint48 - write uint48 -func (a *AccErrByteWriter) WriteUint48(u uint64) { +func (a *ByteWriter) WriteUint48(u uint64) { if a.err != nil { return } @@ -62,7 +63,7 @@ func (a *AccErrByteWriter) WriteUint48(u uint64) { } // WriteUint64 - write uint64 -func (a *AccErrByteWriter) WriteUint64(u uint64) { +func (a *ByteWriter) WriteUint64(u uint64) { if a.err != nil { return } @@ -70,7 +71,7 @@ func (a *AccErrByteWriter) WriteUint64(u uint64) { } // WriteSlice - write a slice -func (a *AccErrByteWriter) WriteSlice(s []byte) { +func (a *ByteWriter) WriteSlice(s []byte) { if a.err != nil { return } diff --git a/bits/bytewriter_test.go b/bits/bytewriter_test.go new file mode 100644 index 00000000..b5f8a06c --- /dev/null +++ b/bits/bytewriter_test.go @@ -0,0 +1,91 @@ +package bits_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/Eyevinn/mp4ff/bits" +) + +func TestByteWriter(t *testing.T) { + t.Run("write ints", func(t *testing.T) { + buf := bytes.Buffer{} + w := bits.NewByteWriter(&buf) + if w.AccError() != nil { + t.Error("Error should be nil") + } + w.WriteUint8(0x00) + w.WriteUint16(0x0102) + w.WriteUint32(0x03040506) + w.WriteUint48(0x0708090a0b0c) + w.WriteUint64(0x0d0e0f1011121314) + w.WriteSlice([]byte{0x15, 0x16, 0x17}) + if w.AccError() != nil { + t.Errorf("unexpected error: %v", w.AccError()) + } + expected := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, + 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17} + if !bytes.Equal(expected, buf.Bytes()) { + t.Errorf("bytes differ: %v %v", expected, buf.Bytes()) + } + }) + t.Run("write after error", func(t *testing.T) { + buf := newLimitedBuffer(1) + w := bits.NewByteWriter(buf) + w.WriteUint8(0x77) + if w.AccError() != nil { + t.Error("limited buffer should accept 1 byte") + } + w.WriteUint48(0x00) + if w.AccError() == nil { + t.Error("limited buffer should not accept any more bytes") + } + w.WriteUint8(0x01) + if w.AccError() == nil { + t.Error("limited buffer should not accept 2 bytes") + } + w.WriteUint16(0x0102) + if w.AccError() == nil { + t.Error("limited buffer should have an accumulated error") + } + w.WriteUint32(0x03040506) + if w.AccError() == nil { + t.Error("limited buffer should have an accumulated error") + } + w.WriteUint48(0x0708090a0b0c) + if w.AccError() == nil { + t.Error("limited buffer should have an accumulated error") + } + w.WriteUint64(0x0d0e0f1011121314) + if w.AccError() == nil { + t.Error("limited buffer should have an accumulated error") + } + w.WriteSlice([]byte{0x15, 0x16, 0x17}) + if w.AccError() == nil { + t.Error("limited buffer should have an accumulated error") + } + expected := []byte{0x77} + if !bytes.Equal(expected, buf.buf) { + t.Errorf("bytes differ: %v %v", expected, buf.buf) + } + }) + +} + +type limitedBuffer struct { + limit int + buf []byte +} + +func newLimitedBuffer(limit int) *limitedBuffer { + return &limitedBuffer{limit: limit, buf: make([]byte, 0, limit)} +} + +func (lb *limitedBuffer) Write(p []byte) (n int, err error) { + if len(lb.buf)+len(p) > lb.limit { + return 0, fmt.Errorf("overflow") + } + lb.buf = append(lb.buf, p...) + return len(p), nil +} diff --git a/bits/doc.go b/bits/doc.go index f655b99c..85d17a42 100644 --- a/bits/doc.go +++ b/bits/doc.go @@ -1,8 +1,9 @@ /* -Package bits - bit reading and writing including Golomb codes and EBSP. - -Beyond plain bit reading and writing, reading of ebsp (Encapsulated Byte Sequence Packets) -Golomb codes as used in the AVC/H.264 and HEVC video coding standards. +Package bits - bit and bytes reading and writing including Golomb codes and EBSP. +All readaer and writer accumulates errors in the sense that they will stop reading or writing at the first error. +The first error can be retrieved with AccError(). +Beyond plain bit reading and writing, reading and writing of ebsp (Encapsulated Byte Sequence Packets) +with start-code emulation prevention bytes and exponential Golomb codes as used in the AVC/H.264 and HEVC video coding standards. */ package bits diff --git a/bits/ebsp.go b/bits/ebsp.go deleted file mode 100644 index 9f19da85..00000000 --- a/bits/ebsp.go +++ /dev/null @@ -1,306 +0,0 @@ -package bits - -import ( - "encoding/binary" - "errors" - "fmt" - "io" -) - -const ( - startCodeEmulationPreventionByte = 0x03 -) - -// ESBPReader errors -var ( - ErrNotReedSeeker = errors.New("Reader does not support Seek") -) - -// NewEBSPReader - return a new Reader. -func NewEBSPReader(rd io.Reader) *EBSPReader { - return &EBSPReader{ - rd: rd, - pos: -1, - } -} - -// EBSPReader - Reader that drops start code emulation 0x03 after two bytes of 0x00 -type EBSPReader struct { - rd io.Reader - n int // current number of bits - v uint // current accumulated value - pos int - zeroCount int // Count number of zero bytes read -} - -// MustRead - read n bits and panic if not possible -func (r *EBSPReader) MustRead(n int) uint { - var err error - - for r.n < n { - r.v <<= 8 - var b uint8 - err = binary.Read(r.rd, binary.BigEndian, &b) - if err != nil { - panic("Reading error") - } - r.pos++ - if r.zeroCount == 2 && b == startCodeEmulationPreventionByte { - err = binary.Read(r.rd, binary.BigEndian, &b) - if err != nil { - panic("Reading error") - } - r.pos++ - r.zeroCount = 0 - } - if b != 0 { - r.zeroCount = 0 - } else { - r.zeroCount++ - } - r.v |= uint(b) - - r.n += 8 - } - v := r.v >> uint(r.n-n) - - r.n -= n - r.v &= mask(r.n) - - return v -} - -// MustReadFlag - read 1 bit into flag. Panic if not possible -func (r *EBSPReader) MustReadFlag() bool { - return r.MustRead(1) == 1 -} - -// MustReadExpGolomb - Read one unsigned exponential golomb code. Panic if not possible -func (r *EBSPReader) MustReadExpGolomb() uint { - leadingZeroBits := 0 - - for { - b := r.MustRead(1) - if b == 1 { - break - } - leadingZeroBits++ - } - - var res uint = (1 << leadingZeroBits) - 1 - endBits := r.MustRead(leadingZeroBits) - - return res + endBits -} - -// MustReadSignedGolomb - Read one signed exponential golomb code. Panic if not possible -func (r *EBSPReader) MustReadSignedGolomb() int { - unsignedGolomb := r.MustReadExpGolomb() - if unsignedGolomb%2 == 1 { - return int((unsignedGolomb + 1) / 2) - } - return -int(unsignedGolomb / 2) -} - -// NrBytesRead - how many bytes read into parser -func (r *EBSPReader) NrBytesRead() int { - return r.pos + 1 // Starts at -1 -} - -// NrBitsReadInCurrentByte - how many bits have been read -func (r *EBSPReader) NrBitsReadInCurrentByte() int { - return 8 - r.n -} - -// EBSP2rbsp - convert from EBSP to RBSP by removing start code emulation prevention bytes -func EBSP2rbsp(ebsp []byte) []byte { - zeroCount := 0 - output := make([]byte, 0, len(ebsp)) - for i := 0; i < len(ebsp); i++ { - b := ebsp[i] - if zeroCount == 2 && b == startCodeEmulationPreventionByte { - zeroCount = 0 - } else { - if b != 0 { - zeroCount = 0 - } else { - zeroCount++ - } - output = append(output, ebsp[i]) - } - } - return output -} - -// Read - read n bits and return error if not possible -func (r *EBSPReader) Read(n int) (uint, error) { - var err error - - for r.n < n { - r.v <<= 8 - var b uint8 - err = binary.Read(r.rd, binary.BigEndian, &b) - if err != nil { - return 0, err - } - r.pos++ - if r.zeroCount == 2 && b == startCodeEmulationPreventionByte { - err = binary.Read(r.rd, binary.BigEndian, &b) - if err != nil { - return 0, err - } - r.pos++ - r.zeroCount = 0 - } - if b != 0 { - r.zeroCount = 0 - } else { - r.zeroCount++ - } - r.v |= uint(b) - - r.n += 8 - } - v := r.v >> uint(r.n-n) - - r.n -= n - r.v &= mask(r.n) - - return v, nil -} - -// ReadFlag - read 1 bit into flag. Return error if not possible -func (r *EBSPReader) ReadFlag() (bool, error) { - bit, err := r.Read(1) - if err != nil { - return false, err - } - return bit == 1, nil -} - -// ReadExpGolomb - Read one unsigned exponential golomb code -func (r *EBSPReader) ReadExpGolomb() (uint, error) { - leadingZeroBits := 0 - - for { - b, err := r.Read(1) - if err != nil { - return 0, err - } - if b == 1 { - break - } - leadingZeroBits++ - } - - var res uint = (1 << leadingZeroBits) - 1 - - endBits, err := r.Read(leadingZeroBits) - if err != nil { - return 0, err - } - - return res + endBits, nil -} - -// ReadSignedGolomb - Read one signed exponential golomb code -func (r *EBSPReader) ReadSignedGolomb() (int, error) { - unsignedGolomb, err := r.ReadExpGolomb() - if err != nil { - return 0, err - } - if unsignedGolomb%2 == 1 { - return int((unsignedGolomb + 1) / 2), nil - } - return -int(unsignedGolomb / 2), nil -} - -// IsSeeker - does reader support Seek -func (r *EBSPReader) IsSeeker() bool { - _, ok := r.rd.(io.ReadSeeker) - return ok -} - -// MoreRbspData - false if next bit is 1 and last 1 in fullSlice -// Underlying reader must support ReadSeeker interface to reset after check -func (r *EBSPReader) MoreRbspData() (bool, error) { - if !r.IsSeeker() { - return false, ErrNotReedSeeker - } - // Find out if next position is the last 1 - stateCopy := *r - - firstBit, err := r.Read(1) - if err != nil { - return false, err - } - if firstBit != 1 { - err = r.reset(stateCopy) - if err != nil { - return false, err - } - return true, nil - } - // If all remainging bits are zero, there is no more rbsp data - more := false - for { - b, err := r.Read(1) - if err == io.EOF { - break - } - if err != nil { - return false, err - } - if b == 1 { - more = true - break - } - } - err = r.reset(stateCopy) - if err != nil { - return false, nil - } - return more, nil -} - -// reset EBSPReader based on copy of previous state -func (r *EBSPReader) reset(prevState EBSPReader) error { - rdSeek, ok := r.rd.(io.ReadSeeker) - - if !ok { - return ErrNotReedSeeker - } - - _, err := rdSeek.Seek(int64(prevState.pos+1), 0) - if err != nil { - return err - } - r.n = prevState.n - r.v = prevState.v - r.pos = prevState.pos - r.zeroCount = prevState.zeroCount - return nil -} - -// ReadRbspTrailingBits - read rbsp_traling_bits. Return false if wrong pattern -func (r *EBSPReader) ReadRbspTrailingBits() error { - firstBit, err := r.Read(1) - if err != nil { - return err - } - if firstBit != 1 { - return fmt.Errorf("RbspTrailingBits don't start with 1") - } - for { - b, err := r.Read(1) - if err == io.EOF { - return nil - } - if err != nil { - return err - } - if b == 1 { - return fmt.Errorf("Another 1 in RbspTrailingBits") - } - } -} diff --git a/bits/ebsp_test.go b/bits/ebsp_test.go deleted file mode 100644 index 014c83bf..00000000 --- a/bits/ebsp_test.go +++ /dev/null @@ -1,254 +0,0 @@ -package bits - -import ( - "bytes" - "encoding/hex" - "io" - "testing" - - "github.com/go-test/deep" -) - -func TestGolomb(t *testing.T) { - t.Run("unsignedGolomb", func(t *testing.T) { - cases := []struct { - input []byte - want uint - }{ - {[]byte{0b10000000}, 0}, - {[]byte{0b01000000}, 1}, - {[]byte{0b01100000}, 2}, - {[]byte{0b00010000}, 7}, - {[]byte{0b00010010}, 8}, - {[]byte{0b00011110}, 14}, - } - - for _, c := range cases { - buf := bytes.NewBuffer(c.input) - r := NewEBSPReader(buf) - got := r.MustReadExpGolomb() - if got != c.want { - t.Errorf("got %d want %d", got, c.want) - } - } - }) - t.Run("signedGolomb", func(t *testing.T) { - cases := []struct { - input []byte - want int - }{ - {[]byte{0b10000000}, 0}, - {[]byte{0b01000000}, 1}, - {[]byte{0b01100000}, -1}, - {[]byte{0b00010000}, 4}, - {[]byte{0b00010010}, -4}, - {[]byte{0b00011110}, -7}, - } - - for _, c := range cases { - buf := bytes.NewBuffer(c.input) - r := NewEBSPReader(buf) - got := r.MustReadSignedGolomb() - if got != c.want { - t.Errorf("got %d want %d", got, c.want) - } - } - }) -} - -// TestEbspParser including startCodeEmulationPrevention removal -func TestEbspParser(t *testing.T) { - - cases := []struct{ name, start, want string }{ - { - "read last byte", - "32", - "32", - }, - { - "remove escape but not 007", - "27640020ac2ec05005bb011000000300100000078e840016e300005b8d8bdef83b438627", - "27640020ac2ec05005bb0110000000100000078e840016e300005b8d8bdef83b438627", - }, - { - "Long zero sequence", - "00000300000300", - "0000000000", - }, - } - - for _, c := range cases { - byteData, _ := hex.DecodeString(c.start) - buf := bytes.NewBuffer(byteData) - r := NewEBSPReader(buf) - got := []byte{} - for { - b, err := r.Read(8) - if err == io.EOF { - break - } - got = append(got, byte(b)) - } - wantBytes, err := hex.DecodeString(c.want) - if err != nil { - t.Error(err) - } - if diff := deep.Equal(got, wantBytes); diff != nil { - t.Errorf("%s: %v", c.name, diff) - } - } -} - -func TestGetPrecisePosition(t *testing.T) { - - testCases := []struct { - name string - inBytes []byte - nrBitsToRead int - nrBytesRead int - nrBitsReadInByte int - }{ - { - name: "Read 13 bits", - inBytes: []byte{0, 1, 2}, - nrBitsToRead: 13, - nrBytesRead: 2, - nrBitsReadInByte: 5, - }, - } - - for _, c := range testCases { - - buf := bytes.NewBuffer(c.inBytes) - r := NewEBSPReader(buf) - _, err := r.Read(c.nrBitsToRead) - if err != nil { - t.Error(err) - } - if r.NrBytesRead() != c.nrBytesRead { - t.Errorf("%s: got %d bytes want %d bytes", c.name, r.NrBytesRead(), c.nrBytesRead) - } - if r.NrBitsReadInCurrentByte() != c.nrBitsReadInByte { - t.Errorf("%s: got %d bits want %d bits", c.name, r.NrBitsReadInCurrentByte(), c.nrBitsReadInByte) - } - - } -} - -func TestMoreRbspData(t *testing.T) { - - testCases := []struct { - name string - inBytes []byte - nrBitsBefore int - moreRbsp bool - nrBitsAfter int - valueRead uint - }{ - { - name: "start of byte", - inBytes: []byte{0b10000000}, - nrBitsBefore: 0, - moreRbsp: false, - nrBitsAfter: 0, - valueRead: 0, - }, - { - name: "one bit left", - inBytes: []byte{0b11000000}, - nrBitsBefore: 0, - moreRbsp: true, - nrBitsAfter: 2, - valueRead: 3, - }, - { - name: "after one bit", - inBytes: []byte{0b11000000}, - nrBitsBefore: 1, - moreRbsp: false, - nrBitsAfter: 0, - valueRead: 0, - }, - { - name: "after one byte", - inBytes: []byte{0b11110111, 0b11001000}, - nrBitsBefore: 9, - moreRbsp: true, - nrBitsAfter: 4, - valueRead: 0b00001001, - }, - } - - for _, c := range testCases { - - brd := bytes.NewReader(c.inBytes) - r := NewEBSPReader(brd) - _, err := r.Read(c.nrBitsBefore) - if err != nil { - t.Error(err) - } - moreRbsp, err := r.MoreRbspData() - if err != nil { - t.Error(err) - } - if moreRbsp != c.moreRbsp { - t.Errorf("%s: got %t want %t", c.name, moreRbsp, c.moreRbsp) - } - got, err := r.Read(c.nrBitsAfter) - if err != nil { - t.Error(err) - } - if got != c.valueRead { - t.Errorf("%s: got %d want %d after check", c.name, got, c.valueRead) - } - } -} - -func TestReadTrailingRbspBits(t *testing.T) { - input := []byte{0b10000000} - brd := bytes.NewReader(input) - reader := NewEBSPReader(brd) - err := reader.ReadRbspTrailingBits() - if err != nil { - t.Error(err) - } - _, err = reader.Read(1) - if err != io.EOF { - t.Errorf("Not at end after reading rbsp_trailing_bits") - } -} - -func TestEBSPWriter(t *testing.T) { - testCases := []struct { - in []byte - out []byte - }{ - { - in: []byte{0, 0, 0, 1}, - out: []byte{0, 0, 3, 0, 1}, - }, - { - in: []byte{1, 0, 0, 2}, - out: []byte{1, 0, 0, 3, 2}, - }, - { - in: []byte{0, 0, 0, 0, 0}, - out: []byte{0, 0, 3, 0, 0, 3, 0}, - }, - { - in: []byte{0, 0, 0, 0, 5, 0}, - out: []byte{0, 0, 3, 0, 0, 5, 0}, - }, - } - for _, tc := range testCases { - buf := bytes.Buffer{} - w := NewEBSPWriter(&buf) - for _, b := range tc.in { - w.Write(uint(b), 8) - } - diff := deep.Equal(buf.Bytes(), tc.out) - if diff != nil { - t.Errorf("Got %v but wanted %d", buf.Bytes(), tc.out) - } - } -} diff --git a/bits/ebsp_writer_test.go b/bits/ebsp_writer_test.go deleted file mode 100644 index 6cd8e77f..00000000 --- a/bits/ebsp_writer_test.go +++ /dev/null @@ -1,45 +0,0 @@ -package bits - -import ( - "bytes" - "fmt" - "testing" -) - -func TestExpGolomb(t *testing.T) { - cases := []struct { - bits string - n uint - }{ - {"1", 0}, - {"010", 1}, - {"011", 2}, - {"00100", 3}, - {"0001111", 14}, - {"000010000", 15}, - {"000011111", 30}, - } - - for _, tc := range cases { - b := bytes.Buffer{} - w := NewEBSPWriter(&b) - w.WriteExpGolomb(tc.n) - gotBits := getBitsWritten(w, &b) - if gotBits != tc.bits { - t.Errorf("wanted %s but got %s for %d", tc.bits, gotBits, tc.n) - } - } -} - -func getBitsWritten(w *EBSPWriter, b *bytes.Buffer) string { - bits := "" - for _, c := range b.Bytes() { - bits += fmt.Sprintf("%08b", c) - } - valueInWriter, nrBitsInWriter := w.BitsInBuffer() - if nrBitsInWriter > 0 { - fullByte := fmt.Sprintf("%08b", valueInWriter) - bits += fullByte[8-nrBitsInWriter:] - } - return bits -} diff --git a/bits/aeebspreader.go b/bits/ebspreader.go similarity index 58% rename from bits/aeebspreader.go rename to bits/ebspreader.go index f11de431..bb91da68 100644 --- a/bits/aeebspreader.go +++ b/bits/ebspreader.go @@ -2,12 +2,23 @@ package bits import ( "encoding/binary" + "errors" "fmt" "io" ) -// AccErrEBSPReader - Reader that drops start code emulation 0x03 after two bytes of 0x00 -type AccErrEBSPReader struct { +// ESBPReader errors +var ( + ErrNotReadSeeker = errors.New("reader does not support Seek") +) + +const ( + startCodeEmulationPreventionByte = 0x03 +) + +// EBSPReader reads an EBSP bitstream dropping start-code emulation bytes. +// It also supports checking for more rbsp data and reading rbsp_trailing_bits. +type EBSPReader struct { rd io.Reader err error n int // current number of bits @@ -16,26 +27,26 @@ type AccErrEBSPReader struct { zeroCount int // Count number of zero bytes read } -// NewAccErrEBSPReader - return a new reader accumulating errors. -func NewAccErrEBSPReader(rd io.Reader) *AccErrEBSPReader { - return &AccErrEBSPReader{ +// NewEBSPReader return a new EBSP reader stopping reading at first error. +func NewEBSPReader(rd io.Reader) *EBSPReader { + return &EBSPReader{ rd: rd, pos: -1, } } -// AccError - accumulated error -func (r *AccErrEBSPReader) AccError() error { +// AccError returns the accumulated error. If no error, returns nil. +func (r *EBSPReader) AccError() error { return r.err } -// NrBytesRead - how many bytes read into parser -func (r *AccErrEBSPReader) NrBytesRead() int { +// NrBytesRead returns how many bytes read into parser. +func (r *EBSPReader) NrBytesRead() int { return r.pos + 1 // Starts at -1 } -// NrBitsRead - how many bits read into parser -func (r *AccErrEBSPReader) NrBitsRead() int { +// NrBitsRead returns total number of bits read into parser. +func (r *EBSPReader) NrBitsRead() int { nrBits := r.NrBytesRead() * 8 if r.NrBitsReadInCurrentByte() != 8 { nrBits += r.NrBitsReadInCurrentByte() - 8 @@ -43,13 +54,13 @@ func (r *AccErrEBSPReader) NrBitsRead() int { return nrBits } -// NrBitsReadInCurrentByte - how many bits have been read -func (r *AccErrEBSPReader) NrBitsReadInCurrentByte() int { +// NrBitsReadInCurrentByte returns number of bits read in current byte. +func (r *EBSPReader) NrBitsReadInCurrentByte() int { return 8 - r.n } -// Read - read n bits and return 0 if (previous) error -func (r *AccErrEBSPReader) Read(n int) uint { +// Read reads n bits and respects and accumulates errors. If error, returns 0. +func (r *EBSPReader) Read(n int) uint { if r.err != nil { return 0 } @@ -84,13 +95,13 @@ func (r *AccErrEBSPReader) Read(n int) uint { v := r.v >> uint(r.n-n) r.n -= n - r.v &= mask(r.n) + r.v &= Mask(r.n) return v } -// Read - read n bytes and return nil if (previous) error or if n bytes not available -func (r *AccErrEBSPReader) ReadBytes(n int) []byte { +// ReadBytes read n bytes and return nil if new or accumulated error. +func (r *EBSPReader) ReadBytes(n int) []byte { if r.err != nil { return nil } @@ -105,13 +116,13 @@ func (r *AccErrEBSPReader) ReadBytes(n int) []byte { return payload } -// ReadFlag - read 1 bit into bool. Return false if not possible -func (r *AccErrEBSPReader) ReadFlag() bool { +// ReadFlag reads 1 bit and translates a bool. +func (r *EBSPReader) ReadFlag() bool { return r.Read(1) == 1 } -// ReadExpGolomb - Read one unsigned exponential golomb code. Return 0 if error -func (r *AccErrEBSPReader) ReadExpGolomb() uint { +// ReadExpGolomb reads one unsigned exponential Golomb code. +func (r *EBSPReader) ReadExpGolomb() uint { if r.err != nil { return 0 } @@ -138,8 +149,8 @@ func (r *AccErrEBSPReader) ReadExpGolomb() uint { return res + endBits } -// ReadSignedGolomb - Read one signed exponential golomb code. Return 0 if error -func (r *AccErrEBSPReader) ReadSignedGolomb() int { +// ReadSignedGolomb reads one signed exponential Golomb code. +func (r *EBSPReader) ReadSignedGolomb() int { if r.err != nil { return 0 } @@ -153,18 +164,18 @@ func (r *AccErrEBSPReader) ReadSignedGolomb() int { return -int(unsignedGolomb / 2) } -// IsSeeker - does reader support Seek -func (r *AccErrEBSPReader) IsSeeker() bool { +// IsSeeker returns tru if underluing reader supports Seek interface. +func (r *EBSPReader) IsSeeker() bool { _, ok := r.rd.(io.ReadSeeker) return ok } -// MoreRbspData - false if next bit is 1 and last 1 in fullSlice. -// Underlying reader must support ReadSeeker interface to reset after check -// Return false, nil if underlying error -func (r *AccErrEBSPReader) MoreRbspData() (bool, error) { +// MoreRbspData returns false if next bit is 1 and last 1-bit in fullSlice. +// Underlying reader must support ReadSeeker interface to reset after check. +// Return false, nil if underlying error. +func (r *EBSPReader) MoreRbspData() (bool, error) { if !r.IsSeeker() { - return false, ErrNotReedSeeker + return false, ErrNotReadSeeker } // Find out if next position is the last 1 stateCopy := *r @@ -203,12 +214,12 @@ func (r *AccErrEBSPReader) MoreRbspData() (bool, error) { return more, nil } -// reset EBSPReader based on copy of previous state -func (r *AccErrEBSPReader) reset(prevState AccErrEBSPReader) error { +// reset resets EBSPReader based on copy of previous state. +func (r *EBSPReader) reset(prevState EBSPReader) error { rdSeek, ok := r.rd.(io.ReadSeeker) if !ok { - return ErrNotReedSeeker + return ErrNotReadSeeker } _, err := rdSeek.Seek(int64(prevState.pos+1), 0) @@ -222,9 +233,9 @@ func (r *AccErrEBSPReader) reset(prevState AccErrEBSPReader) error { return nil } -// ReadRbspTrailingBits - read rbsp_traling_bits. Return error if wrong pattern -// If other error, return nil and let AccError provide that error -func (r *AccErrEBSPReader) ReadRbspTrailingBits() error { +// ReadRbspTrailingBits reads rbsp_traling_bits. Returns error if wrong pattern. +// If other error, returns nil and let AccError() provide that error. +func (r *EBSPReader) ReadRbspTrailingBits() error { if r.err != nil { return nil } @@ -250,9 +261,9 @@ func (r *AccErrEBSPReader) ReadRbspTrailingBits() error { } } -// SetError - set an error if not already set. -func (r *AccErrEBSPReader) SetError(err error) { - if r.err != nil { +// SetError sets an error if not already set. +func (r *EBSPReader) SetError(err error) { + if r.err == nil { r.err = err } } diff --git a/bits/ebspreader_test.go b/bits/ebspreader_test.go new file mode 100644 index 00000000..cdebcaec --- /dev/null +++ b/bits/ebspreader_test.go @@ -0,0 +1,352 @@ +package bits_test + +import ( + "bytes" + "encoding/hex" + "io" + "testing" + + "github.com/Eyevinn/mp4ff/bits" + "github.com/go-test/deep" +) + +func TestEBSPReader(t *testing.T) { + t.Run("read bits", func(t *testing.T) { + testCases := []struct { + name string + hexData string + readBits int + expectedValue uint + expectedNrBytesRead int + expectedNrBitsRead int + expectedNrBitsReadInCurrentByte int + expectedError string + }{ + {"zero bits", "00", 0, 0, 0, 0, 8, ""}, + {"1 bit", "0001", 1, 0, 1, 1, 1, ""}, + {"8 bits", "0001", 8, 0x00, 1, 8, 8, ""}, + {"12 bits", "0010", 12, 0x001, 2, 12, 4, ""}, + {"16 bits", "0102", 16, 0x0102, 2, 16, 8, ""}, + {"24 bits including start code emulation", "00000304", 24, 0x000004, 4, 32, 8, ""}, + {"missing data after start code emulation prevention byte", "000003", 17, 0, 0, 0, 0, "EOF"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data, err := hex.DecodeString(tc.hexData) + if err != nil { + t.Error(err) + } + buf := bytes.NewBuffer(data) + r := bits.NewEBSPReader(buf) + if r.AccError() != nil { + t.Errorf("expected no error, got %v", r.AccError()) + } + v := r.Read(tc.readBits) + if tc.expectedError != "" { + gotErr := r.AccError().Error() + if r.AccError().Error() != tc.expectedError { + t.Errorf("expected error %s, got %s", tc.expectedError, gotErr) + } + return + } + if v != tc.expectedValue { + t.Errorf("expected value %d, got %d", tc.expectedValue, v) + } + if r.NrBytesRead() != tc.expectedNrBytesRead { + t.Errorf("expected %d bytes read, got %d", tc.expectedNrBytesRead, r.NrBytesRead()) + } + if r.NrBitsRead() != tc.expectedNrBitsRead { + t.Errorf("expected %d bits read, got %d", tc.expectedNrBitsRead, r.NrBitsRead()) + } + if r.NrBitsReadInCurrentByte() != tc.expectedNrBitsReadInCurrentByte { + t.Errorf("expected %d bits read in current byte, got %d", tc.expectedNrBitsReadInCurrentByte, + r.NrBitsReadInCurrentByte()) + } + }) + } + }) + t.Run("read bytes", func(t *testing.T) { + testCases := []struct { + name string + hexData string + readNrBytes int + expectedValue []byte + expectedError string + }{ + {"1 byte", "01", 1, []byte{0x01}, ""}, + {"too many bytes", "01", 2, nil, "EOF"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + data, err := hex.DecodeString(tc.hexData) + if err != nil { + t.Error(err) + } + buf := bytes.NewBuffer(data) + r := bits.NewEBSPReader(buf) + got := r.ReadBytes(tc.readNrBytes) + err = r.AccError() + if tc.expectedError != "" { + if err == nil { + t.Errorf("expected error %s, got none", tc.expectedError) + } else { + if err.Error() != tc.expectedError { + t.Errorf("expected error %s, got %s", tc.expectedError, err.Error()) + } + } + } else { + if !bytes.Equal(got, tc.expectedValue) { + t.Errorf("expected %v, got %v", tc.expectedValue, got) + } + } + }) + } + }) + t.Run("read exp golomb", func(t *testing.T) { + cases := []struct { + input uint + wanted uint + expectError bool + }{ + {0b10000000, 0, false}, + {0b01000000, 1, false}, + {0b01100000, 2, false}, + {0b00010000, 7, false}, + {0b00010010, 8, false}, + {0b00011110, 14, false}, + {0x00, 0, true}, + } + for _, c := range cases { + buf := bytes.NewBuffer([]byte{byte(c.input)}) + r := bits.NewEBSPReader(buf) + got := r.ReadExpGolomb() + if c.expectError && r.AccError() == nil { + t.Errorf("expected error, got none") + _ = r.ReadExpGolomb() + if r.AccError() == nil { + t.Errorf("expected error to stay") + } + } else { + if got != c.wanted { + t.Errorf("got %d, wanted %d", got, c.wanted) + } + } + } + }) + t.Run("read signed exp golomb", func(t *testing.T) { + cases := []struct { + input uint + wanted int + expectError bool + }{ + {0b10000000, 0, false}, + {0b01000000, 1, false}, + {0b01100000, -1, false}, + {0b00010000, 4, false}, + {0b00010010, -4, false}, + {0b00011110, -7, false}, + {0x00, 0, true}, + } + for _, c := range cases { + buf := bytes.NewBuffer([]byte{byte(c.input)}) + r := bits.NewEBSPReader(buf) + got := r.ReadSignedGolomb() + if c.expectError && r.AccError() == nil { + t.Errorf("expected error, got none") + _ = r.ReadSignedGolomb() + if r.AccError() == nil { + t.Errorf("expected error to stay") + } + } else { + if got != c.wanted { + t.Errorf("got %d, wanted %d", got, c.wanted) + } + } + } + }) + t.Run("remove escape bytes", func(t *testing.T) { + + cases := []struct{ name, start, want string }{ + { + "read last byte", + "32", + "32", + }, + { + "remove escape but not 007", + "27640020ac2ec05005bb011000000300100000078e840016e300005b8d8bdef83b438627", + "27640020ac2ec05005bb0110000000100000078e840016e300005b8d8bdef83b438627", + }, + { + "Long zero sequence", + "00000300000300", + "0000000000", + }, + } + + for _, c := range cases { + byteData, _ := hex.DecodeString(c.start) + buf := bytes.NewBuffer(byteData) + r := bits.NewEBSPReader(buf) + got := []byte{} + for { + b := r.Read(8) + if r.AccError() == io.EOF { + break + } + got = append(got, byte(b)) + } + wantBytes, err := hex.DecodeString(c.want) + if err != nil { + t.Error(err) + } + if diff := deep.Equal(got, wantBytes); diff != nil { + t.Errorf("%s: %v", c.name, diff) + } + } + }) + t.Run("get precise position", func(t *testing.T) { + testCases := []struct { + name string + inBytes []byte + nrBitsToRead int + nrBytesRead int + nrBitsReadInByte int + }{ + { + name: "Read 13 bits", + inBytes: []byte{0, 1, 2}, + nrBitsToRead: 13, + nrBytesRead: 2, + nrBitsReadInByte: 5, + }, + } + + for _, c := range testCases { + t.Run(c.name, func(t *testing.T) { + buf := bytes.NewBuffer(c.inBytes) + r := bits.NewEBSPReader(buf) + _ = r.Read(c.nrBitsToRead) + if r.AccError() != nil { + t.Error(r.AccError()) + } + if r.NrBytesRead() != c.nrBytesRead { + t.Errorf("%s: got %d bytes want %d bytes", c.name, r.NrBytesRead(), c.nrBytesRead) + } + if r.NrBitsReadInCurrentByte() != c.nrBitsReadInByte { + t.Errorf("%s: got %d bits want %d bits", c.name, r.NrBitsReadInCurrentByte(), c.nrBitsReadInByte) + } + }) + } + }) + t.Run("more rbsp data", func(t *testing.T) { + + testCases := []struct { + name string + inBytes []byte + nrBitsBefore int + moreRbsp bool + nrBitsAfter int + valueRead uint + }{ + { + name: "start of byte", + inBytes: []byte{0b10000000}, + nrBitsBefore: 0, + moreRbsp: false, + nrBitsAfter: 0, + valueRead: 0, + }, + { + name: "one bit left", + inBytes: []byte{0b11000000}, + nrBitsBefore: 0, + moreRbsp: true, + nrBitsAfter: 2, + valueRead: 3, + }, + { + name: "after one bit", + inBytes: []byte{0b11000000}, + nrBitsBefore: 1, + moreRbsp: false, + nrBitsAfter: 0, + valueRead: 0, + }, + { + name: "after one byte", + inBytes: []byte{0b11110111, 0b11001000}, + nrBitsBefore: 9, + moreRbsp: true, + nrBitsAfter: 4, + valueRead: 0b00001001, + }, + } + + for _, c := range testCases { + t.Run(c.name, func(t *testing.T) { + brd := bytes.NewReader(c.inBytes) + r := bits.NewEBSPReader(brd) + _ = r.Read(c.nrBitsBefore) + if r.AccError() != nil { + t.Error(r.AccError()) + } + moreRbsp, err := r.MoreRbspData() + if err != nil { + t.Error(err) + } + if moreRbsp != c.moreRbsp { + t.Errorf("%s: got %t want %t", c.name, moreRbsp, c.moreRbsp) + } + got := r.Read(c.nrBitsAfter) + if r.AccError() != nil { + t.Error(r.AccError()) + } + if got != c.valueRead { + t.Errorf("%s: got %d want %d after check", c.name, got, c.valueRead) + } + }) + } + }) + t.Run("read traling RBSP bits", func(t *testing.T) { + input := []byte{0b10000000} + brd := bytes.NewReader(input) + r := bits.NewEBSPReader(brd) + err := r.ReadRbspTrailingBits() + if err != nil { + t.Error(err) + } + _ = r.Read(1) + if r.AccError() != io.EOF { + t.Errorf("Not at end after reading rbsp_trailing_bits") + } + }) + t.Run("read after earlier error", func(t *testing.T) { + input := []byte{0b10000000} + brd := bytes.NewReader(input) + r := bits.NewEBSPReader(brd) + r.SetError(io.ErrUnexpectedEOF) + // Read shold never result in panic + // Error should be preseverd + _ = r.Read(100) + if r.AccError() != io.ErrUnexpectedEOF { + t.Errorf("Expected error not found") + } + _ = r.ReadBytes(100) + if r.AccError() != io.ErrUnexpectedEOF { + t.Errorf("Expected error not found") + } + _ = r.ReadExpGolomb() + if r.AccError() != io.ErrUnexpectedEOF { + t.Errorf("Expected error not found") + } + _ = r.ReadSignedGolomb() + if r.AccError() != io.ErrUnexpectedEOF { + t.Errorf("Expected error not found") + } + _ = r.ReadRbspTrailingBits() + if r.AccError() != io.ErrUnexpectedEOF { + t.Errorf("Expected error not found") + } + }) +} diff --git a/bits/ebsp_writer.go b/bits/ebspwriter.go similarity index 89% rename from bits/ebsp_writer.go rename to bits/ebspwriter.go index d8844e69..65c10fee 100644 --- a/bits/ebsp_writer.go +++ b/bits/ebspwriter.go @@ -4,9 +4,9 @@ import ( "io" ) -// EBSPWriter - write bits and insert start-code emulation prevention bytes as necessary. -// Cease writing at first error. -// Errors that have occurred can later be checked with AccError(). +// EBSPWriter write bits and insert start-code emulation prevention bytes as necessary. +// Ceases writing at first error. +// The first error can later be checked with AccError(). type EBSPWriter struct { wr io.Writer // underlying writer err error // The first error caused by any write operation @@ -32,10 +32,10 @@ func (w *EBSPWriter) Write(bits uint, n int) { return } w.v <<= uint(n) - w.v |= bits & mask(n) + w.v |= bits & Mask(n) w.n += n for w.n >= 8 { - b := (w.v >> (uint(w.n) - 8)) & mask(8) + b := (w.v >> (uint(w.n) - 8)) & Mask(8) if w.nr0 == 2 && b <= 3 { w.out[0] = 0x3 // start code emulation prevention _, err := w.wr.Write(w.out) @@ -58,7 +58,7 @@ func (w *EBSPWriter) Write(bits uint, n int) { } w.n -= 8 } - w.v &= mask(8) + w.v &= Mask(8) } // WriteExpGolomb - write an exponential Golomb code diff --git a/bits/ebspwriter_test.go b/bits/ebspwriter_test.go new file mode 100644 index 00000000..ca70ce80 --- /dev/null +++ b/bits/ebspwriter_test.go @@ -0,0 +1,84 @@ +package bits_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/Eyevinn/mp4ff/bits" + "github.com/go-test/deep" +) + +func TestEBSPWriter(t *testing.T) { + t.Run("insert startcode emulation prevention bytes", func(t *testing.T) { + testCases := []struct { + in []byte + out []byte + }{ + { + in: []byte{0, 0, 0, 1}, + out: []byte{0, 0, 3, 0, 1}, + }, + { + in: []byte{1, 0, 0, 2}, + out: []byte{1, 0, 0, 3, 2}, + }, + { + in: []byte{0, 0, 0, 0, 0}, + out: []byte{0, 0, 3, 0, 0, 3, 0}, + }, + { + in: []byte{0, 0, 0, 0, 5, 0}, + out: []byte{0, 0, 3, 0, 0, 5, 0}, + }, + } + for _, tc := range testCases { + buf := bytes.Buffer{} + w := bits.NewEBSPWriter(&buf) + for _, b := range tc.in { + w.Write(uint(b), 8) + } + diff := deep.Equal(buf.Bytes(), tc.out) + if diff != nil { + t.Errorf("Got %v but wanted %d", buf.Bytes(), tc.out) + } + } + }) + t.Run("write exp golomb", func(t *testing.T) { + cases := []struct { + bits string + n uint + }{ + {"1", 0}, + {"010", 1}, + {"011", 2}, + {"00100", 3}, + {"0001111", 14}, + {"000010000", 15}, + {"000011111", 30}, + } + + for _, tc := range cases { + b := bytes.Buffer{} + w := bits.NewEBSPWriter(&b) + w.WriteExpGolomb(tc.n) + gotBits := getBitsWritten(w, &b) + if gotBits != tc.bits { + t.Errorf("wanted %s but got %s for %d", tc.bits, gotBits, tc.n) + } + } + }) +} + +func getBitsWritten(w *bits.EBSPWriter, b *bytes.Buffer) string { + bits := "" + for _, c := range b.Bytes() { + bits += fmt.Sprintf("%08b", c) + } + valueInWriter, nrBitsInWriter := w.BitsInBuffer() + if nrBitsInWriter > 0 { + fullByte := fmt.Sprintf("%08b", valueInWriter) + bits += fullByte[8-nrBitsInWriter:] + } + return bits +} diff --git a/bits/fixedslicereader_test.go b/bits/fixedslicereader_test.go index f5d6240a..c07fec3e 100644 --- a/bits/fixedslicereader_test.go +++ b/bits/fixedslicereader_test.go @@ -1,162 +1,260 @@ -package bits +package bits_test import ( "fmt" "testing" + + "github.com/Eyevinn/mp4ff/bits" ) -func TestFixedSliceReaderInts(t *testing.T) { - nrBytes := 64 - data := make([]byte, nrBytes) - for i := 0; i < nrBytes; i++ { - data[i] = byte(i) - } - sr := NewFixedSliceReader(data) - if sr.Length() != nrBytes { - t.Errorf("sliceReader length = %d instead of %d", sr.Length(), nrBytes) - } - for i := 0; i < nrBytes+2; i++ { - val := sr.ReadUint8() - switch { - case i < nrBytes: - if i != int(val) { - t.Errorf("val at %d not %d but %d", i, i, val) - } - default: // Read passed end - if sr.AccError() == nil { - t.Errorf("should have had an accumulated error") - } - if 0 != val { - t.Errorf("should have zero value") +func TestFixedSliceReader(t *testing.T) { + t.Run("read ints", func(t *testing.T) { + nrBytes := 64 + data := make([]byte, nrBytes) + for i := 0; i < nrBytes; i++ { + data[i] = byte(i) + } + sr := bits.NewFixedSliceReader(data) + if sr.Length() != nrBytes { + t.Errorf("sliceReader length = %d instead of %d", sr.Length(), nrBytes) + } + for i := 0; i < nrBytes+2; i++ { + val := sr.ReadUint8() + switch { + case i < nrBytes: + if i != int(val) { + t.Errorf("val at %d not %d but %d", i, i, val) + } + default: // Read passed end + if sr.AccError() == nil { + t.Errorf("should have had an accumulated error") + } + if val != 0 { + t.Errorf("should have zero value") + } } } - } - sr = NewFixedSliceReader(data) - nrVals := nrBytes / 2 - for i := 0; i < nrVals+2; i++ { - val := sr.ReadUint16() - switch { - case i < nrVals: - expectedVal := uint16(expVal(i, 2)) - if val != expectedVal { - t.Errorf("val at %d not %d but %d", i, expectedVal, val) + sr = bits.NewFixedSliceReader(data) + nrVals := nrBytes / 2 + for i := 0; i < nrVals+2; i++ { + val := sr.ReadUint16() + switch { + case i < nrVals: + expectedVal := uint16(expVal(i, 2)) + if val != expectedVal { + t.Errorf("val at %d not %d but %d", i, expectedVal, val) + } + default: // Read passed end + verifyAccErrorInt(t, sr, int(val)) } - default: // Read passed end - verifyAccErrorInt(t, sr, int(val)) } - } - sr = NewFixedSliceReader(data) - nrVals = nrBytes / 2 - for i := 0; i < nrVals+2; i++ { - val := sr.ReadInt16() - switch { - case i < nrVals: - expectedVal := int16(expVal(i, 2)) - if val != expectedVal { - t.Errorf("val at %d not %d but %d", i, expectedVal, val) + sr = bits.NewFixedSliceReader(data) + nrVals = nrBytes / 2 + for i := 0; i < nrVals+2; i++ { + val := sr.ReadInt16() + switch { + case i < nrVals: + expectedVal := int16(expVal(i, 2)) + if val != expectedVal { + t.Errorf("val at %d not %d but %d", i, expectedVal, val) + } + default: // Read passed end + verifyAccErrorInt(t, sr, int(val)) } - default: // Read passed end - verifyAccErrorInt(t, sr, int(val)) } - } - sr = NewFixedSliceReader(data) - nrVals = nrBytes / 3 - for i := 0; i < nrVals+2; i++ { - val := sr.ReadUint24() - switch { - case i < nrVals: - expectedVal := uint32(expVal(i, 3)) - if val != expectedVal { - t.Errorf("val at %d not %d but %d", i, expectedVal, val) + sr = bits.NewFixedSliceReader(data) + nrVals = nrBytes / 3 + for i := 0; i < nrVals+2; i++ { + val := sr.ReadUint24() + switch { + case i < nrVals: + expectedVal := uint32(expVal(i, 3)) + if val != expectedVal { + t.Errorf("val at %d not %d but %d", i, expectedVal, val) + } + default: // Read passed end + verifyAccErrorInt(t, sr, int(val)) } - default: // Read passed end - verifyAccErrorInt(t, sr, int(val)) } - } - sr = NewFixedSliceReader(data) - nrVals = nrBytes / 4 - for i := 0; i < nrVals+2; i++ { - val := sr.ReadUint32() - switch { - case i < nrVals: - expectedVal := uint32(expVal(i, 4)) - if val != expectedVal { - t.Errorf("val at %d not %d but %d", i, expectedVal, val) + sr = bits.NewFixedSliceReader(data) + nrVals = nrBytes / 4 + for i := 0; i < nrVals+2; i++ { + val := sr.ReadUint32() + switch { + case i < nrVals: + expectedVal := uint32(expVal(i, 4)) + if val != expectedVal { + t.Errorf("val at %d not %d but %d", i, expectedVal, val) + } + default: // Read passed end + verifyAccErrorInt(t, sr, int(val)) } - default: // Read passed end - verifyAccErrorInt(t, sr, int(val)) } - } - sr = NewFixedSliceReader(data) - nrVals = nrBytes / 4 - for i := 0; i < nrVals+2; i++ { - val := sr.ReadInt32() - switch { - case i < nrVals: - expectedVal := int32(expVal(i, 4)) - if val != expectedVal { - t.Errorf("val at %d not %d but %d", i, expectedVal, val) + sr = bits.NewFixedSliceReader(data) + nrVals = nrBytes / 4 + for i := 0; i < nrVals+2; i++ { + val := sr.ReadInt32() + switch { + case i < nrVals: + expectedVal := int32(expVal(i, 4)) + if val != expectedVal { + t.Errorf("val at %d not %d but %d", i, expectedVal, val) + } + default: // Read passed end + verifyAccErrorInt(t, sr, int(val)) } - default: // Read passed end - verifyAccErrorInt(t, sr, int(val)) } - } - sr = NewFixedSliceReader(data) - nrVals = nrBytes / 8 - for i := 0; i < nrVals+2; i++ { - val := sr.ReadUint64() - switch { - case i < nrVals: - expectedVal := uint64(expVal(i, 8)) - if val != expectedVal { - t.Errorf("val at %d not %d but %d", i, expectedVal, val) + sr = bits.NewFixedSliceReader(data) + nrVals = nrBytes / 8 + for i := 0; i < nrVals+2; i++ { + val := sr.ReadUint64() + switch { + case i < nrVals: + expectedVal := uint64(expVal(i, 8)) + if val != expectedVal { + t.Errorf("val at %d not %d but %d", i, expectedVal, val) + } + default: // Read passed end + verifyAccErrorInt(t, sr, int(val)) } - default: // Read passed end - verifyAccErrorInt(t, sr, int(val)) } - } - sr = NewFixedSliceReader(data) - nrVals = nrBytes / 8 - for i := 0; i < nrVals+2; i++ { - val := sr.ReadInt64() - switch { - case i < nrVals: - expectedVal := int64(expVal(i, 8)) - if val != int64(expVal(i, 8)) { - t.Errorf("val at %d not %d but %d", i, expectedVal, val) + sr = bits.NewFixedSliceReader(data) + nrVals = nrBytes / 8 + for i := 0; i < nrVals+2; i++ { + val := sr.ReadInt64() + switch { + case i < nrVals: + expectedVal := int64(expVal(i, 8)) + if val != int64(expVal(i, 8)) { + t.Errorf("val at %d not %d but %d", i, expectedVal, val) + } + default: // Read passed end + verifyAccErrorInt(t, sr, int(val)) } - default: // Read passed end - verifyAccErrorInt(t, sr, int(val)) } - } + }) + t.Run("read strings", func(t *testing.T) { + data := []byte("0123456789\x00abcdef") + nrBytes := len(data) + sr := bits.NewFixedSliceReader(data) + if sr.Length() != nrBytes { + t.Errorf("sliceReader length = %d instead of %d", sr.Length(), nrBytes) + } + val := sr.ReadFixedLengthString(4) + if val != "0123" { + t.Errorf(`read string is %q instead of "0123"`, val) + } + val = sr.ReadZeroTerminatedString(10) + if val != "456789" { + t.Errorf(`read string is %q instead of "456789"`, val) + } + val = sr.ReadZeroTerminatedString(2) + verifyAccErrorString(t, sr, val) + + val = sr.ReadFixedLengthString(2) + verifyAccErrorString(t, sr, val) + + val = sr.ReadZeroTerminatedString(2) + verifyAccErrorString(t, sr, val) + + sr = bits.NewFixedSliceReader(data) + val = sr.ReadFixedLengthString(nrBytes + 2) + verifyAccErrorString(t, sr, val) + }) + t.Run("read bytes", func(t *testing.T) { + data := []byte("0123456789abcdef") + nrBytes := len(data) + sr := bits.NewFixedSliceReader(data) + if sr.Length() != nrBytes { + t.Errorf("sliceReader length = %d instead of %d", sr.Length(), nrBytes) + } + val := sr.ReadBytes(4) + if string(val) != "0123" { + t.Errorf(`read bytes are %q instead of "0123"`, string(val)) + } + if sr.NrRemainingBytes() != nrBytes-4 { + t.Errorf("nr remaining = %d instead of %d", sr.NrRemainingBytes(), nrBytes-4) + } + sr.SkipBytes(2) + if sr.NrRemainingBytes() != nrBytes-6 { + t.Errorf("nr remaining = %d instead of %d", sr.NrRemainingBytes(), nrBytes-6) + } + sr.SetPos(8) + if sr.NrRemainingBytes() != nrBytes-8 { + t.Errorf("nr remaining = %d instead of %d", sr.NrRemainingBytes(), nrBytes-8) + } + lookAhead := make([]byte, 8) + err := sr.LookAhead(0, lookAhead) + if err != nil { + t.Error(err) + } + if string(lookAhead) != "89abcdef" { + t.Errorf(`lookahead %q instead of "89abcdef"`, lookAhead) + } + lookAhead = make([]byte, 9) + err = sr.LookAhead(0, lookAhead) + if err == nil { + t.Errorf("should be out of bounds") + } + remaining := sr.RemainingBytes() + if string(remaining) != "89abcdef" { + t.Errorf(`remaining %q instead of "89abcdef"`, remaining) + } + + // Beyond end + val = sr.ReadBytes(2) + verifyAccErrorBytes(t, sr, val) + + val = sr.ReadBytes(2) + verifyAccErrorBytes(t, sr, val) + + val = sr.RemainingBytes() + verifyAccErrorBytes(t, sr, val) + + valNr := sr.NrRemainingBytes() + verifyAccErrorInt(t, sr, valNr) + + sr.SkipBytes(2) + verifyAccErrorInt(t, sr, 0) + + sr = bits.NewFixedSliceReader(data) + sr.SkipBytes(nrBytes + 2) + verifyAccErrorInt(t, sr, 0) + + sr.SetPos(nrBytes + 2) + wantedErrMsg := fmt.Sprintf("attempt to set pos %d beyond slice len %d", nrBytes+2, nrBytes) + if sr.AccError().Error() != wantedErrMsg { + t.Errorf("got error msg %q instead of %q", sr.AccError().Error(), wantedErrMsg) + } + }) } -func verifyAccErrorInt(t *testing.T, sr *FixedSliceReader, val int) { +func verifyAccErrorInt(t *testing.T, sr *bits.FixedSliceReader, val int) { if sr.AccError() == nil { t.Errorf("should have had an accumulated error") } - if 0 != val { + if val != 0 { t.Errorf("should have zero value") } } -func verifyAccErrorString(t *testing.T, sr *FixedSliceReader, val string) { +func verifyAccErrorString(t *testing.T, sr *bits.FixedSliceReader, val string) { if sr.AccError() == nil { t.Errorf("should have had an accumulated error") } - if "" != val { + if val != "" { t.Errorf("should be empty string") } } -func verifyAccErrorBytes(t *testing.T, sr *FixedSliceReader, val []byte) { +func verifyAccErrorBytes(t *testing.T, sr *bits.FixedSliceReader, val []byte) { if sr.AccError() == nil { t.Errorf("should have had an accumulated error") } @@ -172,99 +270,3 @@ func expVal(i, nrBytes int) (val uint64) { } return val } - -func TestFixedSliceReaderStrings(t *testing.T) { - data := []byte("0123456789\x00abcdef") - nrBytes := len(data) - sr := NewFixedSliceReader(data) - if sr.Length() != nrBytes { - t.Errorf("sliceReader length = %d instead of %d", sr.Length(), nrBytes) - } - val := sr.ReadFixedLengthString(4) - if val != "0123" { - t.Errorf(`read string is %q instead of "0123"`, val) - } - val = sr.ReadZeroTerminatedString(10) - if val != "456789" { - t.Errorf(`read string is %q instead of "456789"`, val) - } - val = sr.ReadZeroTerminatedString(2) - verifyAccErrorString(t, sr, val) - - val = sr.ReadFixedLengthString(2) - verifyAccErrorString(t, sr, val) - - val = sr.ReadZeroTerminatedString(2) - verifyAccErrorString(t, sr, val) - - sr = NewFixedSliceReader(data) - val = sr.ReadFixedLengthString(nrBytes + 2) - verifyAccErrorString(t, sr, val) -} - -func TestFixedSliceReadBytes(t *testing.T) { - data := []byte("0123456789abcdef") - nrBytes := len(data) - sr := NewFixedSliceReader(data) - if sr.Length() != nrBytes { - t.Errorf("sliceReader length = %d instead of %d", sr.Length(), nrBytes) - } - val := sr.ReadBytes(4) - if string(val) != "0123" { - t.Errorf(`read bytes are %q instead of "0123"`, string(val)) - } - if sr.NrRemainingBytes() != nrBytes-4 { - t.Errorf("nr remaining = %d instead of %d", sr.NrRemainingBytes(), nrBytes-4) - } - sr.SkipBytes(2) - if sr.NrRemainingBytes() != nrBytes-6 { - t.Errorf("nr remaining = %d instead of %d", sr.NrRemainingBytes(), nrBytes-6) - } - sr.SetPos(8) - if sr.NrRemainingBytes() != nrBytes-8 { - t.Errorf("nr remaining = %d instead of %d", sr.NrRemainingBytes(), nrBytes-8) - } - lookAhead := make([]byte, 8) - err := sr.LookAhead(0, lookAhead) - if err != nil { - t.Error(err) - } - if string(lookAhead) != "89abcdef" { - t.Errorf(`lookahead %q instead of "89abcdef"`, lookAhead) - } - lookAhead = make([]byte, 9) - err = sr.LookAhead(0, lookAhead) - if err == nil { - t.Errorf("should be out of bounds") - } - remaining := sr.RemainingBytes() - if string(remaining) != "89abcdef" { - t.Errorf(`remaining %q instead of "89abcdef"`, remaining) - } - - // Beyond end - val = sr.ReadBytes(2) - verifyAccErrorBytes(t, sr, val) - - val = sr.ReadBytes(2) - verifyAccErrorBytes(t, sr, val) - - val = sr.RemainingBytes() - verifyAccErrorBytes(t, sr, val) - - valNr := sr.NrRemainingBytes() - verifyAccErrorInt(t, sr, valNr) - - sr.SkipBytes(2) - verifyAccErrorInt(t, sr, 0) - - sr = NewFixedSliceReader(data) - sr.SkipBytes(nrBytes + 2) - verifyAccErrorInt(t, sr, 0) - - sr.SetPos(nrBytes + 2) - wantedErrMsg := fmt.Sprintf("attempt to set pos %d beyond slice len %d", nrBytes+2, nrBytes) - if sr.AccError().Error() != wantedErrMsg { - t.Errorf("got error msg %q instead of %q", sr.AccError().Error(), wantedErrMsg) - } -} diff --git a/bits/fixedslicewriter.go b/bits/fixedslicewriter.go index f71c9925..05a8e08b 100644 --- a/bits/fixedslicewriter.go +++ b/bits/fixedslicewriter.go @@ -64,7 +64,7 @@ func (sw *FixedSliceWriter) AccError() error { // WriteUint8 - write byte to slice func (sw *FixedSliceWriter) WriteUint8(n byte) { if sw.off+1 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } sw.buf[sw.off] = n @@ -74,7 +74,7 @@ func (sw *FixedSliceWriter) WriteUint8(n byte) { // WriteUint16 - write uint16 to slice func (sw *FixedSliceWriter) WriteUint16(n uint16) { if sw.off+2 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } binary.BigEndian.PutUint16(sw.buf[sw.off:], n) @@ -84,7 +84,7 @@ func (sw *FixedSliceWriter) WriteUint16(n uint16) { // WriteInt16 - write int16 to slice func (sw *FixedSliceWriter) WriteInt16(n int16) { if sw.off+2 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } binary.BigEndian.PutUint16(sw.buf[sw.off:], uint16(n)) @@ -94,7 +94,7 @@ func (sw *FixedSliceWriter) WriteInt16(n int16) { // WriteUint24 - write uint24 to slice func (sw *FixedSliceWriter) WriteUint24(n uint32) { if sw.off+3 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } sw.WriteUint8(byte(n >> 16)) @@ -104,7 +104,7 @@ func (sw *FixedSliceWriter) WriteUint24(n uint32) { // WriteUint32 - write uint32 to slice func (sw *FixedSliceWriter) WriteUint32(n uint32) { if sw.off+4 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } binary.BigEndian.PutUint32(sw.buf[sw.off:], n) @@ -114,7 +114,7 @@ func (sw *FixedSliceWriter) WriteUint32(n uint32) { // WriteInt32 - write int32 to slice func (sw *FixedSliceWriter) WriteInt32(n int32) { if sw.off+4 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } binary.BigEndian.PutUint32(sw.buf[sw.off:], uint32(n)) @@ -123,7 +123,8 @@ func (sw *FixedSliceWriter) WriteInt32(n int32) { // WriteUint48 - write uint48 func (sw *FixedSliceWriter) WriteUint48(u uint64) { - if sw.accError != nil { + if sw.off+6 > len(sw.buf) { + sw.accError = ErrSliceWrite return } msb := uint16(u >> 32) @@ -138,7 +139,7 @@ func (sw *FixedSliceWriter) WriteUint48(u uint64) { // WriteUint64 - write uint64 to slice func (sw *FixedSliceWriter) WriteUint64(n uint64) { if sw.off+8 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } binary.BigEndian.PutUint64(sw.buf[sw.off:], n) @@ -148,7 +149,7 @@ func (sw *FixedSliceWriter) WriteUint64(n uint64) { // WriteInt64 - write int64 to slice func (sw *FixedSliceWriter) WriteInt64(n int64) { if sw.off+8 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } binary.BigEndian.PutUint64(sw.buf[sw.off:], uint64(n)) @@ -162,7 +163,7 @@ func (sw *FixedSliceWriter) WriteString(s string, addZeroEnd bool) { nrNew++ } if sw.off+nrNew > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } copy(sw.buf[sw.off:sw.off+len(s)], s) @@ -176,7 +177,7 @@ func (sw *FixedSliceWriter) WriteString(s string, addZeroEnd bool) { // WriteZeroBytes - write n byte of zeroes func (sw *FixedSliceWriter) WriteZeroBytes(n int) { if sw.off+n > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } for i := 0; i < n; i++ { @@ -188,7 +189,7 @@ func (sw *FixedSliceWriter) WriteZeroBytes(n int) { // WriteBytes - write []byte func (sw *FixedSliceWriter) WriteBytes(byteSlice []byte) { if sw.off+len(byteSlice) > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } copy(sw.buf[sw.off:sw.off+len(byteSlice)], byteSlice) @@ -198,7 +199,7 @@ func (sw *FixedSliceWriter) WriteBytes(byteSlice []byte) { // WriteUnityMatrix - write a unity matrix for mvhd or tkhd func (sw *FixedSliceWriter) WriteUnityMatrix() { if sw.off+36 > len(sw.buf) { - sw.accError = SliceWriterError + sw.accError = ErrSliceWrite return } sw.WriteUint32(0x00010000) // = 1 fixed 16.16 @@ -217,14 +218,14 @@ func (sw *FixedSliceWriter) WriteBits(bits uint, n int) { return } sw.v <<= uint(n) - sw.v |= bits & mask(n) + sw.v |= bits & Mask(n) sw.n += n for sw.n >= 8 { - b := byte((sw.v >> (uint(sw.n) - 8)) & mask(8)) + b := byte((sw.v >> (uint(sw.n) - 8)) & Mask(8)) sw.WriteUint8(b) sw.n -= 8 } - sw.v &= mask(8) + sw.v &= Mask(8) } // WriteFlag writes a flag as 1 bit. @@ -243,7 +244,7 @@ func (sw *FixedSliceWriter) FlushBits() { return } if sw.n != 0 { - b := byte((sw.v << (8 - uint(sw.n))) & mask(8)) + b := byte((sw.v << (8 - uint(sw.n))) & Mask(8)) sw.WriteUint8(b) } } diff --git a/bits/fixedslicewriter_test.go b/bits/fixedslicewriter_test.go index c21e9c8c..a1fcfa38 100644 --- a/bits/fixedslicewriter_test.go +++ b/bits/fixedslicewriter_test.go @@ -1,46 +1,117 @@ -package bits +package bits_test import ( "bytes" "testing" + + "github.com/Eyevinn/mp4ff/bits" ) func TestFixedSliceWriter(t *testing.T) { - sw := NewFixedSliceWriter(20) - sw.WriteUint8(0xff) - sw.WriteUint16(0xffff) - sw.WriteUint24(0xffffff) - sw.WriteUint32(0xffffffff) - sw.WriteUint64(0xffffffffffffffff) - sw.WriteUint16(0) - expected := make([]byte, 20) - for i := range expected { - if i == 18 { - break - } - expected[i] = 0xff - } - if !bytes.Equal(expected, sw.Bytes()) { - t.Errorf("bytes differ: %v %v", expected, sw.Bytes()) - } - sw.WriteUint24(0xffffff) - if sw.AccError() == nil { - t.Errorf("no overflow error") - } -} - -func TestWriteBits(t *testing.T) { - sw := NewFixedSliceWriter(4) - sw.WriteBits(0xf, 4) - sw.WriteBits(0x2, 4) - sw.WriteBits(0x5, 4) - sw.FlushBits() - if sw.AccError() != nil { - t.Errorf("unexpected error writing bits") - } - result := sw.Bytes() - if !(result[0] == 0xf2 && result[1] == 0x50) { - t.Errorf("got %02x%02x instead of 0xf250", result[0], result[1]) - } - sw.WriteUint16(0xffff) + t.Run("Write uints and overflow", func(t *testing.T) { + sw := bits.NewFixedSliceWriter(20) + sw.WriteUint8(0xff) + sw.WriteUint16(0xffff) + sw.WriteUint24(0xffffff) + sw.WriteUint32(0xffffffff) + sw.WriteUint64(0xffffffffffffffff) + sw.WriteUint16(0) + expected := make([]byte, 20) + for i := range expected { + if i == 18 { + break + } + expected[i] = 0xff + } + if !bytes.Equal(expected, sw.Bytes()) { + t.Errorf("bytes differ: %v %v", expected, sw.Bytes()) + } + sw.WriteUint24(0xffffff) + if sw.AccError() == nil { + t.Errorf("no overflow error") + } + // Write past end of slice (should not panic) + sw.WriteUint8(0xff) + sw.WriteUint16(0xffff) + sw.WriteUint32(0xffffffff) + sw.WriteUint48(0xffffffffffff) + sw.WriteUint64(0xffffffff00112233) + sw.WriteInt16(-1) + sw.WriteInt32(-1) + sw.WriteInt64(-1) + sw.WriteString("hello", true) + sw.WriteBytes([]byte{0x01, 0x02, 0x03}) + sw.WriteZeroBytes(7) + sw.WriteUnityMatrix() + sw.WriteBits(0x0f, 4) + sw.FlushBits() + if sw.AccError() == nil { + t.Errorf("no overflow error") + } + }) + t.Run("write signed", func(t *testing.T) { + s := make([]byte, 14) + sw := bits.NewFixedSliceWriterFromSlice(s) + sw.WriteInt16(-1) + sw.WriteInt32(-2) + sw.WriteInt64(-1) + if sw.AccError() != nil { + t.Errorf("unexpected error writing signed") + } + expected := []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + if !bytes.Equal(expected, sw.Bytes()) { + t.Errorf("bytes differ: %v %v", expected, sw.Bytes()) + } + }) + t.Run("write other", func(t *testing.T) { + sw := bits.NewFixedSliceWriter(55) + sw.WriteUint48(0x123456789abc) + sw.WriteZeroBytes(3) + sw.WriteBytes([]byte{0x01, 0x02, 0x03}) + sw.WriteString("hello", true) + sw.WriteUnityMatrix() + sw.WriteFlag(false) + sw.WriteFlag(true) + sw.FlushBits() + if sw.AccError() != nil { + t.Errorf("unexpected error writing other") + } + expected := []byte{0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, + 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, + 0x40} + if !bytes.Equal(sw.Bytes(), expected) { + t.Errorf("bytes differ: %v %v", expected, sw.Bytes()) + } + }) + t.Run("write bits", func(t *testing.T) { + s := make([]byte, 4) + sw := bits.NewFixedSliceWriterFromSlice(s) + cap := sw.Capacity() + if cap != 4 { + t.Errorf("unexpected capacity %d", cap) + } + sw.WriteBits(0xf, 4) + sw.WriteBits(0x2, 4) + sw.WriteBits(0x5, 4) + sw.FlushBits() + if sw.AccError() != nil { + t.Errorf("unexpected error writing bits") + } + wantedOffset := 2 + offset := sw.Offset() + if offset != wantedOffset { + t.Errorf("unexpected offset %d instead of %d", offset, wantedOffset) + } + result := sw.Bytes() + if !(result[0] == 0xf2 && result[1] == 0x50) { + t.Errorf("got %02x%02x instead of 0xf250", result[0], result[1]) + } + sw.WriteUint16(0xffff) + if sw.Len() != 4 { + t.Errorf("unexpected length %d after 4 bytes", sw.Len()) + } + }) } diff --git a/bits/aereader.go b/bits/reader.go similarity index 57% rename from bits/aereader.go rename to bits/reader.go index 8826ad66..725b338c 100644 --- a/bits/aereader.go +++ b/bits/reader.go @@ -7,9 +7,9 @@ import ( "io/ioutil" ) -// AccErrReader - bit reader that accumulates error -// First error can be fetched as reader.AccError() -type AccErrReader struct { +// Reader is a bit reader that stops reading at first error and stores it. +// First error can be fetched usiin AccError(). +type Reader struct { rd io.Reader err error nrBits int // current number of bits @@ -17,19 +17,19 @@ type AccErrReader struct { } // AccError - accumulated error is first error that occurred -func (r *AccErrReader) AccError() error { +func (r *Reader) AccError() error { return r.err } -// NewAccErrReader - return a new Reader -func NewAccErrReader(rd io.Reader) *AccErrReader { - return &AccErrReader{ +// NewReader return a new Reader that accumulates errors. +func NewReader(rd io.Reader) *Reader { + return &Reader{ rd: rd, } } // Read - read n bits. Return 0, if error now or previously -func (r *AccErrReader) Read(n int) uint { +func (r *Reader) Read(n int) uint { if r.err != nil { return 0 } @@ -49,13 +49,13 @@ func (r *AccErrReader) Read(n int) uint { value := r.value >> uint(r.nrBits-n) r.nrBits -= n - r.value &= mask(r.nrBits) + r.value &= Mask(r.nrBits) return value } // ReadSigned reads a 2-complemented signed int with n bits. -func (r *AccErrReader) ReadSigned(n int) int { +func (r *Reader) ReadSigned(n int) int { nr := int(r.Read(n)) firstBit := nr >> (n - 1) if firstBit == 1 { @@ -64,8 +64,8 @@ func (r *AccErrReader) ReadSigned(n int) int { return nr } -// ReadFlag - read 1 bit into flag. Return false if error now or previously -func (r *AccErrReader) ReadFlag() bool { +// ReadFlag reads 1 bit and interprets as a boolean flag. Returns false if error now or previously. +func (r *Reader) ReadFlag() bool { bit := r.Read(1) if r.err != nil { return false @@ -73,19 +73,8 @@ func (r *AccErrReader) ReadFlag() bool { return bit == 1 } -// ReadFlag - Read i(v) which is 2-complement of n bits -func (r *AccErrReader) ReadVInt(n int) int { - uval := r.Read(n) - var ival int - - if uval >= 1<<(n/2) { - ival = int(uval) - (1 << n) - } - return ival -} - -// ReadRemainingBytes - read remaining bytes if byte-aligned -func (r *AccErrReader) ReadRemainingBytes() []byte { +// ReadRemainingBytes reads remaining bytes if byte-aligned. Returns nil if error now or previously. +func (r *Reader) ReadRemainingBytes() []byte { if r.err != nil { return nil } diff --git a/bits/reader_test.go b/bits/reader_test.go new file mode 100644 index 00000000..fb000bbb --- /dev/null +++ b/bits/reader_test.go @@ -0,0 +1,164 @@ +package bits_test + +import ( + "bytes" + "io" + "testing" + + "github.com/Eyevinn/mp4ff/bits" +) + +func TestAccErrReader(t *testing.T) { + t.Run("Read bits", func(t *testing.T) { + input := []byte{0xff, 0x0f} // 1111 1111 0000 1111 + rd := bytes.NewReader(input) + reader := bits.NewReader(rd) + + cases := []struct { + n int + want uint + }{ + {2, 3}, // 11 + {3, 7}, // 111 + {5, 28}, // 11100 + {3, 1}, // 001 + {3, 7}, // 111 + } + + for _, tc := range cases { + got := reader.Read(tc.n) + + if got != tc.want { + t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) + } + } + err := reader.AccError() + if err != nil { + t.Errorf("Got accumulated error: %s", err.Error()) + } + }) + t.Run("Read signed bits", func(t *testing.T) { + input := []byte{0xff, 0x0f} // 1111 1111 0000 1111 + rd := bytes.NewReader(input) + reader := bits.NewReader(rd) + + cases := []struct { + n int + want uint + }{ + {2, 3}, // 11 + {3, 7}, // 111 + {5, 28}, // 11100 + {3, 1}, // 001 + {3, 7}, // 111 + } + + for _, tc := range cases { + got := reader.Read(tc.n) + + if got != tc.want { + t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) + } + } + err := reader.AccError() + if err != nil { + t.Errorf("Got accumulated error: %s", err.Error()) + } + }) + t.Run("Read remaining bytes", func(t *testing.T) { + input := []byte{0xef, 0x0f} // 1110 1111 0000 1111 + rd := bytes.NewReader(input) + r := bits.NewReader(rd) + gotRemaining := r.ReadRemainingBytes() + if !bytes.Equal(gotRemaining, input) { + t.Errorf("ReadRemainingBytes()=%b, want=%b", gotRemaining, input) + } + // Next check with non-bytealigned status + rd = bytes.NewReader(input) + r = bits.NewReader(rd) + _ = r.ReadFlag() + _ = r.ReadRemainingBytes() + err := r.AccError() + if err == nil { + t.Errorf("Expected error due to not byte-aligned, but got nil") + } + }) + + t.Run("Read flags", func(t *testing.T) { + input := []byte{0xe5} // 1110 0101 + rd := bytes.NewReader(input) + r := bits.NewReader(rd) + for i := 0; i < 8; i++ { + expectedValue := (input[0] >> (7 - i) & 1) > 0 + gotFlag := r.ReadFlag() + if gotFlag != expectedValue { + t.Errorf("Read flag %t, but wanted %t, case %d", gotFlag, expectedValue, i) + } + } + _ = r.ReadFlag() + wantedError := io.EOF + if err := r.AccError(); err != wantedError { + t.Errorf("wanted error %s, got %s", wantedError, err.Error()) + } + }) +} + +func TestAccErrReaderSigned(t *testing.T) { + input := []byte{0xff, 0x0c} // 1111 1111 0000 1100 + rd := bytes.NewReader(input) + reader := bits.NewReader(rd) + + cases := []struct { + n int + want int + }{ + {2, -1}, // 11 + {3, -1}, // 111 + {5, -4}, // 11100 + {3, 1}, // 001 + {3, -4}, // 100 + } + + for _, tc := range cases { + got := reader.ReadSigned(tc.n) + + if got != tc.want { + t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) + } + } + err := reader.AccError() + if err != nil { + t.Errorf("Got accumulated error: %s", err.Error()) + } +} + +func TestBadAccErrReader(t *testing.T) { + // Check that reading beyond EOF provides value = 0 after acc error + input := []byte{0xff, 0x0f} // 1111 1111 0000 1111 + rd := bytes.NewReader(input) + reader := bits.NewReader(rd) + + cases := []struct { + err error + n int + want uint + }{ + {nil, 2, 3}, // 11 + {nil, 3, 7}, // 111 + {io.EOF, 12, 0}, // 0 because of error + {io.EOF, 3, 0}, // 0 because of acc error + {io.EOF, 3, 0}, // 0 because of acc error + } + + for _, tc := range cases { + got := reader.Read(tc.n) + + if got != tc.want { + t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) + } + } + err := reader.AccError() + if err != io.EOF { + t.Errorf("Wanted io.EOF but got %v", err) + } +} diff --git a/bits/slicewriter.go b/bits/slicewriter.go index cc1767cc..f771202b 100644 --- a/bits/slicewriter.go +++ b/bits/slicewriter.go @@ -4,6 +4,8 @@ import ( "errors" ) +var ErrSliceWrite = errors.New("overflow in SliceWriter") + type SliceWriter interface { Len() int Capacity() int @@ -27,5 +29,3 @@ type SliceWriter interface { WriteFlag(f bool) FlushBits() } - -var SliceWriterError = errors.New("overflow in SliceWriter") diff --git a/bits/writer.go b/bits/writer.go new file mode 100644 index 00000000..35a7509a --- /dev/null +++ b/bits/writer.go @@ -0,0 +1,69 @@ +package bits + +import ( + "encoding/binary" + "io" +) + +// Writer writes bits into underlying io.Writer. Stops writing at first error. +// That first error is stored and can later be checked with AccError(). +type Writer struct { + wr io.Writer + err error // The first error caused by any write operation + out []byte // Slice of length 1 to avoid allocation at output + n int // current number of bits + v uint // current accumulated value +} + +// NewWriter returns a new Writer +func NewWriter(w io.Writer) *Writer { + return &Writer{ + wr: w, + out: make([]byte, 1), + } +} + +// Write writes n bits from bits and saves error state. +func (w *Writer) Write(bits uint, n int) { + if w.err != nil { + return + } + w.v <<= uint(n) + w.v |= bits & Mask(n) + w.n += n + for w.n >= 8 { + b := (w.v >> (uint(w.n) - 8)) & Mask(8) + w.out[0] = uint8(b) + _, err := w.wr.Write(w.out) + if err != nil { + w.err = err + return + } + w.n -= 8 + } + w.v &= Mask(8) +} + +// Flush writes remaining bits to the underlying io.Writer by adding zeros to the right. +func (w *Writer) Flush() { + if w.err != nil { + return + } + if w.n != 0 { + b := (w.v << (8 - uint(w.n))) & Mask(8) + if err := binary.Write(w.wr, binary.BigEndian, uint8(b)); err != nil { + w.err = err + return + } + } +} + +// AccError returns the first error that occurred and stopped writing. +func (w *Writer) AccError() error { + return w.err +} + +// Mask returns a binary mask for the n least significant bits. +func Mask(n int) uint { + return (1 << uint(n)) - 1 +} diff --git a/bits/bits_benchmark_test.go b/bits/writer_benchmark_test.go similarity index 63% rename from bits/bits_benchmark_test.go rename to bits/writer_benchmark_test.go index 3214bc08..3cb182df 100644 --- a/bits/bits_benchmark_test.go +++ b/bits/writer_benchmark_test.go @@ -1,24 +1,26 @@ -package bits +package bits_test import ( "io/ioutil" "testing" + + "github.com/Eyevinn/mp4ff/bits" ) func BenchmarkWrite(b *testing.B) { - writer := NewWriter(ioutil.Discard) + writer := bits.NewWriter(ioutil.Discard) b.ResetTimer() for i := 0; i < b.N; i++ { writer.Write(0xff, 8) } - err := writer.Error() + err := writer.AccError() if err != nil { b.Fatal(err) } } func BenchmarkEbspWrite(b *testing.B) { - writer := NewEBSPWriter(ioutil.Discard) + writer := bits.NewEBSPWriter(ioutil.Discard) b.ResetTimer() for i := 0; i < b.N; i++ { writer.Write(0xff, 8) diff --git a/bits/bits_test.go b/bits/writer_test.go similarity index 63% rename from bits/bits_test.go rename to bits/writer_test.go index 63546b5f..61a383f5 100644 --- a/bits/bits_test.go +++ b/bits/writer_test.go @@ -1,40 +1,12 @@ -package bits +package bits_test import ( "bytes" "fmt" - "io" "testing" -) - -func TestReader(t *testing.T) { - input := []byte{0xff, 0x0f} // 1111 1111 0000 1111 - rd := bytes.NewReader(input) - reader := NewReader(rd) - - cases := []struct { - n int - want uint - }{ - {2, 3}, // 11 - {3, 7}, // 111 - {5, 28}, // 11100 - {3, 1}, // 001 - {3, 7}, // 111 - } - - for _, tc := range cases { - got, err := reader.Read(tc.n) - if err != nil && err != io.EOF { - t.Fatalf("Read(%d) should not fail: %s", tc.n, err) - } - - if got != tc.want { - t.Errorf("Read(%d)=%b, want=%b", tc.n, got, tc.want) - } - } -} + "github.com/Eyevinn/mp4ff/bits" +) func TestWriter(t *testing.T) { cases := []struct { @@ -53,18 +25,18 @@ func TestWriter(t *testing.T) { for _, tc := range cases { var buf bytes.Buffer - writer := NewWriter(&buf) + writer := bits.NewWriter(&buf) for _, input := range tc.inputs { writer.Write(input, tc.size) } - err := writer.Error() + err := writer.AccError() if err != nil { t.Fatalf("Write should not fail: %s", err) } writer.Flush() - err = writer.Error() + err = writer.AccError() if err != nil { t.Fatalf("Flush should not fail: %s", err) } @@ -86,7 +58,7 @@ func TestMask(t *testing.T) { } for _, tc := range cases { - m := mask(tc.input) + m := bits.Mask(tc.input) if got := fmt.Sprintf("%08b", m); got != tc.want { t.Errorf("mask(%d)=%s,want=%s", tc.input, got, tc.want) } diff --git a/hevc/pps.go b/hevc/pps.go index 37e5c327..1e46d03d 100644 --- a/hevc/pps.go +++ b/hevc/pps.go @@ -182,7 +182,7 @@ func ParsePPSNALUnit(data []byte, spsMap map[uint32]*SPS) (*PPS, error) { pps := &PPS{} rd := bytes.NewReader(data) - r := bits.NewAccErrEBSPReader(rd) + r := bits.NewEBSPReader(rd) // Note! First two bytes are NALU Header naluHdrBits := r.Read(16) @@ -323,7 +323,7 @@ func ParsePPSNALUnit(data []byte, spsMap map[uint32]*SPS) (*PPS, error) { return pps, nil } -func parseRangeExtension(r *bits.AccErrEBSPReader, transformSkipEnabled bool) (*RangeExtension, error) { +func parseRangeExtension(r *bits.EBSPReader, transformSkipEnabled bool) (*RangeExtension, error) { ext := &RangeExtension{} if transformSkipEnabled { ext.Log2MaxTransformSkipBlockSizeMinus2 = r.ReadExpGolomb() @@ -349,7 +349,7 @@ func parseRangeExtension(r *bits.AccErrEBSPReader, transformSkipEnabled bool) (* return ext, nil } -func parseMultilayerExtension(r *bits.AccErrEBSPReader) (*MultilayerExtension, error) { +func parseMultilayerExtension(r *bits.EBSPReader) (*MultilayerExtension, error) { ext := &MultilayerExtension{} ext.PocResetInfoPresentFlag = r.ReadFlag() ext.InferScalingListFlag = r.ReadFlag() @@ -404,7 +404,7 @@ func parseMultilayerExtension(r *bits.AccErrEBSPReader) (*MultilayerExtension, e return ext, nil } -func parseColourMappingTable(r *bits.AccErrEBSPReader) (*ColourMappingTable, error) { +func parseColourMappingTable(r *bits.EBSPReader) (*ColourMappingTable, error) { cm := &ColourMappingTable{} // value shall be in the range of 0 to 61, inclusive cm.NumCmRefLayersMinus1 = uint8(r.ReadExpGolomb()) @@ -447,7 +447,7 @@ func parseColourMappingTable(r *bits.AccErrEBSPReader) (*ColourMappingTable, err return cm, nil } -func parseColourMappingOctants(r *bits.AccErrEBSPReader, octantDepth uint, partNumY uint, resLsBits int, +func parseColourMappingOctants(r *bits.EBSPReader, octantDepth uint, partNumY uint, resLsBits int, inpDepth, idxY, idxCb, idxCr, inpLength uint) (map[string][4]Octant, error) { var octs map[string][4]Octant @@ -502,7 +502,7 @@ func makeKeyOctant(idxShiftY, idxCb, idxCr uint) string { return fmt.Sprintf("%d-%d-%d", idxShiftY, idxCb, idxCr) } -func parseSccExtension(r *bits.AccErrEBSPReader) (*SccExtension, error) { +func parseSccExtension(r *bits.EBSPReader) (*SccExtension, error) { ext := &SccExtension{} ext.CurrPicRefEnabledFlag = r.ReadFlag() ext.ResidualAdaptiveColourTransformEnabledFlag = r.ReadFlag() @@ -546,7 +546,7 @@ func parseSccExtension(r *bits.AccErrEBSPReader) (*SccExtension, error) { return ext, nil } -func parse3dExtension(r *bits.AccErrEBSPReader) (*D3Extension, error) { +func parse3dExtension(r *bits.EBSPReader) (*D3Extension, error) { ext := &D3Extension{} ext.DltsPresentFlag = r.ReadFlag() if ext.DltsPresentFlag { @@ -585,7 +585,7 @@ func parse3dExtension(r *bits.AccErrEBSPReader) (*D3Extension, error) { return ext, nil } -func parseDeltaDlt(r *bits.AccErrEBSPReader, BitDepthForDepthLayers int) (*DeltaDlt, error) { +func parseDeltaDlt(r *bits.EBSPReader, BitDepthForDepthLayers int) (*DeltaDlt, error) { dd := &DeltaDlt{} dd.NumValDeltaDlt = r.Read(BitDepthForDepthLayers) if dd.NumValDeltaDlt > 0 { diff --git a/hevc/slice.go b/hevc/slice.go index a8d731f1..ed7ce88e 100644 --- a/hevc/slice.go +++ b/hevc/slice.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "github.com/Eyevinn/mp4ff/bits" ) @@ -111,7 +112,7 @@ func ParseSliceHeader(nalu []byte, spsMap map[uint32]*SPS, ppsMap map[uint32]*PP sh := &SliceHeader{} buf := bytes.NewBuffer(nalu) - r := bits.NewAccErrEBSPReader(buf) + r := bits.NewEBSPReader(buf) naluHdrBits := r.Read(16) naluType := GetNaluType(byte(naluHdrBits >> 8)) @@ -380,7 +381,7 @@ func ParseSliceHeader(nalu []byte, spsMap map[uint32]*SPS, ppsMap map[uint32]*PP return sh, nil } -func parseRefPicListsModification(r *bits.AccErrEBSPReader, sliceType SliceType, +func parseRefPicListsModification(r *bits.EBSPReader, sliceType SliceType, refIdxL0Minus1, refIdxL1Minus1 uint8, numPicTotalCurr uint8) (*RefPicListsModification, error) { rplm := &RefPicListsModification{ RefPicListModificationFlagL0: r.ReadFlag(), @@ -408,7 +409,7 @@ func parseRefPicListsModification(r *bits.AccErrEBSPReader, sliceType SliceType, return rplm, nil } -func parsePredWeightTable(r *bits.AccErrEBSPReader, sliceType SliceType, +func parsePredWeightTable(r *bits.EBSPReader, sliceType SliceType, refIdxL0Minus1, refIdxL1Minus1 uint8, chromaArrayType byte) (*PredWeightTable, error) { pwt := &PredWeightTable{ // value shall be in the range of 0 to 7, inclusive diff --git a/hevc/sps.go b/hevc/sps.go index b7edb3e2..cfcaf810 100644 --- a/hevc/sps.go +++ b/hevc/sps.go @@ -103,7 +103,7 @@ func flagFrom(flags uint64, bitNr uint) bool { } // parseProfileTierLevel follows ISO/IEC 23008-2 Section 7.3.3 -func parseProfileTierLevel(r *bits.AccErrEBSPReader, profilePresentFlag bool, maxNumSubLayersMinus1 byte) ProfileTierLevel { +func parseProfileTierLevel(r *bits.EBSPReader, profilePresentFlag bool, maxNumSubLayersMinus1 byte) ProfileTierLevel { ptl := ProfileTierLevel{} if profilePresentFlag { ptl.GeneralProfileSpace = byte(r.Read(2)) @@ -310,7 +310,7 @@ func ParseSPSNALUnit(data []byte) (*SPS, error) { sps := &SPS{} rd := bytes.NewReader(data) - r := bits.NewAccErrEBSPReader(rd) + r := bits.NewEBSPReader(rd) // Note! First two bytes are NALU Header naluHdrBits := r.Read(16) @@ -497,7 +497,7 @@ func (s *SPS) ImageSize() (width, height uint32) { // parseVUI - parse VUI (Visual Usability Information) // if parseVUIBeyondAspectRatio is false, stop after AspectRatio has been parsed -func parseVUI(r *bits.AccErrEBSPReader, MaxSubLayersMinus1 byte) *VUIParameters { +func parseVUI(r *bits.EBSPReader, MaxSubLayersMinus1 byte) *VUIParameters { vui := &VUIParameters{} aspectRatioInfoPresentFlag := r.ReadFlag() if aspectRatioInfoPresentFlag { @@ -564,7 +564,7 @@ func parseVUI(r *bits.AccErrEBSPReader, MaxSubLayersMinus1 byte) *VUIParameters return vui } -func parseHrdParameters(r *bits.AccErrEBSPReader, +func parseHrdParameters(r *bits.EBSPReader, commonInfPresentFlag bool, maxNumSubLayersMinus1 byte) *HrdParameters { hp := &HrdParameters{} if commonInfPresentFlag { @@ -622,7 +622,7 @@ func parseHrdParameters(r *bits.AccErrEBSPReader, return hp } -func parseSubLayerHrdParameters(r *bits.AccErrEBSPReader, +func parseSubLayerHrdParameters(r *bits.EBSPReader, cpbCntMinus1 uint8, subPicHrdParamsPresentFlag bool) []SubLayerHrdParameters { slhp := make([]SubLayerHrdParameters, cpbCntMinus1+1) for i := uint8(0); i <= cpbCntMinus1; i++ { @@ -638,7 +638,7 @@ func parseSubLayerHrdParameters(r *bits.AccErrEBSPReader, return slhp } -func parseBitstreamRestrictions(r *bits.AccErrEBSPReader) *BitstreamRestrictions { +func parseBitstreamRestrictions(r *bits.EBSPReader) *BitstreamRestrictions { br := BitstreamRestrictions{} br.TilesFixedStructureFlag = r.ReadFlag() br.MVOverPicBoundariesFlag = r.ReadFlag() @@ -682,7 +682,7 @@ const maxSTRefPics = 16 // parseShortTermRPS - short-term reference pictures with syntax from 7.3.7. // Focus is on reading/parsing beyond this structure in SPS (and possibly in slice header) -func parseShortTermRPS(r *bits.AccErrEBSPReader, idx, numSTRefPicSets byte, sps *SPS) ShortTermRPS { +func parseShortTermRPS(r *bits.EBSPReader, idx, numSTRefPicSets byte, sps *SPS) ShortTermRPS { stps := ShortTermRPS{} interRPSPredFlag := false @@ -740,7 +740,7 @@ func parseShortTermRPS(r *bits.AccErrEBSPReader, idx, numSTRefPicSets byte, sps } // readPastScalingListData - read and parse all bits of scaling list, without storing values -func readPastScalingListData(r *bits.AccErrEBSPReader) { +func readPastScalingListData(r *bits.EBSPReader) { for sizeId := 0; sizeId < 4; sizeId++ { nrMatrixIds := 6 if sizeId == 3 { @@ -770,7 +770,7 @@ func readPastScalingListData(r *bits.AccErrEBSPReader) { } } -func parseSPS3dExtension(r *bits.AccErrEBSPReader) *SPS3dExtension { +func parseSPS3dExtension(r *bits.EBSPReader) *SPS3dExtension { ext := &SPS3dExtension{ IvDiMcEnabledFlag0: r.ReadFlag(), IvMvScalEnabledFlag0: r.ReadFlag(), @@ -793,7 +793,7 @@ func parseSPS3dExtension(r *bits.AccErrEBSPReader) *SPS3dExtension { return ext } -func parseSPSSccExtension(r *bits.AccErrEBSPReader, ChromaFormatIDC, +func parseSPSSccExtension(r *bits.EBSPReader, ChromaFormatIDC, BitDepthLumaMinus8, BitDepthChromaMinus8 byte) *SPSSccExtension { ext := &SPSSccExtension{} ext.CurrPicRefEnabledFlag = r.ReadFlag() diff --git a/mp4/dac3.go b/mp4/dac3.go index bed22953..b442be5e 100644 --- a/mp4/dac3.go +++ b/mp4/dac3.go @@ -85,13 +85,14 @@ func decodeDac3FromData(data []byte) (Box, error) { buf := bytes.NewBuffer(data) br := bits.NewReader(buf) b := Dac3Box{} - b.FSCod = byte(br.MustRead(2)) - b.BSID = byte(br.MustRead(5)) - b.BSMod = byte(br.MustRead(3)) - b.ACMod = byte(br.MustRead(3)) - b.LFEOn = byte(br.MustRead(1)) - b.BitRateCode = byte(br.MustRead(5)) + b.FSCod = byte(br.Read(2)) + b.BSID = byte(br.Read(5)) + b.BSMod = byte(br.Read(3)) + b.ACMod = byte(br.Read(3)) + b.LFEOn = byte(br.Read(1)) + b.BitRateCode = byte(br.Read(5)) // 5 bits reserved follows + _ = br.Read(5) return &b, nil } diff --git a/mp4/dec3.go b/mp4/dec3.go index acc8b98e..5d124f49 100644 --- a/mp4/dec3.go +++ b/mp4/dec3.go @@ -85,7 +85,7 @@ func DecodeDec3SR(hdr BoxHeader, startPos uint64, sr bits.SliceReader) (Box, err func decodeDec3FromData(data []byte) (Box, error) { buf := bytes.NewBuffer(data) - br := bits.NewAccErrReader(buf) + br := bits.NewReader(buf) b := Dec3Box{} b.DataRate = uint16(br.Read(13)) nrSubs := br.Read(3) + 1 // There must be one base stream diff --git a/mp4/file.go b/mp4/file.go index 8d6332d7..7a55141d 100644 --- a/mp4/file.go +++ b/mp4/file.go @@ -371,10 +371,20 @@ func (f *File) findAndReadMfra(r io.Reader) error { if !ok { return fmt.Errorf("expecting mfra box, but got %T", b) } - if len(mfra.Tfras) != 1 { - return fmt.Errorf("only supports exactly one tfra in mfra") + f.tfra = mfra.Tfras[0] + for i := 1; i < len(mfra.Tfras); i++ { + if mfra.Tfras[i].TrackID == f.tfra.TrackID { + return fmt.Errorf("only one tfra box per trackID is supported") + } + if len(mfra.Tfras[i].Entries) != len(f.tfra.Entries) { + return fmt.Errorf("tfra boxes with different number of entries are not supported") + } + for j := 0; j < len(mfra.Tfras[i].Entries); j++ { + if mfra.Tfras[i].Entries[j].MoofOffset != f.tfra.Entries[j].MoofOffset { + return fmt.Errorf("tfra boxes with different moof offsets for different tracks are not supported") + } + } } - f.tfra = mfra.Tfra _, err = rs.Seek(0, io.SeekStart) return err } diff --git a/sei/sei.go b/sei/sei.go index 1a946bb5..208230bc 100644 --- a/sei/sei.go +++ b/sei/sei.go @@ -553,7 +553,7 @@ func (s *SEIData) Size() uint { // In case the rbsp_trailing_bits 0x80 byte is missing at end, []seiData and // an ErrMissingRbspTrailingBits error are both returned. func ExtractSEIData(r io.ReadSeeker) (seiData []SEIData, err error) { - ar := bits.NewAccErrEBSPReader(r) + ar := bits.NewEBSPReader(r) for { payloadType := uint(0) for { diff --git a/sei/sei136.go b/sei/sei136.go index 43bafd0e..1afe4bbb 100644 --- a/sei/sei136.go +++ b/sei/sei136.go @@ -41,7 +41,7 @@ func CreateClockTS() ClockTS { return ClockTS{} } -func DecodeClockTS(br *bits.AccErrReader) ClockTS { +func DecodeClockTS(br *bits.Reader) ClockTS { c := CreateClockTS() c.ClockTimeStampFlag = br.ReadFlag() if c.ClockTimeStampFlag { @@ -80,7 +80,7 @@ func DecodeClockTS(br *bits.AccErrReader) ClockTS { // DecodeTimeCodeSEI decodes SEI message 136 TimeCode. func DecodeTimeCodeSEI(sd *SEIData) (SEIMessage, error) { buf := bytes.NewBuffer(sd.Payload()) - br := bits.NewAccErrReader(buf) + br := bits.NewReader(buf) numClockTS := int(br.Read(2)) tc := TimeCodeSEI{make([]ClockTS, 0, numClockTS)} for i := 0; i < numClockTS; i++ { diff --git a/sei/sei1_avc.go b/sei/sei1_avc.go index 542e516c..651fec28 100644 --- a/sei/sei1_avc.go +++ b/sei/sei1_avc.go @@ -45,7 +45,7 @@ func DecodePicTimingAvcSEI(sd *SEIData) (SEIMessage, error) { // It is assumed that pict_struct_present_flag is true, so that a 4-bit pict_struct value is present. func DecodePicTimingAvcSEIHRD(sd *SEIData, cbpDbpDelay *CbpDbpDelay, timeOffsetLen byte) (SEIMessage, error) { buf := bytes.NewBuffer(sd.Payload()) - br := bits.NewAccErrReader(buf) + br := bits.NewReader(buf) var outCbDbpDelay CbpDbpDelay if cbpDbpDelay != nil { outCbDbpDelay = *cbpDbpDelay @@ -168,7 +168,7 @@ func CreateClockTSAvc(timeOffsetLen byte) ClockTSAvc { } } -func DecodeClockTSAvc(br *bits.AccErrReader, timeOffsetLen byte) ClockTSAvc { +func DecodeClockTSAvc(br *bits.Reader, timeOffsetLen byte) ClockTSAvc { c := CreateClockTSAvc(timeOffsetLen) c.ClockTimeStampFlag = br.ReadFlag() if c.ClockTimeStampFlag { diff --git a/sei/sei1_hevc.go b/sei/sei1_hevc.go index 966d1053..1fe54aeb 100644 --- a/sei/sei1_hevc.go +++ b/sei/sei1_hevc.go @@ -43,7 +43,7 @@ type HEVCFrameFieldInfo struct { func DecodePicTimingHevcSEI(sd *SEIData, exPar HEVCPicTimingParams) (SEIMessage, error) { buf := bytes.NewBuffer(sd.Payload()) - br := bits.NewAccErrEBSPReader(buf) + br := bits.NewEBSPReader(buf) pt := PicTimingHevcSEI{ payload: sd.Payload(), }