Training

Backends

Classes and utilities

Training and evaluation classes and functions

Config classes

TrainConfig([balanced, reps, normalise, ...])

Class to hold training configuration.

ExperimentResult(scores_dict, ...)

Class to hold results of an experiment.

ExperimentConfig(name, data, model, eval, ...)

Class to hold experiment configuration.

ModelConfig(type, config, args, ...)

Class to hold model configuration.

CrossValidationConfig([part, kfold, test_size])

Class to hold cross-validation configuration.

EvalConfig([cv, train, valid, test, ...])

Class to hold evaluation configuration.

Splitting classes and functions

TrainValidation()

Validation method that uses the training set as the validation set.

ShuffleGroupKFold([n_splits, shuffle, ...])

Like GroupKFold but with random combinations of groups instead of deterministic combinations based on group size.
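The difference from a size-based assignment can be illustrated with a small sketch. This is not the library's implementation: `shuffled_group_folds` and its `seed` parameter are hypothetical names. The sketch assumes that whole groups are shuffled and then dealt into folds, so that no group is split between the training and test indices.

```python
import random


def shuffled_group_folds(groups, n_splits, seed=None):
    """Yield (train_idx, test_idx) pairs with whole groups randomly assigned to folds.

    Hypothetical helper for illustration only.
    """
    unique = sorted(set(groups))
    if n_splits > len(unique):
        raise ValueError("n_splits cannot exceed the number of distinct groups")
    rng = random.Random(seed)
    rng.shuffle(unique)
    # Deal the shuffled groups round-robin into folds, ignoring group size.
    fold_of = {group: i % n_splits for i, group in enumerate(unique)}
    for fold in range(n_splits):
        test_idx = [i for i, g in enumerate(groups) if fold_of[g] == fold]
        train_idx = [i for i, g in enumerate(groups) if fold_of[g] != fold]
        yield train_idx, test_idx


# Six instances from four speakers, split into two folds.
speakers = ["a", "a", "b", "c", "c", "d"]
for train_idx, test_idx in shuffled_group_folds(speakers, n_splits=2, seed=0):
    print(train_idx, test_idx)
```

Changing `seed` changes which groups end up together, which is the behaviour the description contrasts with deterministic, size-based combinations.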

ValidationSplit(train_idx, valid_idx)

Validation method that uses a pre-defined validation set.
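A minimal sketch of what a pre-defined split can look like as a splitter object follows. It assumes the scikit-learn splitter protocol, in which a `cv` object exposes `split()` and `get_n_splits()`. The class name `FixedSplit` is hypothetical and this is not the library's code.

```python
class FixedSplit:
    """Hypothetical splitter that yields exactly one pre-defined split."""

    def __init__(self, train_idx, valid_idx):
        self.train_idx = list(train_idx)
        self.valid_idx = list(valid_idx)

    def split(self, X=None, y=None, groups=None):
        # The data arguments are ignored: the indices were fixed up front.
        yield self.train_idx, self.valid_idx

    def get_n_splits(self, X=None, y=None, groups=None):
        return 1


splitter = FixedSplit(train_idx=[0, 1, 2, 3], valid_idx=[4, 5])
for train_idx, valid_idx in splitter.split():
    print(train_idx, valid_idx)
```

Passing the same indices for both arguments would give the training-set-as-validation behaviour described for TrainValidation.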

get_cv_splitter(group, k[, test_size, ...])

Gets an appropriate cross-validation splitter for the given number of folds and groups, or a single random split.

Miscellaneous functions

get_pipeline_params(params, pipeline)

Modifies parameter names to pass to a Pipeline instance's fit() method.
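For context, scikit-learn's Pipeline routes fit parameters to a step when their names follow the `<step>__<param>` convention. The sketch below shows that renaming for a single step. The helper `prefix_fit_params` and the step name `"clf"` are hypothetical, and the real function may decide the target step differently.

```python
def prefix_fit_params(params, step_name):
    """Rename fit parameters so a Pipeline forwards them to one named step.

    Hypothetical helper for illustration only.
    """
    return {f"{step_name}__{key}": value for key, value in params.items()}


fit_params = prefix_fit_params({"sample_weight": [1.0, 2.0, 1.0]}, "clf")
print(fit_params)
```

The resulting dictionary could then be unpacked into a call such as `pipeline.fit(x, y, **fit_params)`, assuming the pipeline has a step named `"clf"`.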

get_scores(scoring, y_pred, y_true)

Get dictionary of scores for predictions.
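As a rough illustration of the idea, the sketch below assumes that `scoring` maps score names to plain metric functions of the form `metric(y_true, y_pred)`. The library may instead accept scorer objects or metric names. `score_predictions`, `accuracy` and `error_rate` are hypothetical names.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def error_rate(y_true, y_pred):
    return 1.0 - accuracy(y_true, y_pred)


def score_predictions(scoring, y_pred, y_true):
    """Apply every metric in `scoring` to the same predictions."""
    return {name: metric(y_true, y_pred) for name, metric in scoring.items()}


scores = score_predictions(
    {"acc": accuracy, "err": error_rate},
    y_pred=[0, 1, 0, 0],
    y_true=[0, 1, 1, 0],
)
print(scores)
```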

scores_to_df(scores[, index])

Convert a scikit-learn scores dictionary to a pandas DataFrame.

Classification

Classification functions.

This module contains functions for performing classification tasks.

binary_accuracy_score(y_true, y_pred, *[, ...])

Calculate binary accuracy.

standard_class_scoring(classes)

Given a list of classes, returns scikit-learn scorers for overall metrics and per-class metrics, for multiclass classification.

cross_validate(clf_lib, clf, x[, y, groups, ...])

Cross validate a classifier.

dataset_cross_validation(clf, dataset, clf_lib)

Cross validates a Classifier instance on a single dataset.

dataset_train_val_test(clf, dataset, ...[, ...])

Trains a Classifier instance on some training data, optionally using validation data, and returns results on given test data.

get_balanced_class_weights(classes)

Gets class weights such that each class has the same total weight across all instances.

get_balanced_sample_weights(labels)

Gets sample weights such that each unique label has the same total weight across all instances.
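One common way to obtain such weights, sketched below, is to make each weight inversely proportional to the class frequency: `n_samples / (n_classes * count)`. This is the same heuristic as scikit-learn's `"balanced"` class weighting. Whether the library uses exactly this normalisation is an assumption, and the helper names are hypothetical.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Map each class to a weight inversely proportional to its frequency."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


def balanced_sample_weights(labels):
    """Give every instance the weight of its class."""
    class_weights = balanced_class_weights(labels)
    return [class_weights[label] for label in labels]


labels = ["angry", "angry", "angry", "sad"]
print(balanced_class_weights(labels))
print(balanced_sample_weights(labels))
```

With this normalisation the weights of each class sum to `n_samples / n_classes`, so the rare class contributes as much in total as the frequent one.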

class_ratings_to_probs(ratings[, classes])

Convert annotator ratings into distribution over classes for each instance.
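The conversion can be sketched as normalised vote counts. The input layout is an assumption here: each instance is taken to be a list of class labels, one per annotator. The helper name `ratings_to_distribution` is hypothetical.

```python
from collections import Counter


def ratings_to_distribution(ratings, classes):
    """Turn per-instance annotator labels into a distribution over `classes`."""
    probs = []
    for instance_ratings in ratings:
        if not instance_ratings:
            raise ValueError("every instance needs at least one rating")
        counts = Counter(instance_ratings)
        total = len(instance_ratings)
        # Counter returns 0 for classes that no annotator chose.
        probs.append([counts[cls] / total for cls in classes])
    return probs


ratings = [["happy", "happy", "sad"], ["sad"], ["happy", "sad"]]
print(ratings_to_distribution(ratings, classes=["happy", "sad"]))
```

Each output row sums to 1, and its column order follows the order of `classes`.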

Transforms

Transforms to use in estimators.

group_transform(x, groups, transform, *[, ...])

Per-group (offline) transformation.
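One possible per-group transformation is standardisation, where each group (for example, each speaker) is scaled using statistics computed from all of that group's instances. The sketch below uses NumPy and assumes a 2-D feature matrix with one group label per row. `standardise_per_group` is a hypothetical name, and the library's function takes a generic `transform` argument instead.

```python
import numpy as np


def standardise_per_group(x, groups):
    """Standardise the rows of `x` separately within each group."""
    x = np.asarray(x, dtype=float)
    groups = np.asarray(groups)
    out = np.empty_like(x)
    for group in np.unique(groups):
        mask = groups == group
        mean = x[mask].mean(axis=0)
        std = x[mask].std(axis=0)
        std[std == 0] = 1.0  # constant features are left centred, not scaled
        out[mask] = (x[mask] - mean) / std
    return out


x = [[1.0, 10.0], [3.0, 30.0], [100.0, 0.0], [200.0, 0.0]]
groups = [0, 0, 1, 1]
print(standardise_per_group(x, groups))
```

The approach is offline in the sense that all instances of a group must be available before any of them can be transformed.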

instance_transform(x, transform, *[, inplace])

Per-instance transformation.
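A per-instance transformation treats every instance on its own, which suits instances that are sequences of feature vectors. The sketch below standardises each sequence with its own statistics. The input layout (a list of 2-D arrays of shape `(n_frames, n_features)`) and the name `standardise_per_instance` are assumptions for illustration.

```python
import numpy as np


def standardise_per_instance(sequences):
    """Standardise each sequence using only its own mean and deviation."""
    out = []
    for seq in sequences:
        seq = np.asarray(seq, dtype=float)
        std = seq.std(axis=0)
        std[std == 0] = 1.0  # avoid dividing by zero for constant features
        out.append((seq - seq.mean(axis=0)) / std)
    return out


sequences = [
    [[0.0, 1.0], [2.0, 1.0]],              # two frames
    [[5.0, 0.0], [5.0, 2.0], [5.0, 4.0]],  # three frames
]
for seq in standardise_per_instance(sequences):
    print(seq)
```

Because no statistics are shared between instances, sequences of different lengths can be processed in the same call.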

GroupTransformWrapper(transformer)

Transform that modifies groups independently without storing parameters.

InstanceTransformWrapper(transformer)

Transform that modifies instances independently without storing parameters.

SequenceTransform()

Transform designed to process sequences of vectors.

SequenceTransformWrapper(transformer, method)

Wrapper around a scikit-learn transform that can process sequences of vectors.