Skip to content

Commit

Permalink
upstream: Make it compatible with nn.Linear
Browse files Browse the repository at this point in the history
  • Loading branch information
VlaDexa committed May 27, 2024
1 parent c74c30a commit b669f9e
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ impl<B: Backend> Linear<B> {

pub(crate) fn forward(&self, x: &Tensor<B, 2>) -> Tensor<B, 2> {
assert_eq!(x.dims()[1], self.in_features as usize);
let mut original_shape = x.dims();
let x = x.clone().reshape([-1, self.in_features as i32]);
let base_output = burn::nn::Linear {
weight: self.base_weight.clone(),
bias: None,
Expand All @@ -284,6 +286,10 @@ impl<B: Backend> Linear<B> {
}
.forward(spline_output_input);

base_output.add(spline_output)
let output = base_output.add(spline_output);
if let Some(x) = original_shape.last_mut() {
*x = self.out_features as usize;
}
output.reshape(original_shape)
}
}

0 comments on commit b669f9e

Please sign in to comment.