debeir.models.colbert
import json
import logging
import os

import torch
from torch import nn
from transformers import BertConfig, BertModel

logger = logging.getLogger(__name__)

ACT_FUNCS = {
    "relu": nn.ReLU,
}

LOSS_FUNCS = {
    'cross_entropy_loss': nn.CrossEntropyLoss,
}


class CoLBERTConfig(object):
    default_fname = "colbert_config.json"

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.__dict__.update(kwargs)

    def save(self, path, fname=default_fname):
        """
        :param fname: file name
        :param path: Path to save
        """
        json.dump(self.kwargs, open(os.path.join(path, fname), 'w+'))

    @classmethod
    def load(cls, path, fname=default_fname):
        """
        Load the ColBERT config from path (don't point to file name just directory)
        :return ColBERTConfig:
        """

        kwargs = json.load(open(os.path.join(path, fname)))

        return CoLBERTConfig(**kwargs)


class ConvolutionalBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU):
        super(ConvolutionalBlock, self).__init__()

        padding = int((kernel_size - 1) / 2)
        if kernel_size == 3:
            assert padding == 1  # checks
        if kernel_size == 5:
            assert padding == 2  # checks
        layers = [
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=first_stride, padding=padding),
            nn.BatchNorm1d(num_features=out_channels)
        ]

        if act_func is not None:
            layers.append(act_func())

        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        return self.sequential(x)


class KMaxPool(nn.Module):
    def __init__(self, k=1):
        super(KMaxPool, self).__init__()

        self.k = k

    def forward(self, x):
        # x : batch_size, channel, time_steps
        if self.k == 'half':
            time_steps = x.size(2)  # was x.shape(2), which raises a TypeError
            self.k = time_steps // 2

        kmax, kargmax = torch.topk(x, self.k, sorted=True)
        # kmax, kargmax = x.topk(self.k, dim=2)
        return kmax


def visualisation_dump(argmax, input_tensors):
    pass


class ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels, optional_shortcut=True,
                 kernel_size=1, act_func=nn.ReLU):
        super(ResidualBlock, self).__init__()
        self.optional_shortcut = optional_shortcut
        self.convolutional_block = ConvolutionalBlock(in_channels, out_channels, first_stride=1,
                                                      act_func=act_func, kernel_size=kernel_size)

    def forward(self, x):
        residual = x
        x = self.convolutional_block(x)

        if self.optional_shortcut:
            x = x + residual

        return x


class ColBERT(nn.Module):
    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128,
                 k: int = 8,
                 optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
                 act_func="mish", loss_func='cross_entropy_loss', **kwargs):  # kwargs for compat

        super().__init__()
        self.device = device
        hidden_dim = config.hidden_size
        self.seq_length = max_seq_len
        self.use_trans_blocks = use_trans_blocks
        self.use_batch_norms = use_batch_norms
        self.num_layers = config.num_hidden_layers
        num_labels = config.num_labels
        self.loss_func = LOSS_FUNCS[loss_func.lower()]()

        # Save our kwargs to reinitialise the model during evaluation
        self.bert_config = config
        self.colbert_config = CoLBERTConfig(k=k,
                                            optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons,
                                            use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks,
                                            residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc,
                                            act_func=act_func, bert_model_args=bert_model_args,
                                            bert_model_kwargs=bert_model_kwargs)

        logging.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs))

        # relax this constraint later
        assert act_func.lower() in ACT_FUNCS, f"Error not in activation function dictionary, {ACT_FUNCS.keys()}"
        act_func = ACT_FUNCS[act_func.lower()]

        # CNN Part
        conv_layers = []
        transformation_blocks = [None]  # Pad the first element, for the for loop in forward
        batch_norms = [None]  # Pad the first element

        # Adds up to num_layers + 1 embedding layer
        conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))

        for i in range(self.num_layers):
            # Create the residual blocks, batch_norms and transformation blocks

            conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut,
                                             kernel_size=residual_kernel_size, act_func=act_func))
            if use_trans_blocks:
                transformation_blocks.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))
            if use_batch_norms:
                batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.conv_layers = nn.ModuleList(conv_layers)

        if use_trans_blocks:
            self.transformation_blocks = nn.ModuleList(transformation_blocks)
        if use_batch_norms:
            self.batch_norms = nn.ModuleList(batch_norms)

        self.kmax_pooling = KMaxPool(k)

        # Create the MLP to compress the k signals
        linear_layers = list()
        linear_layers.append(nn.Linear(hidden_dim * k, num_labels))  # Downsample into Kmaxpool?
        # linear_layers.append(nn.Linear(hidden_neurons, hidden_neurons))
        # linear_layers.append(nn.Dropout(dropout_perc))
        # linear_layers.append(nn.Linear(hidden_neurons, num_labels))

        self.linear_layers = nn.Sequential(*linear_layers)
        self.apply(weight_init)  # NOTE: weight_init is not defined in this module; it must be provided elsewhere in the package
        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
                                              config=self.bert_config)  # Add Bert model after random initialisation

        for param in self.bert.pooler.parameters():  # We don't need the pooler
            param.requires_grad = False

        self.bert.to(self.device)

    def forward(self, *args, **kwargs):
        # input_ids: batch_size x seq_length x hidden_dim
        labels = kwargs['labels'] if 'labels' in kwargs else None
        if labels is not None:
            del kwargs['labels']

        bert_outputs = self.bert(*args, **kwargs)
        hidden_states = bert_outputs[-1]

        # TODO: rework this; sketch out what the model should do first
        is_embedding_layer = True

        assert len(self.conv_layers) == len(
            hidden_states)  # == len(self.transformation_blocks) == len(self.batch_norms), info

        zip_args = [self.conv_layers, hidden_states]
        identity = lambda k: k

        if self.use_trans_blocks:
            assert len(self.transformation_blocks) == len(hidden_states)
            zip_args.append(self.transformation_blocks)
        else:
            zip_args.append([identity for i in range(self.num_layers + 1)])

        if self.use_batch_norms:
            assert len(self.batch_norms) == len(hidden_states)
            zip_args.append(self.batch_norms)
        else:
            zip_args.append([identity for i in range(self.num_layers + 1)])

        out = None
        for co, hi, tr, bn in zip(*zip_args):
            if is_embedding_layer:
                out = co(hi.transpose(1, 2))  # batch x hidden x seq_len
                is_embedding_layer = not is_embedding_layer
            else:
                out = co(out + tr(bn(hi.transpose(1, 2))))  # add hidden dims together

        assert out.shape[2] == self.seq_length

        out = self.kmax_pooling(out)
        # batch_size x seq_len x hidden -> batch_size x flatten
        logits = self.linear_layers(torch.flatten(out, start_dim=1))

        return self.loss_func(logits, labels), logits

    @classmethod
    def from_config(cls, *args, config_path):
        kwargs = torch.load(config_path)
        return ColBERT(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, output_dir, **kwargs):
        config_found = True
        colbert_config = None

        try:
            colbert_config = CoLBERTConfig.load(output_dir)
        except:
            config_found = False

        bert_config = None

        if 'config' in kwargs:
            bert_config = kwargs['config']
            del kwargs['config']
        else:
            bert_config = BertConfig.from_pretrained(output_dir)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = None

        if config_found:
            model = ColBERT(config=bert_config, device=device, **colbert_config.kwargs)
            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")

        else:
            model = ColBERT((output_dir,), {}, config=bert_config, **kwargs)
            logger.info(f"*** Create New CNN Bert Model ***")

        return model

    def save_pretrained(self, output_dir):
        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
        self.bert.save_pretrained(output_dir)
        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
        self.bert_config.save_pretrained(output_dir)
        self.colbert_config.save(output_dir)
        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")


class ComBERT(nn.Module):
    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128,
                 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
                 act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs):  # kwargs for compat

        super().__init__()
        self.device = device
        hidden_dim = config.hidden_size
        self.seq_length = max_seq_len
        self.use_trans_blocks = use_trans_blocks
        self.use_batch_norms = use_batch_norms
        self.num_layers = config.num_hidden_layers
        num_labels = config.num_labels
        self.num_blocks = num_blocks
        self.loss_func = LOSS_FUNCS[loss_func.lower()]()

        # Save our kwargs to reinitialise the model during evaluation
        self.bert_config = config
        self.colbert_config = CoLBERTConfig(k=k,
                                            optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons,
                                            use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks,
                                            residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc,
                                            act_func=act_func, bert_model_args=bert_model_args,
                                            bert_model_kwargs=bert_model_kwargs)

        logging.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs))

        # relax this constraint later
        assert act_func.lower() in ACT_FUNCS, f"Error not in activation function dictionary, {ACT_FUNCS.keys()}"
        act_func = ACT_FUNCS[act_func.lower()]

        # CNN Part
        conv_layers = []

        # Adds up to num_layers + 1 embedding layer
        conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))

        for i in range(num_blocks):
            conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut,
                                             kernel_size=residual_kernel_size, act_func=act_func))

        self.conv_layers = nn.ModuleList(conv_layers)
        self.kmax_pooling = KMaxPool(k)

        # Create the MLP to compress the k signals
        linear_layers = list()
        linear_layers.append(nn.Linear(hidden_dim * k, hidden_neurons))  # Downsample into Kmaxpool?
        linear_layers.append(nn.Dropout(dropout_perc))
        linear_layers.append(nn.Linear(hidden_neurons, num_labels))

        self.linear_layers = nn.Sequential(*linear_layers)
        self.apply(weight_init)
        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
                                              config=self.bert_config)  # Add Bert model after random initialisation
        self.bert.to(self.device)

    def forward(self, *args, **kwargs):
        # input_ids: batch_size x seq_length x hidden_dim

        labels = kwargs['labels'] if 'labels' in kwargs else None
        if labels is not None:
            del kwargs['labels']

        bert_outputs = self.bert(*args, **kwargs)
        hidden_states = list(bert_outputs[-1])
        embedding_layer = hidden_states.pop(0)

        split_size = len(hidden_states) // self.num_blocks

        assert split_size % 2 == 0, "must be an even number"
        split_layers = [hidden_states[x:x + split_size] for x in range(0, len(hidden_states), split_size)]
        split_layers.insert(0, embedding_layer)

        assert len(self.conv_layers) == len(split_layers), "must have equal inputs in length"

        outputs = []

        for cnv, layer in zip(self.conv_layers, split_layers):
            outputs.append(self.kmax_pooling(cnv(layer)))

        # batch_size x seq_len x hidden -> batch_size x flatten
        logits = self.linear_layers(torch.flatten(torch.cat(outputs, dim=-1), start_dim=1))

        return self.loss_func(logits, labels), logits

    @classmethod
    def from_config(cls, *args, config_path):
        kwargs = torch.load(config_path)
        return ComBERT(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, output_dir, **kwargs):
        config_found = True
        colbert_config = None

        try:
            colbert_config = CoLBERTConfig.load(output_dir)
        except:
            config_found = False

        bert_config = None

        if 'config' in kwargs:
            bert_config = kwargs['config']
            del kwargs['config']
        else:
            bert_config = BertConfig.from_pretrained(output_dir)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = None

        if config_found:
            model = ComBERT(config=bert_config, device=device, **colbert_config.kwargs)
            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
            logger.info(f"*** Loaded CNN Bert Model Weights from {output_dir + '/cnn_bert.pth'}")

        else:
            model = ComBERT((output_dir,), {}, config=bert_config, **kwargs)
            logger.info(f"*** Create New CNN Bert Model ***")

        return model

    def save_pretrained(self, output_dir):
        logger.info(f"*** Saved Bert Model Weights to {output_dir}")
        self.bert.save_pretrained(output_dir)
        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
        self.bert_config.save_pretrained(output_dir)
        self.colbert_config.save(output_dir)
        logger.info(f"*** Saved CNN Bert Model Weights to {output_dir + '/cnn_bert.pth'}")
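A minimal construction-and-forward sketch. The checkpoint name, batch contents and label are illustrative, and it assumes a suitable weight_init function is importable (the listing calls it but does not define it):

# Illustrative only: "bert-base-uncased" and the batch below are placeholders, and
# weight_init must be supplied by the surrounding package for __init__ to run.
import torch
from transformers import BertConfig, BertTokenizerFast

from debeir.models.colbert import ColBERT

config = BertConfig.from_pretrained("bert-base-uncased",
                                    num_labels=2,
                                    output_hidden_states=True)  # forward() consumes every hidden-state layer

model = ColBERT(bert_model_args=("bert-base-uncased",),
                bert_model_kwargs={},
                config=config,
                device="cpu",
                max_seq_len=128,
                k=8,
                act_func="relu")  # "relu" is the only key currently in ACT_FUNCS

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
batch = tokenizer("what is colbert?", "a bert + cnn re-ranker",
                  padding="max_length", max_length=128, truncation=True,
                  return_tensors="pt")  # pad to max_seq_len so the length assert in forward() holds

loss, logits = model(**batch, labels=torch.tensor([1]))  # returns a (loss, logits) tuple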
class CoLBERTConfig

A lightweight configuration holder: every keyword argument passed to the constructor is stored in self.kwargs (and set as an attribute) and can be round-tripped through JSON.

save(self, path, fname=default_fname)

Serialise the stored kwargs as JSON.

Parameters
- path: directory to save to
- fname: file name (defaults to "colbert_config.json")

load(cls, path, fname=default_fname)   (classmethod)

Load the ColBERT config from a directory (pass the directory, not the file name).

Returns
- CoLBERTConfig
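A minimal save/load round trip. The directory path is illustrative, and it must already exist because save() opens the file directly:

import os

from debeir.models.colbert import CoLBERTConfig

save_dir = "./checkpoints/colbert"          # illustrative path
os.makedirs(save_dir, exist_ok=True)

cfg = CoLBERTConfig(k=8, optional_shortcut=True, act_func="relu")
cfg.save(save_dir)                          # writes ./checkpoints/colbert/colbert_config.json

restored = CoLBERTConfig.load(save_dir)     # pass the directory, not the file
assert restored.kwargs == cfg.kwargs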
class ConvolutionalBlock(nn.Module)

Conv1d followed by BatchNorm1d and an optional activation, applied to (batch, channels, time_steps) inputs. Padding is derived from the kernel size, so odd kernel sizes preserve the sequence length.

Members
- __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU)
- forward(self, x)
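A shape-check sketch; the channel count, batch size and sequence length are arbitrary:

import torch

from debeir.models.colbert import ConvolutionalBlock

block = ConvolutionalBlock(in_channels=768, out_channels=768, kernel_size=3)
x = torch.randn(4, 768, 128)   # (batch, channels, time_steps)
y = block(x)                   # Conv1d -> BatchNorm1d -> ReLU
print(y.shape)                 # torch.Size([4, 768, 128]); padding=1 keeps the length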
class KMaxPool(nn.Module)

Keeps the k largest activations along the last (time) dimension via torch.topk. Passing k='half' makes the layer use half of the input's time steps from the first call onwards.

Members
- __init__(self, k=1)
- forward(self, x)
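A small numeric sketch of the pooling behaviour (the values are chosen arbitrarily):

import torch

from debeir.models.colbert import KMaxPool

pool = KMaxPool(k=2)
x = torch.tensor([[[1., 5., 3., 4.]]])    # (batch=1, channel=1, time_steps=4)
print(pool(x))                            # tensor([[[5., 4.]]]) -- the two largest values, sorted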
class ResidualBlock(nn.Module)

Wraps a ConvolutionalBlock and, when optional_shortcut is True, adds the block's input back onto its output (an identity shortcut), so in_channels must equal out_channels for the addition to be valid.

Members
- __init__(self, in_channels, out_channels, optional_shortcut=True, kernel_size=1, act_func=nn.ReLU)
- forward(self, x)
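A quick sketch of the shortcut path; the dimensions are arbitrary:

import torch

from debeir.models.colbert import ResidualBlock

block = ResidualBlock(in_channels=768, out_channels=768, optional_shortcut=True, kernel_size=1)
x = torch.randn(4, 768, 128)
y = block(x)                   # ConvolutionalBlock(x) + x
print(y.shape)                 # torch.Size([4, 768, 128])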
class ColBERT(nn.Module)

A BERT encoder with a CNN head: every hidden-state layer (the embedding output plus each encoder layer) is fed through a stack of Conv1d/ResidualBlock layers, optionally combined through per-layer transformation blocks and batch norms, k-max pooled, and projected to num_labels logits. forward() returns a (loss, logits) tuple.

Members
- __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128, k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True, use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5, act_func="mish", loss_func='cross_entropy_loss', **kwargs)
- forward(self, *args, **kwargs)
- from_config(cls, *args, config_path)   (classmethod)
- from_pretrained(cls, output_dir, **kwargs)   (classmethod)
- save_pretrained(self, output_dir)
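A persistence sketch, assuming model is a ColBERT instance such as the one built in the construction example near the top of this page; the directory path is illustrative:

from debeir.models.colbert import ColBERT

save_dir = "./checkpoints/colbert-cnn"   # illustrative

# Writes the BERT weights, the full state_dict (cnn_bert.pth), the BertConfig
# and the CoLBERTConfig (colbert_config.json) into save_dir.
model.save_pretrained(save_dir)

# from_pretrained() finds colbert_config.json, rebuilds the model with the same
# hyperparameters and reloads the saved weights.
restored = ColBERT.from_pretrained(save_dir)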
class ComBERT(nn.Module)

A variant of ColBERT that groups the encoder hidden states into num_blocks consecutive blocks (plus the embedding layer), applies one convolutional block per group, k-max pools each result, and classifies the concatenated pooled features with a small MLP (Linear -> Dropout -> Linear). forward() returns a (loss, logits) tuple.

Members
- __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128, k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True, use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5, act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs)
- forward(self, *args, **kwargs)
- from_config(cls, *args, config_path)   (classmethod)
- from_pretrained(cls, output_dir, **kwargs)   (classmethod)
- save_pretrained(self, output_dir)
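A small sketch of how forward() groups the hidden states when num_blocks=2 and the encoder has 12 layers; the string placeholders stand in for the real hidden-state tensors:

hidden_states = [f"layer_{i}" for i in range(13)]    # embeddings + 12 encoder layers (placeholders)
embedding_layer = hidden_states.pop(0)

num_blocks = 2
split_size = len(hidden_states) // num_blocks        # 6
split_layers = [hidden_states[x:x + split_size] for x in range(0, len(hidden_states), split_size)]
split_layers.insert(0, embedding_layer)

# split_layers == ["layer_0", ["layer_1", ..., "layer_6"], ["layer_7", ..., "layer_12"]]
# i.e. one entry per module in self.conv_layers (the 1x1 Conv1d plus num_blocks residual blocks).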