This function trains a machine learning model on the training data

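train.model(siamcat,
    method = c("lasso", "enet", "ridge", "lasso_ll", "ridge_ll", "randomForest"),
    stratify = TRUE, modsel.crit = list("auc"), min.nonzero.coeff = 1,
    param.set = NULL, verbose = 1)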

Arguments

siamcat

object of class siamcat-class

method

string, specifies the type of model to be trained, may be one of these: c('lasso', 'enet', 'ridge', 'lasso_ll', 'ridge_ll', 'randomForest')

stratify

boolean, should the folds in the internal cross-validation be stratified? Defaults to TRUE

modsel.crit

list, specifies the model selection criterion during internal cross-validation, may contain these: c('auc', 'f1', 'acc', 'pr'), defaults to list('auc')

min.nonzero.coeff

integer, minimum number of nonzero coefficients that should be present in the model (only for 'lasso', 'ridge', and 'enet'), defaults to 1

param.set

a list of extra parameters for mlr run, may contain:

  • cost - for lasso_ll and ridge_ll

  • alpha for enet

  • ntree and mtry for randomForest.

Defaults to NULL

verbose

control output: 0 for no output at all, 1 for only information about progress and success, 2 for normal level of information and 3 for full debug information, defaults to 1

Value

object of class siamcat-class with added model_list

Details

This function performs the training of the machine learning model and functions as an interface to the mlr-package.

The function expects a siamcat-class object with a prepared cross-validation (see create.data.split) in the data_split slot of the object. It then trains a model for each fold of the data split.
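The idea of training one model per cross-validation fold can be sketched in base R. This is only an illustration and not the package's implementation: the simulated data, the names fold.id and models, and the use of glm in place of an mlr Learner are all assumptions made for the example.

```r
# Illustrative stand-in for "one model per fold"; not SIAMCAT code.
set.seed(42)
n <- 60
label <- factor(rep(c("control", "case"), each = n / 2))
dat <- data.frame(
    label = label,
    f1 = rnorm(n, mean = as.integer(label == "case")),  # informative feature
    f2 = rnorm(n)                                       # noise feature
)

# Stratified fold assignment: shuffle fold ids separately within each class,
# so every fold contains the same number of samples from each class.
num.folds <- 5
fold.id <- integer(n)
for (cl in levels(label)) {
    idx <- which(label == cl)
    fold.id[idx] <- sample(rep_len(seq_len(num.folds), length(idx)))
}

# Train one model per fold, each on the samples outside that fold.
models <- lapply(seq_len(num.folds), function(k) {
    glm(label ~ f1 + f2, data = dat[fold.id != k, ], family = binomial())
})
```

Each element of models was fitted without seeing the samples of its own fold, so those samples remain available as held-out test data for that model.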

For the machine learning methods that require additional hyperparameters (e.g. lasso_ll), the optimal hyperparameters are tuned with the function tuneParams within the mlr-package.
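As a rough sketch of what such tuning looks like when mlr is used directly, the following grid search tunes cost for the 'classif.LiblineaRL1LogReg' Learner. It assumes the mlr and LiblineaR packages are installed. The simulated data, the cost grid, and the 5-fold resampling are arbitrary choices for illustration and are not taken from the package internals.

```r
# Sketch of hyperparameter tuning with mlr::tuneParams; not SIAMCAT code.
set.seed(1)
n <- 100
df <- data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n))
df$label <- factor(ifelse(df$x1 + rnorm(n, sd = 0.5) > 0, "case", "control"))

task <- mlr::makeClassifTask(data = df, target = "label")

# L1-regularized logistic regression; probabilities are needed for AUC.
learner <- mlr::makeLearner("classif.LiblineaRL1LogReg", predict.type = "prob")

# Candidate values for the cost hyperparameter (assumed grid).
cost.grid <- c(0.01, 0.1, 1, 10)
par.set <- ParamHelpers::makeParamSet(
    ParamHelpers::makeDiscreteParam("cost", values = cost.grid)
)

# Stratified internal cross-validation, selecting by AUC.
res <- mlr::tuneParams(
    learner,
    task = task,
    resampling = mlr::makeResampleDesc("CV", iters = 5L, stratify = TRUE),
    measures = list(mlr::auc),
    par.set = par.set,
    control = mlr::makeTuneControlGrid(),
    show.info = FALSE
)

res$x$cost  # the cost value selected by the grid search
```

Grid search is only one of the tuning strategies mlr offers; it is used here because it is the simplest to read.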

The methods 'lasso', 'enet', and 'ridge' are implemented as mlr tasks using the 'classif.cvglmnet' Learner, 'lasso_ll' and 'ridge_ll' use the 'classif.LiblineaRL1LogReg' and the 'classif.LiblineaRL2LogReg' Learners respectively. The 'randomForest' method is implemented via the 'classif.randomForest' Learner.
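For quick reference, the mapping described above can be written as a lookup table in base R. The names learner.for.method and get.learner.name are hypothetical and exist only for this illustration; they are not part of the package.

```r
# Method-to-Learner mapping as described in the text; names are hypothetical.
learner.for.method <- c(
    lasso        = "classif.cvglmnet",
    enet         = "classif.cvglmnet",
    ridge        = "classif.cvglmnet",
    lasso_ll     = "classif.LiblineaRL1LogReg",
    ridge_ll     = "classif.LiblineaRL2LogReg",
    randomForest = "classif.randomForest"
)

# Look up the Learner name for a method; match.arg stops with an error
# if the method is not one of the supported names.
get.learner.name <- function(method) {
    method <- match.arg(method, names(learner.for.method))
    unname(learner.for.method[method])
}

get.learner.name("ridge_ll")
```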

Examples

data(siamcat_example)
# simple working example
siamcat_validated <- train.model(siamcat_example, method='lasso')
#> Error in data.split$num.folds: $ operator not defined for this S4 class