SetFit first fine-tunes a Sentence Transformer model contrastively on sentence pairs generated from a small number of labeled examples (typically 8 or 16 per class). It then trains a classifier head on the embeddings produced by the fine-tuned Sentence Transformer.
Benchmark + Evaluation
SetFit outperforms PET and GPT-3. It is beaten only by T-Few, a model approximately 100x bigger, and it is not far off the human baseline for the benchmarked tasks on the RAFT leaderboard.
SetFit and Python
We can easily train a SetFit model with a few lines of Python.
Installation
We need to install the library with pip install setfit; this will pull in further dependencies such as PyTorch, Hugging Face transformers and sentence-transformers.
Training a SetFit model
Using the Hugging Face datasets library and pandas, we can take a CSV with text and label columns and use it to train and test SetFit:
We can then use the trained model to make predictions:
We can also save the model for loading via from_pretrained next time:
Predicting with Confidence Scores
We can use model.predict_proba() to get an array of label probabilities (after softmax) per input. We then need to map them back onto the known labels:
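A sketch of the mapping step. The label names and the probability array are made-up stand-ins; in practice probs would come from model.predict_proba(texts), with one row per input and one column per class in training-label order:

```python
import numpy as np

labels = ["negative", "neutral", "positive"]  # assumed training label order

# Stand-in for model.predict_proba(texts): one probability row per input.
probs = np.array([
    [0.10, 0.20, 0.70],
    [0.80, 0.15, 0.05],
])

# Take each row's argmax and map it back to a label name, keeping the score.
predictions = [
    {"label": labels[int(row.argmax())], "score": float(row.max())}
    for row in probs
]
print(predictions)
# → [{'label': 'positive', 'score': 0.7}, {'label': 'negative', 'score': 0.8}]
```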
Multi-Label Training
As per the documentation, we need to pass a multi_target_strategy argument when we load the base model. We also need to pass binarized labels as label in the training dataset, which Argilla's prepare_dataset_for_training will do automatically.