Skip to content

Commit

Permalink
Update to test zero length tensor support.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Nov 7, 2023
1 parent d3cd86d commit a43c380
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 4 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

git_repository(
name = "ccv",
commit = "38eca3efbcb47fc4fcd06d2bd3d81e4bfbd76cd2",
commit = "e921d4d66f6def2e8ed082e48a00cebd8306085f",
remote = "https://github.com/liuliu/ccv.git",
shallow_since = "1698450773 -0400",
shallow_since = "1699332984 -0500",
)

load("@ccv//config:ccv.bzl", "ccv_deps", "ccv_setting")
Expand Down
4 changes: 2 additions & 2 deletions deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ def s4nnc_deps():
git_repository,
name = "ccv",
remote = "https://github.com/liuliu/ccv.git",
commit = "38eca3efbcb47fc4fcd06d2bd3d81e4bfbd76cd2",
shallow_since = "1698450773 -0400",
commit = "e921d4d66f6def2e8ed082e48a00cebd8306085f",
shallow_since = "1699332984 -0500",
)

_maybe(
Expand Down
12 changes: 12 additions & 0 deletions test/graph.swift
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,17 @@ final class GraphTests: XCTestCase {
XCTAssertEqual(b0[3], a1[0, 1, 0])
}

func testConcatZeroLengthTensor() throws {
let dynamicGraph = DynamicGraph()
let a0 = dynamicGraph.variable(.CPU, format: .NCHW, shape: [], of: Float.self)
let a1 = dynamicGraph.variable(Tensor<Float>([1, 2, 3, 4], .CPU, .NC(2, 2)))
let b0 = Concat(axis: 1)(inputs: a0, a1)[0].as(of: Float.self)
XCTAssertEqual(b0[0, 0], 1)
XCTAssertEqual(b0[0, 1], 2)
XCTAssertEqual(b0[1, 0], 3)
XCTAssertEqual(b0[1, 1], 4)
}

static let allTests = [
("testGEMM", testGEMM),
("testGEMMGrad", testGEMMGrad),
Expand All @@ -401,5 +412,6 @@ final class GraphTests: XCTestCase {
("testPermute", testPermute),
("testPermuteAndGetASubset", testPermuteAndGetASubset),
("testPermuteAndReshape", testPermuteAndReshape),
("testConcatZeroLengthTensor", testConcatZeroLengthTensor),
]
}
5 changes: 5 additions & 0 deletions test/tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ import NNC
import XCTest

final class TensorTests: XCTestCase {
func testCreateZeroLengthTensor() throws {
let tensor = Tensor<Float>(.CPU, format: .NHWC, shape: [])
XCTAssertEqual([], tensor.shape)
}

func testGetSetPartTensor() throws {
var tensor = Tensor<Int32>(.CPU, .NC(5, 2))
Expand Down Expand Up @@ -193,6 +197,7 @@ final class TensorTests: XCTestCase {
}

static let allTests = [
("testCreateZeroLengthTensor", testCreateZeroLengthTensor),
("testGetSetPartTensor", testGetSetPartTensor),
("testGetSetPartTensorFromArray", testGetSetPartTensorFromArray),
("testGetSetUnboundedPartTensorFromArray", testGetSetUnboundedPartTensorFromArray),
Expand Down

0 comments on commit a43c380

Please sign in to comment.