debeir.models.colbert

import json
import logging
import os

import torch
from torch import nn
from transformers import BertConfig, BertModel

logger = logging.getLogger(__name__)

ACT_FUNCS = {
    "relu": nn.ReLU,
}

LOSS_FUNCS = {
    'cross_entropy_loss': nn.CrossEntropyLoss,
}


class CoLBERTConfig(object):
    default_fname = "colbert_config.json"

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.__dict__.update(kwargs)

    def save(self, path, fname=default_fname):
        """
        :param path: directory to save the config into
        :param fname: file name (defaults to colbert_config.json)
        """
        with open(os.path.join(path, fname), 'w+') as fp:
            json.dump(self.kwargs, fp)

    @classmethod
    def load(cls, path, fname=default_fname):
        """
        Load the CoLBERT config from path (point to the directory, not the file name)
        :return CoLBERTConfig:
        """
        with open(os.path.join(path, fname)) as fp:
            kwargs = json.load(fp)

        return CoLBERTConfig(**kwargs)


class ConvolutionalBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU):
        super(ConvolutionalBlock, self).__init__()

        padding = int((kernel_size - 1) / 2)
        if kernel_size == 3:
            assert padding == 1  # checks
        if kernel_size == 5:
            assert padding == 2  # checks
        layers = [
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=first_stride, padding=padding),
            nn.BatchNorm1d(num_features=out_channels)
        ]

        if act_func is not None:
            layers.append(act_func())

        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        return self.sequential(x)


class KMaxPool(nn.Module):
    def __init__(self, k=1):
        super(KMaxPool, self).__init__()

        self.k = k

    def forward(self, x):
        # x : batch_size, channel, time_steps
        if self.k == 'half':
            time_steps = x.size(2)
            self.k = time_steps // 2

        kmax, kargmax = torch.topk(x, self.k, sorted=True)
        # kmax, kargmax = x.topk(self.k, dim=2)
        return kmax


def visualisation_dump(argmax, input_tensors):
    pass


class ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels, optional_shortcut=True,
                 kernel_size=1, act_func=nn.ReLU):
        super(ResidualBlock, self).__init__()
        self.optional_shortcut = optional_shortcut
        self.convolutional_block = ConvolutionalBlock(in_channels, out_channels, first_stride=1,
                                                      act_func=act_func, kernel_size=kernel_size)

    def forward(self, x):
        residual = x
        x = self.convolutional_block(x)

        if self.optional_shortcut:
            x = x + residual

        return x


class ColBERT(nn.Module):
    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128,
                 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
                 act_func="mish", loss_func='cross_entropy_loss', **kwargs):  # kwargs for compat

        super().__init__()
        self.device = device
        hidden_dim = config.hidden_size
        self.seq_length = max_seq_len
        self.use_trans_blocks = use_trans_blocks
        self.use_batch_norms = use_batch_norms
        self.num_layers = config.num_hidden_layers
        num_labels = config.num_labels
        self.loss_func = LOSS_FUNCS[loss_func.lower()]()

        # Save our kwargs to reinitialise the model during evaluation
        self.bert_config = config
        self.colbert_config = CoLBERTConfig(k=k,
                                            optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons,
                                            use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks,
                                            residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc,
                                            act_func=act_func, bert_model_args=bert_model_args,
                                            bert_model_kwargs=bert_model_kwargs)

        logger.info("ColBERT Configuration %s" % str(self.colbert_config.kwargs))

        # relax this constraint later
        assert act_func.lower() in ACT_FUNCS, f"Activation function not in ACT_FUNCS dictionary, {ACT_FUNCS.keys()}"
        act_func = ACT_FUNCS[act_func.lower()]

        # CNN part
        conv_layers = []
        transformation_blocks = [None]  # Pad the first element, for the for loop in forward
        batch_norms = [None]  # Pad the first element

        # num_layers residual blocks plus one 1x1 conv for the embedding layer (num_layers + 1 in total)
        conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))

        for i in range(self.num_layers):
            # Create the residual blocks, batch_norms and transformation blocks
            conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut,
                                             kernel_size=residual_kernel_size, act_func=act_func))
            if use_trans_blocks:
                transformation_blocks.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))
            if use_batch_norms:
                batch_norms.append(nn.BatchNorm1d(hidden_dim))

        self.conv_layers = nn.ModuleList(conv_layers)

        if use_trans_blocks:
            self.transformation_blocks = nn.ModuleList(transformation_blocks)
        if use_batch_norms:
            self.batch_norms = nn.ModuleList(batch_norms)

        self.kmax_pooling = KMaxPool(k)

        # Create the MLP to compress the k signals
        linear_layers = list()
        linear_layers.append(nn.Linear(hidden_dim * k, num_labels))  # Downsample into Kmaxpool?
        # linear_layers.append(nn.Linear(hidden_neurons, hidden_neurons))
        # linear_layers.append(nn.Dropout(dropout_perc))
        # linear_layers.append(nn.Linear(hidden_neurons, num_labels))

        self.linear_layers = nn.Sequential(*linear_layers)
        self.apply(weight_init)  # weight_init is not defined in this module; it must be imported/provided elsewhere
        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
                                              config=self.bert_config)  # Add BERT model after random initialisation

        for param in self.bert.pooler.parameters():  # We don't need the pooler
            param.requires_grad = False

        self.bert.to(self.device)

    def forward(self, *args, **kwargs):
        # hidden states from BERT: batch_size x seq_length x hidden_dim
        labels = kwargs['labels'] if 'labels' in kwargs else None
        if labels is not None:
            del kwargs['labels']

        bert_outputs = self.bert(*args, **kwargs)
        hidden_states = bert_outputs[-1]

        # Fix this, also draw out what your model should do first
        is_embedding_layer = True

        assert len(self.conv_layers) == len(hidden_states)  # == len(self.transformation_blocks) == len(self.batch_norms)

        zip_args = [self.conv_layers, hidden_states]
        identity = lambda k: k

        if self.use_trans_blocks:
            assert len(self.transformation_blocks) == len(hidden_states)
            zip_args.append(self.transformation_blocks)
        else:
            zip_args.append([identity for i in range(self.num_layers + 1)])

        if self.use_batch_norms:
            assert len(self.batch_norms) == len(hidden_states)
            zip_args.append(self.batch_norms)
        else:
            zip_args.append([identity for i in range(self.num_layers + 1)])

        out = None
        for co, hi, tr, bn in zip(*zip_args):
            if is_embedding_layer:
                out = co(hi.transpose(1, 2))  # batch x hidden x seq_len
                is_embedding_layer = not is_embedding_layer
            else:
                out = co(out + tr(bn(hi.transpose(1, 2))))  # add hidden dims together

        assert out.shape[2] == self.seq_length

        out = self.kmax_pooling(out)
        # batch_size x hidden x k -> batch_size x flattened
        logits = self.linear_layers(torch.flatten(out, start_dim=1))

        return self.loss_func(logits, labels), logits

    @classmethod
    def from_config(cls, *args, config_path):
        kwargs = torch.load(config_path)
        return ColBERT(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, output_dir, **kwargs):
        config_found = True
        colbert_config = None

        try:
            colbert_config = CoLBERTConfig.load(output_dir)
        except Exception:
            config_found = False

        bert_config = None

        if 'config' in kwargs:
            bert_config = kwargs['config']
            del kwargs['config']
        else:
            bert_config = BertConfig.from_pretrained(output_dir)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = None

        if config_found:
            model = ColBERT(config=bert_config, device=device, **colbert_config.kwargs)
            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
            logger.info(f"*** Loaded CNN BERT model weights from {output_dir + '/cnn_bert.pth'}")
        else:
            model = ColBERT((output_dir,), {}, config=bert_config, **kwargs)
            logger.info("*** Created new CNN BERT model ***")

        return model

    def save_pretrained(self, output_dir):
        logger.info(f"*** Saving BERT model weights to {output_dir}")
        self.bert.save_pretrained(output_dir)
        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
        self.bert_config.save_pretrained(output_dir)
        self.colbert_config.save(output_dir)
        logger.info(f"*** Saved CNN BERT model weights to {output_dir + '/cnn_bert.pth'}")


class ComBERT(nn.Module):
    def __init__(self, bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128,
                 k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True,
                 use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5,
                 act_func="mish", loss_func='cross_entropy_loss', num_blocks=2, **kwargs):  # kwargs for compat

        super().__init__()
        self.device = device
        hidden_dim = config.hidden_size
        self.seq_length = max_seq_len
        self.use_trans_blocks = use_trans_blocks
        self.use_batch_norms = use_batch_norms
        self.num_layers = config.num_hidden_layers
        num_labels = config.num_labels
        self.num_blocks = num_blocks
        self.loss_func = LOSS_FUNCS[loss_func.lower()]()

        # Save our kwargs to reinitialise the model during evaluation
        self.bert_config = config
        self.colbert_config = CoLBERTConfig(k=k,
                                            optional_shortcut=optional_shortcut, hidden_neurons=hidden_neurons,
                                            use_batch_norms=use_batch_norms, use_trans_blocks=use_trans_blocks,
                                            residual_kernel_size=residual_kernel_size, dropout_perc=dropout_perc,
                                            act_func=act_func, bert_model_args=bert_model_args,
                                            bert_model_kwargs=bert_model_kwargs)

        logger.info("ComBERT Configuration %s" % str(self.colbert_config.kwargs))

        # relax this constraint later
        assert act_func.lower() in ACT_FUNCS, f"Activation function not in ACT_FUNCS dictionary, {ACT_FUNCS.keys()}"
        act_func = ACT_FUNCS[act_func.lower()]

        # CNN part
        conv_layers = []

        # num_blocks residual blocks plus one 1x1 conv for the embedding layer (num_blocks + 1 in total)
        conv_layers.append(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=1))

        for i in range(num_blocks):
            conv_layers.append(ResidualBlock(hidden_dim, hidden_dim, optional_shortcut=optional_shortcut,
                                             kernel_size=residual_kernel_size, act_func=act_func))

        self.conv_layers = nn.ModuleList(conv_layers)
        self.kmax_pooling = KMaxPool(k)

        # Create the MLP to compress the k signals
        linear_layers = list()
        linear_layers.append(nn.Linear(hidden_dim * k, hidden_neurons))  # Downsample into Kmaxpool?
        linear_layers.append(nn.Dropout(dropout_perc))
        linear_layers.append(nn.Linear(hidden_neurons, num_labels))

        self.linear_layers = nn.Sequential(*linear_layers)
        self.apply(weight_init)  # weight_init is not defined in this module; it must be imported/provided elsewhere
        self.bert = BertModel.from_pretrained(*bert_model_args, **bert_model_kwargs,
                                              config=self.bert_config)  # Add BERT model after random initialisation
        self.bert.to(self.device)

    def forward(self, *args, **kwargs):
        # hidden states from BERT: batch_size x seq_length x hidden_dim

        labels = kwargs['labels'] if 'labels' in kwargs else None
        if labels is not None:
            del kwargs['labels']

        bert_outputs = self.bert(*args, **kwargs)
        hidden_states = list(bert_outputs[-1])
        embedding_layer = hidden_states.pop(0)

        split_size = len(hidden_states) // self.num_blocks

        assert split_size % 2 == 0, "must be an even number"
        split_layers = [hidden_states[x:x + split_size] for x in range(0, len(hidden_states), split_size)]
        split_layers.insert(0, embedding_layer)

        assert len(self.conv_layers) == len(split_layers), "must have equal inputs in length"

        outputs = []

        for cnv, layer in zip(self.conv_layers, split_layers):
            outputs.append(self.kmax_pooling(cnv(layer)))

        # batch_size x hidden x k -> batch_size x flattened
        logits = self.linear_layers(torch.flatten(torch.cat(outputs, dim=-1), start_dim=1))

        return self.loss_func(logits, labels), logits

    @classmethod
    def from_config(cls, *args, config_path):
        kwargs = torch.load(config_path)
        return ComBERT(*args, **kwargs)

    @classmethod
    def from_pretrained(cls, output_dir, **kwargs):
        config_found = True
        colbert_config = None

        try:
            colbert_config = CoLBERTConfig.load(output_dir)
        except Exception:
            config_found = False

        bert_config = None

        if 'config' in kwargs:
            bert_config = kwargs['config']
            del kwargs['config']
        else:
            bert_config = BertConfig.from_pretrained(output_dir)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = None

        if config_found:
            model = ComBERT(config=bert_config, device=device, **colbert_config.kwargs)
            model.load_state_dict(torch.load(output_dir + '/cnn_bert.pth'))
            logger.info(f"*** Loaded CNN BERT model weights from {output_dir + '/cnn_bert.pth'}")
        else:
            model = ComBERT((output_dir,), {}, config=bert_config, **kwargs)
            logger.info("*** Created new CNN BERT model ***")

        return model

    def save_pretrained(self, output_dir):
        logger.info(f"*** Saving BERT model weights to {output_dir}")
        self.bert.save_pretrained(output_dir)
        torch.save(self.state_dict(), output_dir + '/cnn_bert.pth')
        self.bert_config.save_pretrained(output_dir)
        self.colbert_config.save(output_dir)
        logger.info(f"*** Saved CNN BERT model weights to {output_dir + '/cnn_bert.pth'}")
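
The string arguments act_func and loss_func accepted by ColBERT and ComBERT are resolved through the module-level ACT_FUNCS and LOSS_FUNCS dictionaries, so only registered keys are valid ("relu" and "cross_entropy_loss" as shipped). A minimal sketch of extending the registries before constructing a model; it assumes your torch build ships nn.Mish (PyTorch 1.9+):

    from torch import nn
    from debeir.models import colbert

    # Register an extra activation so act_func="mish" (the constructor default) can resolve.
    colbert.ACT_FUNCS["mish"] = nn.Mish
    # Loss functions are looked up the same way; keys are lower-cased before the lookup.
    colbert.LOSS_FUNCS["bce_with_logits"] = nn.BCEWithLogitsLoss
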
class CoLBERTConfig:
A lightweight JSON-backed container for the models' constructor kwargs, used to reinitialise ColBERT or ComBERT at evaluation time.

CoLBERTConfig(**kwargs)

Stores the keyword arguments and also exposes each one as an attribute on the instance.

def save(self, path, fname='colbert_config.json'):

Parameters
  • path: directory to save the config into
  • fname: file name to write (defaults to colbert_config.json)

@classmethod
def load(cls, path, fname='colbert_config.json'):

Load the CoLBERT config from a directory path (point to the directory containing the config file, not to the file itself).

Returns
  • CoLBERTConfig: a config rebuilt from the saved kwargs.
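
A minimal save/load round trip; the directory path is purely illustrative and must exist before saving:

    import os
    from debeir.models.colbert import CoLBERTConfig

    os.makedirs("/tmp/colbert_cfg", exist_ok=True)

    cfg = CoLBERTConfig(k=8, optional_shortcut=True, hidden_neurons=2048)
    cfg.save("/tmp/colbert_cfg")                        # writes /tmp/colbert_cfg/colbert_config.json

    restored = CoLBERTConfig.load("/tmp/colbert_cfg")   # pass the directory, not the file
    assert restored.k == 8                              # kwargs are also exposed as attributes
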
class ConvolutionalBlock(torch.nn.modules.module.Module):

A 1-D convolution followed by batch normalisation and an optional activation, wrapped in a single nn.Sequential. Padding is derived from the kernel size so that odd kernels preserve the sequence length.

ConvolutionalBlock(in_channels, out_channels, kernel_size=1, first_stride=1, act_func=nn.ReLU)

def forward(self, x):

Applies the Conv1d -> BatchNorm1d -> activation stack to x and returns the result.

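A shape-level sketch of the block; the channel and sequence sizes (768 and 128) are only illustrative, matching BERT-base's hidden size and the models' default max_seq_len:

    import torch
    from torch import nn
    from debeir.models.colbert import ConvolutionalBlock

    block = ConvolutionalBlock(in_channels=768, out_channels=768, kernel_size=3, act_func=nn.ReLU)
    x = torch.randn(4, 768, 128)   # batch x channels x time_steps
    y = block(x)                   # Conv1d -> BatchNorm1d -> ReLU; padding preserves the length
    print(y.shape)                 # torch.Size([4, 768, 128])
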
class KMaxPool(torch.nn.modules.module.Module):

k-max pooling over the last (time-step) dimension: keeps the k largest activations per channel. Passing k='half' makes the layer keep half of the time steps, resolved on the first forward call.

KMaxPool(k=1)

def forward(self, x):

Expects x of shape (batch_size, channels, time_steps) and returns the top-k values along the last dimension, sorted largest first.

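A small sketch of the pooling behaviour; torch.topk operates over the last dimension, so the k strongest time steps are kept per channel:

    import torch
    from debeir.models.colbert import KMaxPool

    pool = KMaxPool(k=8)
    x = torch.randn(4, 768, 128)   # batch x channels x time_steps
    top = pool(x)                  # values sorted in descending order along the last dim
    print(top.shape)               # torch.Size([4, 768, 8])
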
def visualisation_dump(argmax, input_tensors):

Placeholder hook for dumping k-max argmax indices and their input tensors for visualisation; currently a no-op.
class ResidualBlock(torch.nn.modules.module.Module):

Wraps a ConvolutionalBlock and, when optional_shortcut is enabled, adds the block's input back onto its output (a residual connection), so in_channels and out_channels must match for the shortcut to be valid.

ResidualBlock(in_channels, out_channels, optional_shortcut=True, kernel_size=1, act_func=nn.ReLU)

def forward(self, x):

Runs x through the convolutional block and adds the residual shortcut if enabled.

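A sketch of the residual path; with optional_shortcut=True the input is added back onto the convolved output:

    import torch
    from torch import nn
    from debeir.models.colbert import ResidualBlock

    block = ResidualBlock(768, 768, optional_shortcut=True, kernel_size=1, act_func=nn.ReLU)
    x = torch.randn(4, 768, 128)
    y = block(x)                   # ConvolutionalBlock(x) + x
    print(y.shape)                 # torch.Size([4, 768, 128])
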
class ColBERT(torch.nn.modules.module.Module):

A CNN-over-BERT classifier. Each of BERT's hidden states (the embedding output plus one per transformer layer) is fed through a 1x1 convolution or residual block and accumulated, k-max pooling keeps the k strongest positions per channel, and a linear head maps the flattened signal to num_labels logits. forward returns a (loss, logits) tuple.

ColBERT(bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128, k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True, use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5, act_func='mish', loss_func='cross_entropy_loss', **kwargs)

Builds the CNN stack (one 1x1 convolution plus num_hidden_layers residual blocks, with optional per-layer transformation convolutions and batch norms), the k-max pooling layer and the linear head, records the constructor kwargs in a CoLBERTConfig for later reloading, then loads the pretrained BERT weights and freezes the pooler.

def forward(self, *args, **kwargs):

Runs BERT on the given inputs, combines all hidden states through the convolutional stack, applies k-max pooling, and returns (loss, logits). Labels are read from kwargs['labels'] and removed before the BERT call. The conv layers are zipped against bert_outputs[-1], so the wrapped BertConfig is expected to expose all hidden states (output_hidden_states=True), one per layer plus the embeddings.

@classmethod
def from_config(cls, *args, config_path):

Rebuilds a ColBERT from a kwargs dictionary previously stored with torch.save at config_path.

@classmethod
def from_pretrained(cls, output_dir, **kwargs):

Loads a model from output_dir. If a colbert_config.json is found there, the model is rebuilt from it and the full weights are restored from output_dir/cnn_bert.pth; otherwise a fresh model is constructed around the pretrained BERT located in output_dir.

def save_pretrained(self, output_dir):

Saves the underlying BERT weights and config, the full state dict (as cnn_bert.pth), and the CoLBERTConfig into output_dir.
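An end-to-end sketch of building, calling, and persisting the model. It assumes a reachable "bert-base-uncased" checkpoint, that the weight_init initialiser referenced in __init__ is importable in this scope (it is not defined in this module), and act_func="relu" because only ReLU is registered in ACT_FUNCS:

    import torch
    from transformers import BertConfig, BertTokenizerFast
    from debeir.models.colbert import ColBERT

    config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2, output_hidden_states=True)
    model = ColBERT(("bert-base-uncased",), {}, config=config, device="cpu",
                    max_seq_len=128, k=8, act_func="relu")

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    batch = tokenizer(["query [SEP] candidate document"], padding="max_length",
                      truncation=True, max_length=128, return_tensors="pt")

    loss, logits = model(**batch, labels=torch.tensor([1]))   # CrossEntropyLoss, logits: batch x num_labels

    model.save_pretrained("./colbert_ckpt")
    reloaded = ColBERT.from_pretrained("./colbert_ckpt")
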
class ComBERT(torch.nn.modules.module.Module):

A block-wise variant of ColBERT: instead of one residual block per BERT layer, the hidden states are split into num_blocks contiguous groups, each group passes through its own convolutional block and k-max pooling, and the pooled signals are concatenated and classified by a dropout-regularised two-layer MLP. forward returns a (loss, logits) tuple.

ComBERT(bert_model_args, bert_model_kwargs, config: BertConfig, device: str, max_seq_len: int = 128, k: int = 8, optional_shortcut: bool = True, hidden_neurons: int = 2048, use_batch_norms: bool = True, use_trans_blocks: bool = False, residual_kernel_size: int = 1, dropout_perc: float = 0.5, act_func='mish', loss_func='cross_entropy_loss', num_blocks=2, **kwargs)

Builds one 1x1 convolution plus num_blocks residual blocks, the k-max pooling layer and a two-layer MLP head, records the constructor kwargs in a CoLBERTConfig, then loads the pretrained BERT weights.

def forward(self, *args, **kwargs):

Runs BERT, pops the embedding output, splits the remaining hidden states into num_blocks contiguous groups, applies the matching convolutional block and k-max pooling to the embedding output and to each group, then concatenates the pooled outputs and maps them to logits. Returns (loss, logits).

@classmethod
def from_config(cls, *args, config_path):

Rebuilds a ComBERT from a kwargs dictionary previously stored with torch.save at config_path.

@classmethod
def from_pretrained(cls, output_dir, **kwargs):

Loads a model from output_dir. If a colbert_config.json is present, the model is rebuilt from it and weights are restored from output_dir/cnn_bert.pth; otherwise a fresh model is constructed around the pretrained BERT located in output_dir.

def save_pretrained(self, output_dir):

Saves the underlying BERT weights and config, the full state dict (as cnn_bert.pth), and the CoLBERTConfig into output_dir.
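A construction and persistence sketch for the block-wise variant, under the same assumptions as the ColBERT example above (a reachable "bert-base-uncased" checkpoint, weight_init importable, act_func="relu"):

    from transformers import BertConfig
    from debeir.models.colbert import ComBERT

    config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2, output_hidden_states=True)
    model = ComBERT(("bert-base-uncased",), {}, config=config, device="cpu",
                    num_blocks=2, k=8, hidden_neurons=2048, act_func="relu")

    model.save_pretrained("./combert_ckpt")                 # BERT weights + cnn_bert.pth + colbert_config.json
    restored = ComBERT.from_pretrained("./combert_ckpt")    # picks the config back up and reloads the state dict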