Skip to content

Commit

Permalink
Bounds checks for Predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
ynse01 committed Nov 16, 2024
1 parent e58a775 commit 51d54c6
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 84 deletions.
13 changes: 8 additions & 5 deletions src/ImageSharp/Formats/Heif/Av1/Prediction/Av1DcFillPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,19 @@ public Av1DcFillPredictor(Av1TransformSize transformSize)
this.blockHeight = (uint)transformSize.GetHeight();
}

public static void PredictScalar(Av1TransformSize transformSize, ref byte destination, nuint stride, ref byte above, ref byte left)
=> new Av1DcFillPredictor(transformSize).PredictScalar(ref destination, stride, ref above, ref left);
public static void PredictScalar(Av1TransformSize transformSize, Span<byte> destination, nuint stride, Span<byte> above, Span<byte> left)
=> new Av1DcFillPredictor(transformSize).PredictScalar(destination, stride, above, left);

public void PredictScalar(ref byte destination, nuint stride, ref byte above, ref byte left)
public void PredictScalar(Span<byte> destination, nuint stride, Span<byte> above, Span<byte> left)
{
const byte expectedDc = 0x80;
Guard.MustBeGreaterThanOrEqualTo(stride, this.blockWidth, nameof(stride));
Guard.MustBeSizedAtLeast(destination, (int)this.blockHeight * (int)stride, nameof(destination));
ref byte destinationRef = ref destination[0];
for (uint r = 0; r < this.blockHeight; r++)
{
Unsafe.InitBlock(ref destination, expectedDc, this.blockWidth);
destination = ref Unsafe.Add(ref destination, stride);
Unsafe.InitBlock(ref destinationRef, expectedDc, this.blockWidth);
destinationRef = ref Unsafe.Add(ref destinationRef, stride);
}
}
}
17 changes: 11 additions & 6 deletions src/ImageSharp/Formats/Heif/Av1/Prediction/Av1DcLeftPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,27 @@ public Av1DcLeftPredictor(Av1TransformSize transformSize)
this.blockHeight = (uint)transformSize.GetHeight();
}

public static void PredictScalar(Av1TransformSize transformSize, ref byte destination, nuint stride, ref byte above, ref byte left)
=> new Av1DcLeftPredictor(transformSize).PredictScalar(ref destination, stride, ref above, ref left);
public static void PredictScalar(Av1TransformSize transformSize, Span<byte> destination, nuint stride, Span<byte> above, Span<byte> left)
=> new Av1DcLeftPredictor(transformSize).PredictScalar(destination, stride, above, left);

public void PredictScalar(ref byte destination, nuint stride, ref byte above, ref byte left)
public void PredictScalar(Span<byte> destination, nuint stride, Span<byte> above, Span<byte> left)
{
int sum = 0;
Guard.MustBeGreaterThanOrEqualTo(stride, this.blockWidth, nameof(stride));
Guard.MustBeSizedAtLeast(left, (int)this.blockHeight, nameof(left));
Guard.MustBeSizedAtLeast(destination, (int)this.blockHeight * (int)stride, nameof(destination));
ref byte leftRef = ref left[0];
ref byte destinationRef = ref destination[0];
for (uint i = 0; i < this.blockHeight; i++)
{
sum += Unsafe.Add(ref left, i);
sum += Unsafe.Add(ref leftRef, i);
}

byte expectedDc = (byte)((sum + (this.blockHeight >> 1)) / this.blockHeight);
for (uint r = 0; r < this.blockHeight; r++)
{
Unsafe.InitBlock(ref destination, expectedDc, this.blockWidth);
destination = ref Unsafe.Add(ref destination, stride);
Unsafe.InitBlock(ref destinationRef, expectedDc, this.blockWidth);
destinationRef = ref Unsafe.Add(ref destinationRef, stride);
}
}
}
21 changes: 14 additions & 7 deletions src/ImageSharp/Formats/Heif/Av1/Prediction/Av1DcPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,35 @@ public Av1DcPredictor(Av1TransformSize transformSize)
this.blockHeight = (uint)transformSize.GetHeight();
}

public static void PredictScalar(Av1TransformSize transformSize, ref byte destination, nuint stride, ref byte above, ref byte left)
=> new Av1DcPredictor(transformSize).PredictScalar(ref destination, stride, ref above, ref left);
public static void PredictScalar(Av1TransformSize transformSize, Span<byte> destination, nuint stride, Span<byte> above, Span<byte> left)
=> new Av1DcPredictor(transformSize).PredictScalar(destination, stride, above, left);

public void PredictScalar(ref byte destination, nuint stride, ref byte above, ref byte left)
public void PredictScalar(Span<byte> destination, nuint stride, Span<byte> above, Span<byte> left)
{
int sum = 0;
Guard.MustBeGreaterThanOrEqualTo(stride, this.blockWidth, nameof(stride));
Guard.MustBeSizedAtLeast(left, (int)this.blockHeight, nameof(left));
Guard.MustBeSizedAtLeast(above, (int)this.blockWidth, nameof(above));
Guard.MustBeSizedAtLeast(destination, (int)this.blockHeight * (int)stride, nameof(destination));
ref byte leftRef = ref left[0];
ref byte aboveRef = ref above[0];
ref byte destinationRef = ref destination[0];
uint count = this.blockWidth + this.blockHeight;
for (uint i = 0; i < this.blockWidth; i++)
{
sum += Unsafe.Add(ref above, i);
sum += Unsafe.Add(ref aboveRef, i);
}

for (uint i = 0; i < this.blockHeight; i++)
{
sum += Unsafe.Add(ref left, i);
sum += Unsafe.Add(ref leftRef, i);
}

byte expectedDc = (byte)((sum + (count >> 1)) / count);
for (uint r = 0; r < this.blockHeight; r++)
{
Unsafe.InitBlock(ref destination, expectedDc, this.blockWidth);
destination = ref Unsafe.Add(ref destination, stride);
Unsafe.InitBlock(ref destinationRef, expectedDc, this.blockWidth);
destinationRef = ref Unsafe.Add(ref destinationRef, stride);
}
}
}
17 changes: 11 additions & 6 deletions src/ImageSharp/Formats/Heif/Av1/Prediction/Av1DcTopPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,27 @@ public Av1DcTopPredictor(Av1TransformSize transformSize)
this.blockHeight = (uint)transformSize.GetHeight();
}

public static void PredictScalar(Av1TransformSize transformSize, ref byte destination, nuint stride, ref byte above, ref byte left)
=> new Av1DcTopPredictor(transformSize).PredictScalar(ref destination, stride, ref above, ref left);
public static void PredictScalar(Av1TransformSize transformSize, Span<byte> destination, nuint stride, Span<byte> above, Span<byte> left)
=> new Av1DcTopPredictor(transformSize).PredictScalar(destination, stride, above, left);

public void PredictScalar(ref byte destination, nuint stride, ref byte above, ref byte left)
public void PredictScalar(Span<byte> destination, nuint stride, Span<byte> above, Span<byte> left)
{
int sum = 0;
Guard.MustBeGreaterThanOrEqualTo(stride, this.blockWidth, nameof(stride));
Guard.MustBeSizedAtLeast(above, (int)this.blockWidth, nameof(above));
Guard.MustBeSizedAtLeast(destination, (int)this.blockHeight * (int)stride, nameof(destination));
ref byte aboveRef = ref above[0];
ref byte destinationRef = ref destination[0];
for (uint i = 0; i < this.blockWidth; i++)
{
sum += Unsafe.Add(ref above, i);
sum += Unsafe.Add(ref aboveRef, i);
}

byte expectedDc = (byte)((sum + (this.blockWidth >> 1)) / this.blockWidth);
for (uint r = 0; r < this.blockHeight; r++)
{
Unsafe.InitBlock(ref destination, expectedDc, this.blockWidth);
destination = ref Unsafe.Add(ref destination, stride);
Unsafe.InitBlock(ref destinationRef, expectedDc, this.blockWidth);
destinationRef = ref Unsafe.Add(ref destinationRef, stride);
}
}
}
73 changes: 35 additions & 38 deletions src/ImageSharp/Formats/Heif/Av1/Prediction/Av1PredictionDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using SixLabors.ImageSharp.Formats.Heif.Av1.Prediction.ChromaFromLuma;
using SixLabors.ImageSharp.Formats.Heif.Av1.Tiling;
using SixLabors.ImageSharp.Formats.Heif.Av1.Transform;
using SixLabors.ImageSharp.Memory;

namespace SixLabors.ImageSharp.Formats.Heif.Av1.Prediction;

Expand Down Expand Up @@ -37,12 +36,9 @@ public void Decode(
int blockModeInfoRowOffset)
{
int bytesPerPixel = (bitDepth == Av1BitDepth.EightBit && !this.is16BitPipeline) ? 2 : 1;
ref byte pixelRef = ref pixelBuffer[0];
ref byte topNeighbor = ref pixelRef;
ref byte leftNeighbor = ref pixelRef;
int stride = pixelStride * bytesPerPixel;
topNeighbor = Unsafe.Subtract(ref topNeighbor, stride);
leftNeighbor = Unsafe.Subtract(ref leftNeighbor, 1);
Span<byte> topNeighbor = pixelBuffer.Slice(-stride);
Span<byte> leftNeighbor = pixelBuffer.Slice(-1);

bool is16BitPipeline = this.is16BitPipeline;
Av1PredictionMode mode = (plane == Av1Plane.Y) ? partitionInfo.ModeInfo.YMode : partitionInfo.ModeInfo.UvMode;
Expand All @@ -54,10 +50,10 @@ public void Decode(
plane,
transformSize,
tileInfo,
ref pixelRef,
pixelBuffer,
stride,
ref topNeighbor,
ref leftNeighbor,
topNeighbor,
leftNeighbor,
stride,
mode,
blockModeInfoColumnOffset,
Expand All @@ -80,10 +76,10 @@ public void Decode(
plane,
transformSize,
tileInfo,
ref pixelRef,
pixelBuffer,
stride,
ref topNeighbor,
ref leftNeighbor,
topNeighbor,
leftNeighbor,
stride,
mode,
blockModeInfoColumnOffset,
Expand Down Expand Up @@ -199,10 +195,10 @@ private void PredictIntraBlock(
Av1Plane plane,
Av1TransformSize transformSize,
Av1TileInfo tileInfo,
ref byte pixelBuffer,
Span<byte> pixelBuffer,
int pixelBufferStride,
ref byte topNeighbor,
ref byte leftNeighbor,
Span<byte> topNeighbor,
Span<byte> leftNeighbor,
int referenceStride,
Av1PredictionMode mode,
int blockModeInfoColumnOffset,
Expand Down Expand Up @@ -290,10 +286,10 @@ private void PredictIntraBlock(
{
this.DecodeBuildIntraPredictors(
partitionInfo,
ref topNeighbor,
ref leftNeighbor,
topNeighbor,
leftNeighbor,
(nuint)referenceStride,
ref pixelBuffer,
pixelBuffer,
(nuint)pixelBufferStride,
mode,
angleDelta,
Expand Down Expand Up @@ -560,10 +556,10 @@ private static bool IntraHasTopRight(Av1BlockSize superblockSize, Av1BlockSize b

private void DecodeBuildIntraPredictors(
Av1PartitionInfo partitionInfo,
ref byte aboveNeighbor,
ref byte leftNeighbor,
Span<byte> aboveNeighbor,
Span<byte> leftNeighbor,
nuint referenceStride,
ref byte destination,
Span<byte> destination,
nuint destinationStride,
Av1PredictionMode mode,
int angleDelta,
Expand Down Expand Up @@ -630,17 +626,18 @@ private void DecodeBuildIntraPredictors(
byte val;
if (needLeft)
{
val = (byte)((topPixelCount > 0) ? aboveNeighbor : 129);
val = (byte)((topPixelCount > 0) ? aboveNeighbor[0] : 129);
}
else
{
val = (byte)((leftPixelCount > 0) ? leftNeighbor : 127);
val = (byte)((leftPixelCount > 0) ? leftNeighbor[0] : 127);
}

ref byte destinationRef = ref destination[0];
for (int i = 0; i < transformHeight; ++i)
{
Unsafe.InitBlock(ref destination, val, (uint)transformWidth);
destination = ref Unsafe.Add(ref destination, destinationStride);
Unsafe.InitBlock(ref destinationRef, val, (uint)transformWidth);
destinationRef = ref Unsafe.Add(ref destinationRef, destinationStride);
}

return;
Expand All @@ -666,15 +663,15 @@ private void DecodeBuildIntraPredictors(
{
for (; i < leftPixelCount; i++)
{
leftColumn[i] = Unsafe.Add(ref leftNeighbor, i * (int)referenceStride);
leftColumn[i] = leftNeighbor[i * (int)referenceStride];
}

if (needBottom && bottomLeftPixelCount > 0)
{
Guard.IsTrue(i == transformHeight, nameof(i), string.Empty);
for (; i < transformHeight + bottomLeftPixelCount; i++)
{
leftColumn[i] = Unsafe.Add(ref leftNeighbor, i * (int)referenceStride);
leftColumn[i] = leftNeighbor[i * (int)referenceStride];
}
}

Expand All @@ -687,7 +684,7 @@ private void DecodeBuildIntraPredictors(
{
if (topPixelCount > 0)
{
Unsafe.InitBlock(ref leftColumn[0], aboveNeighbor, numLeftPixelsNeeded);
Unsafe.InitBlock(ref leftColumn[0], aboveNeighbor[0], numLeftPixelsNeeded);
}
else
{
Expand All @@ -713,12 +710,12 @@ private void DecodeBuildIntraPredictors(
uint numTopPixelsNeeded = (uint)(transformWidth + (needRight ? transformHeight : 0));
if (topPixelCount > 0)
{
Unsafe.CopyBlock(ref aboveRow[0], ref aboveNeighbor, (uint)topPixelCount);
Unsafe.CopyBlock(ref aboveRow[0], ref aboveNeighbor[0], (uint)topPixelCount);
int i = topPixelCount;
if (needRight && topPixelCount > 0)
{
Guard.IsTrue(topPixelCount == transformWidth, nameof(topPixelCount), string.Empty);
Unsafe.CopyBlock(ref aboveRow[transformWidth], ref Unsafe.Add(ref aboveNeighbor, transformWidth), (uint)topPixelCount);
Unsafe.CopyBlock(ref aboveRow[transformWidth], ref aboveNeighbor[transformWidth], (uint)topPixelCount);
i += topPixelCount;
}

Expand All @@ -731,7 +728,7 @@ private void DecodeBuildIntraPredictors(
{
if (leftPixelCount > 0)
{
Unsafe.InitBlock(ref aboveRow[0], leftNeighbor, numTopPixelsNeeded);
Unsafe.InitBlock(ref aboveRow[0], leftNeighbor[0], numTopPixelsNeeded);
}
else
{
Expand All @@ -744,15 +741,15 @@ private void DecodeBuildIntraPredictors(
{
if (topPixelCount > 0 && leftPixelCount > 0)
{
aboveRow[-1] = Unsafe.Subtract(ref aboveNeighbor, 1);
aboveRow[-1] = aboveNeighbor[-1];
}
else if (topPixelCount > 0)
{
aboveRow[-1] = aboveNeighbor;
aboveRow[-1] = aboveNeighbor[0];
}
else if (leftPixelCount > 0)
{
aboveRow[-1] = leftNeighbor;
aboveRow[-1] = leftNeighbor[0];
}
else
{
Expand All @@ -764,7 +761,7 @@ private void DecodeBuildIntraPredictors(

if (useFilterIntra)
{
Av1PredictorFactory.FilterIntraPredictor(ref destination, destinationStride, transformSize, aboveRow, leftColumn, filterIntraMode);
Av1PredictorFactory.FilterIntraPredictor(destination, destinationStride, transformSize, aboveRow, leftColumn, filterIntraMode);
return;
}

Expand Down Expand Up @@ -819,18 +816,18 @@ private void DecodeBuildIntraPredictors(
}
}

Av1PredictorFactory.DirectionalPredictor(ref destination, destinationStride, transformSize, aboveRow, leftColumn, upsampleAbove, upsampleLeft, angle);
Av1PredictorFactory.DirectionalPredictor(destination, destinationStride, transformSize, aboveRow, leftColumn, upsampleAbove, upsampleLeft, angle);
return;
}

// predict
if (mode == Av1PredictionMode.DC)
{
Av1PredictorFactory.DcPredictor(leftPixelCount > 0, topPixelCount > 0, transformSize, ref destination, destinationStride, aboveRow, leftColumn);
Av1PredictorFactory.DcPredictor(leftPixelCount > 0, topPixelCount > 0, transformSize, destination, destinationStride, aboveRow, leftColumn);
}
else
{
Av1PredictorFactory.GeneralPredictor(mode, transformSize, ref destination, destinationStride, aboveRow, leftColumn);
Av1PredictorFactory.GeneralPredictor(mode, transformSize, destination, destinationStride, aboveRow, leftColumn);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,35 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Prediction;

internal class Av1PredictorFactory
{
internal static void DcPredictor(bool hasLeft, bool hasAbove, Av1TransformSize transformSize, ref byte destination, nuint destinationStride, Span<byte> aboveRow, Span<byte> leftColumn)
internal static void DcPredictor(bool hasLeft, bool hasAbove, Av1TransformSize transformSize, Span<byte> destination, nuint destinationStride, Span<byte> aboveRow, Span<byte> leftColumn)
{
if (hasLeft)
{
if (hasAbove)
{
Av1DcPredictor.PredictScalar(transformSize, ref destination, destinationStride, ref aboveRow[0], ref leftColumn[0]);
Av1DcPredictor.PredictScalar(transformSize, destination, destinationStride, aboveRow, leftColumn);
}
else
{
Av1DcLeftPredictor.PredictScalar(transformSize, ref destination, destinationStride, ref aboveRow[0], ref leftColumn[0]);
Av1DcLeftPredictor.PredictScalar(transformSize, destination, destinationStride, aboveRow, leftColumn);
}
}
else
{
if (hasAbove)
{
Av1DcTopPredictor.PredictScalar(transformSize, ref destination, destinationStride, ref aboveRow[0], ref leftColumn[0]);
Av1DcTopPredictor.PredictScalar(transformSize, destination, destinationStride, aboveRow, leftColumn);
}
else
{
Av1DcFillPredictor.PredictScalar(transformSize, ref destination, destinationStride, ref aboveRow[0], ref leftColumn[0]);
Av1DcFillPredictor.PredictScalar(transformSize, destination, destinationStride, aboveRow, leftColumn);
}
}
}

internal static void DirectionalPredictor(ref byte destination, nuint destinationStride, Av1TransformSize transformSize, Span<byte> aboveRow, Span<byte> leftColumn, bool upsampleAbove, bool upsampleLeft, int angle) => throw new NotImplementedException();
internal static void DirectionalPredictor(Span<byte> destination, nuint destinationStride, Av1TransformSize transformSize, Span<byte> aboveRow, Span<byte> leftColumn, bool upsampleAbove, bool upsampleLeft, int angle) => throw new NotImplementedException();

internal static void FilterIntraPredictor(ref byte destination, nuint destinationStride, Av1TransformSize transformSize, Span<byte> aboveRow, Span<byte> leftColumn, Av1FilterIntraMode filterIntraMode) => throw new NotImplementedException();
internal static void FilterIntraPredictor(Span<byte> destination, nuint destinationStride, Av1TransformSize transformSize, Span<byte> aboveRow, Span<byte> leftColumn, Av1FilterIntraMode filterIntraMode) => throw new NotImplementedException();

internal static void GeneralPredictor(Av1PredictionMode mode, Av1TransformSize transformSize, ref byte destination, nuint destinationStride, Span<byte> aboveRow, Span<byte> leftColumn) => throw new NotImplementedException();
internal static void GeneralPredictor(Av1PredictionMode mode, Av1TransformSize transformSize, Span<byte> destination, nuint destinationStride, Span<byte> aboveRow, Span<byte> leftColumn) => throw new NotImplementedException();
}
Loading

0 comments on commit 51d54c6

Please sign in to comment.