Merge pull request #55 from mlverse/updates
Updates
edgararuiz authored Oct 30, 2023
2 parents beb4b31 + 01579db commit ba43481
Showing 8 changed files with 151 additions and 59 deletions.
R/connections-pane.R (21 changes: 8 additions & 13 deletions)
@@ -174,17 +174,18 @@ catalog_sql <- function(
     catalog = NULL,
     schema = NULL,
     name = NULL,
-    type = NULL) {
+    type = NULL,
+    catalog_tbl = in_catalog("system", "information_schema", "catalogs"),
+    schema_tbl = in_catalog("system", "information_schema", "schemata"),
+    tables_tbl = in_catalog("system", "information_schema", "tables")
+    ) {
 
   limit <- as.numeric(
     Sys.getenv("SPARKLYR_CONNECTION_OBJECT_LIMIT", unset = 100)
   )
 
   if(is.null(catalog)) {
-    all_catalogs <- tbl(
-      src = con,
-      in_catalog("system", "information_schema", "catalogs")
-    )
+    all_catalogs <- tbl(src = con, catalog_tbl)
 
     get_catalogs <- all_catalogs %>%
       select(catalog_name, comment) %>%
@@ -203,10 +204,7 @@ catalog_sql <- function(
   }
 
   if(is.null(schema) && !is.null(catalog)) {
-    all_schema <- tbl(
-      src = con,
-      in_catalog("system", "information_schema", "schemata")
-    )
+    all_schema <- tbl(src = con, schema_tbl)
 
     get_schema <- all_schema %>%
       filter(catalog_name == catalog) %>%
@@ -226,10 +224,7 @@ catalog_sql <- function(
   }
 
   if(!is.null(schema) && !is.null(catalog)) {
-    all_tables <- tbl(
-      src = con,
-      in_catalog("system", "information_schema", "tables")
-    )
+    all_tables <- tbl(src = con, tables_tbl)
 
     get_tables <- all_tables %>%
       filter(
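
Note on the change above: catalog_sql() now receives its three metadata sources as arguments instead of hard-coding the Unity Catalog information_schema paths, so tests can inject a local table. A minimal sketch of the overrides, assuming an open connection sc and a copied table named "temp_items" as in the new test further down:

# Point each metadata lookup at a local table instead of system.information_schema
catalog_sql(sc, catalog_tbl = "temp_items")
catalog_sql(sc, catalog = "catalog", schema_tbl = "temp_items")
catalog_sql(sc, catalog = "catalog", schema = "schema", tables_tbl = "temp_items")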
R/install.R (61 changes: 35 additions & 26 deletions)
@@ -101,33 +101,12 @@ install_as_job <- function(
     in_rstudio <- TRUE
   }
   if (as_job && in_rstudio) {
-    args$as_job <- NULL
-    args$method <- args$method[[1]]
-
+    install_code <- build_job_code(args)
     job_name <- paste0("Installing '", libs, "' version '", version, "'")
-
-    arg_list <- args %>%
-      imap(~ {
-        if (inherits(.x, "character")) {
-          x <- paste0("\"", .x, "\"")
-        } else {
-          x <- .x
-        }
-        paste0(.y, " = ", x)
-      }) %>%
-      as.character() %>%
-      paste0(collapse = ", ")
-
-    install_code <- paste0(
-      "pysparklyr:::install_environment(", arg_list, ")"
-    )
     temp_file <- tempfile()
     writeLines(install_code, temp_file)
     invisible(
-      jobRunScript(
-        path = temp_file,
-        name = job_name
-      )
+      jobRunScript(path = temp_file, name = job_name)
     )
     cli_div(theme = cli_colors())
     cli_alert_success("{.header Running installation as an RStudio job }")
@@ -145,6 +124,26 @@ install_as_job <- function(
   }
 }
 
+build_job_code <- function(args) {
+  args$as_job <- NULL
+  args$method <- args$method[[1]]
+  arg_list <- args %>%
+    imap(~ {
+      if (inherits(.x, "character")) {
+        x <- paste0("\"", .x, "\"")
+      } else {
+        x <- .x
+      }
+      paste0(.y, " = ", x)
+    }) %>%
+    as.character() %>%
+    paste0(collapse = ", ")
+
+  paste0(
+    "pysparklyr:::install_environment(", arg_list, ")"
+  )
+}
+
 install_environment <- function(
     libs = NULL,
     version = NULL,
@@ -180,10 +179,17 @@ install_environment <- function(
     version <- paste0(version, ".*")
   }
 
+  add_torch <- FALSE
   if (is.null(envname)) {
     if (libs == "databricks-connect") {
+      if (compareVersion(as.character(ver_name), "14.1") >= 0) {
+        add_torch <- TRUE
+      }
       ln <- "databricks"
     } else {
+      if (compareVersion(as.character(ver_name), "3.5") >= 0) {
+        add_torch <- TRUE
+      }
       ln <- libs
     }
     envname <- glue("r-sparklyr-{ln}-{ver_name}")
@@ -198,11 +204,13 @@ install_environment <- function(
     "PyArrow",
     "grpcio",
     "google-api-python-client",
-    "grpcio_status",
-    "torch",
-    "torcheval"
+    "grpcio_status"
   )
 
+  if (add_torch) {
+    packages <- c(packages, "torch", "torcheval")
+  }
+
   method <- match.arg(method)
 
   if (new_env) {
@@ -303,6 +311,7 @@ py_library_info <- function(lib, ver = NULL) {
 }
 
 version_prep <- function(version) {
+  version <- as.character(version)
   ver <- version %>%
     strsplit("\\.") %>%
     unlist()
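
Note on the changes above: build_job_code() is split out so the generated job script can be tested on its own, torch and torcheval are now installed only when add_torch applies (databricks-connect 14.1 and up, or PySpark 3.5 and up), and version_prep() coerces its input because strsplit() accepts only character input, so numeric versions such as 3.5 would otherwise error. A sketch of what build_job_code() returns, inferred from the function body rather than from a committed snapshot:

# as_job is dropped, method collapses to its first element,
# and character values are re-quoted for the generated call
build_job_code(list(version = "3.5", as_job = TRUE, method = c("auto", "virtualenv")))
#> [1] "pysparklyr:::install_environment(version = \"3.5\", method = \"auto\")"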
R/spark-connect.R (27 changes: 9 additions & 18 deletions)
@@ -2,12 +2,12 @@
 spark_connect_method.spark_method_spark_connect <- function(
     x,
     method,
-    master = NULL,
+    master,
     spark_home,
     config,
     app_name,
     version,
-    packages,
+    hadoop_version,
     extensions,
     scala_version,
     ...) {
@@ -29,7 +29,7 @@ spark_connect_method.spark_method_databricks_connect <- function(
     config,
     app_name,
     version,
-    packages,
+    hadoop_version,
     extensions,
     scala_version,
     ...) {
@@ -53,24 +53,21 @@ py_spark_connect <- function(master,
   env_base <- "r-sparklyr-pyspark-"
   envs <- find_environments(env_base)
   if (length(envs) == 0) {
-    cli_div(theme = cli_colors())
-    cli_abort(c(
+    cli_internal_abort(c(
      paste0(
        "{.header No environment name provided, and }",
        "{.header no environment was automatically identified.}"
      ),
      "Run {.run pysparklyr::install_pyspark()} to install."
-    ), call = NULL)
-    cli_end()
+    ))
   } else {
     if (!is.null(spark_version)) {
       sp_version <- version_prep(spark_version)
       envname <- glue("{env_base}{sp_version}")
       matched <- envs[envs == envname]
       if (length(matched) == 0) {
         envname <- envs[[1]]
-        cli_div(theme = cli_colors())
-        cli_alert_warning(paste(
+        cli_internal_alert_warning(paste(
          "{.header A Python environment with a matching version was not found}",
          "* {.header Will attempt connecting using }{.emph '{envname}'}",
          paste0(
@@ -79,7 +76,6 @@ py_spark_connect <- function(master,
          ),
          sep = "\n"
        ))
-        cli_end()
       } else {
         envname <- matched
       }
@@ -118,8 +114,7 @@ py_spark_connect <- function(master,
     matched <- envs[envs == envname]
     if (length(matched) == 0) {
       envname <- envs[[1]]
-      cli_div(theme = cli_colors())
-      cli_alert_warning(paste(
+      cli_internal_alert_warning(paste(
        "{.header A Python environment with a matching version was not found}",
        "* {.header Will attempt connecting using }{.emph '{envname}'}",
        paste0(
@@ -128,7 +123,6 @@ py_spark_connect <- function(master,
        ),
        sep = "\n"
      ))
-      cli_end()
     }
   } else {
     if (!is.na(reticulate_python)) {
@@ -137,9 +131,7 @@ py_spark_connect <- function(master,
        "{.emph 'RETICULATE_PYTHON' }{.header environment variable}",
        "{.class ({py_exe()})}"
      )
-      cli_div(theme = cli_colors())
-      cli_alert_warning(msg)
-      cli_end()
+      cli_internal_alert_warning(msg)
      envname <- reticulate_python
    }
  }
@@ -257,8 +249,7 @@ build_user_agent <- function() {
 cluster_dbr_version <- function(cluster_id,
                                 host = Sys.getenv("DATABRICKS_HOST"),
                                 token = Sys.getenv("DATABRICKS_TOKEN")) {
-  cli_div(theme = cli_colors())
-  cli_alert_warning(
+  cli_internal_alert_warning(
    "{.header Retrieving version from cluster }{.emph '{cluster_id}'}"
  )
 
R/utils.R (12 changes: 12 additions & 0 deletions)
@@ -39,3 +39,15 @@ cli_colors <- function(envir = parent.frame()) {
     span.spark = list(color = "darkgray")
   )
 }
+
+cli_internal_alert_warning <- function(msg) {
+  cli_div(theme = cli_colors())
+  cli_alert_warning(msg)
+  cli_end()
+}
+
+cli_internal_abort <- function(msg) {
+  cli_div(theme = cli_colors())
+  cli_abort(msg, call = NULL)
+  cli_end()
+}
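
These two helpers centralize the cli_div(theme = cli_colors()) / cli_end() bracketing that R/spark-connect.R previously repeated at every call site. Equivalent call-site usage, with illustrative message text:

cli_internal_alert_warning("{.header Retrieving version from cluster}")
cli_internal_abort("{.header No environment name provided.}")  # signals the error with call = NULL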
tests/testthat/_snaps/compat-sparklyr.md (48 changes: 48 additions & 0 deletions)
@@ -0,0 +1,48 @@
+# Internal functions work
+
+    Code
+      .strsep("domino", 3)
+    Output
+      [[1]]
+      [1] "dom"
+
+      [[2]]
+      [1] "ino"
+
+---
+
+    Code
+      .str_split_n("hello", "l", n_max = 3)
+    Output
+      [[1]]
+      [1] "he" "" "o"
+
+---
+
+    Code
+      .list_indices(c("one", "two", "three"), 2)
+    Output
+      [1] "one, two, ..."
+
+---
+
+    Code
+      .simplify_pieces(c("one", "two", "three"), 2, FALSE)
+    Output
+      $strings
+      $strings[[1]]
+      [1] NA NA NA
+
+      $strings[[2]]
+      [1] NA NA NA
+
+      $too_big
+      NULL
+
+      $too_sml
+      [1] 1 2 3
+
tests/testthat/test-compat-sparklyr.R (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,8 @@
+test_that("Internal functions work", {
+  expect_snapshot(.strsep("domino", 3))
+  expect_equal(.slice_match("hello", 1), "hello")
+  expect_snapshot(.str_split_n("hello", "l", n_max = 3))
+  expect_snapshot(.list_indices(c("one", "two", "three"), 2))
+  expect_snapshot(.simplify_pieces(c("one", "two", "three"), 2, FALSE))
+  expect_warning(.str_split_fixed("hello", "l", 1))
+})
tests/testthat/test-connections-pane.R (25 changes: 25 additions & 0 deletions)
@@ -12,3 +12,28 @@ test_that("Object retrieval function work", {
   )
 
 })
+
+test_that("DB sql", {
+  sc <- test_spark_connect()
+  temp_items <- data.frame(
+    catalog_name = "catalog",
+    schema_name = "schema",
+    table_catalog = "catalog",
+    table_schema = "schema",
+    table_name = "table",
+    comment = "this is a comment"
+  )
+  tbl_items <- copy_to(sc, temp_items, overwrite = TRUE)
+
+  expect_s3_class(catalog_sql(sc, catalog_tbl = "temp_items"), "data.frame")
+  expect_s3_class(
+    catalog_sql(sc, catalog = "catalog", schema_tbl = "temp_items")
+    , "data.frame"
+  )
+  expect_s3_class(
+    catalog_sql(sc, catalog = "catalog", schema = "schema", tables_tbl = "temp_items"),
+    "data.frame"
+  )
+
+  dbRemoveTable(sc, "temp_items")
+})
tests/testthat/test-install.R (8 changes: 6 additions & 2 deletions)
@@ -26,11 +26,15 @@ skip_if(
 )
 
 test_that("PySpark installation works", {
-  expect_output(install_pyspark("3.3", python = Sys.which("python")))
+  expect_output(install_pyspark("3.3", as_job = FALSE, python = Sys.which("python")))
   reticulate::virtualenv_remove("r-sparklyr-pyspark-3.3", confirm = FALSE)
 })
 
 test_that("DB Connect installation works", {
-  expect_output(install_databricks("13.0", python = Sys.which("python")))
+  expect_output(install_databricks("13.0", as_job = FALSE, python = Sys.which("python")))
   reticulate::virtualenv_remove("r-sparklyr-databricks-13.0", confirm = FALSE)
 })
+
+test_that("Install code is correctly created", {
+  expect_snapshot(build_job_code(list(a = 1)))
+})
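
The snapshot for this new test is not part of the diff; judging from the body of build_job_code(), the recorded output would be (an expectation, not the committed snapshot):

build_job_code(list(a = 1))
#> [1] "pysparklyr:::install_environment(a = 1)"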
