Skip to content

Commit

Permalink
Add return types to Tensor class
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Apr 6, 2024
1 parent 44f8a0b commit b88b6a2
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ export class Tensor {
* If you would like a copy, use `tensor.clone()` before squeezing.
*
* @param {number} [dim=null] If given, the input will be squeezed only in the specified dimensions.
* @returns The squeezed tensor
* @returns {Tensor} The squeezed tensor
*/
squeeze(dim = null) {
return new Tensor(
Expand All @@ -504,7 +504,7 @@ export class Tensor {
* NOTE: The returned tensor shares the same underlying data with this tensor.
*
* @param {number} dim The index at which to insert the singleton dimension
* @returns The unsqueezed tensor
* @returns {Tensor} The unsqueezed tensor
*/
unsqueeze(dim = null) {
return new Tensor(
Expand Down Expand Up @@ -543,7 +543,7 @@ export class Tensor {
* and ending with `end_dim` are flattened. The order of elements in input is unchanged.
* @param {number} start_dim the first dim to flatten
* @param {number} end_dim the last dim to flatten
* @returns The flattened tensor.
* @returns {Tensor} The flattened tensor.
*/
flatten(start_dim = 0, end_dim = -1) {
return this.clone().flatten_(start_dim, end_dim);
Expand Down Expand Up @@ -601,7 +601,7 @@ export class Tensor {
* Clamps all elements in input into the range [ min, max ]
* @param {number} min lower-bound of the range to be clamped to
* @param {number} max upper-bound of the range to be clamped to
* @returns the output tensor.
* @returns {Tensor} the output tensor.
*/
clamp(min, max) {
return this.clone().clamp_(min, max);
Expand All @@ -619,7 +619,7 @@ export class Tensor {

/**
* Rounds elements of input to the nearest integer.
* @returns the output tensor.
* @returns {Tensor} the output tensor.
*/
round() {
return this.clone().round_();
Expand Down Expand Up @@ -828,7 +828,7 @@ export function layer_norm(input, normalized_shape, {
* Helper function to calculate new dimensions when performing a squeeze operation.
* @param {number[]} dims The dimensions of the tensor.
* @param {number|number[]|null} dim The dimension(s) to squeeze.
* @returns The new dimensions.
* @returns {number[]} The new dimensions.
* @private
*/
function calc_squeeze_dims(dims, dim) {
Expand All @@ -851,7 +851,7 @@ function calc_squeeze_dims(dims, dim) {
* Helper function to calculate new dimensions when performing an unsqueeze operation.
* @param {number[]} dims The dimensions of the tensor.
* @param {number} dim The dimension to unsqueeze.
* @returns The new dimensions.
* @returns {number[]} The new dimensions.
* @private
*/
function calc_unsqueeze_dims(dims, dim) {
Expand Down Expand Up @@ -1038,7 +1038,7 @@ export function std_mean(input, dim = null, correction = 1, keepdim = false) {
* @param {Tensor} input the input tensor.
* @param {number|null} dim the dimension to reduce.
* @param {boolean} keepdim whether the output tensor has dim retained or not.
* @returns A new tensor with means taken along the specified dimension.
* @returns {Tensor} A new tensor with means taken along the specified dimension.
*/
export function mean(input, dim = null, keepdim = false) {

Expand Down

0 comments on commit b88b6a2

Please sign in to comment.