ertk.train.get_cv_splitter

ertk.train.get_cv_splitter(group: bool, k: int, test_size: float = 0.2, shuffle: bool = False, random_state: int | None = None) -> BaseCrossValidator

Gets a cross-validation splitter appropriate for the given number of folds and grouping: k-fold, leave-one-out, a single random split, or the whole training set reused as the validation split.

Parameters:
group: bool

Whether to split over pre-defined groups of instances.

k: int

If k > 1, do k-fold CV. If k == 1, do one random split. If k == -1, do leave-one-out. If k == 0, use the whole training set as the validation split.

test_size: float

The size of the test set when k == 1 (one random split).

shuffle: bool

Whether to shuffle when using k-fold for k > 1.

random_state: int, optional

The random state to set for splitters with shuffling behaviour.

Returns:
splitter: BaseCrossValidator

Cross-validation splitter that has split() and get_n_splits() methods.
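
The mapping from k and group to a splitter can be illustrated with scikit-learn classes alone. The following is a minimal sketch, not ERTK's actual implementation: the function name sketch_cv_splitter is hypothetical, the choice of splitter class for each case is an assumption based on the parameter descriptions above, and the k == 0 case is left out because it has no direct scikit-learn equivalent.

```python
from sklearn.model_selection import (
    GroupKFold,
    GroupShuffleSplit,
    KFold,
    LeaveOneGroupOut,
    LeaveOneOut,
    ShuffleSplit,
)


def sketch_cv_splitter(group, k, test_size=0.2, shuffle=False, random_state=None):
    """Hypothetical stand-in for the documented dispatch; not ERTK's code."""
    if k > 1:
        if group:
            # GroupKFold keeps every group inside a single fold.
            return GroupKFold(n_splits=k)
        # scikit-learn rejects a random_state when shuffle is False,
        # so only pass it through when shuffling is requested.
        return KFold(
            n_splits=k,
            shuffle=shuffle,
            random_state=random_state if shuffle else None,
        )
    if k == 1:
        # One random split; test_size controls the held-out portion.
        cls = GroupShuffleSplit if group else ShuffleSplit
        return cls(n_splits=1, test_size=test_size, random_state=random_state)
    if k == -1:
        return LeaveOneGroupOut() if group else LeaveOneOut()
    # k == 0 (training set reused as validation) is not covered here.
    raise ValueError(f"k={k} is not covered by this sketch")


# Toy data: 10 instances in 5 groups of 2.
X = [[i] for i in range(10)]
y = [0, 1] * 5
groups = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]

kfold = sketch_cv_splitter(group=False, k=5)
n_kfold = kfold.get_n_splits(X, y)

logo = sketch_cv_splitter(group=True, k=-1)
n_logo = logo.get_n_splits(X, y, groups)

single = sketch_cv_splitter(group=False, k=1, random_state=0)
train_idx, test_idx = next(single.split(X, y))
```

Because every branch returns an object with split() and get_n_splits(), the caller can pass the result to utilities such as scikit-learn's cross_validate via its cv argument, supplying groups when a group-based splitter is used.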