debeir.training.losses.contrastive
Author: Yonglong Tian (yonglong@mit.edu) Date: May 07, 2020
Code imported from: https://github.com/HobbitLong/SupContrast/blob/master/losses.py
1""" 2Author: Yonglong Tian (yonglong@mit.edu) 3Date: May 07, 2020 4 5 6Code imported from: https://github.com/HobbitLong/SupContrast/blob/master/losses.py 7""" 8 9from enum import Enum 10from typing import Dict, Iterable 11 12import torch 13import torch.nn.functional as F 14from torch import Tensor, nn 15 16 17class SupConLoss(nn.Module): 18 """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 19 It also supports the unsupervised contrastive loss in SimCLR""" 20 21 # def __init__(self, temperature=0.07, contrast_mode='all', 22 # base_temperature=0.07): 23 def __init__(self, temperature=1.0, contrast_mode='all', 24 base_temperature=1.0): 25 super(SupConLoss, self).__init__() 26 self.temperature = temperature 27 self.base_temperature = base_temperature 28 self.contrast_mode = contrast_mode 29 30 def forward(self, features, labels=None, mask=None): 31 """Compute loss for model. If both `labels` and `mask` are None, 32 it degenerates to SimCLR unsupervised loss: 33 https://arxiv.org/pdf/2002.05709.pdf 34 Args: 35 features: hidden vector of shape [bsz, n_views, ...]. 36 labels: ground truth of shape [bsz]. 37 mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 38 has the same class as sample i. Can be asymmetric. 39 Returns: 40 A loss scalar. 41 """ 42 device = (torch.device('cuda') 43 if features.is_cuda 44 else torch.device('cpu')) 45 46 if len(features.shape) < 3: 47 raise ValueError('`features` needs to be [bsz, n_views, ...],' 48 'at least 3 dimensions are required') 49 if len(features.shape) > 3: 50 features = features.view(features.shape[0], features.shape[1], -1) 51 52 batch_size = features.shape[0] 53 if labels is not None and mask is not None: 54 raise ValueError('Cannot define both `labels` and `mask`') 55 elif labels is None and mask is None: 56 mask = torch.eye(batch_size, dtype=torch.float32).to(device) 57 elif labels is not None: 58 labels = labels.contiguous().view(-1, 1) 59 if labels.shape[0] != batch_size: 60 raise ValueError('Num of labels does not match num of features') 61 mask = torch.eq(labels, labels.T).float().to(device) 62 else: 63 mask = mask.float().to(device) 64 65 contrast_count = features.shape[1] 66 contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 67 if self.contrast_mode == 'one': 68 anchor_feature = features[:, 0] 69 anchor_count = 1 70 elif self.contrast_mode == 'all': 71 anchor_feature = contrast_feature 72 anchor_count = contrast_count 73 else: 74 raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 75 76 # compute logits 77 anchor_dot_contrast = torch.div( 78 torch.matmul(anchor_feature, contrast_feature.T), 79 self.temperature) 80 # for numerical stability 81 logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 82 logits = anchor_dot_contrast - logits_max.detach() 83 84 # tile mask 85 mask = mask.repeat(anchor_count, contrast_count) 86 # mask-out self-contrast cases 87 logits_mask = torch.scatter( 88 torch.ones_like(mask), 89 1, 90 torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 91 0 92 ) 93 mask = mask * logits_mask 94 95 # compute log_prob 96 exp_logits = torch.exp(logits) * logits_mask 97 log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 98 99 # compute mean of log-likelihood over positive 100 mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 101 102 # loss 103 loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 104 loss = loss.view(anchor_count, batch_size).mean() 105 106 return loss 107 108 109class 
SiameseDistanceMetric(Enum): 110 """ 111 The metric for the contrastive loss 112 """ 113 EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2) 114 MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1) 115 COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y) 116 117 118class ContrastiveSentLoss(nn.Module): 119 """ 120 Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the 121 two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased. 122 Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf 123 :param model: SentenceTransformer model 124 :param distance_metric: Function that returns a distance between two emeddings. The class SiameseDistanceMetric contains pre-defined metrices that can be used 125 :param margin: Negative samples (label == 0) should have a distance of at least the margin value. 126 :param size_average: Average by the size of the mini-batch. 127 Example:: 128 from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample 129 from torch.utils.data import DataLoader 130 model = SentenceTransformer('all-MiniLM-L6-v2') 131 train_examples = [ 132 InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1), 133 InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)] 134 train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2) 135 train_loss = losses.ContrastiveLoss(model=model) 136 model.fit([(train_dataloader, train_loss)], show_progress_bar=True) 137 """ 138 139 def __init__(self, model, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, 140 margin: float = 0.5, size_average: bool = True): 141 super(ContrastiveSentLoss, self).__init__() 142 self.distance_metric = distance_metric 143 self.margin = margin 144 self.model = model 145 self.size_average = size_average 146 147 def get_config_dict(self): 148 distance_metric_name = self.distance_metric.__name__ 149 for name, value in vars(SiameseDistanceMetric).items(): 150 if value == self.distance_metric: 151 distance_metric_name = "SiameseDistanceMetric.{}".format(name) 152 break 153 154 return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average} 155 156 def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor): 157 reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features] 158 assert len(reps) == 2 159 rep_anchor, rep_other = reps 160 distances = self.distance_metric(rep_anchor, rep_other) 161 losses = 0.5 * ( 162 labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2)) 163 return losses.mean() if self.size_average else losses.sum()
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""

    # def __init__(self, temperature=0.07, contrast_mode='all',
    #              base_temperature=0.07):
    def __init__(self, temperature=1.0, contrast_mode='all',
                 base_temperature=1.0):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.contrast_mode = contrast_mode

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf
        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. It also supports the unsupervised contrastive loss used in SimCLR.
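For reference, the objective implemented by forward is the supervised contrastive loss from the paper above. Writing τ for the temperature, z_i for the (normalized) feature of anchor i, A(i) for all other views in the batch, and P(i) for the positives of i, the loss can be written as

\[
\mathcal{L}^{sup} \;=\; \sum_{i}\frac{-1}{|P(i)|}\sum_{p\in P(i)}\log\frac{\exp(z_i\cdot z_p/\tau)}{\sum_{a\in A(i)}\exp(z_i\cdot z_a/\tau)}
\]

The implementation additionally multiplies each term by temperature / base_temperature (equal to 1 with the defaults used here) and averages over anchors rather than summing.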
SupConLoss.__init__(temperature=1.0, contrast_mode='all', base_temperature=1.0)

Stores the temperature, base temperature, and contrast mode ('one' or 'all') used by forward. The original 0.07 defaults are kept as a comment in the source above.
SupConLoss.forward(features, labels=None, mask=None)

Compute loss for model. If both `labels` and `mask` are None, it degenerates to the SimCLR unsupervised loss: https://arxiv.org/pdf/2002.05709.pdf

Args:
- features: hidden vector of shape [bsz, n_views, ...].
- labels: ground truth of shape [bsz].
- mask: contrastive mask of shape [bsz, bsz], mask_{i,j} = 1 if sample j has the same class as sample i. Can be asymmetric.

Returns:
- A loss scalar.
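A minimal usage sketch follows; the shapes, the 0.07 temperature, the number of classes, and the L2 normalization are illustrative choices, not requirements of this module:

import torch
import torch.nn.functional as F

bsz, n_views, dim = 8, 2, 128                                    # hypothetical batch: 8 samples with 2 augmented views each
features = F.normalize(torch.randn(bsz, n_views, dim), dim=-1)   # projection-head outputs, L2-normalized per view
labels = torch.randint(0, 4, (bsz,))                             # class ids; only needed for the supervised variant

criterion = SupConLoss(temperature=0.07)

sup_loss = criterion(features, labels=labels)   # positives are all same-class samples in the batch
simclr_loss = criterion(features)               # no labels or mask -> SimCLR-style unsupervised loss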
class SiameseDistanceMetric(Enum):
    """
    The metric for the contrastive loss
    """
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)
The distance metrics available for the contrastive loss.
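For illustration, each attribute is a plain callable taking two batches of embeddings and returning one distance per pair; the tensors below are placeholders:

import torch

a = torch.randn(4, 256)   # hypothetical batch of 4 embeddings
b = torch.randn(4, 256)

euclidean = SiameseDistanceMetric.EUCLIDEAN(a, b)             # L2 distance per pair, shape [4]
manhattan = SiameseDistanceMetric.MANHATTAN(a, b)             # L1 distance per pair, shape [4]
cosine_dist = SiameseDistanceMetric.COSINE_DISTANCE(a, b)     # 1 - cosine similarity, shape [4]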
class ContrastiveSentLoss(nn.Module):
    """
    Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the
    two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.
    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    :param model: SentenceTransformer model
    :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: Average by the size of the mini-batch.
    Example::
        from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
        from torch.utils.data import DataLoader

        model = SentenceTransformer('all-MiniLM-L6-v2')
        train_examples = [
            InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
            InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
        train_loss = losses.ContrastiveLoss(model=model)
        model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
    """

    def __init__(self, model, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
                 margin: float = 0.5, size_average: bool = True):
        super(ContrastiveSentLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.model = model
        self.size_average = size_average

    def get_config_dict(self):
        distance_metric_name = self.distance_metric.__name__
        for name, value in vars(SiameseDistanceMetric).items():
            if value == self.distance_metric:
                distance_metric_name = "SiameseDistanceMetric.{}".format(name)
                break

        return {'distance_metric': distance_metric_name, 'margin': self.margin, 'size_average': self.size_average}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
        assert len(reps) == 2
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)
        losses = 0.5 * (
            labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
        return losses.mean() if self.size_average else losses.sum()
Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, the distance between the two embeddings is reduced. If the label == 0, the distance between the embeddings is increased. Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

Parameters
- model: SentenceTransformer model
- distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used.
- margin: Negative samples (label == 0) should have a distance of at least the margin value.
- size_average: Average by the size of the mini-batch.

Example::
    from sentence_transformers import SentenceTransformer, LoggingHandler, losses, InputExample
    from torch.utils.data import DataLoader

    model = SentenceTransformer('all-MiniLM-L6-v2')
    train_examples = [
        InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
        InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0)]
    train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
    train_loss = losses.ContrastiveLoss(model=model)
    model.fit([(train_dataloader, train_loss)], show_progress_bar=True)
ContrastiveSentLoss.__init__(model, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average: bool = True)

Stores the embedding model, the distance metric, the margin, and whether the per-pair losses are averaged (rather than summed) over the mini-batch.
ContrastiveSentLoss.get_config_dict()

Returns a dict with the configured distance metric name (resolved against SiameseDistanceMetric where possible), the margin, and size_average.
ContrastiveSentLoss.forward(sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor)

Encodes the two inputs with the wrapped model, computes the configured distance between the resulting sentence embeddings, and applies the margin-based contrastive loss; returns the mean over the batch if size_average is True, otherwise the sum.
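In equation form (this just restates the code above), with d the chosen distance between the two embeddings and y ∈ {0, 1} the label, each pair contributes

\[
\ell \;=\; \tfrac{1}{2}\Bigl( y\, d^2 + (1-y)\,\max(0,\ \mathrm{margin} - d)^2 \Bigr),
\]

and the batch loss is the mean of these terms when size_average is True, the sum otherwise.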