Skip to content

Commit

Permalink
Add extractModel and extractData
Browse files Browse the repository at this point in the history
  • Loading branch information
fouodo committed Nov 15, 2024
1 parent 2d5f2dc commit cc1ae16
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 18 deletions.
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
S3method(predict,Training)
S3method(predict,bestSpecificLearner)
S3method(predict,weightedMeanLearner)
S3method(summary,Testing)
S3method(summary,Training)
export(Data)
export(HashTable)
export(Lrner)
Expand All @@ -27,6 +29,8 @@ export(createTesting)
export(createTrainLayer)
export(createTrainMetaLayer)
export(createTraining)
export(extractData)
export(extractModel)
export(fusemlr)
export(upsetplot)
export(varSelection)
Expand Down
6 changes: 0 additions & 6 deletions R/Lrner.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,6 @@ Lrner <- R6Class("Lrner",
cat(sprintf(" TrainLayer : %s\n", private$train_layer$getId()))
cat(sprintf(" Package : %s\n", private$package))
cat(sprintf(" Learn function : %s\n", private$lrn_fct))
cat("Predicting parameter\n")
print(expand.grid(private$param_train))
if (!length(private$param_pred)) {
cat("Predicting parameter\n")
print(expand.grid(private$param_pred))
}
},
#' @description
#' Learner and prediction parameter interface. Use this function
Expand Down
16 changes: 8 additions & 8 deletions R/TrainData.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,16 @@ TrainData <- R6Class("TrainData",
#'
summary = function (...) {
if ("TrainMetaLayer" %in% class(private$train_layer)) {
cat(sprintf("TrainData : %s\n", "meta data"))
cat(sprintf(" TrainData : %s\n", "meta data"))
} else {
cat(sprintf("TrainData : %s\n", private$id))
cat(sprintf(" TrainData : %s\n", private$id))
}
cat(sprintf(" Layer : %s\n", private$train_layer$getId()))
cat(sprintf(" Ind. id. : %s\n", private$ind_col))
cat(sprintf(" Target : %s\n", private$target))
cat(sprintf(" n : %s\n", nrow(private$data_frame)))
cat(sprintf(" Missing : %s\n", sum(!complete.cases(private$data_frame))))
cat(sprintf(" p : %s\n", ncol(private$data_frame)))
cat(sprintf(" Layer : %s\n", private$train_layer$getId()))
cat(sprintf(" Ind. id. : %s\n", private$ind_col))
cat(sprintf(" Target : %s\n", private$target))
cat(sprintf(" n : %s\n", nrow(private$data_frame)))
cat(sprintf(" Missing : %s\n", sum(!complete.cases(private$data_frame))))
cat(sprintf(" p : %s\n", ncol(private$data_frame)))
},
#' @description
#' Getter of the current \code{data.frame} wihtout individual
Expand Down
39 changes: 39 additions & 0 deletions R/Training.R
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,44 @@ Training <- R6Class("Training",
return(meta_layer)
},
#' @description
#' Retrieve models from all layer.
#'
#' @return
#' A \code{list} containing all (base and meta) models.
#' @export
#'
getModel = function() {
layers = self$getKeyClass()
# This code accesses each layer (except TrainMetaLayer) level
# and get the individual IDs.
layers = layers[layers$class %in% c("TrainLayer", "TrainMetaLayer"), ]
current_model = NULL
models = list()
for (k in layers$key) {
layer = self$getFromHashTable(key = k)
models[[layer$getId()]] = layer$getModel()$getBaseModel()
}
return(models)
},
#' @description
#' Retrieve meta data.
#'
#' @return
#' A \code{list} containing all (base and meta) models.
#' @export
#'
getData = function() {
layers = self$getKeyClass()
layers = layers[layers$class %in% c("TrainLayer", "TrainMetaLayer"), ]
current_model = NULL
all_data = list()
for (k in layers$key) {
layer = self$getFromHashTable(key = k)
all_data[[layer$getId()]] = layer$getTrainData()$getDataFrame()
}
return(all_data)
},
#' @description
#' Remove a layer of a given ID.
#'
#' @param id `character(1)` \cr
Expand Down Expand Up @@ -619,6 +657,7 @@ Training <- R6Class("Training",
cat("----------------\n")
cat("\n")
layers = self$getKeyClass()
layers = layers[layers$class %in% c("TrainLayer", "TrainMetaLayer"), ]
for (k in layers$key) {
layer = self$getFromHashTable(key = k)
layer$summary()
Expand Down
15 changes: 15 additions & 0 deletions R/testingFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,18 @@ createTestLayer = function (testing,
new_layer = test_layer)
return(testing)
}

#' @title Testing object Summaries
#' @description
#' Summaries a fuseMLR [Testing] object.
#'
#' @param object (`Testing(1)`) \cr
#' The [Testing] object of interest.
#' @param ... \cr
#' Further arguments.
#'
#' @export
#'
summary.Testing = function (object, ...) {
return(object$summary())
}
45 changes: 45 additions & 0 deletions R/trainingFunctions.R
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,51 @@ predict.Training = function (object,
return(predictions)
}

#' @title extractModel
#' @description
#' Extracts models stored on each layers; base and meta models are extracted.
#'
#' @param training (`Training(1)`) \cr
#' The [Training] object of interest.
#'
#' @return
#' A list of models is returned.
#' @export
#'
extractModel = function (training) {
return(training$getModel())
}

#' @title extractData
#' @description
#' Extracts data stored on each layers; base and meta data are extracted.
#'
#' @param training (`Training(1)`) \cr
#' The [Training] object of interest.
#'
#' @return
#' A list of data is returned.
#' @export
#'
extractData = function (training) {
return(training$getData())
}

#' @title Training object Summaries
#' @description
#' Summaries a fuseMLR [Training] object.
#'
#' @param object (`Training(1)`) \cr
#' The [Training] object of interest.
#' @param ... \cr
#' Further arguments.
#'
#' @export
#'
summary.Training = function (object, ...) {
return(object$summary())
}

#' @title upsetplot
#' @description
#' An upset plot of overlapping individuals.
Expand Down
8 changes: 4 additions & 4 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,14 @@ training <- fusemlr(training = training,
verbose = FALSE)
print(training)
# See also summary(training)
```

- Retrieve the basic model of a specific layer.
- Use `extractModel` to retrieve the list of stored models and `extractData` to retrieve training data.

```{r basic_lrnr, include=TRUE, eval=TRUE}
lay_genexpr <- training$getLayer(id = "geneexpr")
model_ge <- lay_genexpr$getModel()
print(model_ge)
models_list <- extractModel(training = training)
data_list <- extractData(training = training)
```

#### E) Predicting
Expand Down
28 changes: 28 additions & 0 deletions man/Training.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/extractData.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/extractModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/summary.Testing.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions man/summary.Training.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit cc1ae16

Please sign in to comment.