Update to use the new Python custom op APIs
Won't land this until 2.4 comes out.

[ghstack-poisoned]
zou3519 committed Apr 26, 2024
1 parent a5ed0b0 commit f99d09c
Showing 2 changed files with 35 additions and 32 deletions.
1 change: 0 additions & 1 deletion extension_cpp/csrc/lltm.cpp
@@ -89,7 +89,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
 
 // Defines the operators
 TORCH_LIBRARY(extension_cpp, m) {
-  m.impl_abstract_pystub("extension_cpp.ops");
   m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
   m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
 }
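For context: each m.def schema registers the operator with the dispatcher, and the built extension exposes it under the torch.ops namespace. Below is a minimal sketch of calling the raw forward op; the shapes (batch 16, 32 input features, state size 128) are assumptions taken from the usual LLTM tutorial setup, not part of this commit.

    import torch
    import extension_cpp  # importing the package loads the compiled C++ library

    batch, features, state = 16, 32, 128  # assumed toy sizes
    x = torch.randn(batch, features)
    weights = torch.randn(3 * state, features + state)
    bias = torch.randn(3 * state)
    old_h = torch.randn(batch, state)
    old_cell = torch.randn(batch, state)

    # lltm_forward returns 7 tensors; the first two are the new hidden/cell states.
    outs = torch.ops.extension_cpp.lltm_forward(x, weights, bias, old_h, old_cell)
    new_h, new_cell = outs[:2]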
66 changes: 35 additions & 31 deletions extension_cpp/ops.py
@@ -8,37 +8,41 @@
 def lltm(
     input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
 ) -> Tuple[Tensor, Tensor]:
-    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
-
-
-class LLTMFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, weights, bias, old_h, old_cell):
-        outputs = torch.ops.extension_cpp.lltm_forward.default(
-            input, weights, bias, old_h, old_cell
-        )
-        new_h, new_cell = outputs[:2]
-        variables = list(outputs[1:]) + [weights]
-        ctx.save_for_backward(*variables)
-
-        return new_h, new_cell
-
-    @staticmethod
-    @torch.autograd.function.once_differentiable
-    def backward(ctx, grad_h, grad_cell):
-        (
-            d_old_h,
-            d_input,
-            d_weights,
-            d_bias,
-            d_old_cell,
-        ) = torch.ops.extension_cpp.lltm_backward.default(
-            grad_h, grad_cell, *ctx.saved_tensors
-        )
-        return d_input, d_weights, d_bias, d_old_h, d_old_cell
-
-
-@torch.library.impl_abstract("extension_cpp::lltm_forward")
+    """The lltm API"""
+    outputs = torch.ops.extension_cpp.lltm_forward.default(
+        input, weights, bias, old_h, old_cell
+    )
+    new_h, new_cell = outputs[:2]
+    return new_h, new_cell
+
+
+# This is the backward for lltm_forward.
+# lltm_forward has 7 returns so they all get gradients.
+def backward(ctx, grad_h, grad_cell, _0, _1, _2, _3, _4):
+    (
+        d_old_h,
+        d_input,
+        d_weights,
+        d_bias,
+        d_old_cell,
+    ) = torch.ops.extension_cpp.lltm_backward.default(
+        grad_h, grad_cell, *ctx.saved_tensors
+    )
+    return d_input, d_weights, d_bias, d_old_h, d_old_cell
+
+
+def setup_context(ctx, inputs, output):
+    weights = inputs[1]
+    new_h, new_cell = output[:2]
+    variables = list(output[1:]) + [weights]
+    ctx.save_for_backward(*variables)
+
+
+torch.library.register_autograd(
+    "extension_cpp::lltm_forward", backward, setup_context=setup_context)
+
+
+@torch.library.register_fake("extension_cpp::lltm_forward")
 def _(input, weights, bias, old_h, old_cell):
     X = torch.cat([old_h, input], dim=1)
     gate_weights = torch.nn.functional.linear(X, weights, bias)
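With register_autograd in place, gradients flow through the custom op without an autograd.Function subclass. A quick sketch of exercising that path, reusing the toy shapes from the earlier sketch (again an assumption, not part of the commit):

    x = torch.randn(batch, features, requires_grad=True)
    weights = torch.randn(3 * state, features + state, requires_grad=True)
    bias = torch.randn(3 * state, requires_grad=True)
    old_h = torch.randn(batch, state, requires_grad=True)
    old_cell = torch.randn(batch, state, requires_grad=True)

    new_h, new_cell = extension_cpp.ops.lltm(x, weights, bias, old_h, old_cell)
    # The registered backward dispatches to the C++ lltm_backward kernel.
    (new_h.sum() + new_cell.sum()).backward()
    assert weights.grad.shape == weights.shape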
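register_fake (the replacement for the impl_abstract / impl_abstract_pystub pair this commit deletes) gives FakeTensor tracing the output metadata it needs, so the op can be traced without running the C++ kernel. A hedged sketch of what that enables, continuing from the snippet above:

    # Assumption about intended usage, not part of this commit: with the fake
    # impl registered, torch.compile can trace through the custom op.
    compiled_lltm = torch.compile(extension_cpp.ops.lltm)
    new_h, new_cell = compiled_lltm(x, weights, bias, old_h, old_cell)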
