# SetFit

## Introduction
SetFit is a framework for few-shot fine-tuning of Sentence Transformers, developed by Hugging Face.
## Key Features
- Prompt-Free Approach: Unlike other few-shot learning methods, SetFit requires no handcrafted prompts or verbalizers. It generates rich embeddings directly from a small number of labeled text examples.
- Efficiency: SetFit achieves high accuracy without relying on large-scale models such as T0 or GPT-3, making it significantly faster to train and run inference with.
- Multilingual Support: SetFit can be used with any Sentence Transformer model available on the Hub, enabling text classification in multiple languages with ease (see the sketch after this list).
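As a minimal illustration of the multilingual point, swapping in a multilingual checkpoint is the only change required; the model name below is one of many available on the Hub:

```python
from sentence_transformers import SentenceTransformer

# Any Sentence Transformer checkpoint works; this one covers 50+ languages.
model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# Sentences in different languages map into the same embedding space.
embeddings = model.encode(["This is great!", "C'est génial !", "¡Esto es genial!"])
print(embeddings.shape)  # (3, 384)
```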
## How It Works

SetFit adopts a two-stage training process:

1. Fine-tuning the Sentence Transformer: In the first stage, SetFit fine-tunes a Sentence Transformer model on a small number of labeled examples using contrastive training. The model learns to generate dense embeddings for each example.
2. Training the Classifier: In the second stage, SetFit trains a classifier head on the embeddings generated by the fine-tuned Sentence Transformer. This classifier can then predict the labels of unseen examples from their embeddings.
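The two stages can be sketched with plain `sentence-transformers` and scikit-learn. This is a conceptual illustration under simplifying assumptions (pair construction, cosine-similarity loss, logistic-regression head), not this library's implementation; the usage example below shows the actual API:

```python
from itertools import combinations
from sentence_transformers import SentenceTransformer, InputExample, losses
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader

texts = ["the match went to extra time", "stocks rallied on Friday",
         "the striker scored twice", "the central bank cut rates"]
labels = [0, 1, 0, 1]  # 0 = sports, 1 = business

# Stage 1: contrastive fine-tuning. Pairs from the same class get target
# similarity 1.0, pairs from different classes get 0.0.
pairs = [InputExample(texts=[texts[i], texts[j]], label=float(labels[i] == labels[j]))
         for i, j in combinations(range(len(texts)), 2)]
model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
loader = DataLoader(pairs, shuffle=True, batch_size=4)
model.fit(train_objectives=[(loader, losses.CosineSimilarityLoss(model))], epochs=1)

# Stage 2: train a classifier head on the fine-tuned embeddings.
clf = LogisticRegression().fit(model.encode(texts), labels)
print(clf.predict(model.encode(["the goalkeeper saved a penalty"])))
```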
## Usage Example

### Preprocessing
```python
from datasets import load_dataset
import pandas as pd
from few_shot_learning_nlp.utils import stratified_train_test_split
from torch.utils.data import DataLoader
from few_shot_learning_nlp.few_shot_text_classification.setfit_dataset import SetFitDataset

# Load a dataset for text classification
ag_news_dataset = load_dataset("ag_news")

# Extract the number of classes from the dataset
num_classes = len(ag_news_dataset['train'].features['label'].names)

# Build a few-shot split by keeping a limited number of examples per class
n_shots = 50
train_validation, test_df = stratified_train_test_split(ag_news_dataset['train'], num_shots_per_class=n_shots)
train_df, val_df = stratified_train_test_split(pd.DataFrame(train_validation), num_shots_per_class=30)

# Create SetFitDataset objects for training and validation
set_fit_data_train = SetFitDataset(train_df['text'], train_df['label'], input_example_format=True)
set_fit_data_val = SetFitDataset(val_df['text'], val_df['label'], input_example_format=False)

# Create DataLoader objects for the training and validation datasets
train_dataloader = DataLoader(set_fit_data_train.data, shuffle=False)
val_dataloader = DataLoader(set_fit_data_val)
```
### Defining the Classifier
```python
import torch

class CLF(torch.nn.Module):
    """Simple MLP classification head trained on sentence embeddings."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *args,
        **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.layer1 = torch.nn.Linear(in_features, 128)
        self.relu = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(128, 32)
        self.layer3 = torch.nn.Linear(32, out_features)

    def forward(self, x: torch.Tensor):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        x = self.relu(x)
        return self.layer3(x)  # raw logits, one per class
```
### Training the Embedding Model
```python
import torch
from sentence_transformers import SentenceTransformer
from few_shot_learning_nlp.few_shot_text_classification.setfit import SetFitTrainer

# Load a pre-trained Sentence Transformer model
model = SentenceTransformer("whaleloops/phrase-bert")

# Initialize the SetFitTrainer with the embedding model and classifier
embedding_model = model.to("cuda")
in_features = embedding_model.get_sentence_embedding_dimension()
clf = CLF(in_features, num_classes).to("cuda")
trainer = SetFitTrainer(embedding_model, clf, num_classes)

# Train the embedding model
trainer.train_embedding(train_dataloader, val_dataloader, n_epochs=10)
```
### Training the Classifier Model
```python
import numpy as np
from few_shot_learning_nlp.utils import shuffle_two_lists  # assumed to live alongside stratified_train_test_split

# Inspect the class distribution, then shuffle the training data
_, class_counts = np.unique(train_df['label'], return_counts=True)
X_train_shuffled, y_train_shuffled = shuffle_two_lists(train_df['text'], train_df['label'])

# Train the classifier head on the fine-tuned embeddings
history, embedding_model, clf = trainer.train_classifier(
    X_train_shuffled, y_train_shuffled, val_df['text'], val_df['label'],
    clf=CLF(in_features, num_classes),
    n_epochs=15,
    lr=1e-4
)
```
### Testing the Models

```python
y_true, y_pred = trainer.test(test_df)
```
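The returned label lists can then be scored with standard scikit-learn metrics; this follow-up is a suggestion rather than part of the library:

```python
from sklearn.metrics import classification_report

# Per-class precision, recall, and F1 on the held-out test set
label_names = ag_news_dataset['train'].features['label'].names
print(classification_report(y_true, y_pred, target_names=label_names))
```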
# SetFitTrainer

## Introduction
The SetFitTrainer class is designed to facilitate the training and testing of embedding and classification models using Sentence Transformers and PyTorch. It provides methods for training embedding models, training classifier models, and testing the performance of trained models on test datasets.
## Initialization

```python
def __init__(
    self,
    embedding_model,
    classifier_model: torch.nn.Module,
    num_classes: int,
    dataset_name: str = None,
    model_name: str = None,
    device: str = 'cuda',
) -> None:
```

- `embedding_model`: Pre-trained embedding model.
- `classifier_model`: Classifier model for text classification.
- `num_classes`: Number of classes in the classification task.
- `dataset_name`: Name of the dataset (optional).
- `model_name`: Name of the model (optional).
- `device`: Device on which calculations are performed (default: `"cuda"`).
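For instance, reusing the embedding model and classifier head from the usage example above (the `dataset_name` and `model_name` strings here are purely illustrative):

```python
trainer = SetFitTrainer(
    embedding_model,            # SentenceTransformer to fine-tune
    clf,                        # torch.nn.Module classifier head
    num_classes,
    dataset_name="ag_news",     # optional bookkeeping metadata
    model_name="phrase-bert",
    device="cuda",
)
```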
## Methods

- `train_embedding(train_dataloader, val_dataloader, n_epochs=10, filepath=None, **kwargs)`: Train the embedding model using the provided training dataloader and validate it using the validation dataloader.
  - Args:
    - `train_dataloader`: DataLoader containing the training data.
    - `val_dataloader`: DataLoader containing the validation data.
    - `n_epochs`: Number of epochs for training (default: 10).
    - `filepath`: Filepath to save the best model (default: None).
    - `**kwargs`: Additional keyword arguments to pass to the embedding model's fit method.
  - Returns: None
- `train_classifier(X_train, y_train, X_val, y_val, n_epochs=100, loss_fn=torch.nn.CrossEntropyLoss(), embedding_model=None, clf=None, lr=1e-5)`: Train the classifier model using the provided training and validation data.
  - Args:
    - `X_train`: List of training texts.
    - `y_train`: List of corresponding training labels.
    - `X_val`: List of validation texts.
    - `y_val`: List of corresponding validation labels.
    - `n_epochs`: Number of epochs for training (default: 100).
    - `loss_fn`: Loss function for training (default: `torch.nn.CrossEntropyLoss()`).
    - `embedding_model`: Pre-trained embedding model to use. If None, uses the `best_model` (default: None).
    - `clf`: Classifier model to use. If None, uses `self.clf` (default: None).
    - `lr`: Learning rate for the optimizer (default: 1e-5).
  - Returns: Tuple containing the history of F1 scores during training, the embedding model, and the best classifier model.
- `test(test_df, embedding_model=None, clf=None)`: Test the performance of the trained models on the provided test dataset.
  - Args:
    - `test_df`: DataFrame containing the test dataset with 'text' and 'label' columns.
    - `embedding_model`: SentenceTransformer model for text embedding. If None, the best trained embedding model will be used (default: None).
    - `clf`: Trained classifier model. If None, the best trained classifier will be used (default: None).
  - Returns: Tuple containing the true labels and predicted labels for the test dataset.
# SetFitDataset

## Introduction

The SetFitDataset class creates pairs of texts with their corresponding labels for contrastive training. It expands the dataset either by considering all possible pairs of texts or by randomly selecting pairs within a specified radius, as illustrated in the sketch below.
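A minimal sketch of the pair-expansion idea (this mirrors what `expand_data` does conceptually; the exact output format and sampling details of the library may differ):

```python
import random
from itertools import combinations
from typing import List, Tuple

def expand_pairs(X: List[str], y: List[int], R: int = -1) -> List[Tuple[str, str, int]]:
    """Build (text_a, text_b, label) pairs, where label is 1 if both
    texts belong to the same class and 0 otherwise."""
    if R < 0:
        # Negative R: every possible pair in the dataset (quadratic growth).
        index_pairs = list(combinations(range(len(X)), 2))
    else:
        # Positive R: sample R random partners per text (linear growth).
        index_pairs = [(i, random.randrange(len(X)))
                       for i in range(len(X)) for _ in range(R)]
    return [(X[i], X[j], int(y[i] == y[j])) for i, j in index_pairs]

pairs = expand_pairs(["a goal was scored", "rates went up", "a red card"], [0, 1, 0])
print(pairs)  # 3 texts -> 3 pairs; same-class pairs are labeled 1
```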
## Initialization

```python
def __init__(
    self,
    text: List[str],
    labels: List[int],
    R: int = -1,
    input_example_format: bool = True
) -> None:
```

- `text`: List of texts.
- `labels`: List of corresponding labels.
- `R`: Radius for data expansion. If negative, considers all possible pairs within the dataset. If positive, randomly selects pairs within the specified radius (default: -1).
- `input_example_format`: If True, returns the expanded data in the InputExample format. If False, returns the expanded data as a list of lists (default: True).
## Attributes

- `data`: Expanded dataset containing pairs of texts with their labels.
## Methods

- `expand_data(X, y, R, input_example_format)`: Static method to expand the dataset by creating pairs of texts with their corresponding labels.
  - Args:
    - `X`: List of texts.
    - `y`: List of corresponding labels.
    - `R`: Radius for data expansion. If negative, considers all possible pairs within the dataset. If positive, randomly selects pairs within the specified radius (default: -1).
    - `input_example_format`: If True, returns the expanded data in the InputExample format. If False, returns the expanded data as a list of lists (default: True).
  - Returns: Expanded dataset containing pairs of texts with their labels.
- `__len__()`: Returns the length of the dataset.
- `__getitem__(index)`: Returns the item at the specified index in the dataset.