Skip to content

Commit

Permalink
Merge pull request #27 from UchidaMizuki/fix-matrixOps-#26
Browse files Browse the repository at this point in the history
Fix matrixOps #26
  • Loading branch information
UchidaMizuki authored Dec 8, 2024
2 parents ce45e39 + e600cc6 commit b5dbd77
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 9 deletions.
28 changes: 19 additions & 9 deletions R/methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,36 @@ matrixOps_dibble <- function(e1, e2) {
e2 <- as_ddf_col(e2)

class <- class(e1)
dim_names_x <- dimnames(e1)
dim_names_y <- dimnames(e2)
dim_names_e1 <- dimnames(e1)
dim_names_e2 <- dimnames(e2)

if (vec_size(dim_names_x) == 1L) {
size_dim_names_e1 <- vec_size(dim_names_e1)
size_dim_names_e2 <- vec_size(dim_names_e2)

new_dim_name <- union_dim_names(list(dim_names_e1[size_dim_names_e1], dim_names_e2[1]))
dim_names_e1[size_dim_names_e1] <- new_dim_name
dim_names_e2[1] <- new_dim_name

e1 <- broadcast(e1, dim_names_e1)
e2 <- broadcast(e2, dim_names_e2)

if (vec_size(dim_names_e1) == 1L) {
e1 <- as.vector(e1)
dim_names_x <- NULL
dim_names_e1 <- NULL
} else {
e1 <- as.matrix(e1)
dim_names_x <- dim_names_x[1L]
dim_names_e1 <- dim_names_e1[1L]
}

if (vec_size(dim_names_y) == 1L) {
if (vec_size(dim_names_e2) == 1L) {
e2 <- as.vector(e2)
dim_names_y <- NULL
dim_names_e2 <- NULL
} else {
e2 <- as.matrix(e2)
dim_names_y <- dim_names_y[2L]
dim_names_e2 <- dim_names_e2[2L]
}

new_dim_names <- purrr::compact(c(dim_names_x, dim_names_y))
new_dim_names <- purrr::compact(c(dim_names_e1, dim_names_e2))

out <- NextMethod()

Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/test-matrix.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
test_that("`%*%`() works", {
rev_axis <- function(x, axis) {
dim_names <- dimnames(x)
dim_names[[axis]] <- rev(dim_names[[axis]])
broadcast(x, dim_names)
}

# mat %*% mat
mat_x <- matrix(1:9, 3,
dimnames = list(axis1 = 1:3,
Expand All @@ -9,10 +15,14 @@ test_that("`%*%`() works", {
ddf_x <- as_dibble(mat_x)
ddf_y <- as_dibble(mat_y)
expect_equal(as.matrix(ddf_x %*% ddf_y), unname(mat_x %*% mat_y))
expect_equal(as.matrix(rev_axis(ddf_x, 2) %*% ddf_y), unname(mat_x %*% mat_y))
expect_equal(as.matrix(ddf_x %*% rev_axis(ddf_y, 1)), unname(mat_x %*% mat_y))

ddf_x <- dibble(x = ddf_x)
ddf_y <- dibble(x = ddf_y)
expect_equal(as.matrix(ddf_x %*% ddf_y), unname(mat_x %*% mat_y))
expect_equal(as.matrix(rev_axis(ddf_x, 2) %*% ddf_y), unname(mat_x %*% mat_y))
expect_equal(as.matrix(ddf_x %*% rev_axis(ddf_y, 1)), unname(mat_x %*% mat_y))

# vec %*% mat
vec_x <- array(1:3, 3,
Expand All @@ -23,6 +33,8 @@ test_that("`%*%`() works", {
ddf_x <- as_dibble(vec_x)
ddf_y <- as_dibble(mat_y)
expect_equal(as.matrix(ddf_x %*% ddf_y), t(unname(vec_x %*% mat_y)))
expect_equal(as.matrix(rev_axis(ddf_x, 1) %*% ddf_y), t(unname(vec_x %*% mat_y)))
expect_equal(as.matrix(ddf_x %*% rev_axis(ddf_y, 1)), t(unname(vec_x %*% mat_y)))

# mat %*% vec
mat_x <- matrix(1:9, 3,
Expand All @@ -33,6 +45,8 @@ test_that("`%*%`() works", {
ddf_x <- as_dibble(mat_x)
ddf_y <- as_dibble(vec_y)
expect_equal(as.matrix(ddf_x %*% ddf_y), unname(mat_x %*% vec_y))
expect_equal(as.matrix(rev_axis(ddf_x, 2) %*% ddf_y), unname(mat_x %*% vec_y))
expect_equal(as.matrix(ddf_x %*% rev_axis(ddf_y, 1)), unname(mat_x %*% vec_y))

# vec %*% vec
vec_x <- array(1:3, 3,
Expand All @@ -42,6 +56,8 @@ test_that("`%*%`() works", {
ddf_x <- as_dibble(vec_x)
ddf_y <- as_dibble(vec_y)
expect_equal(ddf_x %*% ddf_y, as.vector(vec_x %*% vec_y))
expect_equal(rev_axis(ddf_x, 1) %*% ddf_y, as.vector(vec_x %*% vec_y))
expect_equal(ddf_x %*% rev_axis(ddf_y, 1), as.vector(vec_x %*% vec_y))
})

test_that("t() works", {
Expand Down

0 comments on commit b5dbd77

Please sign in to comment.