catboost method to embed categorical variables #138
Unfortunately, catboost (the R package) is not on CRAN 😔, which blocks us from implementing catboost methods in our packages. You can see the related discussion in catboost/catboost#439.
Hey Julia,
Hey @talegari 👋 That sounds great! Feel free to open an issue, and ping me if you need any help!
Hello @talegari 👋 Are you still interested in opening a PR for this step? If not, I will do it.
Hey @EmilHvitfeldt ... it just fell off the radar. I will submit a PR. I am planning something along these lines; let me know if you have a different suggestion.
Amazing! That looks like a great place to start!
By 24th Mar.
On Thu, March 16, 2023 at 09:34 PM, Emil Hvitfeldt <***@***.***> wrote:
> Amazing! That looks like a great place to start!
> Do you know when you will have time to work on this? No rush!
Hey @EmilHvitfeldt, something unforeseen stopped me from working on this. I just wanted to let you know that I am on it and will raise a PR shortly.
No problem! It might not make it into the next {embed} release, but that is fine; we can send it in later.
@EmilHvitfeldt, I am one step away from raising a PR, but I need your help resolving a small issue. Here is the context: I have implemented the catboost encoder as an R6 class:
Category encoder R6 class
# catboost encoder core logic
pacman::p_load("tidyverse")
#' catboost_encoder R6 class
#'
#' An R6 class to encode categorical variables with the CatBoost method.
#'
#' @name catboost_encoder
#' @docType class
#' @importFrom R6 R6Class
#'
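#' @field dataset The dataset used to fit the encoder
#' @field mean The mean of the response variable in the dataset (the prior)
#' @field varnames_to_encode The names of the categorical variables to encode
#' @field response_varname The name of the response variable in the dataset
#' @field is_fitted A flag indicating whether the encoder has been fitted
#' @field a The weight of the prior (the overall mean) in the smoothed encoding
#' @field encode_novel_levels Whether levels unseen at fit time are encoded with the prior
#' @field encode_missing_levels A flag controlling how missing levels in new data are handled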
#'
#' @section Public methods: \describe{
#' \item{\code{initialize(dataset)}}{Constructor method for the
#' catboost_encoder class}
#' \item{\code{fit(varnames_to_encode, response_varname, a = 1,
#' encode_novel_levels = TRUE, encode_missing_levels = FALSE)}}{Fit the
#' encoder to the data}
#' \item{\code{transform(new_data = NULL)}}{Transform a new dataset using the
#' fitted encoder} }
#'
#' @section Private methods: \describe{
#' \item{\code{encode_with_y(df, varname_to_encode)}}{Encode a categorical
#' variable using the response variable (ordered target statistic)}
#' \item{\code{encode_without_y(df, varname_to_encode)}}{Encode a categorical
#' variable without using the response variable (smoothed level means from
#' the fitted dataset)} }
#'
#' @section Usage:
#'
#' ce <- catboost_encoder$new(dataset)
#' ce$fit(varnames_to_encode, response_varname)
#' encoded_data <- ce$transform(new_data)
#'
#' @export catboost_encoder
catboost_encoder = R6::R6Class(
"catboost_encoder",
public = list(
dataset = NULL,
mean = NULL,
varnames_to_encode = NULL,
response_varname = NULL,
is_fitted = FALSE,
a = NULL,
encode_novel_levels = NULL,
encode_missing_levels = NULL,
initialize = function(dataset){
checkmate::assert_data_frame(dataset)
self$dataset = dataset
return(invisible(NULL))
},
fit = function(varnames_to_encode,
response_varname,
a = 1,
encode_novel_levels = TRUE,
encode_missing_levels = FALSE
){
checkmate::assert_string(response_varname)
checkmate::assert_subset(response_varname,
choices = colnames(self$dataset)
)
checkmate::assert_numeric(self$dataset[[response_varname]],
any.missing = FALSE
)
checkmate::assert_character(varnames_to_encode)
checkmate::assert_subset(varnames_to_encode,
choices = colnames(self$dataset)
)
for (avarname in varnames_to_encode){
checkmate::assert_factor(self$dataset[[avarname]])
}
checkmate::assert_number(a)
checkmate::assert_flag(encode_novel_levels)
checkmate::assert_flag(encode_missing_levels)
self$varnames_to_encode = varnames_to_encode
self$response_varname = response_varname
self$mean = mean(self$dataset[[response_varname]], na.rm = TRUE)
self$a = a
self$encode_novel_levels = encode_novel_levels
self$encode_missing_levels = encode_missing_levels
self$is_fitted = TRUE
return(invisible(NULL))
},
transform = function(new_data = NULL){
# check this first: the validation below relies on fields set by 'fit'
if (!self$is_fitted){
stop("please 'fit' before 'transform'")
}
new_data_is_null = is.null(new_data)
if (!new_data_is_null){
checkmate::assert_data_frame(new_data)
# new data must not contain the response ...
checkmate::assert_false(self$response_varname %in% colnames(new_data))
names_sorted = sort(colnames(new_data))
# ... must have the same predictor columns as the fitted dataset,
# and the same set of column classes
checkmate::assert_set_equal(colnames(new_data),
setdiff(colnames(self$dataset),
self$response_varname
)
)
checkmate::assert_set_equal(
sapply(new_data, class)[names_sorted],
sapply(dplyr::select(self$dataset, -dplyr::all_of(self$response_varname))
, class
)[names_sorted]
)
} else {
message("transforming on the dataset")
new_data = self$dataset
}
if (new_data_is_null){
encoded_cols = map(self$varnames_to_encode,
~ private$encode_with_y(new_data, .x)
)
} else {
encoded_cols = map(self$varnames_to_encode,
~ private$encode_without_y(new_data,.x)
)
}
names(encoded_cols) = self$varnames_to_encode
res = as_tibble(encoded_cols) %>%
bind_cols(select(new_data, -all_of(self$varnames_to_encode))) %>%
relocate(all_of(colnames(new_data)))
# encode novel (in new data case only)
if (self$encode_novel_levels && !new_data_is_null){
for (avarname in self$varnames_to_encode){
new_levels = setdiff(levels(new_data[[avarname]]),
levels(self$dataset[[avarname]])
)
if (length(new_levels) > 0){
res[[avarname]] = ifelse(new_data[[avarname]] %in% new_levels,
self$mean,
res[[avarname]]
)
}
}
}
# encode missing (in new data case only)
if (self$encode_missing_levels && !new_data_is_null){
for (avarname in self$varnames_to_encode){
res[[avarname]][ is.na(new_data[[avarname]]) ] = NA
}
}
return(res)
}
),
private = list(
encode_with_y = function(df, varname_to_encode){
# new levels: not applicable
# NA: encoded
res = df %>%
select(all_of(c(varname_to_encode, self$response_varname))) %>%
group_by(.data[[varname_to_encode]]) %>%
mutate(cs__ = cumsum(.data[[self$response_varname]]),
cc__ = row_number() - 1L
) %>%
ungroup() %>%
# ordered target statistic: (sum of y over earlier rows of this level
# + prior * a) / (number of earlier rows of this level + a)
transmute(!!varname_to_encode := (cs__ -
.data[[self$response_varname]] +
self$mean * self$a
) / (cc__ + self$a)
) %>%
pull()
return(res)
},
encode_without_y = function(df, varname_to_encode){
# new levels: NA
# NA: NA
level_means = "level_means__"
agg_frame = self$dataset %>%
select(all_of(c(varname_to_encode, self$response_varname))) %>%
group_by(.data[[varname_to_encode]]) %>%
summarise(sum__ = sum(.data[[self$response_varname]], na.rm = TRUE),
count__ = n()
) %>%
ungroup() %>%
mutate(level_means__ =
ifelse(count__ == 1,
self$mean,
(sum__ + self$mean * self$a) / (count__ + self$a)
)
) %>%
drop_na(all_of(varname_to_encode)) %>%
select(all_of(c(varname_to_encode, level_means)))
res = df %>%
select(all_of(c(varname_to_encode))) %>%
left_join(agg_frame, by = varname_to_encode) %>%
pull(level_means)
return(res)
}
)
)
recipe wrapper as 'step_catboost'
step_catboost = function(recipe,
...,
role = NA,
trained = FALSE,
outcome = NULL,
mapping = NULL,
skip = FALSE,
id = rand_id("catboost")
){
if (is.null(outcome)) {
rlang::abort("Please list a variable in `outcome`")
}
recipes::add_step(
recipe,
step_catboost_new(
terms = enquos(...),
role = role,
trained = trained,
outcome = outcome,
mapping = mapping,
skip = skip,
id = id
)
)
}
step_catboost_new =
function(terms,
role,
trained,
outcome,
mapping,
skip,
id
){
step(
subclass = "catboost",
terms = terms,
role = role,
trained = trained,
outcome = outcome,
mapping = mapping,
skip = skip,
id = id
)
}
#' @export
prep.step_catboost = function(x,
training,
info = NULL,
...
){
col_names = recipes_eval_select(x$terms, training, info)
if (length(col_names) > 0) {
y_name = recipes_eval_select(x$outcome, training, info)
# instantiate R6 class obj
ce = catboost_encoder$new(training)
ce$fit(varnames_to_encode = col_names,
response_varname = y_name
)
} else {
ce = list()
}
step_catboost_new(
terms = x$terms,
role = x$role,
trained = TRUE,
outcome = x$outcome,
mapping = ce,
skip = x$skip,
id = x$id
)
}
#' @export
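bake.step_catboost = function(object, new_data, ...) {
# recipes::bake() handles `new_data = NULL` itself (it returns the training
# set processed at prep() time), so this method always receives a data frame.
# Note that prep() also calls this method on the training set, so that set
# is encoded with encode_without_y(), not with the ordered statistic.
ce = object$mapping
if (!inherits(ce, "catboost_encoder")){
# no columns were selected at prep() time: nothing to encode
return(new_data)
}
y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
y_values = NULL
if (y_name %in% colnames(new_data)){
# the encoder expects new data without the response column
y_values = new_data[[y_name]]
new_data[[y_name]] = NULL
}
res = ce$transform(new_data)
if (!is.null(y_values)){
# restore the outcome so that later steps and juice() still see it
res[[y_name]] = y_values
}
return(res)
}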
#' @rdname required_pkgs.embed
#' @export
required_pkgs.step_catboost = function(x, ...) {
c("embed")
}
Example
pacman::p_load("recipes", "tidyverse")
source("~/personal/catboost_encoding_r6.R")
#> transforming on the dataset
#> transforming on the dataset
source("~/personal/step_catboost.R")
pen1 = palmerpenguins::penguins %>%
drop_na(bill_length_mm) %>%
slice_sample(prop = 0.7, by = 'species')
pen2 = palmerpenguins::penguins %>%
drop_na(bill_length_mm) %>%
setdiff(pen1)
# example with R6 class
ce = catboost_encoder$new(pen1)
ce$fit(c('species', 'sex'), response_varname = 'bill_length_mm')
# when the input to transform is empty, it uses the training dataset
# (here it is pen1)
ce$transform()
#> transforming on the dataset
#> # A tibble: 238 × 8
#> species island bill_length_mm bill_depth_mm flipper_…¹ body_…² sex year
#> <dbl> <fct> <dbl> <dbl> <int> <int> <dbl> <int>
#> 1 43.8 Torgersen 39.6 17.2 196 3550 43.8 2008
#> 2 41.7 Dream 37.5 18.9 179 2975 43.8 2007
#> 3 40.3 Biscoe 35.5 16.2 195 3350 41.7 2008
#> 4 39.1 Torgersen 40.6 19 199 4000 43.8 2009
#> 5 39.4 Biscoe 40.1 18.9 188 4300 42.2 2008
#> 6 39.5 Dream 39.6 18.8 190 4600 41.5 2007
#> 7 39.5 Dream 32.1 15.5 188 3050 39.6 2009
#> 8 38.6 Dream 39.8 19.1 184 4650 41.0 2007
#> 9 38.7 Torgersen 34.1 18.1 193 3475 40.6 2007
#> 10 38.3 Dream 37 16.9 185 3000 37.7 2007
#> # … with 228 more rows, and abbreviated variable names ¹flipper_length_mm,
#> # ²body_mass_g
# transform on a new dataset
ce$transform(pen2 %>% select(-bill_length_mm))
#> # A tibble: 104 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <dbl> <int>
#> 1 38.7 Torgersen 18 195 3250 42.2 2007
#> 2 38.7 Torgersen 20.6 190 3650 45.6 2007
#> 3 38.7 Torgersen 17.8 181 3625 42.2 2007
#> 4 38.7 Torgersen 19.6 195 4675 45.6 2007
#> 5 38.7 Torgersen 21.2 191 3800 45.6 2007
#> 6 38.7 Torgersen 17.8 185 3700 42.2 2007
#> 7 38.7 Torgersen 20.7 197 4500 45.6 2007
#> 8 38.7 Torgersen 21.5 194 4200 45.6 2007
#> 9 38.7 Biscoe 18.6 172 3150 42.2 2007
#> 10 38.7 Dream 16.7 178 3250 42.2 2007
#> # … with 94 more rows
# example with step_catboost recipe
ar = recipe(bill_length_mm ~ ., data = pen1) %>%
step_catboost(species, outcome = "bill_length_mm") %>%
prep(training = pen1)
ar
#> Recipe
#>
#> Inputs:
#>
#> role #variables
#> outcome 1
#> predictor 7
#>
#> Training data contained 238 data points and 9 incomplete rows.
#>
#> Operations:
#>
#> $terms
#> <list_of<quosure>>
#>
#> [[1]]
#> <quosure>
#> expr: ^species
#> env: 0x7fbbb5a65120
#>
#>
#> $role
#> [1] NA
#>
#> $trained
#> [1] TRUE
#>
#> $outcome
#> [1] "bill_length_mm"
#>
#> $mapping
#> <catboost_encoder>
#> Public:
#> a: 1
#> clone: function (deep = FALSE)
#> dataset: tbl_df, tbl, data.frame
#> encode_missing_levels: FALSE
#> encode_novel_levels: TRUE
#> fit: function (varnames_to_encode, response_varname, a = 1, encode_novel_levels = TRUE,
#> initialize: function (dataset)
#> is_fitted: TRUE
#> mean: 43.7655462184874
#> response_varname: bill_length_mm
#> transform: function (new_data = NULL)
#> varnames_to_encode: species
#> Private:
#> encode_with_y: function (df, varname_to_encode)
#> encode_without_y: function (df, varname_to_encode)
#>
#> $skip
#> [1] FALSE
#>
#> $id
#> [1] "catboost_LGVzz"
#>
#> attr(,"class")
#> [1] "step_catboost" "step"
ar %>%
juice()
#> # A tibble: 238 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <fct> <int>
#> 1 38.7 Torgersen 17.2 196 3550 female 2008
#> 2 38.7 Dream 18.9 179 2975 <NA> 2007
#> 3 38.7 Biscoe 16.2 195 3350 female 2008
#> 4 38.7 Torgersen 19 199 4000 male 2009
#> 5 38.7 Biscoe 18.9 188 4300 male 2008
#> 6 38.7 Dream 18.8 190 4600 male 2007
#> 7 38.7 Dream 15.5 188 3050 female 2009
#> 8 38.7 Dream 19.1 184 4650 male 2007
#> 9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows
ar %>%
bake(new_data = NULL)
#> # A tibble: 238 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <fct> <int>
#> 1 38.7 Torgersen 17.2 196 3550 female 2008
#> 2 38.7 Dream 18.9 179 2975 <NA> 2007
#> 3 38.7 Biscoe 16.2 195 3350 female 2008
#> 4 38.7 Torgersen 19 199 4000 male 2009
#> 5 38.7 Biscoe 18.9 188 4300 male 2008
#> 6 38.7 Dream 18.8 190 4600 male 2007
#> 7 38.7 Dream 15.5 188 3050 female 2009
#> 8 38.7 Dream 19.1 184 4650 male 2007
#> 9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows
ar %>%
bake(new_data = pen1)
#> # A tibble: 238 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <fct> <int>
#> 1 38.7 Torgersen 17.2 196 3550 female 2008
#> 2 38.7 Dream 18.9 179 2975 <NA> 2007
#> 3 38.7 Biscoe 16.2 195 3350 female 2008
#> 4 38.7 Torgersen 19 199 4000 male 2009
#> 5 38.7 Biscoe 18.9 188 4300 male 2008
#> 6 38.7 Dream 18.8 190 4600 male 2007
#> 7 38.7 Dream 15.5 188 3050 female 2009
#> 8 38.7 Dream 19.1 184 4650 male 2007
#> 9 38.7 Torgersen 18.1 193 3475 <NA> 2007
#> 10 38.7 Dream 16.9 185 3000 female 2007
#> # … with 228 more rows
ar %>%
bake(new_data = pen2)
#> # A tibble: 104 × 7
#> species island bill_depth_mm flipper_length_mm body_mass_g sex year
#> <dbl> <fct> <dbl> <int> <int> <fct> <int>
#> 1 38.7 Torgersen 18 195 3250 female 2007
#> 2 38.7 Torgersen 20.6 190 3650 male 2007
#> 3 38.7 Torgersen 17.8 181 3625 female 2007
#> 4 38.7 Torgersen 19.6 195 4675 male 2007
#> 5 38.7 Torgersen 21.2 191 3800 male 2007
#> 6 38.7 Torgersen 17.8 185 3700 female 2007
#> 7 38.7 Torgersen 20.7 197 4500 male 2007
#> 8 38.7 Torgersen 21.5 194 4200 male 2007
#> 9 38.7 Biscoe 18.6 172 3150 female 2007
#> 10 38.7 Dream 16.7 178 3250 female 2007
#> # … with 94 more rows
Issue: The
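For readers less familiar with the method: the training-time encoding produced by `encode_with_y()` is CatBoost's ordered target statistic. Each row is encoded using only the response values of earlier rows in the same level, smoothed towards a prior (the overall mean) with weight `a`. The base-R sketch below reproduces that formula; `ordered_target_stat()` is a hypothetical helper, not part of {embed} or of the R6 class. Like the R6 code, it uses the existing row order, whereas CatBoost itself computes these statistics over random permutations of the rows.

```r
# Hypothetical helper: CatBoost-style ordered target statistic, matching the
# formula in encode_with_y():
#   (sum of y over earlier rows of the same level + a * prior) /
#   (number of earlier rows of the same level + a)
ordered_target_stat <- function(x, y, prior = mean(y), a = 1) {
  stopifnot(length(x) == length(y), is.numeric(y), length(a) == 1, a > 0)
  # treat NA as its own level, as dplyr::group_by() does
  g <- addNA(as.factor(x), ifany = TRUE)
  y <- as.numeric(y)
  cum_sum <- ave(y, g, FUN = cumsum)     # running sum within level, incl. current row
  cum_cnt <- ave(y, g, FUN = seq_along)  # running count within level, incl. current row
  # remove the current row so that only earlier rows contribute
  (cum_sum - y + a * prior) / (cum_cnt - 1 + a)
}

ordered_target_stat(x = c("a", "b", "a", "a", "b"), y = c(1, 2, 3, 5, 4))
```

With these inputs the prior is 3 and `a = 1`, so the result is 3, 3, 2, 7/3 and 2.5: the first row of each level gets the prior, and later rows blend the earlier responses of their level with it.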
Hello @talegari, sorry for taking a while to answer. I'm not terribly familiar with {R6}, so I'm not sure how much I can help you. However, I can tell you where the problem might be. In
if (!is.null(new_data)){
y_name = purrr::map_chr(object$outcome, rlang::as_name) # string
ce = object$mapping
if (y_name %in% colnames(new_data)){
new_data[[y_name]] = NULL
}
res = ce$transform(new_data)
} else {
res = ce$transform()
}
I'm assuming that you thought this was needed to deal with
Secondly, I'm sad to say this since you put in a lot of effort, but I don't want to add {R6} and {checkmate} as dependencies just for this step. If you don't want to go through the work of translating away from {R6} and {checkmate}, I understand, and if you want, I can take over and do the last parts. Thanks again for all the work!
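On the {R6} and {checkmate} dependencies: the state the encoder needs for new data is small (the prior, `a`, and one smoothed mean per level), so it can live in a plain list, with base `stopifnot()` standing in for the {checkmate} assertions. The sketch below assumes that restructuring; `fit_level_means()` and `apply_level_means()` are hypothetical names, not an {embed} API. It mirrors `encode_without_y()` (singleton levels fall back to the prior, missing levels stay `NA`) and the `encode_novel_levels = TRUE` behavior (unseen levels get the prior).

```r
# Hypothetical R6/checkmate-free sketch of the "new data" encoding.

# Fit: smoothed mean of the response for each level of one categorical column.
fit_level_means <- function(x, y, a = 1) {
  stopifnot(length(x) == length(y), is.numeric(y), !anyNA(y),
            length(a) == 1, a >= 0)
  y <- as.numeric(y)
  prior <- mean(y)                    # overall mean of the response
  x <- as.character(x)
  keep <- !is.na(x)                   # missing levels get no encoding
  by_level <- split(y[keep], x[keep])
  sums <- vapply(by_level, sum, numeric(1))
  counts <- vapply(by_level, length, integer(1))
  means <- (sums + a * prior) / (counts + a)
  means[counts == 1] <- prior         # singleton levels fall back to the prior
  list(means = means, prior = prior, a = a)
}

# Apply: look up each value's level mean; unseen (novel) levels get the prior.
apply_level_means <- function(fit, x) {
  x <- as.character(x)
  out <- unname(fit$means[x])         # NA for missing or unseen levels
  novel <- !is.na(x) & !(x %in% names(fit$means))
  out[novel] <- fit$prior
  out
}

fit <- fit_level_means(x = c("a", "a", "b", "b", "c"), y = c(1, 3, 5, 7, 9))
apply_level_means(fit, c("b", "a", "z", NA))  # 17/3, 3, 5 (prior), NA
```

In a recipes step, one such list per selected column could be stored in the step object at `prep()` time in place of the R6 object; that is an assumption about the design, not something settled in this thread.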
Hi Emil,
I am planning to implement a step_catboost (along these lines). IMHO, it belongs here. Let me know if you are open to a PR.