ertk.classification.cross_validate

ertk.classification.cross_validate(clf_lib: str, clf, x: ndarray, y: ndarray | None = None, *, groups: ndarray | None = None, cv: BaseCrossValidator | None = None, scoring: str | List[str] | Dict[str, Callable[[ndarray, ndarray], float]] | Callable[..., float] = 'accuracy', verbose: int = 0, n_jobs: int = 1, fit_params: Dict[str, Any] = {})

Cross-validate a classifier.

Parameters:
clf_lib: str

Classifier library to use. Must be one of “sk” (scikit-learn), “tf” (TensorFlow), or “pt” (PyTorch).

clf:

Classifier to cross-validate.

x: np.ndarray

Input data.

y: np.ndarray, optional

Target labels.

groups: np.ndarray, optional

Group label for each sample, passed to the cross-validation splitter when dividing the data into training and test sets.
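As a sketch of what a groups array does, assuming a group-aware splitter such as scikit-learn's GroupKFold (an illustrative choice; scikit-learn's StratifiedKFold ignores its groups argument), every group lands entirely on one side of each split. This is the usual way to keep, for example, all samples from one speaker out of training when that speaker is tested. The data and speaker names below are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Illustrative data: 12 samples from 4 hypothetical speakers, 3 samples each.
x = np.arange(24, dtype=float).reshape(12, 2)
y = np.array([0, 1, 0] * 4)
groups = np.repeat(["spk1", "spk2", "spk3", "spk4"], 3)

cv = GroupKFold(n_splits=4)
test_groups_per_fold = []
for train_idx, test_idx in cv.split(x, y, groups=groups):
    train_groups = set(groups[train_idx].tolist())
    test_groups = set(groups[test_idx].tolist())
    # A group never appears on both sides of the same split.
    assert train_groups.isdisjoint(test_groups)
    test_groups_per_fold.append(sorted(test_groups))

print(test_groups_per_fold)  # with 4 groups and 4 folds, each fold tests one speaker
```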

cv: BaseCrossValidator, optional

Cross-validation strategy. If None, StratifiedKFold is used.
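StratifiedKFold keeps the class proportions of y approximately equal in every test fold. The sketch below calls scikit-learn directly with an explicit n_splits=3 and shuffling; the number of folds and shuffling behaviour used when cv is None are not stated here, so passing an explicit splitter like this one as cv is a way to fix them. The labels are illustrative.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Imbalanced toy labels: 9 samples of class 0 and 6 of class 1.
y = np.array([0] * 9 + [1] * 6)
x = np.zeros((len(y), 1))  # feature values do not affect the split itself

cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)
fold_counts = []
for train_idx, test_idx in cv.split(x, y):
    # Count how many samples of each class fall in this test fold.
    counts = np.bincount(y[test_idx], minlength=2)
    fold_counts.append(counts.tolist())

print(fold_counts)  # [[3, 2], [3, 2], [3, 2]]
```

Each test fold keeps the 3:2 class ratio of the full label array.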

scoring: str, list, dict, or callable, optional

Scoring metric(s) to use. If a string, it must be the name of a valid scikit-learn scorer.
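Going by the type hint, a dict value for scoring maps a metric name to a function that takes (y_true, y_pred) and returns a float. The sketch below builds such a dict from scikit-learn metric functions and applies it to one hypothetical fold's predictions. The metric names are arbitrary labels chosen for illustration; "uar" (unweighted average recall) is computed as macro-averaged recall.

```python
from functools import partial

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, recall_score

# Metric functions with the (y_true, y_pred) -> float shape from the type hint.
scoring = {
    "accuracy": accuracy_score,
    "uar": partial(recall_score, average="macro"),  # unweighted average recall
    "f1_macro": partial(f1_score, average="macro"),
}

# Hypothetical true labels and predictions for a single test fold.
y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 1, 2])

fold_scores = {name: float(fn(y_true, y_pred)) for name, fn in scoring.items()}
print(fold_scores)
```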

verbose: int, optional

Verbosity level.

n_jobs: int, optional

Number of jobs to run in parallel. Only used when clf_lib is “sk”.
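For scikit-learn models, parallelism presumably resembles scikit-learn's own sklearn.model_selection.cross_validate, where n_jobs sets how many folds are fitted and scored concurrently. The sketch below calls that scikit-learn function directly, not ertk; whether ertk forwards n_jobs to this exact function is an assumption. The dataset is synthetic.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate

# Small synthetic binary classification problem.
x, y = make_classification(n_samples=120, n_features=5, random_state=0)

results = cross_validate(
    LogisticRegression(max_iter=1000),
    x,
    y,
    cv=StratifiedKFold(n_splits=5),
    scoring=["accuracy", "balanced_accuracy"],
    n_jobs=2,  # fit and score up to two folds at the same time
)

# One score per fold for each requested metric.
print(results["test_accuracy"], results["test_balanced_accuracy"])
```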

fit_params: dict, optional

Parameters to pass to the classifier’s fit method.
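To show how cv, groups, scoring, and fit_params fit together, here is a minimal hand-written cross-validation loop. It is a hypothetical sketch of the general technique, not ertk's implementation: manual_cross_validate is an illustrative name, a fresh clone of the classifier is trained on every fold, and fit_params is forwarded to each fit call. The sketch slices per-sample arrays such as sample_weight to the training fold and passes other values through unchanged; whether ertk does the same is an assumption.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold


def manual_cross_validate(clf, x, y, cv, scoring, groups=None, fit_params=None):
    """Hypothetical sketch of a cross-validation loop (not ertk's code)."""
    fit_params = fit_params or {}
    scores = {name: [] for name in scoring}
    for train_idx, test_idx in cv.split(x, y, groups=groups):
        # Slice per-sample fit parameters (e.g. sample_weight) to the training fold;
        # pass any other fit parameter through unchanged.
        fold_params = {
            k: v[train_idx]
            if isinstance(v, np.ndarray) and v.ndim > 0 and len(v) == len(x)
            else v
            for k, v in fit_params.items()
        }
        model = clone(clf)  # fresh, unfitted copy for every fold
        model.fit(x[train_idx], y[train_idx], **fold_params)
        y_pred = model.predict(x[test_idx])
        for name, metric in scoring.items():
            scores[name].append(float(metric(y[test_idx], y_pred)))
    return scores


rng = np.random.default_rng(0)
x = rng.normal(size=(60, 3))
y = (x[:, 0] > 0).astype(int)
weights = np.where(y == 1, 2.0, 1.0)  # illustrative per-sample weights

scores = manual_cross_validate(
    LogisticRegression(),
    x,
    y,
    cv=StratifiedKFold(n_splits=3),
    scoring={"accuracy": accuracy_score},
    fit_params={"sample_weight": weights},
)
print(scores["accuracy"])  # one accuracy per fold
```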

Returns:
result: ExperimentResult

Cross-validation results.