ertk.classification.cross_validate
- ertk.classification.cross_validate(clf_lib: str, clf, x: ndarray, y: ndarray | None = None, *, groups: ndarray | None = None, cv: BaseCrossValidator | None = None, scoring: str | List[str] | Dict[str, Callable[[ndarray, ndarray], float]] | Callable[..., float] = 'accuracy', verbose: int = 0, n_jobs: int = 1, fit_params: Dict[str, Any] = {})
Cross-validate a classifier.
- Parameters:
- clf_lib: str
Classifier library to use. Must be one of “sk” (scikit-learn), “tf” (TensorFlow), or “pt” (PyTorch).
- clf:
Classifier to cross-validate.
- x: np.ndarray
Input data.
- y: np.ndarray, optional
Target labels.
- groups: np.ndarray, optional
Group label for each sample, used when splitting the data into cross-validation folds.
- cv: BaseCrossValidator, optional
Cross-validation strategy. If None, StratifiedKFold is used.
- scoring: str, list, dict, or callable, optional
Scoring metric(s) to use. A string must be the name of a valid scikit-learn scorer.
- verbose: int, optional
Verbosity level.
- n_jobs: int, optional
Number of jobs to run in parallel. Only used when clf_lib is “sk”.
- fit_params: dict, optional
Parameters to pass to the classifier’s fit method.
- Returns:
- result: ExperimentResult
Cross-validation results.
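The scoring parameter also accepts a dict mapping metric names to callables that take two arrays and return a float, as the Dict[str, Callable[[ndarray, ndarray], float]] annotation in the signature indicates. The sketch below defines two such metrics in plain Python: accuracy and unweighted average recall (UAR, the mean of per-class recalls). The names accuracy and uar are hypothetical, and the argument order (true labels first, predictions second) is an assumption borrowed from the scikit-learn metric convention, not something this page confirms.

```python
from collections import Counter
from typing import Callable, Dict, Sequence


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


def uar(y_true: Sequence, y_pred: Sequence) -> float:
    """Unweighted average recall: mean per-class recall over classes in y_true."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    support = Counter(y_true)  # number of true samples per class
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct per class
    recalls = [hits[c] / n for c, n in support.items()]
    return sum(recalls) / len(recalls)


# Hypothetical scoring dict in the shape the signature describes.
# Both functions also accept NumPy arrays, since they only use len() and zip().
scoring: Dict[str, Callable[[Sequence, Sequence], float]] = {
    "accuracy": accuracy,
    "uar": uar,
}

y_true = [0, 0, 0, 1]
y_pred = [0, 0, 0, 0]  # always predicts the majority class
for name, metric in scoring.items():
    print(name, metric(y_true, y_pred))
```

On this imbalanced toy example, always predicting the majority class scores 0.75 accuracy but only 0.5 UAR, which is why reporting both metrics side by side can be informative.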
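The groups parameter gives each sample a group identifier (for example, a speaker ID) so that a group-aware splitter can keep every sample of a group on the same side of each train/test split. In scikit-learn, StratifiedKFold accepts but ignores groups, so a group-aware cv such as GroupKFold is presumably needed for the grouping to take effect; how cross_validate forwards groups to cv is assumed here, not documented on this page. The pure-Python sketch below uses a hypothetical group_k_fold helper to illustrate the idea; it is not ERTK's or scikit-learn's implementation.

```python
from collections import Counter
from typing import Hashable, Iterator, List, Sequence, Tuple


def group_k_fold(
    groups: Sequence[Hashable], n_splits: int
) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, test_indices) so that no group spans train and test.

    Hypothetical helper: groups are sorted by size (largest first) and each one
    is assigned to the fold that currently holds the fewest samples.
    """
    sizes = Counter(groups)
    if n_splits > len(sizes):
        raise ValueError("n_splits cannot exceed the number of distinct groups")
    fold_sizes = [0] * n_splits
    fold_of_group = {}
    # Placing the largest groups first keeps fold sample counts roughly balanced.
    for group, size in sorted(sizes.items(), key=lambda kv: kv[1], reverse=True):
        fold = fold_sizes.index(min(fold_sizes))
        fold_of_group[group] = fold
        fold_sizes[fold] += size
    for fold in range(n_splits):
        test = [i for i, g in enumerate(groups) if fold_of_group[g] == fold]
        train = [i for i, g in enumerate(groups) if fold_of_group[g] != fold]
        yield train, test


# Six samples from three hypothetical speakers.
speakers = ["a", "a", "b", "b", "b", "c"]
for train, test in group_k_fold(speakers, n_splits=3):
    print("train:", train, "test:", test)
```

Each speaker's samples appear in exactly one test fold and never in the training indices of that same fold, which prevents a model from being evaluated on a speaker it was trained on.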