debeir.training.evaluate_reranker

 1from collections import defaultdict
 2from typing import Dict, List, Union
 3
 4import numpy as np
 5from debeir.evaluation.evaluator import Evaluator
 6from debeir.rankers.transformer_sent_encoder import Encoder
 7from sklearn.metrics.pairwise import cosine_similarity
 8
 9from datasets import Dataset
10
# Distance functions selectable by name in SentenceEvaluator.
# NOTE(review): sklearn's cosine_similarity requires 2-D arrays and raises
# ValueError on the 1-D embedding vectors passed by _get_score, so cosine
# similarity is computed directly with numpy instead.
distance_fns = {
    "dot_score": np.dot,
    "cos_sim": lambda a, b: float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))),
}
15
16
class SentenceEvaluator(Evaluator):
    """Rerank documents for each topic by embedding similarity.

    Encodes every topic's query columns and every document's text columns with
    the supplied encoder, then scores query/document embedding pairs with the
    configured distance function to produce per-topic ranked lists.
    """

    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
                 text_cols: List[str], query_cols: List[str], id_col: str,
                 distance_fn: str,
                 qrels: str, metrics: List[str]):
        """
        :param model: Sentence encoder; called as ``model(text)`` to produce an embedding.
        :param dataset: Dataset of documents; each row must contain ``id_col`` and all ``text_cols``.
        :param parsed_topics: Mapping of topic number -> dict of topic fields
            (must contain every column named in ``query_cols``).
        :param text_cols: Document text columns to embed.
        :param query_cols: Topic query columns to embed.
        :param id_col: Column holding a ``"<topic_num>_<doc_id>"`` identifier.
        :param distance_fn: Key into ``distance_fns`` ("dot_score" or "cos_sim").
        :param qrels: Qrels file path, forwarded to ``Evaluator``.
        :param metrics: Metric names to compute, forwarded to ``Evaluator``.
        """
        super().__init__(qrels, metrics)
        self.encoder = model
        self.dataset = dataset
        self.parsed_topics = parsed_topics
        self.distance_fn = distance_fns[distance_fn]
        self.query_cols = query_cols
        self.text_cols = text_cols

        # Precompute all query and document embeddings up front.
        self._get_topic_embeddings(query_cols)
        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)

    def _get_topic_embeddings(self, query_cols):
        """Embed each topic's query columns in place under the key "<col>_eb"."""
        for _topic_num, topic in self.parsed_topics.items():
            for query_col in query_cols:
                topic[query_col + "_eb"] = self.encoder(topic[query_col])

    def _get_document_embedding_and_mapping(self, id_col, text_cols):
        """Embed document text columns.

        :return: Nested mapping topic_num -> doc_id -> list of [text_col, embedding].
        """
        document_ebs = defaultdict(lambda: defaultdict(list))

        for datum in self.dataset:
            # The id column is "<topic_num>_<doc_id>"; split only on the first
            # underscore so doc ids that themselves contain "_" are preserved.
            topic_num, doc_id = datum[id_col].split("_", 1)
            for text_col in text_cols:
                embedding = self.encoder(datum[text_col])
                document_ebs[topic_num][doc_id].append([text_col, embedding])

        return document_ebs

    def _get_score(self, a, b, aggregate="sum"):
        """Score two (lists of) embeddings with the configured distance function.

        :param a: A single embedding vector or a list of embedding vectors.
        :param b: A single embedding vector or a list of embedding vectors.
        :param aggregate: How to combine pairwise scores: "max", "min", "sum" or "avg".
        :return: The aggregated pairwise score.
        """
        aggs = {
            "max": max,
            "min": min,
            "sum": sum,
            "avg": lambda scores: sum(scores) / len(scores),
        }

        # Normalise a single vector to a one-element list so the pairwise loop
        # works uniformly. np.ndarray is accepted too, since encoders commonly
        # return arrays rather than plain lists.
        if not isinstance(a[0], (list, tuple, np.ndarray)):
            a = [a]
        if not isinstance(b[0], (list, tuple, np.ndarray)):
            b = [b]

        scores = [float(self.distance_fn(_a, _b)) for _a in a for _b in b]

        return aggs[aggregate](scores)

    def produce_ranked_lists(self):
        """Rank every topic's documents by query/document embedding similarity.

        :return: Dict of topic_num -> [[doc_id, score], ...] sorted by descending score.
        """
        topics = defaultdict(list)  # topic_num -> [[doc_id, score], ...]

        for topic_num, doc_topics in self.document_ebs.items():
            # BUG FIX: query embeddings are stored per topic under
            # "<query_col>_eb" by _get_topic_embeddings; the original looked up
            # self.parsed_topics[text_col + "_eb"] (no topic_num, wrong column
            # set), which raises KeyError.
            query_ebs = [self.parsed_topics[topic_num][query_col + "_eb"]
                         for query_col in self.query_cols]

            for doc_id, document_repr in doc_topics.items():
                _doc_cols, doc_embeddings = zip(*document_repr)
                topics[topic_num].append([doc_id, self._get_score(query_ebs, list(doc_embeddings))])

        for topic_num in topics:
            topics[topic_num].sort(key=lambda entry: entry[1], reverse=True)

        return topics
class SentenceEvaluator(Evaluator):
    """Rerank documents for each topic by embedding similarity.

    Encodes every topic's query columns and every document's text columns with
    the supplied encoder, then scores query/document embedding pairs with the
    configured distance function to produce per-topic ranked lists.
    """

    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
                 text_cols: List[str], query_cols: List[str], id_col: str,
                 distance_fn: str,
                 qrels: str, metrics: List[str]):
        """
        :param model: Sentence encoder; called as ``model(text)`` to produce an embedding.
        :param dataset: Dataset of documents; each row must contain ``id_col`` and all ``text_cols``.
        :param parsed_topics: Mapping of topic number -> dict of topic fields
            (must contain every column named in ``query_cols``).
        :param text_cols: Document text columns to embed.
        :param query_cols: Topic query columns to embed.
        :param id_col: Column holding a ``"<topic_num>_<doc_id>"`` identifier.
        :param distance_fn: Key into ``distance_fns`` ("dot_score" or "cos_sim").
        :param qrels: Qrels file path, forwarded to ``Evaluator``.
        :param metrics: Metric names to compute, forwarded to ``Evaluator``.
        """
        super().__init__(qrels, metrics)
        self.encoder = model
        self.dataset = dataset
        self.parsed_topics = parsed_topics
        self.distance_fn = distance_fns[distance_fn]
        self.query_cols = query_cols
        self.text_cols = text_cols

        # Precompute all query and document embeddings up front.
        self._get_topic_embeddings(query_cols)
        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)

    def _get_topic_embeddings(self, query_cols):
        """Embed each topic's query columns in place under the key "<col>_eb"."""
        for _topic_num, topic in self.parsed_topics.items():
            for query_col in query_cols:
                topic[query_col + "_eb"] = self.encoder(topic[query_col])

    def _get_document_embedding_and_mapping(self, id_col, text_cols):
        """Embed document text columns.

        :return: Nested mapping topic_num -> doc_id -> list of [text_col, embedding].
        """
        document_ebs = defaultdict(lambda: defaultdict(list))

        for datum in self.dataset:
            # The id column is "<topic_num>_<doc_id>"; split only on the first
            # underscore so doc ids that themselves contain "_" are preserved.
            topic_num, doc_id = datum[id_col].split("_", 1)
            for text_col in text_cols:
                embedding = self.encoder(datum[text_col])
                document_ebs[topic_num][doc_id].append([text_col, embedding])

        return document_ebs

    def _get_score(self, a, b, aggregate="sum"):
        """Score two (lists of) embeddings with the configured distance function.

        :param a: A single embedding vector or a list of embedding vectors.
        :param b: A single embedding vector or a list of embedding vectors.
        :param aggregate: How to combine pairwise scores: "max", "min", "sum" or "avg".
        :return: The aggregated pairwise score.
        """
        aggs = {
            "max": max,
            "min": min,
            "sum": sum,
            "avg": lambda scores: sum(scores) / len(scores),
        }

        # Normalise a single vector to a one-element list so the pairwise loop
        # works uniformly. np.ndarray is accepted too, since encoders commonly
        # return arrays rather than plain lists.
        if not isinstance(a[0], (list, tuple, np.ndarray)):
            a = [a]
        if not isinstance(b[0], (list, tuple, np.ndarray)):
            b = [b]

        scores = [float(self.distance_fn(_a, _b)) for _a in a for _b in b]

        return aggs[aggregate](scores)

    def produce_ranked_lists(self):
        """Rank every topic's documents by query/document embedding similarity.

        :return: Dict of topic_num -> [[doc_id, score], ...] sorted by descending score.
        """
        topics = defaultdict(list)  # topic_num -> [[doc_id, score], ...]

        for topic_num, doc_topics in self.document_ebs.items():
            # BUG FIX: query embeddings are stored per topic under
            # "<query_col>_eb" by _get_topic_embeddings; the original looked up
            # self.parsed_topics[text_col + "_eb"] (no topic_num, wrong column
            # set), which raises KeyError.
            query_ebs = [self.parsed_topics[topic_num][query_col + "_eb"]
                         for query_col in self.query_cols]

            for doc_id, document_repr in doc_topics.items():
                _doc_cols, doc_embeddings = zip(*document_repr)
                topics[topic_num].append([doc_id, self._get_score(query_ebs, list(doc_embeddings))])

        for topic_num in topics:
            topics[topic_num].sort(key=lambda entry: entry[1], reverse=True)

        return topics

Evaluation class that reranks documents by query/document embedding similarity and computes retrieval metrics from TREC-style qrels files.

SentenceEvaluator(model: debeir.rankers.transformer_sent_encoder.Encoder, dataset: datasets.arrow_dataset.Dataset, parsed_topics: Dict[Union[str, int], Dict], text_cols: List[str], query_cols: List[str], id_col: str, distance_fn: str, qrels: str, metrics: List[str])
19    def __init__(self, model: Encoder, dataset: Dataset, parsed_topics: Dict[Union[str, int], Dict],
20                 text_cols: List[str], query_cols: List[str], id_col: str,
21                 distance_fn: str,
22                 qrels: str, metrics: List[str]):
23        super().__init__(qrels, metrics)
24        self.encoder = model
25        self.dataset = dataset
26        self.parsed_topics = parsed_topics
27        self.distance_fn = distance_fns[distance_fn]
28        self.query_cols = query_cols
29        self.text_cols = text_cols
30
31        self._get_topic_embeddings(query_cols)
32        self.document_ebs = self._get_document_embedding_and_mapping(id_col, text_cols)
def produce_ranked_lists(self):
    """Rank every topic's documents by query/document embedding similarity.

    :return: Dict of topic_num -> [[doc_id, score], ...] sorted by descending score.
    """
    topics = defaultdict(list)  # topic_num -> [[doc_id, score], ...]

    for topic_num, doc_topics in self.document_ebs.items():
        # BUG FIX: query embeddings are stored per topic under
        # "<query_col>_eb" by _get_topic_embeddings; the original looked up
        # self.parsed_topics[text_col + "_eb"] (no topic_num, wrong column
        # set), which raises KeyError.
        query_ebs = [self.parsed_topics[topic_num][query_col + "_eb"]
                     for query_col in self.query_cols]

        for doc_id, document_repr in doc_topics.items():
            _doc_cols, doc_embeddings = zip(*document_repr)
            topics[topic_num].append([doc_id, self._get_score(query_ebs, list(doc_embeddings))])

    for topic_num in topics:
        topics[topic_num].sort(key=lambda entry: entry[1], reverse=True)

    return topics