Skip to content

Commit

Permalink
update causal cache
Browse files Browse the repository at this point in the history
  • Loading branch information
wtomin committed Nov 30, 2024
1 parent 1b79586 commit a196345
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from collections import deque
from typing import Tuple, Union

import mindspore as ms
@@ -115,33 +116,38 @@ def __init__(
**kwargs,
)
self.enable_cached = enable_cached
self.causal_cached = None
self.is_first_chunk = True

self.causal_cached = deque()
self.cache_offset = 0

def construct(self, x):
x_dtype = x.dtype
# x: (bs, Cin, T, H, W )
# first_frame_pad = ops.repeat_interleave(first_frame, (self.time_kernel_size - 1), axis=2)
if self.time_kernel_size - 1 > 0:
if self.causal_cached is None:
if self.is_first_chunk:
first_frame = x[:, :, :1, :, :]
first_frame_pad = mint.cat([first_frame] * (self.time_kernel_size - 1), dim=2)
# first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))
else:
first_frame_pad = self.causal_cached
first_frame_pad = self.causal_cached.popleft()

x = mint.cat((first_frame_pad, x), dim=2)

if self.enable_cached and self.time_kernel_size != 1:
if (self.time_kernel_size - 1) // self.stride[0] != 0:
if self.cache_offset == 0:
self.causal_cached = x[:, :, -(self.time_kernel_size - 1) // self.stride[0] :]
causal_cached = x[:, :, -(self.time_kernel_size - 1) // self.stride[0] :]
else:
self.causal_cached = x[:, :, : -self.cache_offset][
causal_cached = x[:, :, : -self.cache_offset][
:, :, -(self.time_kernel_size - 1) // self.stride[0] :
]
else:
self.causal_cached = x[:, :, 0:0, :, :]
causal_cached = x[:, :, 0:0, :, :]
self.causal_cached.append(causal_cached.copy())
elif self.enable_cached:
self.causal_cached.append(x[:, :, 0:0, :, :].copy())

if npu_config is not None and npu_config.on_npu:
return npu_config.run_conv3d(self.conv, x, x_dtype)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from collections import deque
from typing import Tuple, Union

from opensora.npu_config import npu_config
@@ -324,27 +325,27 @@ def __init__(
self.conv = CausalConv3d(in_channels, out_channels, kernel_size=3, padding=1)
self.interpolate = TrilinearInterpolate()
self.enable_cached = enable_cached
self.causal_cached = None
self.causal_cached = deque()

def construct(self, x):
if x.shape[2] > 1 or self.causal_cached is not None:
if self.enable_cached and self.causal_cached is not None:
x = mint.cat([self.causal_cached, x], dim=2)
self.causal_cached = x[:, :, -2:-1]
if x.shape[2] > 1 or len(self.causal_cached) > 0:
if self.enable_cached and len(self.causal_cached) > 0:
x = mint.cat([self.causal_cached.popleft(), x], dim=2)
self.causal_cached.append(x[:, :, -2:-1].copy())
x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(2.0, 1.0, 1.0))
x = x[:, :, 2:]
x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(1.0, 2.0, 2.0))
else:
if self.enable_cached:
self.causal_cached = x[:, :, -1:]
self.causal_cached.append(x[:, :, -1:].copy())
x, x_ = x[:, :, :1], x[:, :, 1:]
x_ = npu_config.run_interpolate(self.interpolate, x_, scale_factor=(2.0, 1.0, 1.0))
x_ = npu_config.run_interpolate(self.interpolate, x_, scale_factor=(1.0, 2.0, 2.0))
x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(1.0, 2.0, 2.0))
x = mint.cat([x, x_], dim=2)
else:
if self.enable_cached:
self.causal_cached = x[:, :, -1:]
self.causal_cached.append(x[:, :, -1:].copy())
x = npu_config.run_interpolate(self.interpolate, x, scale_factor=(1.0, 2.0, 2.0))
return self.conv(x)

Expand Down
Original file line number Diff line number Diff line change
@@ -148,7 +148,7 @@ def __init__(self, enable_cached=False, dtype=ms.float32, *args, **kwargs) -> No
self.hh_v = Tensor([[[1, 1], [-1, -1]], [[-1, -1], [1, 1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536
self.gh_v = Tensor([[[1, -1], [-1, 1]], [[-1, 1], [1, -1]]], dtype=dtype).view(1, 1, 2, 2, 2) * 0.3536
self.enable_cached = enable_cached
self.causal_cached = None
self.is_first_chunk = True
self.conv_transpose3d = ops.Conv3DTranspose(1, 1, kernel_size=2, stride=2)

def construct(self, coeffs):
@@ -185,7 +185,7 @@ def construct(self, coeffs):
high_low_high = self.conv_transpose3d(high_low_high, self.g_v)
high_high_low = self.conv_transpose3d(high_high_low, self.hh_v)
high_high_high = self.conv_transpose3d(high_high_high, self.gh_v)
if self.enable_cached and self.causal_cached:
if self.enable_cached and not self.is_first_chunk:
reconstructed = (
low_low_low
+ low_low_high
@@ -207,7 +207,7 @@ def construct(self, coeffs):
+ high_high_low[:, :, 1:]
+ high_high_high[:, :, 1:]
)
self.causal_cached = True

reconstructed = reconstructed.reshape(b, -1, *reconstructed.shape[-3:])

return reconstructed.to(input_dtype)
Expand Down
Loading

0 comments on commit a196345

Please sign in to comment.