---
title: Creating a new Learner
output:
  rmarkdown::html_vignette:
    toc: true
vignette: >
  %\VignetteIndexEntry{Creating a new Learner}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

Many learners are already included in the `mlr3` ecosystem, but many more have not been implemented yet. In this vignette we look at how to add a new `Learner` to `mlr3`.
We assume a lot of prior knowledge here, so it is best to read the [mlr3book](https://mlr3book.mlr-org.com/) first!
If you want to contribute a new `Learner` to `mlr3extralearners`, it is recommended to first open an issue to discuss the learner.

As another side note, the `mlr3extralearners::create_learner()` function can be used to generate files containing templates for the learner and test files required to create and test a `Learner`.
We will not use it in this vignette for the sake of clarity, but recommend using it in practice to avoid writing a lot of boilerplate code.

We will first demonstrate what the final class looks like and then explain it line by line (and method by method).
As a working example, we will implement a slightly stripped-down version of `regr.rpart`.

```{r}
library(mlr3)
library(paradox)
library(R6)
library(mlr3misc)

LearnerRegrRpartSimple = R6Class("LearnerRegrRpartSimple",
  inherit = LearnerRegr,
  public = list(
    initialize = function() {
      param_set = ps(
        cp = p_dbl(0, 1, default = 0.01, tags = "train"),
        maxcompete = p_int(0L, default = 4L, tags = "train"),
        maxdepth = p_int(1L, 30L, default = 30L, tags = "train"),
        maxsurrogate = p_int(0L, default = 5L, tags = "train"),
        minbucket = p_int(1L, tags = "train"),
        minsplit = p_int(1L, default = 20L, tags = "train"),
        surrogatestyle = p_int(0L, 1L, default = 0L, tags = "train"),
        usesurrogate = p_int(0L, 2L, default = 2L, tags = "train"),
        xval = p_int(0L, default = 10L, tags = "train")
      )
      # initial value that deviates from the upstream default of 10
      param_set$values$xval = 0L

      super$initialize(
        id = "regr.rpart_simple",
        feature_types = c("logical", "integer", "numeric", "factor", "ordered"),
        predict_types = "response",
        packages = "rpart",
        param_set = param_set,
        properties = c("weights", "missings", "importance"),
        label = "Regression Tree",
        man = "mlr3::mlr_learners_regr.rpart"
      )
    },

    importance = function() {
      if (is.null(self$model)) {
        stopf("No model stored")
      }
      # importance is only present if there is at least one split
      if (is.null(self$model$variable.importance)) {
        importance = set_names(numeric())
      } else {
        importance = sort(self$model$variable.importance, decreasing = TRUE)
      }
      return(importance)
    }
  ),

  private = list(
    .train = function(task) {
      pv = self$param_set$get_values(tags = "train")
      if ("weights" %in% task$properties) {
        pv$weights = task$weights$weight
      }
      invoke(
        rpart::rpart,
        formula = task$formula(),
        data = task$data(),
        .args = pv
      )
    },

    .predict = function(task) {
      pv = self$param_set$get_values(tags = "predict")
      # ensure same column order in train and predict
      newdata = mlr3extralearners:::ordered_features(task, self)
      response = invoke(predict, self$model, newdata = newdata, .args = pv)
      list(response = unname(response))
    }
  )
)
```

We first specify the class name according to the `mlr3` naming convention as `"LearnerRegrRpartSimple"`.
In the second line, we select the appropriate parent class from which we want to inherit.
Because we are implementing a regression learner, we have to inherit from the `LearnerRegr` class.
For classification we would inherit from `LearnerClassif`, for survival from `mlr3proba::LearnerSurv`, and for clustering from `mlr3cluster::LearnerClust`.

In the constructor, the constructor of the super class (in this case `LearnerRegr`) is called with meta-information about the learner which we are defining.
We strongly recommend reading the relevant class documentation to understand its arguments (see e.g. the [learner documentation](https://mlr3.mlr-org.com/reference/Learner.html)).
These include:

* `id`: The ID of the new learner.
* `packages`: The upstream package name(s) of the implemented learner.
* `param_set`: A set of hyperparameters and their descriptions provided as a `paradox::ParamSet`.
* `predict_types`: Set of predict types the learner supports.
* `feature_types`: Set of feature types the learner is able to handle.
* `properties`: Set of properties of the learner.
* `man`: The roxygen identifier of the learner. This is used within the `$help()` method of the super class to open the help page of the learner.
* `label`: The label of the learner. This should briefly describe the learner (similar to the description's title) and is for example used for printing.

To fill out some of this meta-information, one has to go through the manual pages of the upstream packages.
In order to know e.g. which predict types or properties are available, reading the manual pages of the corresponding superclass(es) is indispensable.
Furthermore, it is a good idea to look at already existing learner implementations for inspiration.

Three properties in `mlr3` deserve special consideration:

* `"validation"` - the learner can make use of an internal validation to e.g.
  track the performance
* `"internal_tuning"` - the learner can internally optimize hyperparameters, e.g. via early stopping
* `"marshal"` - the learner cannot be serialized without loss of information (e.g. learners from `mlr3torch` have this property)

While these properties are rarely needed, it is still important to know how to implement them if they do apply.
For the first two properties, see the [documentation of `Learner`](https://mlr3.mlr-org.com/reference/Learner.html) or, for a concrete implementation, the [`XGBoost` learners](https://github.com/mlr-org/mlr3learners/blob/main/R/LearnerClassifXgboost.R).
The documentation on marshaling can be found by running `?marshaling`, and for an example of how to implement it see [`mlr3torch::LearnerTorch`](https://github.com/mlr-org/mlr3torch/blob/main/R/LearnerTorch.R).
In any case, if you suspect your learner to have any of these properties, it is best to first discuss the implementation with the `mlr-org` team.

## Defining the Parameter Set of a Learner

The parameter set of a learner is the set of hyperparameters used in model training and predicting; it is given as a `paradox::ParamSet`.
The set consists of a list of hyperparameters, where each has a specific class for the hyperparameter type.

For an explanation of parameters and the `ParamSet` class we refer to the [mlr3book](https://mlr3book.mlr-org.com/), where we show how to create parameter sets to define search spaces for hyperparameter tuning.
However, there are some aspects of parameters that are mostly relevant when creating new `Learner`s, so we cover them here.
These are:

* **tags** that organize the parameters, e.g. whether they are used during training or prediction.
* **default values** which give information about the value that is used by the upstream function when no value is set. Note that this is not necessarily the default of the function signature but what is used internally: e.g.
  the default of parameter `a` for a function is `NULL`, but internally the value for `a` is set to `1` if it is left as `NULL` by the caller.
* **initial values** are the initial `$values` of the parameter set that are set during construction.
* **custom checks** which allow specifying custom constraints for untyped parameters (`ParamUty`).

**Tags**

Each parameter has one or more tags that determine in which method it is used.
That is, parameters that are used during training must be tagged with `"train"` and those that are used during prediction with `"predict"`.
In our case there are only training parameters.
Parameters with certain tags can be retrieved by calling the `$get_values()` method of a `ParamSet` with the `tags` argument.

Furthermore, there are other tags that serve specific purposes:

* The tag `"threads"` should be used (if applicable) to tag the parameter that determines the number of threads used for the learner's internal parallelization. Once a learner is constructed, this parameter can be set using `set_threads()`.
* The tag `"required"` should be used to tag parameters that must be provided for the algorithm to be executable.

**Default Values**

The default values of each parameter indicate which value the upstream function (in this case `rpart::rpart()`) uses if no parameter is specified.
These have to be retrieved from the package's documentation.
If parameters have no default, they have to be tagged with the `"required"` tag.
For such parameters, we recommend setting the parameter to a value in the initialize method, so that the learner can be executed after creation.
Note that setting a parameter value in the `$initialize()` method does not change the parameter's default value (as the default denotes what happens when no value is set).
Parameters expect the default values to be of their respective type, e.g. the default of `ParamInt` must be an integer.
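
To make the distinction between tags, annotated defaults and set values concrete, the following sketch builds a small parameter set and queries it. The `type` parameter and its levels are made up for illustration and are not part of the learner above.

```{r}
library(paradox)

# two training parameters taken from the learner above and one
# hypothetical prediction parameter (`type` is only for illustration)
param_set = ps(
  cp = p_dbl(0, 1, default = 0.01, tags = "train"),
  xval = p_int(0L, default = 10L, tags = "train"),
  type = p_fct(c("a", "b"), tags = "predict")
)

# set a value explicitly; the annotated default stays untouched
param_set$values$xval = 0L

# only explicitly set values with the requested tag are returned
train_values = param_set$get_values(tags = "train")
predict_values = param_set$get_values(tags = "predict")

names(train_values)    # "xval": `cp` has no value set, so it is not returned
length(predict_values) # 0: no predict parameter has a value
param_set$default$xval # 10: the annotated default, not the set value
```

Because `cp` is not returned by `$get_values()`, it is not passed to the upstream function at all, which is exactly the situation the annotated default describes.
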
Some packages have R expressions as default values which cannot be properly expressed in `paradox`.
In such cases, a practical compromise is to not specify any default value when constructing the parameter, e.g. with `p_int()`.

**Initial Values**

In some cases you want to change the parameter values during initialization.
You can do this by changing the `param_set$values` during the learner's construction.
You can see we have done this for `regr.rpart_simple`, where the value of `xval` is initialized to `0`.
Note however that we still annotate the default of the `xval` parameter as 10, because this is what the upstream function uses when no value is set.
It is recommended to change initial parameter values only for good reasons and to document this properly when doing so.
When contributing a learner to `mlr3extralearners` this information should go into the *Initial parameter values* section.

**Custom Checks**

In `paradox` only integers, numerics, factors, and logicals have dedicated parameter types.
For all other parameters, the `ParamUty` class has to be used, which is constructable via the `p_uty()` function.
It has the argument `custom_check`, which allows specifying custom constraints.
This function must return `TRUE` if the input is valid and a `character(1)` with the error message otherwise.

**Custom Parameters**

Sometimes you might want to add parameters that are not present in the upstream function.
One example is the `mtry.ratio` parameter of `lrn("surv.ranger")`.
Such parameters should usually not have `default` values (because here we are not calling an upstream function with a default behaviour) but instead have their value initialized during construction and have the `"required"` tag.
This avoids having to deal with the case where this parameter has no value set.
When contributing a learner to `mlr3extralearners` this information should go into the *Custom parameters* section.

## Train function

The train function takes a `Task` as input and must return a model, i.e. some `R` object that allows the learner to make predictions.
We want to translate the following call of `rpart::rpart()` into code that can be used inside the `.train()` method.
First, we write something down that works completely without `mlr3`:

```{r}
data = mtcars
model = rpart::rpart(mpg ~ ., data = data, xval = 0)
```

We need to pass the formula notation `mpg ~ .`, the data and the hyperparameters.
To get the hyperparameters, we call `self$param_set$get_values(tags = "train")` and thereby query all parameters that are used during `"train"`.
Because the learner has the property `"weights"`, we insert the weights of the task if there are any.
Then we obtain the formula from the task using `task$formula()` and access the training data using `task$data()`.

Last, we call the upstream function `rpart::rpart()` with the formula and the data, and pass all hyperparameters via the argument `.args` of the `mlr3misc::invoke()` function.
The latter is simply an optimized version of `do.call()` that we use within the `mlr3` ecosystem.
The return value of this method will be available as the `$model` slot of the trained learner.

```{r}
.train = function(task) {
  pv = self$param_set$get_values(tags = "train")
  if ("weights" %in% task$properties) {
    pv$weights = task$weights$weight
  }
  formula = task$formula()
  data = task$data()
  invoke(
    rpart::rpart,
    formula = formula,
    data = data,
    .args = pv
  )
}
```

At this point, an explanation for the difference between the public method `$train()` and the private method `$.train()` is in order.
The former essentially calls the latter and additionally performs some checks and tasks that are the same for every learner.
The same holds for the predict method.

## Predict function

The internal method `$.predict()` also operates on a `Task` as well as on the fitted model that has been created by the `$train()` call previously and has been stored in `self$model`.
The return value must contain the required information to produce a `mlr3::Prediction` object that is returned when `$predict()` is called on a learner.

We proceed analogously to what we did in the previous section.
We start with a version without any `mlr3` objects and continue to replace objects until we have reached the desired interface:

```{r eval=TRUE}
# inputs:
task = tsk("mtcars")
self = list(model = rpart::rpart(task$formula(), data = task$data()))
data = mtcars

# response prediction
response = predict(self$model, newdata = data)
```

Next, we transition from `data` to a `Task` again and construct a list with the return type requested by the user.
This is stored in the `$predict_type` slot of a learner class.
Because the `regr.rpart_simple` learner we are implementing has only one predict type, we do not have to take the value of `self$predict_type` into account.
For, e.g., regression learners with predict types `"response"` and `"se"` or classification learners with predict types `"response"` and `"prob"`, the predict method must return the prediction requested by the user.

The final `$.predict()` method is shown below. We could omit the `pv` line as there are no parameters with the `"predict"` tag, but we keep it here to be consistent:

```{r}
.predict = function(task) {
  pv = self$param_set$get_values(tags = "predict")
  # ensure same column order in train and predict
  newdata = mlr3extralearners:::ordered_features(task, self)
  response = invoke(predict, self$model, newdata = newdata, .args = pv)
  list(response = unname(response))
}
```

Beware that you cannot rely on the column order of the data returned by `task$data()`, as the order of columns may be different from the order of the columns during `$.train()`.
The `newdata` line ensures that the features are passed in the same order as during training!
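
The idea behind this reordering can be sketched in plain R. This is not the implementation of `ordered_features()`, only an illustration under the assumption that the feature names seen during training are known; `train_cols` is a name introduced here.

```{r}
# feature order used when fitting the model
train_cols = c("cyl", "disp", "hp")
model = rpart::rpart(mpg ~ ., data = mtcars[, c("mpg", train_cols)])

# the same features arrive in a different order at predict time
newdata = mtcars[, rev(train_cols)]

# restore the training order before calling the upstream predict function
newdata = newdata[, train_cols, drop = FALSE]
response = unname(predict(model, newdata = newdata))
```

`rpart` matches columns by name, so the reordering makes no difference for this particular model, but upstream functions that match features by position can silently return wrong predictions without it.
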

## Optional Extractors

Specific learner implementations are free to implement additional getters to ease the access of certain parts of the model in the inherited subclasses.
Some of these methods are standardized in `mlr3`, cf. the `Learner` documentation.
If it is one of the standardized extractors, some custom code might have to be written for the return value to comply with the `mlr3` standards (see the `Learner` documentation).

Because we specified earlier that the learner has the property `"importance"`, we implement the public `$importance()` method.
It accesses the `$model` slot of the learner after training and returns one of its fields.
Because the `$importance()` method is standardized, it must return the importance scores as a numeric vector sorted in decreasing order, with the feature names as names.

```{r}
importance = function() {
  if (is.null(self$model)) {
    stopf("No model stored")
  }
  # importance is only present if there is at least one split
  if (is.null(self$model$variable.importance)) {
    importance = set_names(numeric())
  } else {
    importance = sort(self$model$variable.importance, decreasing = TRUE)
  }
  return(importance)
}
```

## Testing the learner

Once your learner is created, you should write tests to verify its correctness.
While you should also do this when implementing a learner only for your own use, this is absolutely required when you want to contribute a new learner to `mlr3extralearners`.
Besides manual tests, the `mlr3` package also provides automatic tests to check that a learner satisfies some basic sanity checks and that it implements all the available parameters of the upstream function.
The manual page for these functions can be accessed via `?mlr_test_helpers`.

For a bare-bones check you can just try to run a simple `$train()` call.

```{r}
task = tsk("mtcars")
learner = LearnerRegrRpartSimple$new()
learner$train(task)
p = learner$predict(task)
p$score(msr("regr.mse"))
learner$importance()
```

If it runs without erroring, that's a very good start!
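
Manual tests can go a step further than checking for errors. The following sketch uses the built-in `lrn("regr.rpart")` as a stand-in so that it runs on its own, and asserts a few things that should hold for the learner above as well. The specific checks are suggestions, not requirements of `mlr3`.

```{r}
library(mlr3)
library(testthat)

test_that("manual sanity checks", {
  task = tsk("mtcars")
  # stand-in learner; replace with LearnerRegrRpartSimple$new()
  learner = lrn("regr.rpart")
  learner$train(task)
  p = learner$predict(task)

  # one non-missing prediction per observation
  expect_equal(length(p$response), task$nrow)
  expect_false(anyNA(p$response))

  # importance is a named numeric vector, sorted in decreasing order
  imp = learner$importance()
  expect_true(is.numeric(imp))
  expect_true(all(names(imp) %in% task$feature_names))
  expect_false(is.unsorted(rev(imp)))
})
```
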

### Autotest

To ensure that your learner is able to handle all kinds of different properties and feature types, we have written an "autotest" that checks the learner for different combinations of such.
In addition to the autotest we also check some other learner properties using `expect_learner()`.
See section *run_autotest* and *expect_learner* of `?mlr_test_helpers` for an explanation.

Because these functions are not in the `mlr3` namespace, we have to source them from the `./testthat` subfolder of the installed `mlr3` package.
Note that when implementing survival learners, you must also source these helper files from the `mlr3proba` package.

```{r, results = 'hide'}
lapply(list.files(system.file("testthat", package = "mlr3"), pattern = "^helper.*\\.[rR]", full.names = TRUE), source)
# mlr3proba
# lapply(list.files(system.file("testthat", package = "mlr3proba"), pattern = "^helper.*\\.[rR]", full.names = TRUE), source)
```

We now perform the autotest.
By setting `expect_learner(..., check_man = FALSE)`, we disable a test that verifies that the correct manual page exists, which is only relevant when adding the learner to `mlr3extralearners`.

```{r}
library(testthat)

test_that("autotest", {
  learner = LearnerRegrRpartSimple$new()
  # basic learner properties
  expect_learner(learner, check_man = FALSE)
  # you can skip tests using the `exclude` argument
  result = run_autotest(learner)
  expect_true(result, info = result$error)
})
```

### Checking Parameters

Some learners have a high number of parameters and it is easy to miss some of them or to misspell a parameter when creating a new learner.
Therefore we have written a "Parameter Check" that can be used to check that the parameters are correctly implemented.
See section *run_paramtest* of `?mlr_test_helpers` for an explanation.

This "Parameter Check" compares the parameters of the `mlr3` ParamSet against all arguments available in the upstream function that is called during `$train()` and `$predict()`.
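
The core idea of this comparison can be illustrated with a few lines of base R. This is only a sketch of the principle, not what `run_paramtest()` does internally; the misspelled id `minsplt` is deliberate, and the names `extra` and `not_covered` are introduced here.

```{r}
library(paradox)

# a parameter set with one deliberately misspelled id (`minsplt`)
param_set = ps(
  cp = p_dbl(0, 1, default = 0.01, tags = "train"),
  maxdepth = p_int(1L, 30L, default = 30L, tags = "train"),
  minsplt = p_int(1L, default = 20L, tags = "train")
)

# all argument names of the upstream functions used during training
upstream_args = union(
  names(formals(rpart::rpart)),
  names(formals(rpart::rpart.control))
)

# ids that do not exist upstream point to typos
extra = setdiff(param_set$ids(), upstream_args)

# upstream arguments that are not in the parameter set; each of these
# must either be implemented or explicitly excluded
not_covered = setdiff(upstream_args, param_set$ids())

extra # "minsplt"
```
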
When the `$.train()` method either calls multiple functions or, e.g., passes the `...` arguments on to a control function (as is the case here), a list of functions can be passed to the parameter test.
The test comes with an `exclude` argument that should be used to _exclude and explain_ why certain arguments of the upstream function are not within the `ParamSet` of the learner.
This will likely be required for all learners, as common arguments like `x`, `target` or `data` are handled by the `mlr3` interface and are therefore not included within the ParamSet.

However, there might be more parameters that need to be excluded, for example:

* Type dependent parameters, i.e. parameters that only apply for classification or regression learners.
* Parameters that are actually deprecated by the upstream package and which were therefore not included in the `mlr3` ParamSet.

All excluded parameters should have a comment justifying their exclusion.

In our example, the final paramtest script looks like:

```{r}
test_that("paramtest", {
  learner = LearnerRegrRpartSimple$new()
  exclude = c(
    "formula", "data", "weights", # supplied from the task in .train()
    "control", # its arguments are checked via rpart::rpart.control
    # not exposed in this stripped-down learner
    "subset", "na.action", "method", "model", "x", "y", "parms", "cost", "keep_model"
  )
  result = run_paramtest(learner, list(rpart::rpart, rpart::rpart.control), exclude, tag = "train")
  expect_true(result, info = result$error)

  exclude = c(
    "object", # handled internally
    "newdata", # handled internally
    "type", # handled internally
    "na.action", # handled internally
    "model" # not implemented
  )
  result = run_paramtest(learner, rpart:::predict.rpart, exclude, tag = "predict")
  expect_true(result, info = result$error)
})
```

## Contributing to mlr3extralearners

When adding a `Learner` to `mlr3extralearners` there are some additional requirements that have to be satisfied:

* Adding the line `.extralrns_dict$add("", function() $new())` to the bottom of the learner file to add the learner to the learners `Dictionary`.
* Add the learner dependencies to *Suggests* (and not e.g. to *Imports*) in the DESCRIPTION file.
* The learner must be properly documented. When using `mlr3extralearners::create_learner()` this is automatically generated. Otherwise you can use and adapt the documentation of an already implemented learner.
* Conforming to the `mlr3` [style guide](https://github.com/mlr-org/mlr3/wiki/Style-Guide).
* The tests must pass in the CI (also check the CodeFactor results).
* Update the `NEWS.md` file and describe that you added the learner.
* Add yourself as a contributor