ertk.classification.dataset_train_val_test

ertk.classification.dataset_train_val_test(clf, dataset: Dataset, train_idx: Sequence[int] | ndarray, valid_idx: Sequence[int] | ndarray, test_idx: Sequence[int] | ndarray | None = None, label: str = 'label', clf_lib: str | None = None, sample_weight: ndarray | None = None, verbose: int = 0, scoring=None, fit_params: Dict[str, Any] = {}) -> ExperimentResult

Trains a classifier instance on the dataset instances selected by train_idx, optionally using the validation instances selected by valid_idx, and returns results on the given test data.
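As a rough illustration of the index-based split this function works with, the sketch below uses scikit-learn on a synthetic dataset in place of an ertk Dataset; it is not ertk's implementation. The index arrays play the role of train_idx, valid_idx and test_idx. Here the validation subset is only scored, whereas a library-specific training routine might also use it during training (for example, for early stopping).

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Synthetic stand-in for a Dataset: feature matrix X and labels y.
X, y = make_classification(n_samples=300, n_features=10, random_state=0)

# Disjoint index sets, analogous to train_idx, valid_idx and test_idx.
rng = np.random.default_rng(0)
perm = rng.permutation(len(y))
train_idx, valid_idx, test_idx = perm[:200], perm[200:250], perm[250:]

clf = LogisticRegression(max_iter=1000)
clf.fit(X[train_idx], y[train_idx])  # fit only on the training subset

# Score the fitted model separately on the validation and test subsets.
valid_acc = accuracy_score(y[valid_idx], clf.predict(X[valid_idx]))
test_acc = accuracy_score(y[test_idx], clf.predict(X[test_idx]))
print(f"valid={valid_acc:.3f} test={test_acc:.3f}")
```

Keeping the three index sets disjoint is what makes the test score an estimate on unseen data.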

Parameters:
clf: object that implements fit() and predict()

The classifier instance to train and evaluate.

dataset: Dataset

The dataset containing the training, validation and test instances.

clf_lib: str, optional

One of {"sk", "tf", "pt"}, selecting which library-specific training method to use, since the scikit-learn, TensorFlow and PyTorch interfaces are not fully compatible.

verbose: int

Verbosity level, passed to train_val_test().

scoring: str, list, dict, optional

Scoring metric(s) to use. Can be anything accepted by the scoring argument of scikit-learn's cross_validate(): a metric name (str), a list of metric names, or a dict mapping names to metrics.

fit_params: dict

Additional keyword arguments passed to the model's fit() method. Use this for any model-specific parameters not covered by the other arguments.
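fit_params is typically unpacked into keyword arguments of fit(). The sketch below shows that forwarding with a hypothetical ToyClassifier and a hypothetical fit_with_params() helper; neither is part of ertk, and the epochs and batch_size options are purely illustrative.

```python
from typing import Any, Dict, Optional

import numpy as np


class ToyClassifier:
    """Hypothetical classifier whose fit() accepts extra keyword options."""

    def fit(self, X, y, epochs: int = 1, batch_size: int = 32):
        # Record the forwarded options so they can be inspected.
        self.epochs_ = epochs
        self.batch_size_ = batch_size
        # "Train" by remembering the most frequent class.
        values, counts = np.unique(y, return_counts=True)
        self.majority_ = values[np.argmax(counts)]
        return self

    def predict(self, X):
        # Predict the majority class for every input row.
        return np.full(len(X), self.majority_)


def fit_with_params(clf, X, y, fit_params: Optional[Dict[str, Any]] = None):
    # Hypothetical helper: forward extra options unchanged to fit().
    clf.fit(X, y, **(fit_params or {}))
    return clf


X = np.zeros((6, 2))
y = np.array([0, 1, 1, 1, 0, 1])
clf = fit_with_params(ToyClassifier(), X, y, fit_params={"epochs": 10, "batch_size": 4})
print(clf.epochs_, clf.batch_size_, clf.predict(X[:2]))  # prints: 10 4 [1 1]
```

The helper defaults fit_params to None rather than {} so that no mutable default object is shared between calls.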

Returns:
result: ExperimentResult

The results of training and evaluating the model on the given data.
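The scoring parameter is described as accepting the same forms as scikit-learn's scoring argument. The sketch below calls scikit-learn's cross_validate() directly on a synthetic dataset, not through ertk, to show a single metric name, a list of names and a dict mapping custom names to metric names; that ertk forwards scoring in this form unchanged is an assumption.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

X, y = make_classification(n_samples=200, n_features=8, random_state=0)
clf = LogisticRegression(max_iter=1000)

# A single metric name: scores are stored under "test_score".
res_str = cross_validate(clf, X, y, cv=3, scoring="accuracy")

# A list of metric names: one "test_<name>" key per metric.
res_list = cross_validate(clf, X, y, cv=3, scoring=["accuracy", "f1_macro"])

# A dict of custom name -> metric name: keys use the custom names.
# "balanced_accuracy" is the unweighted average recall (UAR).
res_dict = cross_validate(
    clf, X, y, cv=3, scoring={"acc": "accuracy", "uar": "balanced_accuracy"}
)

print(sorted(k for k in res_dict if k.startswith("test_")))
```

The dict form is useful when results should be reported under short, domain-specific names such as "uar".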