debeir.training.hparm_tuning.trainer
import abc
from collections import defaultdict
from functools import partial
from typing import Dict, Sequence, Union

import loguru
import optuna
import torch
import torch_optimizer
from debeir.training.hparm_tuning.config import HparamConfig
from debeir.training.hparm_tuning.types import Hparam
from debeir.training.utils import LoggingEvaluator, LoggingLoss
from sentence_transformers import SentenceTransformer, losses
from torch.utils.data import DataLoader
from wandb import wandb

from datasets import Dataset, DatasetDict


class OptimizersWrapper:
    def __getattr__(self, name):
        if name in torch.optim.__dict__:
            return getattr(torch.optim, name)
        elif name in torch_optimizer.__dict__:
            return getattr(torch_optimizer, name)
        else:
            raise ModuleNotFoundError("Optimizer is not implemented, doesn't exist or is not supported.")


class Trainer:
    """
    Wrapper class for a trainer class.
    """

    def __init__(self, model, evaluator_fn, dataset_loading_fn):
        self.evaluator_fn = evaluator_fn
        self.model_cls = model  # Trainer object or method we will initialize
        self.dataset_loading_fn = dataset_loading_fn

    @abc.abstractmethod
    def fit(self, in_trial: optuna.Trial, train_dataset, val_dataset):
        raise NotImplementedError()


class SentenceTransformerHparamTrainer(Trainer):
    " See Optuna documentation for types! "
    model: SentenceTransformer

    def __init__(self, dataset_loading_fn, evaluator_fn, hparams_config: HparamConfig):
        super().__init__(SentenceTransformer, evaluator_fn, dataset_loading_fn)
        self.loss_fn = None
        self.hparams = hparams_config.parse_config_to_py() if hparams_config else None

    def get_optuna_hparams(self, trial: optuna.Trial, hparams: Sequence[Hparam] = None):
        """
        Get hyperparameters suggested by the optuna library

        :param trial: The optuna trial object
        :param hparams: Optional, pass a dictionary of HparamType[Enum] objects
        :return:
        """

        loguru.logger.info("Fitting the trainer.")

        hparam_values = defaultdict(lambda: 0.0)

        hparams = hparams if hparams else self.hparams

        if hparams is None:
            raise RuntimeError("No hyperparameters were specified")

        for key, hparam in hparams.items():
            if hasattr(hparam, 'suggest'):
                hparam_values[hparam.name] = hparam.suggest(trial)
                loguru.logger.info(f"Using {hparam_values[hparam.name]} for {hparam.name}.")
            else:
                hparam_values[key] = hparam

        return hparam_values

    def build_kwargs_and_model(self, hparams: Dict):
        kwargs = {}

        for hparam, hparam_value in list(hparams.items()):
            loguru.logger.info(f"Building model with {hparam}: {hparam_value}")

            if hparam == "lr":
                kwargs["optimizer_params"] = {
                    "lr": hparam_value
                }
            elif hparam == "model_name":
                self.model = self.model_cls(hparam_value)
            elif hparam == "optimizer":
                kwargs["optimizer_class"] = getattr(OptimizersWrapper(), hparam_value)
            elif hparam == "loss_fn":
                self.loss_fn = getattr(losses, hparam_value)
            else:
                kwargs[hparam] = hparam_value

        return kwargs

    def fit(self, in_trial: optuna.Trial, train_dataset, val_dataset):
        hparams = self.get_optuna_hparams(in_trial)
        kwargs = self.build_kwargs_and_model(hparams)

        evaluator = self.evaluator_fn.from_input_examples(val_dataset)
        loss = self.loss_fn(model=self.model)
        train_dataloader = DataLoader(train_dataset, shuffle=True,
                                      batch_size=int(kwargs.pop("batch_size")), drop_last=True)

        self.model.fit(
            train_objectives=[(train_dataloader, loss)],
            **kwargs,
            evaluator=evaluator,
            use_amp=True,
            callback=partial(trial_callback, in_trial)
        )

        return self.model.evaluate(evaluator)


def trial_callback(trial, score, epoch, *args, **kwargs):
    trial.report(score, epoch)
    # Handle pruning based on the intermediate value
    if trial.should_prune():
        raise optuna.exceptions.TrialPruned()


class SentenceTransformerTrainer(SentenceTransformerHparamTrainer):
    def __init__(self, dataset: Union[DatasetDict, Dict[str, Dataset]], hparams_config: HparamConfig,
                 evaluator_fn=None, evaluator=None, use_wandb=False):
        super().__init__(None, evaluator_fn, hparams_config)
        self.evaluator = evaluator
        self.use_wandb = use_wandb
        self.dataset = dataset

    def fit(self, **extra_kwargs):
        kwargs = self.build_kwargs_and_model(self.hparams)

        if not self.evaluator:
            self.evaluator = LoggingEvaluator(self.evaluator_fn.from_input_examples(self.dataset['val']), wandb)

        loss = self.loss_fn(model=self.model)

        if self.use_wandb:
            wandb.watch(self.model)
            loss = LoggingLoss(loss, wandb)

        train_dataloader = DataLoader(self.dataset['train'], shuffle=True,
                                      batch_size=int(kwargs.pop("batch_size")),
                                      drop_last=True)

        self.model.fit(
            train_objectives=[(train_dataloader, loss)],
            **kwargs,
            evaluator=self.evaluator,
            use_amp=True,
            **extra_kwargs
        )

        return self.model.evaluate(self.evaluator)
class OptimizersWrapper:
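OptimizersWrapper carries no state; its __getattr__ resolves an optimizer class by name, checking torch.optim first and falling back to torch_optimizer, and raises ModuleNotFoundError for unknown names. A minimal sketch of that lookup, with the optimizer names chosen purely for illustration:

from debeir.training.hparm_tuning.trainer import OptimizersWrapper

wrapper = OptimizersWrapper()

adamw = wrapper.AdamW    # found in torch.optim
lamb = wrapper.Lamb      # not in torch.optim, so resolved from torch_optimizer

try:
    wrapper.NotARealOptimizer
except ModuleNotFoundError as err:
    print(err)           # "Optimizer is not implemented, doesn't exist or is not supported."

This is how build_kwargs_and_model turns the "optimizer" hyperparameter string into the optimizer_class argument it hands to SentenceTransformer.fit.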
class Trainer:
Wrapper class for a trainer. Holds the model class to instantiate, an evaluator factory, and a dataset-loading function; subclasses implement fit() and return the objective value for an Optuna trial.
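A schematic subclass showing the contract rather than a working implementation; the class name and body below are placeholders:

import optuna

from debeir.training.hparm_tuning.trainer import Trainer


class MyTrainer(Trainer):
    """Illustrative only: trains the wrapped model and returns a score."""

    def fit(self, in_trial: optuna.Trial, train_dataset, val_dataset):
        model = self.model_cls()    # instantiate the wrapped model class
        # ... train `model` on train_dataset, then score it on val_dataset ...
        score = 0.0                 # placeholder objective value
        return score                # reported back to the Optuna study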
class SentenceTransformerHparamTrainer(Trainer):
See Optuna documentation for types!
SentenceTransformerHparamTrainer(dataset_loading_fn, evaluator_fn, hparams_config: debeir.training.hparm_tuning.config.HparamConfig)
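fit(trial, train_dataset, val_dataset) draws hyperparameters from the trial, builds the model and loss, trains with mixed precision, and returns the evaluator score, so the class slots directly into an Optuna objective. A sketch of that wiring, assuming a config that defines at least model_name, loss_fn, optimizer, lr, batch_size and epochs, and using sentence_transformers' EmbeddingSimilarityEvaluator as the evaluator factory; HparamConfig.from_json is an assumed helper, not necessarily the real constructor:

import optuna
from sentence_transformers import InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

from debeir.training.hparm_tuning.config import HparamConfig
from debeir.training.hparm_tuning.trainer import SentenceTransformerHparamTrainer


def load_examples():
    # Toy split; in practice load real sentence pairs here.
    pos = InputExample(texts=["a query", "a relevant passage"], label=1.0)
    neg = InputExample(texts=["a query", "an unrelated passage"], label=0.0)
    return [pos, neg] * 32, [pos, neg] * 8


hparams_config = HparamConfig.from_json("hparam_config.json")  # assumed helper

trainer = SentenceTransformerHparamTrainer(
    dataset_loading_fn=load_examples,
    evaluator_fn=EmbeddingSimilarityEvaluator,  # anything with a from_input_examples classmethod
    hparams_config=hparams_config,
)


def objective(trial: optuna.Trial) -> float:
    train_dataset, val_dataset = load_examples()
    return trainer.fit(trial, train_dataset, val_dataset)


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
print(study.best_params)

Because fit reports intermediate scores through trial_callback, the pruner can stop unpromising trials early.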
def get_optuna_hparams(self, trial: optuna.trial._trial.Trial, hparams: Sequence[debeir.training.hparm_tuning.types.Hparam] = None):
Get hyperparameters suggested by the Optuna library.

Parameters
- trial: The Optuna trial object
- hparams: Optional; a dictionary of Hparam objects keyed by name. Defaults to the hyperparameters parsed from the config passed at construction.

Returns
- A dictionary mapping each hyperparameter name to its suggested (or fixed) value.
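The hparams mapping may mix fixed values with objects exposing a suggest(trial) method (the Hparam types from debeir.training.hparm_tuning.types); suggestable entries are sampled per trial and stored under their .name, while plain values pass through under their dictionary key. A sketch using a hand-rolled stand-in for an Hparam, purely for illustration:

import optuna


class FloatHparam:
    """Stand-in for a debeir Hparam: anything with .name and .suggest(trial) works."""

    def __init__(self, name, low, high, log=False):
        self.name, self.low, self.high, self.log = name, low, high, log

    def suggest(self, trial: optuna.Trial):
        return trial.suggest_float(self.name, self.low, self.high, log=self.log)


hparams = {
    "lr": FloatHparam("lr", 1e-6, 1e-4, log=True),            # sampled per trial
    "batch_size": 16,                                         # fixed, passed through
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",   # fixed, passed through
}

# Inside an Optuna objective:
#   values = trainer.get_optuna_hparams(trial, hparams)
#   -> {"lr": <suggested float>, "batch_size": 16, "model_name": "sentence-transformers/all-MiniLM-L6-v2"}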
def build_kwargs_and_model(self, hparams: Dict):
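This method routes each hyperparameter: "model_name" instantiates the SentenceTransformer, "optimizer" is resolved by name through OptimizersWrapper, "loss_fn" is looked up on sentence_transformers.losses, "lr" is wrapped into optimizer_params, and every remaining key is forwarded unchanged to SentenceTransformer.fit. A sketch of the result for an illustrative hyperparameter set, continuing the trainer constructed above:

hparams = {
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    "optimizer": "AdamW",
    "loss_fn": "CosineSimilarityLoss",
    "lr": 2e-5,
    "epochs": 3,
    "batch_size": 32,
}

kwargs = trainer.build_kwargs_and_model(hparams)

# Side effects: trainer.model is SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#               trainer.loss_fn is losses.CosineSimilarityLoss (the class, not an instance)
# kwargs == {
#     "optimizer_class": torch.optim.AdamW,
#     "optimizer_params": {"lr": 2e-5},
#     "epochs": 3,
#     "batch_size": 32,   # popped off again by fit() and given to the DataLoader
# }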
def fit(self, in_trial: optuna.trial._trial.Trial, train_dataset, val_dataset):
def trial_callback(trial, score, epoch, *args, **kwargs):
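SentenceTransformerHparamTrainer.fit binds the active trial into this callback with functools.partial and passes it to SentenceTransformer.fit, which invokes it as callback(score, epoch, steps) after each evaluation; the callback reports the intermediate score and raises TrialPruned once the study's pruner gives up on the trial. A self-contained toy run, with hard-coded scores standing in for real evaluations:

from functools import partial

import optuna

from debeir.training.hparm_tuning.trainer import trial_callback


def objective(trial: optuna.Trial) -> float:
    # Stand-in for SentenceTransformer.fit, which would call callback(score, epoch, steps)
    # after each evaluation; here the callback is driven by hand.
    callback = partial(trial_callback, trial)
    scores = [0.20, 0.25, 0.26]
    for epoch, score in enumerate(scores):
        callback(score, epoch, -1)   # may raise optuna.TrialPruned and end the trial early
    return scores[-1]


study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=5)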
class SentenceTransformerTrainer(SentenceTransformerHparamTrainer):
See Optuna documentation for types!
SentenceTransformerTrainer(dataset: Union[datasets.dataset_dict.DatasetDict, Dict[str, datasets.arrow_dataset.Dataset]], hparams_config: debeir.training.hparm_tuning.config.HparamConfig, evaluator_fn=None, evaluator=None, use_wandb=False)
def fit(self, **extra_kwargs):
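For a single training run outside an Optuna search, this class takes a dataset with "train" and "val" splits plus a config of fixed hyperparameters, optionally logging to Weights & Biases, and forwards any extra keyword arguments to SentenceTransformer.fit. A sketch, assuming the config again supplies model_name, loss_fn, optimizer, lr, batch_size and epochs, and with HparamConfig.from_json as an assumed helper; all names and paths are illustrative:

import wandb
from sentence_transformers import InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

from debeir.training.hparm_tuning.config import HparamConfig
from debeir.training.hparm_tuning.trainer import SentenceTransformerTrainer

# Toy splits; a datasets.DatasetDict with the same keys also works.
pos = InputExample(texts=["a query", "a relevant passage"], label=1.0)
neg = InputExample(texts=["a query", "an unrelated passage"], label=0.0)
dataset = {
    "train": [pos, neg] * 32,
    "val": [pos, neg] * 8,
}

hparams_config = HparamConfig.from_json("best_hparams.json")  # assumed helper

wandb.init(project="debeir-finetune")  # must be called before wandb.watch / LoggingLoss log anything

trainer = SentenceTransformerTrainer(
    dataset=dataset,
    hparams_config=hparams_config,
    evaluator_fn=EmbeddingSimilarityEvaluator,
    use_wandb=True,
)

score = trainer.fit(output_path="outputs/finetuned-model")  # extra kwargs go to SentenceTransformer.fit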