Skip to content

Commit

Permalink
Merge pull request #39 from SciNim/add_contiguous
Browse files Browse the repository at this point in the history
add contiguous function
  • Loading branch information
Clonkk authored Oct 13, 2023
2 parents 1ac1a01 + 90b40e3 commit 0b9afe7
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion flambeau/install/torch_installer.nim
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
std/[asyncdispatch, httpclient,
std/[httpclient,
strformat, strutils, os],
#zippy/ziparchives,
zip/zipfiles
Expand Down
1 change: 1 addition & 0 deletions flambeau/raw/bindings/rawtensors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ template `==`*(a, b: RawTensor): bool =
# Functions.h
# -----------------------------------------------------------------------

func contiguous*(self: RawTensor): RawTensor {.importcpp: "#.contiguous(@)".}
func toType*(self: RawTensor, dtype: ScalarKind): RawTensor {.importcpp: "#.toType(@)".}
func toSparse*(self: RawTensor): RawTensor {.importcpp: "#.to_sparse()".}
func toSparse*(self: RawTensor, sparseDim: int64): RawTensor {.importcpp: "#.to_sparse(@)".}
Expand Down
5 changes: 5 additions & 0 deletions flambeau/tensors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,11 @@ func backward*[T](self: var Tensor[T]) =

# # Functions.h
# # -----------------------------------------------------------------------
func contiguous*[T](self: Tensor[T]) : Tensor[T] =
asTensor[T](
rawtensors.contiguous(asRaw(self))
)

func toType*[T](self: Tensor[T], dtype: ScalarKind): Tensor[T] =
asTensor[T](
rawtensors.toType(asRaw(self), dtype)
Expand Down
5 changes: 5 additions & 0 deletions tests/tensor/test_accessors_slicer.nim
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ proc main() =
let test = @[@[1], @[16], @[81], @[256], @[625]]
check: t_van[_.._, 3] == test.toTensor().squeeze()

test "Span slices - foo[_, 3] in assignment":
let test = @[@[1], @[16], @[81], @[256], @[625]]
var tmp = t_van[_, 3]
check tmp == test.toTensor().squeeze()

test "Stepping - foo[1..3|2, 3]":
let test = @[@[16], @[256]]
check: t_van[1..3|2, 3] == test.toTensor().squeeze()
Expand Down

0 comments on commit 0b9afe7

Please sign in to comment.