debeir.training.hparm_tuning.types

 1import abc
 2import dataclasses
 3from typing import Sequence
 4
 5import optuna
 6
 7
 8class Hparam:
 9    name: str
10
11    @abc.abstractmethod
12    def suggest(self, *args, **kwargs):
13        raise NotImplementedError()
14
15
@dataclasses.dataclass
class HparamFloat(Hparam):
    """Floating-point hyperparameter bounded by ``low`` and ``high``.

    ``log`` and ``step`` are forwarded unchanged to ``trial.suggest_float``.
    """

    name: str
    low: float
    high: float
    log: bool = False
    step: float = None  # forwarded as-is; None means no step is requested

    def suggest(self, trial: optuna.Trial):
        """Sample a float for ``self.name`` from *trial* and return it."""
        lower, upper = self.low, self.high
        return trial.suggest_float(
            self.name,
            lower,
            upper,
            log=self.log,
            step=self.step,
        )
26
27
@dataclasses.dataclass
class HparamInt(Hparam):
    """Integer hyperparameter bounded by ``low`` and ``high``.

    ``log`` and ``step`` are forwarded unchanged to ``trial.suggest_int``.
    """

    name: str
    low: int
    high: int
    log: bool = False
    step: int = 1

    def suggest(self, trial: optuna.Trial):
        """Sample an int for ``self.name`` from *trial* and return it."""
        bounds = (self.low, self.high)
        return trial.suggest_int(self.name, *bounds, log=self.log, step=self.step)
38
39
@dataclasses.dataclass
class HparamCategorical(Hparam):
    """Hyperparameter whose value is picked from a fixed ``choices`` sequence."""

    name: str
    choices: Sequence
    # Name of the matching optuna Trial method; not read by ``suggest`` here.
    func: str = "suggest_categorical"

    def suggest(self, trial: optuna.Trial):
        """Pick one of ``self.choices`` for ``self.name`` via *trial* and return it."""
        options = self.choices
        return trial.suggest_categorical(self.name, options)
48
49
@dataclasses.dataclass(init=True)
class HparamUniform(Hparam):
    """Float hyperparameter over ``[low, high]`` on a linear scale.

    ``Trial.suggest_uniform`` is deprecated in optuna 3.0; its documented
    replacement ``trial.suggest_float(name, low, high)`` is used instead.
    The ``func`` field keeps its legacy value for compatibility with code
    that reads it.
    """

    name: str
    low: float
    high: float
    func: str = "suggest_uniform"  # legacy method name, kept for compatibility

    def suggest(self, trial: optuna.Trial):
        """Sample a float for ``self.name`` from *trial* and return it.

        Equivalent to the deprecated ``trial.suggest_uniform(name, low, high)``.
        """
        return trial.suggest_float(self.name, self.low, self.high)
59
60
@dataclasses.dataclass(init=True)
class HparamLogUniform(Hparam):
    """Float hyperparameter over ``[low, high]`` sampled in the log domain.

    ``Trial.suggest_loguniform`` is deprecated in optuna 3.0; its documented
    replacement ``trial.suggest_float(name, low, high, log=True)`` is used
    instead. The ``func`` field keeps its legacy value for compatibility.
    """

    name: str
    low: float
    high: float
    func: str = "suggest_loguniform"  # legacy method name, kept for compatibility

    def suggest(self, trial: optuna.Trial):
        """Sample a float for ``self.name`` from *trial* and return it.

        Equivalent to the deprecated ``trial.suggest_loguniform(name, low, high)``.
        """
        return trial.suggest_float(self.name, self.low, self.high, log=True)
70
71
@dataclasses.dataclass(init=True)
class HparamDiscreteUniform(Hparam):
    """Float hyperparameter over ``[low, high]`` discretised with step ``q``.

    ``Trial.suggest_discrete_uniform`` is deprecated in optuna 3.0; its
    documented replacement ``trial.suggest_float(name, low, high, step=q)``
    is used instead. The ``func`` field keeps its legacy value for
    compatibility.
    """

    name: str
    low: float
    high: float
    q: float  # discretisation step, passed to suggest_float as ``step``
    func: str = "suggest_discrete_uniform"  # legacy method name, kept for compatibility

    def suggest(self, trial: optuna.Trial):
        """Sample a float for ``self.name`` from *trial* and return it.

        Equivalent to the deprecated
        ``trial.suggest_discrete_uniform(name, low, high, q)``.
        """
        return trial.suggest_float(self.name, self.low, self.high, step=self.q)
82
83
# Lookup table from a string type key to the Hparam subclass that handles it.
HparamTypes = dict(
    float=HparamFloat,
    int=HparamInt,
    categorical=HparamCategorical,
    uniform=HparamUniform,
    loguniform=HparamLogUniform,
    discrete_uniform=HparamDiscreteUniform,
)
class Hparam:
 9class Hparam:
10    name: str
11
12    @abc.abstractmethod
13    def suggest(self, *args, **kwargs):
14        raise NotImplementedError()
Hparam()
@abc.abstractmethod
def suggest(self, *args, **kwargs):
12    @abc.abstractmethod
13    def suggest(self, *args, **kwargs):
14        raise NotImplementedError()
@dataclasses.dataclass(init=True)
class HparamFloat(Hparam):
@dataclasses.dataclass
class HparamFloat(Hparam):
    """Floating-point hyperparameter bounded by ``low`` and ``high``.

    ``log`` and ``step`` are forwarded unchanged to ``trial.suggest_float``.
    """

    name: str
    low: float
    high: float
    log: bool = False
    step: float = None  # forwarded as-is; None means no step is requested

    def suggest(self, trial: optuna.Trial):
        """Sample a float for ``self.name`` from *trial* and return it."""
        lower, upper = self.low, self.high
        return trial.suggest_float(
            self.name,
            lower,
            upper,
            log=self.log,
            step=self.step,
        )
HparamFloat( name: str, low: float, high: float, log: bool = False, step: float = None)
def suggest(self, trial: optuna.trial._trial.Trial):
25    def suggest(self, trial: optuna.Trial):
26        return trial.suggest_float(self.name, self.low, self.high, step=self.step, log=self.log)
@dataclasses.dataclass(init=True)
class HparamInt(Hparam):
@dataclasses.dataclass
class HparamInt(Hparam):
    """Integer hyperparameter bounded by ``low`` and ``high``.

    ``log`` and ``step`` are forwarded unchanged to ``trial.suggest_int``.
    """

    name: str
    low: int
    high: int
    log: bool = False
    step: int = 1

    def suggest(self, trial: optuna.Trial):
        """Sample an int for ``self.name`` from *trial* and return it."""
        bounds = (self.low, self.high)
        return trial.suggest_int(self.name, *bounds, log=self.log, step=self.step)
HparamInt(name: str, low: int, high: int, log: bool = False, step: int = 1)
def suggest(self, trial: optuna.trial._trial.Trial):
37    def suggest(self, trial: optuna.Trial):
38        return trial.suggest_int(self.name, self.low, self.high, step=self.step, log=self.log)
@dataclasses.dataclass(init=True)
class HparamCategorical(Hparam):
@dataclasses.dataclass
class HparamCategorical(Hparam):
    """Hyperparameter whose value is picked from a fixed ``choices`` sequence."""

    name: str
    choices: Sequence
    # Name of the matching optuna Trial method; not read by ``suggest`` here.
    func: str = "suggest_categorical"

    def suggest(self, trial: optuna.Trial):
        """Pick one of ``self.choices`` for ``self.name`` via *trial* and return it."""
        options = self.choices
        return trial.suggest_categorical(self.name, options)
HparamCategorical(name: str, choices: Sequence, func: str = 'suggest_categorical')
def suggest(self, trial: optuna.trial._trial.Trial):
47    def suggest(self, trial: optuna.Trial):
48        return trial.suggest_categorical(self.name, self.choices)
@dataclasses.dataclass(init=True)
class HparamUniform(Hparam):
@dataclasses.dataclass(init=True)
class HparamUniform(Hparam):
    """Float hyperparameter over ``[low, high]`` on a linear scale.

    ``Trial.suggest_uniform`` is deprecated in optuna 3.0; its documented
    replacement ``trial.suggest_float(name, low, high)`` is used instead.
    The ``func`` field keeps its legacy value for compatibility with code
    that reads it.
    """

    name: str
    low: float
    high: float
    func: str = "suggest_uniform"  # legacy method name, kept for compatibility

    def suggest(self, trial: optuna.Trial):
        """Sample a float for ``self.name`` from *trial* and return it.

        Equivalent to the deprecated ``trial.suggest_uniform(name, low, high)``.
        """
        return trial.suggest_float(self.name, self.low, self.high)
HparamUniform(name: str, low: float, high: float, func: str = 'suggest_uniform')
def suggest(self, trial: optuna.trial._trial.Trial):
58    def suggest(self, trial: optuna.Trial):
59        return trial.suggest_uniform(self.name, self.low, self.high)
@dataclasses.dataclass(init=True)
class HparamLogUniform(Hparam):
@dataclasses.dataclass(init=True)
class HparamLogUniform(Hparam):
    """Float hyperparameter over ``[low, high]`` sampled in the log domain.

    ``Trial.suggest_loguniform`` is deprecated in optuna 3.0; its documented
    replacement ``trial.suggest_float(name, low, high, log=True)`` is used
    instead. The ``func`` field keeps its legacy value for compatibility.
    """

    name: str
    low: float
    high: float
    func: str = "suggest_loguniform"  # legacy method name, kept for compatibility

    def suggest(self, trial: optuna.Trial):
        """Sample a float for ``self.name`` from *trial* and return it.

        Equivalent to the deprecated ``trial.suggest_loguniform(name, low, high)``.
        """
        return trial.suggest_float(self.name, self.low, self.high, log=True)
HparamLogUniform(name: str, low: float, high: float, func: str = 'suggest_loguniform')
def suggest(self, trial: optuna.trial._trial.Trial):
69    def suggest(self, trial: optuna.Trial):
70        return trial.suggest_loguniform(self.name, self.low, self.high)
@dataclasses.dataclass(init=True)
class HparamDiscreteUniform(Hparam):
@dataclasses.dataclass(init=True)
class HparamDiscreteUniform(Hparam):
    """Float hyperparameter over ``[low, high]`` discretised with step ``q``.

    ``Trial.suggest_discrete_uniform`` is deprecated in optuna 3.0; its
    documented replacement ``trial.suggest_float(name, low, high, step=q)``
    is used instead. The ``func`` field keeps its legacy value for
    compatibility.
    """

    name: str
    low: float
    high: float
    q: float  # discretisation step, passed to suggest_float as ``step``
    func: str = "suggest_discrete_uniform"  # legacy method name, kept for compatibility

    def suggest(self, trial: optuna.Trial):
        """Sample a float for ``self.name`` from *trial* and return it.

        Equivalent to the deprecated
        ``trial.suggest_discrete_uniform(name, low, high, q)``.
        """
        return trial.suggest_float(self.name, self.low, self.high, step=self.q)
HparamDiscreteUniform( name: str, low: float, high: float, q: float, func: str = 'suggest_discrete_uniform')
def suggest(self, trial: optuna.trial._trial.Trial):
81    def suggest(self, trial: optuna.Trial):
82        return trial.suggest_discrete_uniform(self.name, self.low, self.high, self.q)