debeir.core.query
import dataclasses
from typing import Dict, Optional, Union

import loguru
from debeir.engines.elasticsearch.generate_script_score import generate_script
from debeir.core.config import GenericConfig, apply_config
from debeir.utils.scaler import get_z_value


@dataclasses.dataclass(init=True)
class Query:
    """
    A query interface class

    :param topics: Topics that the query will be composed of
    :param config: Config object that contains the settings for querying
    """
    topics: Dict[int, Dict[str, str]]
    config: GenericConfig


class GenericElasticsearchQuery(Query):
    """
    A generic Elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries.
    Requires topics and a config to be provided.
    """
    # Class-level default; this is the value get_id_mapping() reads.
    id_mapping: str = "Id"

    def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
        """
        :param topics: Topics that the query will be composed of, keyed by topic number
        :param config: Config object; its ``query_type`` attribute is copied onto the instance
        :param top_bm25_scores: Optional top BM25 scores of the form {topic_num: top_bm25_score}
        :param mappings: Document fields to match against. Defaults to ``["Text"]``
        :param id_mapping: Name of the document ID field. Defaults to ``"id"``
        :param args: Accepted for call compatibility; not used
        :param kwargs: Accepted for call compatibility; not used
        """
        # Query.__init__ stores topics and config on the instance.
        super().__init__(topics, config)

        # An explicitly supplied id_mapping used to be dropped silently; honour it.
        self.id_mapping = "id" if id_mapping is None else id_mapping
        self.mappings = ["Text"] if mappings is None else mappings

        self.query_type = self.config.query_type
        self.embed_mappings = ["Text_Embedding"]

        # Dispatch table from query type name to the method that builds it.
        self.query_funcs = {
            "query": self.generate_query,
            "embedding": self.generate_query_embedding,
        }

        self.top_bm25_scores = top_bm25_scores

    def _generate_base_query(self, topic_num):
        """
        Builds the ``should`` clause that matches the first query field of a topic
        against every document field in ``self.mappings``.

        :param topic_num: Key of the topic in ``self.topics``
        :return:
            A tuple of (query field name, query text, should clause)
        """
        qfield = list(self.topics[topic_num].keys())[0]
        query = self.topics[topic_num][qfield]
        should = {
            "should": [
                {
                    "match": {
                        f"{field}": {
                            "query": query,
                        }
                    }
                }
                for field in self.mappings
            ]
        }

        return qfield, query, should

    def generate_query(self, topic_num, *args, **kwargs):
        """
        Generates a simple BM25 query based on the first query facet of the topic.
        Searches over all the document facets in ``self.mappings``.

        :param topic_num: Key of the topic in ``self.topics``
        :param args: Accepted for call compatibility; not used
        :param kwargs: Accepted for call compatibility; not used
        :return:
            An elasticsearch bool query
        """
        _, _, should = self._generate_base_query(topic_num)

        query = {
            "query": {
                "bool": should,
            }
        }

        return query

    def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
        """
        Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used
        for log normalization.

        Score = log(bm25)/log(z) + embed_score

        :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
        """
        self.top_bm25_scores = scores

    def has_bm25_scores(self):
        """
        Checks if BM25 scores have been set

        :return:
            True when ``top_bm25_scores`` is not None
        """
        return self.top_bm25_scores is not None

    @apply_config
    def generate_query_embedding(
            self, topic_num, encoder, *args, norm_weight=2.15, ablations=False,
            cosine_ceiling: Optional[float] = None, cosine_offset: float = 1.0, **kwargs):
        """
        Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.

        :param topic_num: The topic number to search for
        :param encoder: The encoder that will be used for encoding the topics
        :param norm_weight: The BM25 log normalization constant. Replaced by an automatically
                            computed value when BM25 scores have been set
        :param ablations: Passed to the script as ``disable_bm25``
        :param cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation.
                               When None, ``len(embed_mappings) * number of query fields`` is used
        :param cosine_offset: Passed to the script as ``offset``
        :param args: Accepted for call compatibility; not used
        :param kwargs: Extra keyword arguments; not read by this method
        :return:
            An elasticsearch script_score query
        """
        qfields = list(self.topics[topic_num].keys())
        should = {"should": []}

        if self.has_bm25_scores():
            # The default used to be the typing object ``Optional[float]``, so this
            # fallback never ran; with a real None default it now does.
            if cosine_ceiling is None:
                cosine_ceiling = len(self.embed_mappings) * len(qfields)

            norm_weight = get_z_value(
                cosine_ceiling=cosine_ceiling,
                bm25_ceiling=self.top_bm25_scores[topic_num],
            )
            loguru.logger.debug(f"Automatic norm_weight: {norm_weight}")

        params = {
            # NOTE(review): sized by document mappings, not query fields - confirm against generate_script
            "weights": [1] * (len(self.embed_mappings) * len(self.mappings)),
            "offset": cosine_offset,
            "norm_weight": norm_weight,
            "disable_bm25": ablations,
        }

        embed_fields = []

        for qfield in qfields:
            topic_text = self.topics[topic_num][qfield]

            for field in self.mappings:
                should["should"].append(
                    {
                        "match": {
                            f"{field}": {
                                "query": topic_text,
                            }
                        }
                    }
                )

            # One embedding per query field, exposed to the script as "<qfield>_eb".
            params[f"{qfield}_eb"] = encoder.encode(topic=topic_text)
            embed_fields.append(f"{qfield}_eb")

        query = {
            "query": {
                "script_score": {
                    "query": {
                        "bool": should,
                    },
                    "script": generate_script(
                        self.embed_mappings, params, qfields=embed_fields
                    ),
                }
            }
        }

        loguru.logger.debug(query)
        return query

    @classmethod
    def get_id_mapping(cls, hit):
        """
        Get the document ID

        Reads the class-level ``id_mapping``; a value set on an instance is not consulted.

        :param hit: The raw document result
        :return:
            The document's ID
        """
        return hit[cls.id_mapping]
@dataclasses.dataclass(init=True)
class
Query:
@dataclasses.dataclass
class Query:
    """
    Base interface shared by all query classes.

    :param topics: Topics that the query will be composed of, keyed by topic number;
                   each topic maps a field name to its text
    :param config: Config object that contains the settings for querying
    """
    topics: Dict[int, Dict[str, str]]
    config: GenericConfig
A query interface class
Parameters
- topics: Topics that the query will be composed of
- config: Config object that contains the settings for querying
Query( topics: Dict[int, Dict[str, str]], config: debeir.core.config.GenericConfig)
class GenericElasticsearchQuery(Query):
    """
    A generic Elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries.
    Requires topics and a config to be provided.
    """
    # Class-level default; this is the value get_id_mapping() reads.
    id_mapping: str = "Id"

    def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
        """
        :param topics: Topics that the query will be composed of, keyed by topic number
        :param config: Config object; its ``query_type`` attribute is copied onto the instance
        :param top_bm25_scores: Optional top BM25 scores of the form {topic_num: top_bm25_score}
        :param mappings: Document fields to match against. Defaults to ``["Text"]``
        :param id_mapping: Name of the document ID field. Defaults to ``"id"``
        :param args: Accepted for call compatibility; not used
        :param kwargs: Accepted for call compatibility; not used
        """
        # Query.__init__ stores topics and config on the instance.
        super().__init__(topics, config)

        # An explicitly supplied id_mapping used to be dropped silently; honour it.
        self.id_mapping = "id" if id_mapping is None else id_mapping
        self.mappings = ["Text"] if mappings is None else mappings

        self.query_type = self.config.query_type
        self.embed_mappings = ["Text_Embedding"]

        # Dispatch table from query type name to the method that builds it.
        self.query_funcs = {
            "query": self.generate_query,
            "embedding": self.generate_query_embedding,
        }

        self.top_bm25_scores = top_bm25_scores

    def _generate_base_query(self, topic_num):
        """
        Builds the ``should`` clause that matches the first query field of a topic
        against every document field in ``self.mappings``.

        :param topic_num: Key of the topic in ``self.topics``
        :return:
            A tuple of (query field name, query text, should clause)
        """
        qfield = list(self.topics[topic_num].keys())[0]
        query = self.topics[topic_num][qfield]
        should = {
            "should": [
                {
                    "match": {
                        f"{field}": {
                            "query": query,
                        }
                    }
                }
                for field in self.mappings
            ]
        }

        return qfield, query, should

    def generate_query(self, topic_num, *args, **kwargs):
        """
        Generates a simple BM25 query based on the first query facet of the topic.
        Searches over all the document facets in ``self.mappings``.

        :param topic_num: Key of the topic in ``self.topics``
        :param args: Accepted for call compatibility; not used
        :param kwargs: Accepted for call compatibility; not used
        :return:
            An elasticsearch bool query
        """
        _, _, should = self._generate_base_query(topic_num)

        query = {
            "query": {
                "bool": should,
            }
        }

        return query

    def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
        """
        Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used
        for log normalization.

        Score = log(bm25)/log(z) + embed_score

        :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
        """
        self.top_bm25_scores = scores

    def has_bm25_scores(self):
        """
        Checks if BM25 scores have been set

        :return:
            True when ``top_bm25_scores`` is not None
        """
        return self.top_bm25_scores is not None

    @apply_config
    def generate_query_embedding(
            self, topic_num, encoder, *args, norm_weight=2.15, ablations=False,
            cosine_ceiling: Optional[float] = None, cosine_offset: float = 1.0, **kwargs):
        """
        Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.

        :param topic_num: The topic number to search for
        :param encoder: The encoder that will be used for encoding the topics
        :param norm_weight: The BM25 log normalization constant. Replaced by an automatically
                            computed value when BM25 scores have been set
        :param ablations: Passed to the script as ``disable_bm25``
        :param cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation.
                               When None, ``len(embed_mappings) * number of query fields`` is used
        :param cosine_offset: Passed to the script as ``offset``
        :param args: Accepted for call compatibility; not used
        :param kwargs: Extra keyword arguments; not read by this method
        :return:
            An elasticsearch script_score query
        """
        qfields = list(self.topics[topic_num].keys())
        should = {"should": []}

        if self.has_bm25_scores():
            # The default used to be the typing object ``Optional[float]``, so this
            # fallback never ran; with a real None default it now does.
            if cosine_ceiling is None:
                cosine_ceiling = len(self.embed_mappings) * len(qfields)

            norm_weight = get_z_value(
                cosine_ceiling=cosine_ceiling,
                bm25_ceiling=self.top_bm25_scores[topic_num],
            )
            loguru.logger.debug(f"Automatic norm_weight: {norm_weight}")

        params = {
            # NOTE(review): sized by document mappings, not query fields - confirm against generate_script
            "weights": [1] * (len(self.embed_mappings) * len(self.mappings)),
            "offset": cosine_offset,
            "norm_weight": norm_weight,
            "disable_bm25": ablations,
        }

        embed_fields = []

        for qfield in qfields:
            topic_text = self.topics[topic_num][qfield]

            for field in self.mappings:
                should["should"].append(
                    {
                        "match": {
                            f"{field}": {
                                "query": topic_text,
                            }
                        }
                    }
                )

            # One embedding per query field, exposed to the script as "<qfield>_eb".
            params[f"{qfield}_eb"] = encoder.encode(topic=topic_text)
            embed_fields.append(f"{qfield}_eb")

        query = {
            "query": {
                "script_score": {
                    "query": {
                        "bool": should,
                    },
                    "script": generate_script(
                        self.embed_mappings, params, qfields=embed_fields
                    ),
                }
            }
        }

        loguru.logger.debug(query)
        return query

    @classmethod
    def get_id_mapping(cls, hit):
        """
        Get the document ID

        Reads the class-level ``id_mapping``; a value set on an instance is not consulted.

        :param hit: The raw document result
        :return:
            The document's ID
        """
        return hit[cls.id_mapping]
A generic Elasticsearch query. Contains methods for NIR-style (embedding) queries and normal BM25 queries. Requires topics and a config to be provided.
GenericElasticsearchQuery( topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs)
def __init__(self, topics, config, top_bm25_scores=None, mappings=None, id_mapping=None, *args, **kwargs):
    """
    :param topics: Topics that the query will be composed of, keyed by topic number
    :param config: Config object; its ``query_type`` attribute is copied onto the instance
    :param top_bm25_scores: Optional top BM25 scores of the form {topic_num: top_bm25_score}
    :param mappings: Document fields to match against. Defaults to ``["Text"]``
    :param id_mapping: Name of the document ID field. Defaults to ``"id"``
    :param args: Accepted for call compatibility; not used
    :param kwargs: Accepted for call compatibility; not used
    """
    # Query.__init__ stores topics and config on the instance.
    super().__init__(topics, config)

    # An explicitly supplied id_mapping used to be dropped silently; honour it.
    self.id_mapping = "id" if id_mapping is None else id_mapping
    self.mappings = ["Text"] if mappings is None else mappings

    self.query_type = self.config.query_type
    self.embed_mappings = ["Text_Embedding"]

    # Dispatch table from query type name to the method that builds it.
    self.query_funcs = {
        "query": self.generate_query,
        "embedding": self.generate_query_embedding,
    }

    self.top_bm25_scores = top_bm25_scores
def
generate_query(self, topic_num, *args, **kwargs):
def generate_query(self, topic_num, *args, **kwargs):
    """
    Build a plain BM25 query for one topic by wrapping the base ``should``
    clause in an elasticsearch bool query.

    :param topic_num: Key of the topic to build the query for
    :param args: Accepted for call compatibility; not used
    :param kwargs: Accepted for call compatibility; not used
    :return:
        A dict of the form {"query": {"bool": <should clause>}}
    """
    _qfield, _text, bool_clause = self._generate_base_query(topic_num)
    return {"query": {"bool": bool_clause}}
Generates a simple BM25 query based on the query facets. Searches over all the document facets.
Parameters
- topic_num:
- args:
- kwargs:
Returns
def
set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
def set_bm25_scores(self, scores: Dict[Union[str, int], Union[int, float]]):
    """
    Store the per-topic top BM25 scores used for NIR-style scoring.

    The top BM25 score of each topic feeds the log normalization:

    Score = log(bm25)/log(z) + embed_score

    :param scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
    """
    setattr(self, "top_bm25_scores", scores)
Sets BM25 scores that are used for NIR-style scoring. The top BM25 score for each topic is used for log normalization.
Score = log(bm25)/log(z) + embed_score
Parameters
- scores: Top BM25 Scores of the form {topic_num: top_bm25_score}
def
has_bm25_scores(self):
def has_bm25_scores(self):
    """
    Report whether BM25 scores have been set.

    :return:
        False while ``top_bm25_scores`` is None, True otherwise
    """
    current_scores = self.top_bm25_scores
    return current_scores is not None
Checks if BM25 scores have been set
Returns
def
generate_query_embedding(self, *args, **kwargs):
def use_config(self, *args, **kwargs):
    """
    Call the wrapped function with keyword arguments passed through ``self.config``.

    When ``self.config`` is not None, the keyword arguments are replaced by the
    result of ``self.config.__update__(**kwargs)``; positional arguments are
    forwarded unchanged.

    :param self: Object whose ``config`` attribute is consulted
    :param args: Forwarded to the wrapped function as-is
    :param kwargs: To be updated from the config
    :return:
        Whatever the wrapped function returns
    """
    config = self.config
    merged = kwargs if config is None else config.__update__(**kwargs)
    return func(self, *args, **merged)
Generates an embedding script score query for Elasticsearch as part of the NIR scoring function.
Parameters
- topic_num: The topic number to search for
- encoder: The encoder that will be used for encoding the topics
- norm_weight: The BM25 log normalization constant
- ablations: Whether to execute ablation style queries (i.e. one query facet or one document facet at a time)
- cosine_ceiling: Cosine ceiling used for automatic z-log normalization parameter calculation
- args:
- kwargs: Pass disable_cache to disable encoder caching
Returns
An elasticsearch script_score query
@classmethod
def
get_id_mapping(cls, hit):
@classmethod
def get_id_mapping(cls, hit):
    """
    Look up the document ID in a raw hit.

    The field name is read from the class-level ``id_mapping`` attribute.

    :param hit: The raw document result
    :return:
        The document's ID
    """
    id_field = cls.id_mapping
    return hit[id_field]
Get the document ID
Parameters
- hit: The raw document result
Returns
The document's ID