PseudoLabelsClassifier

class scikit_weak.classification.PseudoLabelsClassifier(estimator=LogisticRegression(), n_iterations=10, n_restarts=5, threshold=0.5, random_state=None)

A classifier for weakly supervised data, based on the pseudo-labels strategy. The y argument of the fit method should be an iterable of GenericWeakLabel objects.

Parameters
  • estimator (estimator, default=LogisticRegression()) – Base estimator to be fitted. It should support predict and predict_proba

  • n_restarts (int, default=5) – The number of restarts

  • n_iterations (int, default=10) – The number of iterations for fitting

  • threshold (float, default=0.5) – The threshold for pseudo-label selection

  • random_state (int, default=None) – Random seed
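The threshold parameter is described only as "the threshold for pseudo-label selection". One common reading, sketched below in plain Python, is that a sample receives a pseudo-label only when its most probable class reaches the threshold. This is an illustration under that assumption, not the library's implementation: the helper name select_pseudo_labels and the use of >= are hypothetical.

```python
def select_pseudo_labels(probabilities, threshold=0.5):
    """Return (sample_index, class_index) pairs for confident samples.

    Hypothetical helper: assumes a sample is kept when its highest
    class probability is greater than or equal to the threshold.
    """
    selected = []
    for index, row in enumerate(probabilities):
        # Index of the most probable class for this sample.
        label = max(range(len(row)), key=row.__getitem__)
        if row[label] >= threshold:
            selected.append((index, label))
    return selected


# Made-up probability rows for three samples and three classes.
probabilities = [
    [0.8, 0.1, 0.1],
    [0.4, 0.35, 0.25],
    [0.2, 0.7, 0.1],
]

pseudo_labels = select_pseudo_labels(probabilities, threshold=0.5)
```

Under this reading, raising the threshold keeps fewer but more confident pseudo-labels, while lowering it keeps more samples at the cost of noisier labels.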

Variables
  • estimator (estimator) – The last fitted estimator

  • __n_classes (int) – The number of unique classes in y

  • __classes (list of int) – The unique classes in y

fit(X, y)

Fit the PseudoLabelsClassifier model

predict(X)

Returns predictions for the given X

predict_proba(X)

Returns probability distributions for the given X
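In scikit-learn-style classifiers, the output of predict usually corresponds to the most probable class in each row returned by predict_proba. The sketch below shows that relationship in plain Python. Whether PseudoLabelsClassifier follows it exactly is an assumption, and the helper name labels_from_probabilities is hypothetical.

```python
def labels_from_probabilities(probabilities, classes):
    """Map each probability row to the class with the highest probability.

    Hypothetical helper: assumes column i of each row refers to classes[i].
    """
    return [
        classes[max(range(len(row)), key=row.__getitem__)]
        for row in probabilities
    ]


# Made-up class list and probability rows for two samples.
classes = [0, 1, 2]
probabilities = [
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
]

predicted = labels_from_probabilities(probabilities, classes)
```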