Title: | Inference on the Generalization Error |
---|---|
Description: | Confidence interval and resampling methods for inference on the generalization error. |
Authors: | Sebastian Fischer [cre, aut] , Hannah Schulz-Kümpel [aut] |
Maintainer: | Sebastian Fischer |
License: | LGPL-3 |
Version: | 0.1.0 |
Built: | 2025-01-14 08:26:54 UTC |
Source: | https://github.com/mlr-org/mlr3inferr |
Useful links:
Report bugs at https://github.com/mlr-org/mlr3inferr/issues
Base class for confidence interval measures. See section Inheriting on how to add a new method.
The aggregator of the wrapped measure is ignored, as the inheriting CI dictates how the point estimate is constructed. If a measure for which to calculate a CI has $obs_loss but also a $trafo (such as RMSE), the delta method is used to obtain confidence intervals.
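As a rough illustration of the delta method mentioned above (a sketch, not necessarily the exact computation used by the package): let $\hat\theta$ be the point estimate of the untransformed loss (e.g. the MSE) with standard error $\widehat{\mathrm{se}}$, and let $g$ be the transformation (e.g. $g(x) = \sqrt{x}$ for the RMSE). These symbols are introduced here for illustration only.

```latex
\widehat{\mathrm{se}}\bigl(g(\hat\theta)\bigr) \approx \lvert g'(\hat\theta) \rvert \, \widehat{\mathrm{se}},
\qquad
\text{CI: } g(\hat\theta) \pm z_{1-\alpha/2} \, \lvert g'(\hat\theta) \rvert \, \widehat{\mathrm{se}}

% Example: g(x) = \sqrt{x}, so g'(x) = 1 / (2\sqrt{x}) and
\widehat{\mathrm{se}}(\mathrm{RMSE}) \approx \frac{\widehat{\mathrm{se}}(\mathrm{MSE})}{2 \sqrt{\mathrm{MSE}}}
```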
alpha :: numeric(1)
The desired alpha level. This is initialized to 0.05.
within_range :: logical(1)
Whether to restrict the confidence interval within the range of possible values. This is initialized to TRUE.
To define a new CI method, inherit from the abstract base class and implement the private method:
ci: function(tbl: data.table, rr: ResampleResult, param_vals: named list()) -> numeric(3)
If requires_obs_loss is set to TRUE, tbl contains the columns loss, row_id and iteration, which are the pointwise loss, the identifier of the observation and the resampling iteration.
Otherwise, tbl contains the result of rr$score() with the name of the loss column set to "loss".
The method should return a vector containing the estimate, lower and upper boundary, in that order.
If the confidence interval is not of the form (estimate, estimate - z * se, estimate + z * se), it is also necessary to implement the private method:
.trafo: function(ci: numeric(3), measure: Measure) -> numeric(3)
This method receives a confidence interval for a pointwise loss (e.g. squared error) and transforms it according to the transformation measure$trafo (e.g. sqrt to go from MSE to RMSE).
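To make the contract of the private ci method more concrete, here is a minimal stand-alone sketch in base R. It is not taken from the package: the function name ci_wald_sketch is hypothetical, it assumes that tbl has a numeric loss column (the requires_obs_loss = TRUE case), and it assumes that the alpha level is available as param_vals$alpha. It computes a plain normal-approximation interval around the mean pointwise loss.

```r
# Hypothetical stand-alone version of a private `ci` method (illustration only).
# Assumes `tbl` has a numeric column `loss` and `param_vals$alpha` is the alpha level.
ci_wald_sketch = function(tbl, rr = NULL, param_vals = list(alpha = 0.05)) {
  n = nrow(tbl)
  estimate = mean(tbl$loss)
  # standard error of the mean pointwise loss
  se = sd(tbl$loss) / sqrt(n)
  z = qnorm(1 - param_vals$alpha / 2)
  # return value in the required order: estimate, lower, upper
  c(estimate = estimate, lower = estimate - z * se, upper = estimate + z * se)
}

# Toy input mimicking the columns described above
tbl = data.frame(loss = c(0, 1, 0, 0, 1, 0, 1, 0), row_id = 1:8, iteration = 1L)
ci_wald_sketch(tbl)
```

In an actual subclass, a function of this shape would live in the private list of the R6 class that inherits from MeasureAbstractCi.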
mlr3::Measure -> MeasureAbstractCi
resamplings (character())
On which resampling classes this method can operate.
measure (Measure)
new()
Creates a new instance of this R6 class.
MeasureAbstractCi$new( measure = NULL, param_set = ps(), packages = character(), resamplings, label, delta_method = FALSE, requires_obs_loss = TRUE )
measure (Measure)
The measure for which to calculate a confidence interval. Must have $obs_loss.
param_set (ParamSet)
Set of hyperparameters.
packages (character())
Set of required packages. The constructor signals a warning if at least one of the packages is not installed. The packages are loaded (not attached) later on demand via requireNamespace().
resamplings (character())
To which resampling classes this measure can be applied.
label (character(1))
Label for the new instance.
delta_method (logical(1))
Whether to use the delta method for measures (such as RMSE) that have a trafo.
requires_obs_loss (logical(1))
Whether the inference method requires a pointwise loss function.
aggregate()
Obtain a point estimate, as well as lower and upper CI boundary.
MeasureAbstractCi$aggregate(rr)
rr (ResampleResult)
The resample result.
Returns a named numeric(3).
clone()
The objects of this class are cloneable with this method.
MeasureAbstractCi$clone(deep = FALSE)
deep
Whether to make a deep clone.
For certain resampling methods, there are default confidence interval methods. See mlr3::mlr_reflections$default_ci_methods for a selection. This measure selects the appropriate CI method depending on the class of the Resampling used.
Parameters: only those from MeasureAbstractCi.
mlr3::Measure -> mlr3inferr::MeasureAbstractCi -> MeasureCi
new()
Creates a new instance of this R6 class.
MeasureCi$new(measure)
measure (Measure or character(1))
A measure or the ID of a measure.
aggregate()
Obtain a point estimate, as well as lower and upper CI boundary.
MeasureCi$aggregate(rr)
rr (ResampleResult)
Resample result.
Returns a named numeric(3).
clone()
The objects of this class are cloneable with this method.
MeasureCi$clone(deep = FALSE)
deep
Whether to make a deep clone.
rr = resample(tsk("sonar"), lrn("classif.featureless"), rsmp("holdout"))
rr$aggregate(msr("ci", "classif.acc"))
# is the same as:
rr$aggregate(msr("ci.holdout", "classif.acc"))
Conservative-z confidence intervals based on ResamplingPairedSubsampling. Because the variance estimate is obtained using only n / 2 observations, it tends to be conservative. This inference method can also be applied to non-decomposable losses.
Parameters: only those from MeasureAbstractCi.
mlr3::Measure -> mlr3inferr::MeasureAbstractCi -> MeasureCiConZ
new()
Creates a new instance of this R6 class.
MeasureCiConZ$new(measure)
measure (Measure or character(1))
A measure or the ID of a measure.
clone()
The objects of this class are cloneable with this method.
MeasureCiConZ$clone(deep = FALSE)
deep
Whether to make a deep clone.
Nadeau, Claude, Bengio, Yoshua (1999). “Inference for the generalization error.” Advances in neural information processing systems, 12.
ci_conz = msr("ci.con_z", "classif.acc")
ci_conz
Corrected-T confidence intervals based on ResamplingSubsampling. A heuristic factor is applied to correct for the dependence between the iterations. The confidence intervals tend to be liberal. This inference method can also be applied to non-decomposable losses.
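The heuristic correction can be sketched in base R as follows. This assumes the correction has the form proposed by Nadeau and Bengio, where the usual 1 / J factor for J subsampling iterations is replaced by 1 / J + n_test / n_train. The function corrected_t_ci and its arguments are hypothetical names used for illustration, and the package's internal implementation may differ in detail.

```r
# Sketch of a corrected-t interval; all names here are illustrative.
# estimates: per-iteration performance estimates from subsampling
# n_train, n_test: sizes of the training and test sets in each iteration
corrected_t_ci = function(estimates, n_train, n_test, alpha = 0.05) {
  J = length(estimates)
  estimate = mean(estimates)
  # The naive variance of the mean would be var / J; the additional
  # n_test / n_train term inflates it to account for overlapping training sets.
  se = sqrt((1 / J + n_test / n_train) * var(estimates))
  q = qt(1 - alpha / 2, df = J - 1)
  c(estimate = estimate, lower = estimate - q * se, upper = estimate + q * se)
}

# Made-up per-iteration accuracies, for illustration only
x = c(0.52, 0.55, 0.49, 0.53, 0.51)
corrected_t_ci(x, n_train = 139, n_test = 69)
```

Compared with a naive t-interval that treats the iterations as independent, the corrected interval is always wider, because the variance factor grows from 1 / J to 1 / J + n_test / n_train.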
Parameters: only those from MeasureAbstractCi.
mlr3::Measure -> mlr3inferr::MeasureAbstractCi -> MeasureCiCorrectedT
new()
Creates a new instance of this R6 class.
MeasureCiCorrectedT$new(measure)
measure (Measure or character(1))
A measure or the ID of a measure.
clone()
The objects of this class are cloneable with this method.
MeasureCiCorrectedT$clone(deep = FALSE)
deep
Whether to make a deep clone.
Nadeau, Claude, Bengio, Yoshua (1999). “Inference for the generalization error.” Advances in neural information processing systems, 12.
m_cort = msr("ci.cor_t", "classif.acc")
m_cort
rr = resample(
  tsk("sonar"),
  lrn("classif.featureless"),
  rsmp("subsampling", repeats = 10)
)
rr$aggregate(m_cort)
Standard holdout CI. This inference method can only be applied to decomposable losses.
Parameters: only those from MeasureAbstractCi.
mlr3::Measure -> mlr3inferr::MeasureAbstractCi -> MeasureCiHoldout
new()
Creates a new instance of this R6 class.
MeasureCiHoldout$new(measure)
measure (Measure or character(1))
A measure or the ID of a measure.
clone()
The objects of this class are cloneable with this method.
MeasureCiHoldout$clone(deep = FALSE)
deep
Whether to make a deep clone.
ci_ho = msr("ci.holdout", "classif.acc")
ci_ho
rr = resample(tsk("sonar"), lrn("classif.featureless"), rsmp("holdout"))
rr$aggregate(ci_ho)
Confidence intervals based on ResamplingNestedCV, including bias correction. This inference method can only be applied to decomposable losses.
Parameters: those from MeasureAbstractCi, as well as:
bias :: logical(1)
Whether to do bias correction. This is initialized to TRUE. If FALSE, the outer iterations are used for the point estimate and no bias correction is applied.
mlr3::Measure -> mlr3inferr::MeasureAbstractCi -> MeasureCiNestedCV
new()
Creates a new instance of this R6 class.
MeasureCiNestedCV$new(measure)
measure (Measure or character(1))
A measure or the ID of a measure.
clone()
The objects of this class are cloneable with this method.
MeasureCiNestedCV$clone(deep = FALSE)
deep
Whether to make a deep clone.
Bates, Stephen, Hastie, Trevor, Tibshirani, Robert (2024). “Cross-validation: what does it estimate and how well does it do it?” Journal of the American Statistical Association, 119(546), 1434–1445.
ci_ncv = msr("ci.ncv", "classif.acc")
ci_ncv
Confidence intervals for cross-validation. The method is asymptotically exact for the so-called Test Error as defined by Bayle et al. (2020). For the (expected) risk, the confidence intervals tend to be too liberal. This inference method can only be applied to decomposable losses.
Parameters: those from MeasureAbstractCi, as well as:
variance :: "all-pairs" or "within-fold"
How to estimate the variance. The results tend to be very similar.
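The difference between the two variance options can be illustrated with a small base R sketch. It follows one reading of the estimators in Bayle et al. (2020): "all-pairs" pools squared deviations from the overall CV estimate, while "within-fold" averages the sample variances computed inside each test fold. The function wald_cv_ci and its arguments are hypothetical, and the package's actual implementation may differ.

```r
# Illustrative only: `loss` holds pointwise losses on the test folds,
# `fold` gives the test fold of each observation.
wald_cv_ci = function(loss, fold, variance = c("all-pairs", "within-fold"), alpha = 0.05) {
  variance = match.arg(variance)
  n = length(loss)
  estimate = mean(loss)
  sigma2 = if (variance == "all-pairs") {
    # deviations from the overall CV estimate, pooled over all observations
    mean((loss - estimate)^2)
  } else {
    # sample variance within each fold, averaged over the folds
    mean(tapply(loss, fold, var))
  }
  z = qnorm(1 - alpha / 2)
  se = sqrt(sigma2 / n)
  c(estimate = estimate, lower = estimate - z * se, upper = estimate + z * se)
}

# Toy 0-1 losses from a 3-fold CV with 4 test observations per fold
loss = c(0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1)
fold = rep(1:3, each = 4)
wald_cv_ci(loss, fold, "all-pairs")
wald_cv_ci(loss, fold, "within-fold")
```

Both options share the same point estimate and differ only in the width of the interval.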
mlr3::Measure -> mlr3inferr::MeasureAbstractCi -> MeasureCiWaldCV
new()
Creates a new instance of this R6 class.
MeasureCiWaldCV$new(measure)
measure (Measure or character(1))
A measure or the ID of a measure.
clone()
The objects of this class are cloneable with this method.
MeasureCiWaldCV$clone(deep = FALSE)
deep
Whether to make a deep clone.
Bayle, Pierre, Bayle, Alexandre, Janson, Lucas, Mackey, Lester (2020). “Cross-validation confidence intervals for test error.” Advances in Neural Information Processing Systems, 33, 16339–16350.
m_waldcv = msr("ci.wald_cv", "classif.ce")
m_waldcv
rr = resample(tsk("sonar"), lrn("classif.featureless"), rsmp("cv"))
rr$aggregate(m_waldcv)
This implements the nested CV resampling procedure by Bates et al. (2024).
folds :: integer(1)
The number of folds. This is initialized to 5.
repeats :: integer(1)
The number of repetitions. This is initialized to 10.
mlr3::Resampling -> ResamplingNestedCV
iters (integer(1))
The total number of resampling iterations.
new()
Creates a new instance of this R6 class.
ResamplingNestedCV$new()
unflatten()
Convert a resampling iteration to a more useful representation.
For outer resampling iterations, inner is NA.
ResamplingNestedCV$unflatten(iter)
iter (integer(1))
The iteration.
Returns list(rep, outer, inner).
clone()
The objects of this class are cloneable with this method.
ResamplingNestedCV$clone(deep = FALSE)
deep
Whether to make a deep clone.
Bates, Stephen, Hastie, Trevor, Tibshirani, Robert (2024). “Cross-validation: what does it estimate and how well does it do it?” Journal of the American Statistical Association, 119(546), 1434–1445.
ncv = rsmp("nested_cv", folds = 3, repeats = 10L)
ncv
rr = resample(tsk("mtcars"), lrn("regr.featureless"), ncv)
Paired subsampling to enable inference on the generalization error.
One should not directly call $aggregate() with a non-CI measure on a resample result using paired subsampling, as most of the resampling iterations are only intended for estimating the standard error.
The first repeats_in iterations are a standard ResamplingSubsampling and should be used to obtain a point estimate of the generalization error.
The remaining iterations should be used to estimate the standard error.
Here, the data is divided repeats_out times into two equally sized disjoint subsets, to each of which a subsampling with repeats_in repetitions is applied.
See the $unflatten(iter) method to map the iterations to this nested structure.
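The nested structure described above implies a simple count of resampling iterations, sketched here in base R. The helper paired_subsampling_iters is hypothetical (it is not part of the package) and only restates the description: repeats_in iterations for the point estimate, plus repeats_in iterations for each of the two halves in each of the repeats_out splits. The argument values in the call are arbitrary and not the package defaults.

```r
# Illustrative helper (not part of the package): number of iterations
# implied by the description of paired subsampling.
paired_subsampling_iters = function(repeats_in, repeats_out) {
  # standard subsampling used for the point estimate
  point_estimate_iters = repeats_in
  # each outer repetition splits the data into two halves, and each half
  # gets its own subsampling with repeats_in repetitions
  variance_iters = repeats_out * 2 * repeats_in
  point_estimate_iters + variance_iters
}

paired_subsampling_iters(repeats_in = 15, repeats_out = 10)  # 15 + 10 * 2 * 15 = 315
```

This makes it clear why aggregating over all iterations with an ordinary measure is misleading: in this example only 15 of the 315 iterations are meant for the point estimate.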
repeats_in :: integer(1)
The inner repetitions.
repeats_out :: integer(1)
The outer repetitions.
ratio :: numeric(1)
The proportion of data to use for training.
mlr3::Resampling -> ResamplingPairedSubsampling
iters (integer(1))
The total number of resampling iterations.
new()
Creates a new instance of this R6 class.
ResamplingPairedSubsampling$new()
unflatten()
Unflatten the resampling iteration into a more informative representation:
inner: The subsampling iteration.
outer: NA for the first repeats_in iterations. Otherwise it indicates the outer iteration of the paired subsamplings.
partition: NA for the first repeats_in iterations. Otherwise it indicates whether the subsampling is applied to the first or second partition of the two disjoint halves.
ResamplingPairedSubsampling$unflatten(iter)
iter (integer(1))
Resampling iteration.
Returns list(outer, partition, inner).
clone()
The objects of this class are cloneable with this method.
ResamplingPairedSubsampling$clone(deep = FALSE)
deep
Whether to make a deep clone.
Nadeau, Claude, Bengio, Yoshua (1999). “Inference for the generalization error.” Advances in neural information processing systems, 12.
pw_subs = rsmp("paired_subsampling")
pw_subs