debeir.training.train_reranker
from typing import List

from debeir.datasets.types import RelevanceExample
from debeir.training.utils import _train_sentence_transformer
from sentence_transformers.evaluation import SentenceEvaluator


def train_cross_encoder_reranker(model_fp_or_name: str, output_dir: str, train_dataset: List[RelevanceExample],
                                 dev_dataset: List[RelevanceExample], train_batch_size=32, num_epochs=3,
                                 warmup_steps=None,
                                 evaluate_every_n_step: int = 1000,
                                 special_tokens=None, pooling_mode=None, loss_func=None,
                                 evaluator: SentenceEvaluator = None,
                                 *args, **kwargs):
    """
    Trains a reranker with relevance signals

    :param model_fp_or_name: The model name or path to the model
    :param output_dir: Output directory to save model, logs etc.
    :param train_dataset: Training Examples
    :param dev_dataset: Dev examples
    :param train_batch_size: Training batch size
    :param num_epochs: Number of epochs
    :param warmup_steps: Warmup steps for the scheduler
    :param evaluate_every_n_step: Evaluate the model every n steps
    :param special_tokens: Special tokens to add, defaults to [DOC], [QRY] tokens (bi-encoder)
    :param pooling_mode: Pooling mode for a sentence transformer model
    :param loss_func: Loss function(s) to use
    :param evaluator: Evaluator to use
    """

    if special_tokens is None:
        special_tokens = ["[DOC]", "[QRY]"]

    return _train_sentence_transformer(model_fp_or_name, output_dir, train_dataset,
                                       dev_dataset, train_batch_size,
                                       num_epochs, warmup_steps, evaluate_every_n_step,
                                       special_tokens, pooling_mode, loss_func,
                                       evaluator)
def train_cross_encoder_reranker(
    model_fp_or_name: str,
    output_dir: str,
    train_dataset: List[RelevanceExample],
    dev_dataset: List[RelevanceExample],
    train_batch_size=32,
    num_epochs=3,
    warmup_steps=None,
    evaluate_every_n_step: int = 1000,
    special_tokens=None,
    pooling_mode=None,
    loss_func=None,
    evaluator: SentenceEvaluator = None,
    *args,
    **kwargs,
)
Trains a reranker with relevance signals
Parameters
- model_fp_or_name: The model name or path to the model
- output_dir: Output directory to save the model, logs, etc.
- train_dataset: Training examples
- dev_dataset: Dev examples
- train_batch_size: Training batch size
- num_epochs: Number of epochs
- warmup_steps: Warmup steps for the scheduler
- evaluate_every_n_step: Evaluate the model every n steps
- special_tokens: Special tokens to add; defaults to the [DOC] and [QRY] tokens (bi-encoder)
- pooling_mode: Pooling mode for a sentence transformer model
- loss_func: Loss function(s) to use
- evaluator: Evaluator to use
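
A minimal usage sketch follows. Building the RelevanceExample lists is elided here, and the model name and output path are illustrative placeholders, not library defaults:

    from typing import List

    from debeir.datasets.types import RelevanceExample
    from debeir.training.train_reranker import train_cross_encoder_reranker

    # Hypothetical: construct these from your own relevance judgements.
    train_examples: List[RelevanceExample] = ...
    dev_examples: List[RelevanceExample] = ...

    train_cross_encoder_reranker(
        model_fp_or_name="bert-base-uncased",  # any HF model name or a local checkpoint path
        output_dir="./reranker_output",        # illustrative output path
        train_dataset=train_examples,
        dev_dataset=dev_examples,
        train_batch_size=32,
        num_epochs=3,
        evaluate_every_n_step=1000,
    )

Leaving special_tokens as None falls back to the [DOC] and [QRY] defaults, as the source above shows; training itself is delegated to _train_sentence_transformer.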