PseudoLabelsClassifier
- class scikit_weak.classification.PseudoLabelsClassifier(estimator=LogisticRegression(), n_iterations=10, n_restarts=5, threshold=0.5, random_state=None)
A classifier for weakly supervised data, based on the pseudo-labels strategy. The y argument of the fit method must be an iterable of GenericWeakLabel instances
- Parameters
estimator (estimator object, default=LogisticRegression()) – Base estimator to be fitted. Must support predict and predict_proba
n_iterations (int, default=10) – The number of fitting iterations
n_restarts (int, default=5) – The number of restarts
threshold (float, default=0.5) – The threshold used to select pseudo-labels
random_state (int or None, default=None) – Random seed
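The entry above does not spell out the pseudo-labels strategy, so the following is only a minimal pure-Python sketch of one common reading of it, not the scikit-weak implementation. It assumes each weak label is a set of candidate classes (standing in for GenericWeakLabel), swaps LogisticRegression for a hypothetical nearest-centroid estimator `CentroidProbClassifier` so the sketch has no dependencies, and guesses at the role of `threshold`. The names `pseudo_label_fit` and `candidate_mass` are likewise hypothetical.

```python
import math
import random


class CentroidProbClassifier:
    """Hypothetical stand-in for the base estimator (not LogisticRegression):
    nearest-centroid classifier with a softmax over negative squared distances."""

    def fit(self, X, y):
        self.classes_ = sorted(set(y))
        self.centroids_ = {}
        for c in self.classes_:
            pts = [x for x, t in zip(X, y) if t == c]
            # per-feature mean of the points currently labelled c
            self.centroids_[c] = [sum(col) / len(pts) for col in zip(*pts)]
        return self

    def predict_proba(self, X):
        rows = []
        for x in X:
            scores = [-sum((a - b) ** 2 for a, b in zip(x, self.centroids_[c]))
                      for c in self.classes_]
            top = max(scores)  # subtract the max for numerical stability
            exps = [math.exp(s - top) for s in scores]
            total = sum(exps)
            rows.append([e / total for e in exps])
        return rows

    def predict(self, X):
        return [self.classes_[row.index(max(row))] for row in self.predict_proba(X)]


def candidate_mass(est, row, cands):
    """Probability the estimator assigns to each candidate label (0 if unseen)."""
    p = dict(zip(est.classes_, row))
    return {c: p.get(c, 0.0) for c in cands}


def pseudo_label_fit(X, candidates, n_iterations=10, n_restarts=5,
                     threshold=0.5, random_state=None):
    rng = random.Random(random_state)
    best_est, best_score = None, -math.inf
    for _ in range(n_restarts):
        # Random initial pseudo-label drawn from each candidate set.
        labels = [rng.choice(sorted(c)) for c in candidates]
        est = CentroidProbClassifier().fit(X, labels)
        for _ in range(n_iterations):
            new_labels = []
            for row, cands, old in zip(est.predict_proba(X), candidates, labels):
                mass = candidate_mass(est, row, cands)
                total = sum(mass.values())
                best = max(sorted(cands), key=mass.get)
                # Assumed use of threshold: relabel only when the best candidate
                # holds more than `threshold` of the candidates' probability mass.
                if total > 0 and mass[best] / total > threshold:
                    new_labels.append(best)
                else:
                    new_labels.append(old)
            if new_labels == labels:
                break  # pseudo-labels are stable
            labels = new_labels
            est = CentroidProbClassifier().fit(X, labels)
        # Score this restart by the log-likelihood of the candidate sets.
        score = sum(math.log(max(sum(candidate_mass(est, row, cands).values()), 1e-12))
                    for row, cands in zip(est.predict_proba(X), candidates))
        if score > best_score:
            best_est, best_score = est, score
    return best_est


X = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]]
candidates = [{0}, {0}, {0, 1}, {1}, {1}, {0, 1}]
clf = pseudo_label_fit(X, candidates, random_state=0)
print(clf.predict([[0.05, 0.05], [5.05, 5.05]]))  # [0, 1]
```

Keeping the restart whose candidate sets receive the highest log-likelihood is one reasonable way to choose among restarts; the library may use a different criterion.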
- Variables
estimator (estimator) – The last fitted estimator
__n_classes (int) – The number of unique classes in y
__classes (list of int) – The unique classes in y
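Because the two class attributes listed above start with a double underscore, Python name-mangles them when they are assigned inside the class body, so they cannot be read as `clf.__classes` from outside the class. The stub below is a hypothetical stand-in, not the scikit-weak code: it assumes the attributes are set as `self.__classes` and `self.__n_classes` in `fit`, and uses plain integer labels instead of GenericWeakLabel objects, to show the mangled names under which they would be readable.

```python
class PseudoLabelsClassifier:
    """Hypothetical stub: mimics only how the documented private attributes
    might be stored; it does not perform any classification."""

    def fit(self, X, y):
        classes = sorted(set(y))
        self.__classes = classes          # stored as _PseudoLabelsClassifier__classes
        self.__n_classes = len(classes)   # stored as _PseudoLabelsClassifier__n_classes
        return self


clf = PseudoLabelsClassifier().fit([[0], [1], [2]], [2, 0, 2])
print(clf._PseudoLabelsClassifier__classes)    # [0, 2]
print(clf._PseudoLabelsClassifier__n_classes)  # 2
print(hasattr(clf, "__classes"))               # False
```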
- fit(X, y)
Fit the PseudoLabelsClassifier model
- predict(X)
Returns predictions for the given X
- predict_proba(X)
Returns, for each sample in X, a probability distribution over the classes
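The entry does not state how predict relates to predict_proba. Assuming PseudoLabelsClassifier follows the usual scikit-learn convention, in which predict returns the class with the highest probability in the matching predict_proba row, the two outputs can be related as in the hypothetical helper `predict_from_proba` below, which also assumes that column j of the probability matrix refers to the j-th entry of the class list.

```python
def predict_from_proba(proba, classes):
    """Map each row of class probabilities to the class with the largest value.

    Hypothetical helper: assumes column j of `proba` refers to classes[j]."""
    predictions = []
    for row in proba:
        # each row should be a probability distribution over the classes
        if abs(sum(row) - 1.0) > 1e-9:
            raise ValueError("row does not sum to 1")
        best = max(range(len(row)), key=row.__getitem__)
        predictions.append(classes[best])
    return predictions


proba = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
print(predict_from_proba(proba, [0, 1, 2]))  # [0, 2]
```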