debeir.training.hparm_tuning.types
"""Hyperparameter search-space descriptors for Optuna trials.

Each ``Hparam*`` dataclass describes one tunable hyperparameter and asks an
Optuna trial for a value via ``suggest(trial)``.
"""
import abc
import dataclasses
from typing import Dict, Optional, Sequence, Type

import optuna


class Hparam(abc.ABC):
    """Abstract base for a single tunable hyperparameter.

    Subclasses provide a ``name`` and implement :meth:`suggest`. Inheriting
    from ``abc.ABC`` is what makes ``@abc.abstractmethod`` enforced: without
    it, a subclass that forgot to implement ``suggest`` could still be
    instantiated and would only fail when ``suggest`` was first called.
    """

    name: str

    @abc.abstractmethod
    def suggest(self, *args, **kwargs):
        """Draw a value for this hyperparameter from a trial."""
        raise NotImplementedError()


@dataclasses.dataclass(init=True)
class HparamFloat(Hparam):
    """Float hyperparameter drawn with ``trial.suggest_float``.

    Attributes:
        name: Parameter name registered with the trial.
        low: Lower bound of the range.
        high: Upper bound of the range.
        log: Forwarded as ``log`` to ``suggest_float``.
        step: Optional discretization step forwarded as ``step``.
    """

    name: str
    low: float
    high: float
    log: bool = False
    step: Optional[float] = None

    def suggest(self, trial: optuna.Trial) -> float:
        return trial.suggest_float(
            self.name, self.low, self.high, step=self.step, log=self.log
        )


@dataclasses.dataclass(init=True)
class HparamInt(Hparam):
    """Integer hyperparameter drawn with ``trial.suggest_int``.

    Attributes:
        name: Parameter name registered with the trial.
        low: Lower bound of the range.
        high: Upper bound of the range.
        log: Forwarded as ``log`` to ``suggest_int``.
        step: Step between candidate values, forwarded as ``step``.
    """

    name: str
    low: int
    high: int
    log: bool = False
    step: int = 1

    def suggest(self, trial: optuna.Trial) -> int:
        return trial.suggest_int(
            self.name, self.low, self.high, step=self.step, log=self.log
        )


@dataclasses.dataclass(init=True)
class HparamCategorical(Hparam):
    """Categorical hyperparameter drawn with ``trial.suggest_categorical``.

    Attributes:
        name: Parameter name registered with the trial.
        choices: Candidate values to choose from.
        func: Name of the Optuna method this type maps to; not read by
            ``suggest`` here, kept for backward compatibility.
    """

    name: str
    choices: Sequence
    func: str = "suggest_categorical"

    def suggest(self, trial: optuna.Trial):
        return trial.suggest_categorical(self.name, self.choices)


@dataclasses.dataclass(init=True)
class HparamUniform(Hparam):
    """Float hyperparameter drawn uniformly from ``[low, high]``.

    Attributes:
        name: Parameter name registered with the trial.
        low: Lower bound of the range.
        high: Upper bound of the range.
        func: Legacy Optuna method name; not read by ``suggest``, kept for
            backward compatibility.
    """

    name: str
    low: float
    high: float
    func: str = "suggest_uniform"

    def suggest(self, trial: optuna.Trial) -> float:
        # ``suggest_uniform`` is deprecated since Optuna 3.0; a plain
        # ``suggest_float`` is its documented drop-in replacement.
        return trial.suggest_float(self.name, self.low, self.high)


@dataclasses.dataclass(init=True)
class HparamLogUniform(Hparam):
    """Float hyperparameter drawn log-uniformly from ``[low, high]``.

    Attributes:
        name: Parameter name registered with the trial.
        low: Lower bound of the range.
        high: Upper bound of the range.
        func: Legacy Optuna method name; not read by ``suggest``, kept for
            backward compatibility.
    """

    name: str
    low: float
    high: float
    func: str = "suggest_loguniform"

    def suggest(self, trial: optuna.Trial) -> float:
        # ``suggest_loguniform`` is deprecated since Optuna 3.0; the
        # replacement is ``suggest_float(..., log=True)``.
        return trial.suggest_float(self.name, self.low, self.high, log=True)


@dataclasses.dataclass(init=True)
class HparamDiscreteUniform(Hparam):
    """Float hyperparameter drawn from ``[low, high]`` in steps of ``q``.

    Attributes:
        name: Parameter name registered with the trial.
        low: Lower bound of the range.
        high: Upper bound of the range.
        q: Discretization step.
        func: Legacy Optuna method name; not read by ``suggest``, kept for
            backward compatibility.
    """

    name: str
    low: float
    high: float
    q: float
    func: str = "suggest_discrete_uniform"

    def suggest(self, trial: optuna.Trial) -> float:
        # ``suggest_discrete_uniform`` is deprecated since Optuna 3.0; the
        # replacement is ``suggest_float(..., step=q)``.
        return trial.suggest_float(self.name, self.low, self.high, step=self.q)


# Lookup from a short type key (presumably taken from a tuning config —
# verify against callers) to the descriptor class that handles it.
HparamTypes: Dict[str, Type[Hparam]] = {
    "float": HparamFloat,
    "int": HparamInt,
    "categorical": HparamCategorical,
    "uniform": HparamUniform,
    "loguniform": HparamLogUniform,
    "discrete_uniform": HparamDiscreteUniform,
}
class
Hparam:
@dataclasses.dataclass(init=True)
class HparamDiscreteUniform(Hparam):
    """Float hyperparameter drawn from ``[low, high]`` in steps of ``q``.

    Attributes:
        name: Parameter name registered with the trial.
        low: Lower bound of the range.
        high: Upper bound of the range.
        q: Discretization step.
        func: Legacy Optuna method name; not read by ``suggest``, kept for
            backward compatibility.
    """

    name: str
    low: float
    high: float
    q: float
    func: str = "suggest_discrete_uniform"

    def suggest(self, trial: optuna.Trial) -> float:
        # ``suggest_discrete_uniform`` is deprecated since Optuna 3.0; the
        # replacement is ``suggest_float(..., step=q)``.
        return trial.suggest_float(self.name, self.low, self.high, step=self.q)