debeir.datasets.clinical_trials
import csv
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import loguru
from debeir.engines.elasticsearch.generate_script_score import generate_script
from debeir.core.config import GenericConfig, apply_config
from debeir.core.executor import GenericElasticsearchExecutor
from debeir.core.parser import Parser
from debeir.core.query import GenericElasticsearchQuery
from debeir.rankers.transformer_sent_encoder import Encoder
from debeir.utils.scaler import get_z_value
from elasticsearch import AsyncElasticsearch as Elasticsearch


@dataclass(init=True, unsafe_hash=True)
class TrialsQueryConfig(GenericConfig):
    """
    Query configuration for the Clinical Trials index.
    """

    # Key into the query object's field-usage table for the lexical part of the query.
    query_field_usage: Optional[str] = None
    # Key into the field-usage table for the embedding part of the query.
    embed_field_usage: Optional[str] = None
    # Explicit field list; accepted by validate() as an alternative to query_field_usage.
    fields: Optional[List[str]] = None

    def validate(self):
        """
        Checks if query type is included, and checks if an encoder is included for embedding queries

        :raises AssertionError: if any check fails. The checks raise explicitly instead of
            using ``assert`` so that they are not stripped when Python runs with ``-O``.
        """

        def require(condition, message):
            # Same exception type as the former assert statements.
            if not condition:
                raise AssertionError(message)

        if self.query_type == "embedding":
            require(
                self.query_field_usage and self.embed_field_usage,
                "Must have both field usages if embedding query",
            )
            require(
                self.encoder_fp and self.encoder,
                "Must provide encoder path for embedding model",
            )
            require(
                self.norm_weight is not None or self.automatic is not None,
                "Norm weight must be specified or be automatic",
            )

        require(
            self.query_field_usage is not None or self.fields is not None,
            "Must have a query field",
        )
        require(
            self.query_type in ("ablation", "query", "query_best", "embedding"),
            "Check your query type",
        )

    @classmethod
    def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
        """
        Instantiates a config of this class from a TOML file via the parent implementation.

        :param fp: File path of the Config TOML file
        :param args: Arguments forwarded to the parent ``from_toml``
        :param kwargs: Keyword arguments forwarded to the parent ``from_toml``
        :return: Whatever the parent ``from_toml`` returns for this class
        """
        return super().from_toml(fp, cls, *args, **kwargs)

    @classmethod
    def from_dict(cls, **kwargs) -> "GenericConfig":
        """
        Instantiates a config of this class from keyword arguments via the parent implementation.

        :param kwargs: Keyword arguments forwarded to the parent ``from_dict``
        :return: Whatever the parent ``from_dict`` returns for this class
        """
        return super().from_dict(cls, **kwargs)


class TrialsElasticsearchQuery(GenericElasticsearchQuery):
    """
    Elasticsearch Query object for the Clinical Trials Index
    """

    topics: Dict[int, Dict[str, str]]
    query_type: str
    fields: List[int]  # indices into ``mappings`` (see generate_query_ablation)
    query_funcs: Dict
    id_mapping: str = "_id"
    mappings: List[str]
    config: TrialsQueryConfig

    def __init__(self, topics, query_type, config=None, *args, **kwargs):
        """
        Set up the field lists and the query-type dispatch table.

        :param topics: Mapping of topic number to that topic's fields
        :param query_type: Key into ``query_funcs`` ("query", "ablation" or "embedding")
        :param config: Optional query config
        :param args: Forwarded to the parent constructor
        :param kwargs: Forwarded to the parent constructor
        """
        super().__init__(topics, config, *args, **kwargs)
        self.query_type = query_type
        self.config = config
        self.topics = topics
        self.fields = []
        self.mappings = [
            "HasExpandedAccess",
            "BriefSummary.Textblock",
            "CompletionDate.Type",
            "OversightInfo.Text",
            "OverallContactBackup.PhoneExt",
            "RemovedCountries.Text",
            "SecondaryOutcome",
            "Sponsors.LeadSponsor.Text",
            "BriefTitle",
            "IDInfo.NctID",
            "IDInfo.SecondaryID",
            "OverallContactBackup.Phone",
            "Eligibility.StudyPop.Textblock",
            "DetailedDescription.Textblock",
            "Eligibility.MinimumAge",
            "Sponsors.Collaborator",
            "Reference",
            "Eligibility.Criteria.Textblock",
            "XMLName.Space",
            "Rank",
            "OverallStatus",
            "InterventionBrowse.Text",
            "Eligibility.Text",
            "Intervention",
            "BiospecDescr.Textblock",
            "ResponsibleParty.NameTitle",
            "NumberOfArms",
            "ResponsibleParty.ResponsiblePartyType",
            "IsSection801",
            "Acronym",
            "Eligibility.MaximumAge",
            "DetailedDescription.Text",
            "StudyDesign",
            "OtherOutcome",
            "VerificationDate",
            "ConditionBrowse.MeshTerm",
            "Enrollment.Text",
            "IDInfo.Text",
            "ConditionBrowse.Text",
            "FirstreceivedDate",
            "NumberOfGroups",
            "OversightInfo.HasDmc",
            "PrimaryCompletionDate.Text",
            "ResultsReference",
            "Eligibility.StudyPop.Text",
            "IsFdaRegulated",
            "WhyStopped",
            "ArmGroup",
            "OverallContact.LastName",
            "Phase",
            "RemovedCountries.Country",
            "InterventionBrowse.MeshTerm",
            "Eligibility.HealthyVolunteers",
            "Location",
            "OfficialTitle",
            "OverallContact.Email",
            "RequiredHeader.Text",
            "RequiredHeader.URL",
            "LocationCountries.Country",
            "OverallContact.PhoneExt",
            "Condition",
            "PrimaryOutcome",
            "LocationCountries.Text",
            "BiospecDescr.Text",
            "IDInfo.OrgStudyID",
            "Link",
            "OverallContact.Phone",
            "Source",
            "ResponsibleParty.InvestigatorAffiliation",
            "StudyType",
            "FirstreceivedResultsDate",
            "Enrollment.Type",
            "Eligibility.Gender",
            "OverallContactBackup.LastName",
            "Keyword",
            "BiospecRetention",
            "CompletionDate.Text",
            "OverallContact.Text",
            "RequiredHeader.DownloadDate",
            "Sponsors.Text",
            "Text",
            "Eligibility.SamplingMethod",
            "LastchangedDate",
            "ResponsibleParty.InvestigatorFullName",
            "StartDate",
            "RequiredHeader.LinkText",
            "OverallOfficial",
            "Sponsors.LeadSponsor.AgencyClass",
            "OverallContactBackup.Text",
            "Eligibility.Criteria.Text",
            "XMLName.Local",
            "OversightInfo.Authority",
            "PrimaryCompletionDate.Type",
            "ResponsibleParty.Organization",
            "IDInfo.NctAlias",
            "ResponsibleParty.Text",
            "TargetDuration",
            "Sponsors.LeadSponsor.Agency",
            "BriefSummary.Text",
            "OverallContactBackup.Email",
            "ResponsibleParty.InvestigatorTitle",
        ]

        self.best_recall_fields = [
            "LocationCountries.Country",
            "BiospecRetention",
            "DetailedDescription.Textblock",
            "HasExpandedAccess",
            "ConditionBrowse.MeshTerm",
            "RequiredHeader.LinkText",
            "WhyStopped",
            "BriefSummary.Textblock",
            "Eligibility.Criteria.Textblock",
            "OfficialTitle",
            "Eligibility.MaximumAge",
            "Eligibility.StudyPop.Textblock",
            "BiospecDescr.Textblock",
            "BriefTitle",
            "Eligibility.MinimumAge",
            "ResponsibleParty.Organization",
            "TargetDuration",
            "Condition",
            "IDInfo.OrgStudyID",
            "Keyword",
            "Source",
            "Sponsors.LeadSponsor.Agency",
            "ResponsibleParty.InvestigatorAffiliation",
            "OversightInfo.Authority",
            "OversightInfo.HasDmc",
            "OverallContact.Phone",
            "Phase",
            "OverallContactBackup.LastName",
            "Acronym",
            "InterventionBrowse.MeshTerm",
            "RemovedCountries.Country",
        ]
        self.best_map_fields = [
            "Eligibility.Gender",
            "LocationCountries.Country",
            "DetailedDescription.Textblock",
            "BriefSummary.Textblock",
            "ConditionBrowse.MeshTerm",
            "Eligibility.Criteria.Textblock",
            "InterventionBrowse.MeshTerm",
            "StudyType",
            "IsFdaRegulated",
            "HasExpandedAccess",
            "RequiredHeader.LinkText",
            "BiospecRetention",
            "OfficialTitle",
            "Eligibility.SamplingMethod",
            "Eligibility.StudyPop.Textblock",
            "Condition",
            "Eligibility.MinimumAge",
            "Keyword",
            "Eligibility.MaximumAge",
            "BriefTitle",
        ]
        self.best_embed_fields = [
            "WhyStopped",
            "HasExpandedAccess",
            "BiospecRetention",
            "BriefSummary.Textblock",
            "LocationCountries.Country",
            "ConditionBrowse.MeshTerm",
            "DetailedDescription.Textblock",
            "RequiredHeader.LinkText",
            "Eligibility.Criteria.Textblock",
        ]

        # The first two entries used to be fused by implicit string concatenation
        # (missing comma) into the non-existent field "BriefSummary.TextblockBriefTitle".
        self.sensible = [
            "BriefSummary.Textblock",
            "BriefTitle",
            "Eligibility.StudyPop.Textblock",
            "DetailedDescription.Textblock",
            "Eligibility.MinimumAge",
            "Eligibility.Criteria.Textblock",
            "InterventionBrowse.Text",
            "Eligibility.Text",
            "BiospecDescr.Textblock",
            "Eligibility.MaximumAge",
            "DetailedDescription.Text",
            "ConditionBrowse.MeshTerm",
            "ConditionBrowse.Text",
            "Eligibility.StudyPop.Text",
            "InterventionBrowse.MeshTerm",
            "OfficialTitle",
            "Condition",
            "PrimaryOutcome",
            "BiospecDescr.Text",
            "Eligibility.Gender",
            "Keyword",
            "BiospecRetention",
            "Eligibility.Criteria.Text",
            "BriefSummary.Text",
        ]

        # Same missing-comma fix as in ``sensible``.
        self.sensible_embed = [
            "BriefSummary.Textblock",
            "BriefTitle",
            "Eligibility.StudyPop.Textblock",
            "DetailedDescription.Textblock",
            "Eligibility.Criteria.Textblock",
            "InterventionBrowse.Text",
            "Eligibility.Text",
            "BiospecDescr.Textblock",
            "DetailedDescription.Text",
            "ConditionBrowse.MeshTerm",
            "ConditionBrowse.Text",
            "Eligibility.StudyPop.Text",
            "InterventionBrowse.MeshTerm",
            "OfficialTitle",
            "Condition",
            "PrimaryOutcome",
            "BiospecDescr.Text",
            "Keyword",
            "BiospecRetention",
            "Eligibility.Criteria.Text",
            "BriefSummary.Text",
        ]

        # Fields present in both lists, kept in ``sensible_embed`` order so the
        # result is deterministic (a bare set intersection is not).
        recall_fields = set(self.best_recall_fields)
        self.sensible_embed_safe = [
            field for field in self.sensible_embed if field in recall_fields
        ]

        self.query_funcs = {
            "query": self.generate_query,
            "ablation": self.generate_query_ablation,
            "embedding": self.generate_query_embedding,
        }

        loguru.logger.debug(self.sensible_embed_safe)

        self.field_usage = {
            "best_recall_fields": self.best_recall_fields,
            "all": self.mappings,
            "best_map_fields": self.best_map_fields,
            "best_embed_fields": self.best_embed_fields,
            "sensible": self.sensible,
            "sensible_embed": self.sensible_embed,
            "sensible_embed_safe": self.sensible_embed_safe,
        }

    @apply_config
    def generate_query(self, topic_num, query_field_usage, **kwargs) -> Dict:
        """
        Generates a query for the clinical trials index

        :param topic_num: Topic number to search
        :param query_field_usage: Which document facets to search over
        :param kwargs:
        :return:
            A basic elasticsearch query for clinical trials
        """
        fields = self.field_usage[query_field_usage]

        # The first field of the topic supplies the query text.
        topic = self.topics[topic_num]
        query_text = topic[list(topic)[0]]

        should = [{"match": {field: {"query": query_text}}} for field in fields]

        return {"query": {"bool": {"should": should}}}

    def generate_query_ablation(self, topic_num, **kwargs):
        """
        Only search one document facet at a time

        :param topic_num: Topic number whose field values form the query text
        :param kwargs:
        :return: A ``match`` query with one entry per index in ``self.fields``
        """
        match = {self.mappings[index]: "" for index in self.fields}

        for index in self.fields:
            # All topic field values are concatenated without a separator.
            match[self.mappings[index]] += "".join(self.topics[topic_num].values())

        return {"query": {"match": match}}

    @apply_config
    def generate_query_embedding(
        self,
        topic_num,
        encoder,
        query_field_usage,
        embed_field_usage,
        cosine_weights: List[float] = None,
        query_weight: List[float] = None,
        norm_weight=2.15,
        ablations=False,
        automatic_scores=None,
        **kwargs,
    ):
        """
        Computes the NIR score for a given topic

        Score = log(BM25)/log(norm_weight) + embedding_score

        :param topic_num: Topic number to search
        :param encoder: Object whose ``encode`` method embeds the query text
        :param query_field_usage: Key into ``field_usage`` for the match clauses
        :param embed_field_usage: Key into ``field_usage`` for the embedding fields
        :param cosine_weights: Script weights; defaults to 1 per embedding field
        :param query_weight: Per-field boosts for the match clauses; defaults to 1
        :param norm_weight: Normalisation weight, replaced when ``automatic_scores`` is given
        :param ablations: Passed to the script as ``disable_bm25``
        :param automatic_scores: Per-topic values used as ``bm25_ceiling`` for ``get_z_value``
        :param kwargs:
        :return: A ``script_score`` elasticsearch query
        :raises AssertionError: if neither ``norm_weight`` nor ``automatic_scores`` is set
        """
        if not (norm_weight or automatic_scores):
            raise AssertionError("Either norm_weight or automatic_scores must be provided")

        query_fields = self.field_usage[query_field_usage]
        embed_fields = self.field_usage[embed_field_usage]

        # The first field of the topic supplies the query text.
        topic = self.topics[topic_num]
        query_text = topic[list(topic)[0]]

        should = [
            {
                "match": {
                    field: {
                        "query": query_text,
                        "boost": query_weight[i] if query_weight else 1,
                    }
                }
            }
            for i, field in enumerate(query_fields)
        ]

        if automatic_scores is not None:
            norm_weight = get_z_value(
                cosine_ceiling=len(embed_fields) * len(query_fields),
                bm25_ceiling=automatic_scores[topic_num],
            )

        params = {
            "weights": cosine_weights if cosine_weights else [1] * len(embed_fields),
            "q_eb": encoder.encode(query_text),
            "offset": 1.0,
            "norm_weight": norm_weight,
            "disable_bm25": ablations,
        }

        # NOTE(review): the script is generated over ``best_embed_fields`` while the default
        # weights are sized from ``embed_fields`` -- confirm whether ``embed_fields`` was
        # intended here. Left unchanged to avoid altering which index fields are scored.
        return {
            "query": {
                "script_score": {
                    "query": {"bool": {"should": should}},
                    "script": generate_script(self.best_embed_fields, params=params),
                },
            }
        }

    def get_query_type(self, *args, **kwargs):
        """
        Build the query using the generator registered for ``self.query_type``.
        """
        return self.query_funcs[self.query_type](*args, **kwargs)

    def get_id_mapping(self, hit):
        """
        Return the value stored under ``self.id_mapping`` in a hit.
        """
        return hit[self.id_mapping]


class ClinicalTrialsElasticsearchExecutor(GenericElasticsearchExecutor):
    """
    Executes queries given a query object.
    """

    query: TrialsElasticsearchQuery

    def __init__(
        self,
        topics: Dict[Union[str, int], Dict[str, str]],
        client: Elasticsearch,
        index_name: str,
        output_file: str,
        query: TrialsElasticsearchQuery,
        encoder: Optional[Encoder] = None,
        config=None,
        *args,
        **kwargs,
    ):
        """
        Forward everything to the parent executor and register the query functions.
        """
        super().__init__(
            topics,
            client,
            index_name,
            output_file,
            query,
            encoder,
            *args,
            config=config,
            **kwargs,
        )

        # These methods are not defined in this class; presumably inherited from the parent.
        self.query_fns = {
            "query": self.generate_query,
            "ablation": self.generate_query_ablation,
            "embedding": self.generate_embedding_query,
        }


class ClinicalTrialParser(Parser):
    """
    Parser for Clinical Trials topics
    """

    @classmethod
    def get_topics(cls, csvfile) -> Dict[str, Dict[str, str]]:
        """
        Read topics from an open CSV file whose first row is a header.

        :param csvfile: Iterable of CSV lines (e.g. an open text file)
        :return: Mapping of the first column (as a string) to ``{"text": second column}``
        """
        reader = csv.reader(csvfile)
        next(reader, None)  # skip the header row

        topics = {}
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; skip them instead of crashing.
                continue
            topics[row[0]] = {"text": row[1]}

        return topics
@dataclass(init=True, unsafe_hash=True)
class TrialsQueryConfig(GenericConfig):
    """
    Query configuration for the Clinical Trials index.
    """

    # Key into the query object's field-usage table for the lexical part of the query.
    query_field_usage: Optional[str] = None
    # Key into the field-usage table for the embedding part of the query.
    embed_field_usage: Optional[str] = None
    # Explicit field list; accepted by validate() as an alternative to query_field_usage.
    fields: Optional[List[str]] = None

    def validate(self):
        """
        Checks if query type is included, and checks if an encoder is included for embedding queries

        :raises AssertionError: if any check fails. The checks raise explicitly instead of
            using ``assert`` so that they are not stripped when Python runs with ``-O``.
        """

        def require(condition, message):
            # Same exception type as the former assert statements.
            if not condition:
                raise AssertionError(message)

        if self.query_type == "embedding":
            require(
                self.query_field_usage and self.embed_field_usage,
                "Must have both field usages if embedding query",
            )
            require(
                self.encoder_fp and self.encoder,
                "Must provide encoder path for embedding model",
            )
            require(
                self.norm_weight is not None or self.automatic is not None,
                "Norm weight must be specified or be automatic",
            )

        require(
            self.query_field_usage is not None or self.fields is not None,
            "Must have a query field",
        )
        require(
            self.query_type in ("ablation", "query", "query_best", "embedding"),
            "Check your query type",
        )

    @classmethod
    def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
        """
        Instantiates a config of this class from a TOML file via the parent implementation.

        :param fp: File path of the Config TOML file
        :param args: Arguments forwarded to the parent ``from_toml``
        :param kwargs: Keyword arguments forwarded to the parent ``from_toml``
        :return: Whatever the parent ``from_toml`` returns for this class
        """
        return super().from_toml(fp, cls, *args, **kwargs)

    @classmethod
    def from_dict(cls, **kwargs) -> "GenericConfig":
        """
        Instantiates a config of this class from keyword arguments via the parent implementation.

        :param kwargs: Keyword arguments forwarded to the parent ``from_dict``
        :return: Whatever the parent ``from_dict`` returns for this class
        """
        return super().from_dict(cls, **kwargs)
TrialsQueryConfig( query_type: str, index: str = None, encoder_normalize: bool = True, ablations: bool = False, norm_weight: float = None, automatic: bool = None, encoder: object = None, encoder_fp: str = None, query_weights: List[float] = None, cosine_weights: List[float] = None, evaluate: bool = False, qrels: str = None, config_fn: str = None, query_fn: str = None, parser_fn: str = None, executor_fn: str = None, cosine_ceiling: float = None, topics_path: str = None, return_id_only: bool = False, overwrite_output_if_exists: bool = False, output_file: str = None, run_name: str = None, query_field_usage: str = None, embed_field_usage: str = None, fields: List[str] = None)
def
validate(self):
def validate(self):
    """
    Checks if query type is included, and checks if an encoder is included for embedding queries

    Embedding queries additionally need both field usages, an encoder (path and object)
    and either a norm weight or the automatic setting.

    :raises AssertionError: if any check fails. The checks raise explicitly instead of
        using ``assert`` so that they are not stripped when Python runs with ``-O``.
    """

    def require(condition, message):
        # Same exception type as the former assert statements.
        if not condition:
            raise AssertionError(message)

    if self.query_type == "embedding":
        require(
            self.query_field_usage and self.embed_field_usage,
            "Must have both field usages if embedding query",
        )
        require(
            self.encoder_fp and self.encoder,
            "Must provide encoder path for embedding model",
        )
        require(
            self.norm_weight is not None or self.automatic is not None,
            "Norm weight must be specified or be automatic",
        )

    require(
        self.query_field_usage is not None or self.fields is not None,
        "Must have a query field",
    )
    require(
        self.query_type in ("ablation", "query", "query_best", "embedding"),
        "Check your query type",
    )
Checks if query type is included, and checks if an encoder is included for embedding queries
@classmethod
def from_toml(cls, fp: str, *args, **kwargs) -> "GenericConfig":
    """
    Instantiates a config of this class from a TOML file.

    Delegates to the parent ``from_toml``, passing this class as the class to instantiate.

    :param fp: File path of the Config TOML file
    :param args: Arguments forwarded to the parent ``from_toml``
    :param kwargs: Keyword arguments forwarded to the parent ``from_toml``
    :return: Whatever the parent ``from_toml`` returns for this class
    """
    return super().from_toml(fp, cls, *args, **kwargs)
Instantiates a Config object from a toml file
Parameters
- fp: File path of the Config TOML file
- field_class: Class of the Config object to be instantiated
- args: Arguments to be passed to Config
- kwargs: Keyword arguments to be passed
Returns
An instantiated and validated Config object.
@classmethod
def from_dict(cls, **kwargs) -> "GenericConfig":
    """
    Instantiates a config of this class from keyword arguments.

    Delegates to the parent ``from_dict``, passing this class as the class to instantiate.

    :param kwargs: Keyword arguments forwarded to the parent ``from_dict``
    :return: Whatever the parent ``from_dict`` returns for this class
    """
    return super().from_dict(cls, **kwargs)
Instantiates a Config object from a dictionary
Parameters
- data_class:
- kwargs:
Returns
Inherited Members
class TrialsElasticsearchQuery(GenericElasticsearchQuery):
    """
    Elasticsearch Query object for the Clinical Trials Index
    """

    topics: Dict[int, Dict[str, str]]
    query_type: str
    fields: List[int]  # indices into ``mappings`` (see generate_query_ablation)
    query_funcs: Dict
    id_mapping: str = "_id"
    mappings: List[str]
    config: TrialsQueryConfig

    def __init__(self, topics, query_type, config=None, *args, **kwargs):
        """
        Set up the field lists and the query-type dispatch table.

        :param topics: Mapping of topic number to that topic's fields
        :param query_type: Key into ``query_funcs`` ("query", "ablation" or "embedding")
        :param config: Optional query config
        :param args: Forwarded to the parent constructor
        :param kwargs: Forwarded to the parent constructor
        """
        super().__init__(topics, config, *args, **kwargs)
        self.query_type = query_type
        self.config = config
        self.topics = topics
        self.fields = []
        self.mappings = [
            "HasExpandedAccess",
            "BriefSummary.Textblock",
            "CompletionDate.Type",
            "OversightInfo.Text",
            "OverallContactBackup.PhoneExt",
            "RemovedCountries.Text",
            "SecondaryOutcome",
            "Sponsors.LeadSponsor.Text",
            "BriefTitle",
            "IDInfo.NctID",
            "IDInfo.SecondaryID",
            "OverallContactBackup.Phone",
            "Eligibility.StudyPop.Textblock",
            "DetailedDescription.Textblock",
            "Eligibility.MinimumAge",
            "Sponsors.Collaborator",
            "Reference",
            "Eligibility.Criteria.Textblock",
            "XMLName.Space",
            "Rank",
            "OverallStatus",
            "InterventionBrowse.Text",
            "Eligibility.Text",
            "Intervention",
            "BiospecDescr.Textblock",
            "ResponsibleParty.NameTitle",
            "NumberOfArms",
            "ResponsibleParty.ResponsiblePartyType",
            "IsSection801",
            "Acronym",
            "Eligibility.MaximumAge",
            "DetailedDescription.Text",
            "StudyDesign",
            "OtherOutcome",
            "VerificationDate",
            "ConditionBrowse.MeshTerm",
            "Enrollment.Text",
            "IDInfo.Text",
            "ConditionBrowse.Text",
            "FirstreceivedDate",
            "NumberOfGroups",
            "OversightInfo.HasDmc",
            "PrimaryCompletionDate.Text",
            "ResultsReference",
            "Eligibility.StudyPop.Text",
            "IsFdaRegulated",
            "WhyStopped",
            "ArmGroup",
            "OverallContact.LastName",
            "Phase",
            "RemovedCountries.Country",
            "InterventionBrowse.MeshTerm",
            "Eligibility.HealthyVolunteers",
            "Location",
            "OfficialTitle",
            "OverallContact.Email",
            "RequiredHeader.Text",
            "RequiredHeader.URL",
            "LocationCountries.Country",
            "OverallContact.PhoneExt",
            "Condition",
            "PrimaryOutcome",
            "LocationCountries.Text",
            "BiospecDescr.Text",
            "IDInfo.OrgStudyID",
            "Link",
            "OverallContact.Phone",
            "Source",
            "ResponsibleParty.InvestigatorAffiliation",
            "StudyType",
            "FirstreceivedResultsDate",
            "Enrollment.Type",
            "Eligibility.Gender",
            "OverallContactBackup.LastName",
            "Keyword",
            "BiospecRetention",
            "CompletionDate.Text",
            "OverallContact.Text",
            "RequiredHeader.DownloadDate",
            "Sponsors.Text",
            "Text",
            "Eligibility.SamplingMethod",
            "LastchangedDate",
            "ResponsibleParty.InvestigatorFullName",
            "StartDate",
            "RequiredHeader.LinkText",
            "OverallOfficial",
            "Sponsors.LeadSponsor.AgencyClass",
            "OverallContactBackup.Text",
            "Eligibility.Criteria.Text",
            "XMLName.Local",
            "OversightInfo.Authority",
            "PrimaryCompletionDate.Type",
            "ResponsibleParty.Organization",
            "IDInfo.NctAlias",
            "ResponsibleParty.Text",
            "TargetDuration",
            "Sponsors.LeadSponsor.Agency",
            "BriefSummary.Text",
            "OverallContactBackup.Email",
            "ResponsibleParty.InvestigatorTitle",
        ]

        self.best_recall_fields = [
            "LocationCountries.Country",
            "BiospecRetention",
            "DetailedDescription.Textblock",
            "HasExpandedAccess",
            "ConditionBrowse.MeshTerm",
            "RequiredHeader.LinkText",
            "WhyStopped",
            "BriefSummary.Textblock",
            "Eligibility.Criteria.Textblock",
            "OfficialTitle",
            "Eligibility.MaximumAge",
            "Eligibility.StudyPop.Textblock",
            "BiospecDescr.Textblock",
            "BriefTitle",
            "Eligibility.MinimumAge",
            "ResponsibleParty.Organization",
            "TargetDuration",
            "Condition",
            "IDInfo.OrgStudyID",
            "Keyword",
            "Source",
            "Sponsors.LeadSponsor.Agency",
            "ResponsibleParty.InvestigatorAffiliation",
            "OversightInfo.Authority",
            "OversightInfo.HasDmc",
            "OverallContact.Phone",
            "Phase",
            "OverallContactBackup.LastName",
            "Acronym",
            "InterventionBrowse.MeshTerm",
            "RemovedCountries.Country",
        ]
        self.best_map_fields = [
            "Eligibility.Gender",
            "LocationCountries.Country",
            "DetailedDescription.Textblock",
            "BriefSummary.Textblock",
            "ConditionBrowse.MeshTerm",
            "Eligibility.Criteria.Textblock",
            "InterventionBrowse.MeshTerm",
            "StudyType",
            "IsFdaRegulated",
            "HasExpandedAccess",
            "RequiredHeader.LinkText",
            "BiospecRetention",
            "OfficialTitle",
            "Eligibility.SamplingMethod",
            "Eligibility.StudyPop.Textblock",
            "Condition",
            "Eligibility.MinimumAge",
            "Keyword",
            "Eligibility.MaximumAge",
            "BriefTitle",
        ]
        self.best_embed_fields = [
            "WhyStopped",
            "HasExpandedAccess",
            "BiospecRetention",
            "BriefSummary.Textblock",
            "LocationCountries.Country",
            "ConditionBrowse.MeshTerm",
            "DetailedDescription.Textblock",
            "RequiredHeader.LinkText",
            "Eligibility.Criteria.Textblock",
        ]

        # The first two entries used to be fused by implicit string concatenation
        # (missing comma) into the non-existent field "BriefSummary.TextblockBriefTitle".
        self.sensible = [
            "BriefSummary.Textblock",
            "BriefTitle",
            "Eligibility.StudyPop.Textblock",
            "DetailedDescription.Textblock",
            "Eligibility.MinimumAge",
            "Eligibility.Criteria.Textblock",
            "InterventionBrowse.Text",
            "Eligibility.Text",
            "BiospecDescr.Textblock",
            "Eligibility.MaximumAge",
            "DetailedDescription.Text",
            "ConditionBrowse.MeshTerm",
            "ConditionBrowse.Text",
            "Eligibility.StudyPop.Text",
            "InterventionBrowse.MeshTerm",
            "OfficialTitle",
            "Condition",
            "PrimaryOutcome",
            "BiospecDescr.Text",
            "Eligibility.Gender",
            "Keyword",
            "BiospecRetention",
            "Eligibility.Criteria.Text",
            "BriefSummary.Text",
        ]

        # Same missing-comma fix as in ``sensible``.
        self.sensible_embed = [
            "BriefSummary.Textblock",
            "BriefTitle",
            "Eligibility.StudyPop.Textblock",
            "DetailedDescription.Textblock",
            "Eligibility.Criteria.Textblock",
            "InterventionBrowse.Text",
            "Eligibility.Text",
            "BiospecDescr.Textblock",
            "DetailedDescription.Text",
            "ConditionBrowse.MeshTerm",
            "ConditionBrowse.Text",
            "Eligibility.StudyPop.Text",
            "InterventionBrowse.MeshTerm",
            "OfficialTitle",
            "Condition",
            "PrimaryOutcome",
            "BiospecDescr.Text",
            "Keyword",
            "BiospecRetention",
            "Eligibility.Criteria.Text",
            "BriefSummary.Text",
        ]

        # Fields present in both lists, kept in ``sensible_embed`` order so the
        # result is deterministic (a bare set intersection is not).
        recall_fields = set(self.best_recall_fields)
        self.sensible_embed_safe = [
            field for field in self.sensible_embed if field in recall_fields
        ]

        self.query_funcs = {
            "query": self.generate_query,
            "ablation": self.generate_query_ablation,
            "embedding": self.generate_query_embedding,
        }

        loguru.logger.debug(self.sensible_embed_safe)

        self.field_usage = {
            "best_recall_fields": self.best_recall_fields,
            "all": self.mappings,
            "best_map_fields": self.best_map_fields,
            "best_embed_fields": self.best_embed_fields,
            "sensible": self.sensible,
            "sensible_embed": self.sensible_embed,
            "sensible_embed_safe": self.sensible_embed_safe,
        }

    @apply_config
    def generate_query(self, topic_num, query_field_usage, **kwargs) -> Dict:
        """
        Generates a query for the clinical trials index

        :param topic_num: Topic number to search
        :param query_field_usage: Which document facets to search over
        :param kwargs:
        :return:
            A basic elasticsearch query for clinical trials
        """
        fields = self.field_usage[query_field_usage]

        # The first field of the topic supplies the query text.
        topic = self.topics[topic_num]
        query_text = topic[list(topic)[0]]

        should = [{"match": {field: {"query": query_text}}} for field in fields]

        return {"query": {"bool": {"should": should}}}

    def generate_query_ablation(self, topic_num, **kwargs):
        """
        Only search one document facet at a time

        :param topic_num: Topic number whose field values form the query text
        :param kwargs:
        :return: A ``match`` query with one entry per index in ``self.fields``
        """
        match = {self.mappings[index]: "" for index in self.fields}

        for index in self.fields:
            # All topic field values are concatenated without a separator.
            match[self.mappings[index]] += "".join(self.topics[topic_num].values())

        return {"query": {"match": match}}

    @apply_config
    def generate_query_embedding(
        self,
        topic_num,
        encoder,
        query_field_usage,
        embed_field_usage,
        cosine_weights: List[float] = None,
        query_weight: List[float] = None,
        norm_weight=2.15,
        ablations=False,
        automatic_scores=None,
        **kwargs,
    ):
        """
        Computes the NIR score for a given topic

        Score = log(BM25)/log(norm_weight) + embedding_score

        :param topic_num: Topic number to search
        :param encoder: Object whose ``encode`` method embeds the query text
        :param query_field_usage: Key into ``field_usage`` for the match clauses
        :param embed_field_usage: Key into ``field_usage`` for the embedding fields
        :param cosine_weights: Script weights; defaults to 1 per embedding field
        :param query_weight: Per-field boosts for the match clauses; defaults to 1
        :param norm_weight: Normalisation weight, replaced when ``automatic_scores`` is given
        :param ablations: Passed to the script as ``disable_bm25``
        :param automatic_scores: Per-topic values used as ``bm25_ceiling`` for ``get_z_value``
        :param kwargs:
        :return: A ``script_score`` elasticsearch query
        :raises AssertionError: if neither ``norm_weight`` nor ``automatic_scores`` is set
        """
        if not (norm_weight or automatic_scores):
            raise AssertionError("Either norm_weight or automatic_scores must be provided")

        query_fields = self.field_usage[query_field_usage]
        embed_fields = self.field_usage[embed_field_usage]

        # The first field of the topic supplies the query text.
        topic = self.topics[topic_num]
        query_text = topic[list(topic)[0]]

        should = [
            {
                "match": {
                    field: {
                        "query": query_text,
                        "boost": query_weight[i] if query_weight else 1,
                    }
                }
            }
            for i, field in enumerate(query_fields)
        ]

        if automatic_scores is not None:
            norm_weight = get_z_value(
                cosine_ceiling=len(embed_fields) * len(query_fields),
                bm25_ceiling=automatic_scores[topic_num],
            )

        params = {
            "weights": cosine_weights if cosine_weights else [1] * len(embed_fields),
            "q_eb": encoder.encode(query_text),
            "offset": 1.0,
            "norm_weight": norm_weight,
            "disable_bm25": ablations,
        }

        # NOTE(review): the script is generated over ``best_embed_fields`` while the default
        # weights are sized from ``embed_fields`` -- confirm whether ``embed_fields`` was
        # intended here. Left unchanged to avoid altering which index fields are scored.
        return {
            "query": {
                "script_score": {
                    "query": {"bool": {"should": should}},
                    "script": generate_script(self.best_embed_fields, params=params),
                },
            }
        }

    def get_query_type(self, *args, **kwargs):
        """
        Build the query using the generator registered for ``self.query_type``.
        """
        return self.query_funcs[self.query_type](*args, **kwargs)

    def get_id_mapping(self, hit):
        """
        Return the value stored under ``self.id_mapping`` in a hit.
        """
        return hit[self.id_mapping]
Elasticsearch Query object for the Clinical Trials Index
TrialsElasticsearchQuery(topics, query_type, config=None, *args, **kwargs)
def __init__(self, topics, query_type, config=None, *args, **kwargs):
    """
    Set up the field lists and the query-type dispatch table.

    :param topics: Mapping of topic number to that topic's fields
    :param query_type: Key into ``query_funcs`` ("query", "ablation" or "embedding")
    :param config: Optional query config
    :param args: Forwarded to the parent constructor
    :param kwargs: Forwarded to the parent constructor
    """
    super().__init__(topics, config, *args, **kwargs)
    self.query_type = query_type
    self.config = config
    self.topics = topics
    self.fields = []
    self.mappings = [
        "HasExpandedAccess",
        "BriefSummary.Textblock",
        "CompletionDate.Type",
        "OversightInfo.Text",
        "OverallContactBackup.PhoneExt",
        "RemovedCountries.Text",
        "SecondaryOutcome",
        "Sponsors.LeadSponsor.Text",
        "BriefTitle",
        "IDInfo.NctID",
        "IDInfo.SecondaryID",
        "OverallContactBackup.Phone",
        "Eligibility.StudyPop.Textblock",
        "DetailedDescription.Textblock",
        "Eligibility.MinimumAge",
        "Sponsors.Collaborator",
        "Reference",
        "Eligibility.Criteria.Textblock",
        "XMLName.Space",
        "Rank",
        "OverallStatus",
        "InterventionBrowse.Text",
        "Eligibility.Text",
        "Intervention",
        "BiospecDescr.Textblock",
        "ResponsibleParty.NameTitle",
        "NumberOfArms",
        "ResponsibleParty.ResponsiblePartyType",
        "IsSection801",
        "Acronym",
        "Eligibility.MaximumAge",
        "DetailedDescription.Text",
        "StudyDesign",
        "OtherOutcome",
        "VerificationDate",
        "ConditionBrowse.MeshTerm",
        "Enrollment.Text",
        "IDInfo.Text",
        "ConditionBrowse.Text",
        "FirstreceivedDate",
        "NumberOfGroups",
        "OversightInfo.HasDmc",
        "PrimaryCompletionDate.Text",
        "ResultsReference",
        "Eligibility.StudyPop.Text",
        "IsFdaRegulated",
        "WhyStopped",
        "ArmGroup",
        "OverallContact.LastName",
        "Phase",
        "RemovedCountries.Country",
        "InterventionBrowse.MeshTerm",
        "Eligibility.HealthyVolunteers",
        "Location",
        "OfficialTitle",
        "OverallContact.Email",
        "RequiredHeader.Text",
        "RequiredHeader.URL",
        "LocationCountries.Country",
        "OverallContact.PhoneExt",
        "Condition",
        "PrimaryOutcome",
        "LocationCountries.Text",
        "BiospecDescr.Text",
        "IDInfo.OrgStudyID",
        "Link",
        "OverallContact.Phone",
        "Source",
        "ResponsibleParty.InvestigatorAffiliation",
        "StudyType",
        "FirstreceivedResultsDate",
        "Enrollment.Type",
        "Eligibility.Gender",
        "OverallContactBackup.LastName",
        "Keyword",
        "BiospecRetention",
        "CompletionDate.Text",
        "OverallContact.Text",
        "RequiredHeader.DownloadDate",
        "Sponsors.Text",
        "Text",
        "Eligibility.SamplingMethod",
        "LastchangedDate",
        "ResponsibleParty.InvestigatorFullName",
        "StartDate",
        "RequiredHeader.LinkText",
        "OverallOfficial",
        "Sponsors.LeadSponsor.AgencyClass",
        "OverallContactBackup.Text",
        "Eligibility.Criteria.Text",
        "XMLName.Local",
        "OversightInfo.Authority",
        "PrimaryCompletionDate.Type",
        "ResponsibleParty.Organization",
        "IDInfo.NctAlias",
        "ResponsibleParty.Text",
        "TargetDuration",
        "Sponsors.LeadSponsor.Agency",
        "BriefSummary.Text",
        "OverallContactBackup.Email",
        "ResponsibleParty.InvestigatorTitle",
    ]

    self.best_recall_fields = [
        "LocationCountries.Country",
        "BiospecRetention",
        "DetailedDescription.Textblock",
        "HasExpandedAccess",
        "ConditionBrowse.MeshTerm",
        "RequiredHeader.LinkText",
        "WhyStopped",
        "BriefSummary.Textblock",
        "Eligibility.Criteria.Textblock",
        "OfficialTitle",
        "Eligibility.MaximumAge",
        "Eligibility.StudyPop.Textblock",
        "BiospecDescr.Textblock",
        "BriefTitle",
        "Eligibility.MinimumAge",
        "ResponsibleParty.Organization",
        "TargetDuration",
        "Condition",
        "IDInfo.OrgStudyID",
        "Keyword",
        "Source",
        "Sponsors.LeadSponsor.Agency",
        "ResponsibleParty.InvestigatorAffiliation",
        "OversightInfo.Authority",
        "OversightInfo.HasDmc",
        "OverallContact.Phone",
        "Phase",
        "OverallContactBackup.LastName",
        "Acronym",
        "InterventionBrowse.MeshTerm",
        "RemovedCountries.Country",
    ]
    self.best_map_fields = [
        "Eligibility.Gender",
        "LocationCountries.Country",
        "DetailedDescription.Textblock",
        "BriefSummary.Textblock",
        "ConditionBrowse.MeshTerm",
        "Eligibility.Criteria.Textblock",
        "InterventionBrowse.MeshTerm",
        "StudyType",
        "IsFdaRegulated",
        "HasExpandedAccess",
        "RequiredHeader.LinkText",
        "BiospecRetention",
        "OfficialTitle",
        "Eligibility.SamplingMethod",
        "Eligibility.StudyPop.Textblock",
        "Condition",
        "Eligibility.MinimumAge",
        "Keyword",
        "Eligibility.MaximumAge",
        "BriefTitle",
    ]
    self.best_embed_fields = [
        "WhyStopped",
        "HasExpandedAccess",
        "BiospecRetention",
        "BriefSummary.Textblock",
        "LocationCountries.Country",
        "ConditionBrowse.MeshTerm",
        "DetailedDescription.Textblock",
        "RequiredHeader.LinkText",
        "Eligibility.Criteria.Textblock",
    ]

    # The first two entries used to be fused by implicit string concatenation
    # (missing comma) into the non-existent field "BriefSummary.TextblockBriefTitle".
    self.sensible = [
        "BriefSummary.Textblock",
        "BriefTitle",
        "Eligibility.StudyPop.Textblock",
        "DetailedDescription.Textblock",
        "Eligibility.MinimumAge",
        "Eligibility.Criteria.Textblock",
        "InterventionBrowse.Text",
        "Eligibility.Text",
        "BiospecDescr.Textblock",
        "Eligibility.MaximumAge",
        "DetailedDescription.Text",
        "ConditionBrowse.MeshTerm",
        "ConditionBrowse.Text",
        "Eligibility.StudyPop.Text",
        "InterventionBrowse.MeshTerm",
        "OfficialTitle",
        "Condition",
        "PrimaryOutcome",
        "BiospecDescr.Text",
        "Eligibility.Gender",
        "Keyword",
        "BiospecRetention",
        "Eligibility.Criteria.Text",
        "BriefSummary.Text",
    ]

    # Same missing-comma fix as in ``sensible``.
    self.sensible_embed = [
        "BriefSummary.Textblock",
        "BriefTitle",
        "Eligibility.StudyPop.Textblock",
        "DetailedDescription.Textblock",
        "Eligibility.Criteria.Textblock",
        "InterventionBrowse.Text",
        "Eligibility.Text",
        "BiospecDescr.Textblock",
        "DetailedDescription.Text",
        "ConditionBrowse.MeshTerm",
        "ConditionBrowse.Text",
        "Eligibility.StudyPop.Text",
        "InterventionBrowse.MeshTerm",
        "OfficialTitle",
        "Condition",
        "PrimaryOutcome",
        "BiospecDescr.Text",
        "Keyword",
        "BiospecRetention",
        "Eligibility.Criteria.Text",
        "BriefSummary.Text",
    ]

    # Fields present in both lists, kept in ``sensible_embed`` order so the
    # result is deterministic (a bare set intersection is not).
    recall_fields = set(self.best_recall_fields)
    self.sensible_embed_safe = [
        field for field in self.sensible_embed if field in recall_fields
    ]

    self.query_funcs = {
        "query": self.generate_query,
        "ablation": self.generate_query_ablation,
        "embedding": self.generate_query_embedding,
    }

    loguru.logger.debug(self.sensible_embed_safe)

    self.field_usage = {
        "best_recall_fields": self.best_recall_fields,
        "all": self.mappings,
        "best_map_fields": self.best_map_fields,
        "best_embed_fields": self.best_embed_fields,
        "sensible": self.sensible,
        "sensible_embed": self.sensible_embed,
        "sensible_embed_safe": self.sensible_embed_safe,
    }
def
generate_query(self, *args, **kwargs):
def use_config(self, *args, **kwargs):
    """
    Call the wrapped function with keyword arguments taken from self.config.

    When ``self.config`` is set, the incoming keyword arguments are handed to
    ``self.config.__update__`` and whatever it returns is used as the keyword
    arguments for the call. Without a config they are passed through untouched.

    :param self: Instance the wrapped function is bound to
    :param args: Positional arguments, forwarded unchanged
    :param kwargs: Keyword arguments, possibly replaced via the config
    :return: The return value of the wrapped function
    """
    config = self.config
    # ``func`` is the wrapped callable captured from the enclosing scope.
    merged = kwargs if config is None else config.__update__(**kwargs)

    return func(self, *args, **merged)
Generates a query for the clinical trials index
Parameters
- topic_num: Topic number to search
- query_field_usage: Which document facets to search over
- kwargs:
Returns
A basic elasticsearch query for clinical trials
def
generate_query_ablation(self, topic_num, **kwargs):
def generate_query_ablation(self, topic_num, **kwargs):
    """
    Only search one document facet at a time.

    Every index in ``self.fields`` is translated to a field name through
    ``self.mappings`` and matched against the text of the requested topic.

    :param topic_num: Key of the topic in ``self.topics``
    :param kwargs: Accepted for interface compatibility; not used here
    :return: A dict of the form ``{"query": {"match": {field_name: text}}}``
    """
    field_names = [self.mappings[field] for field in self.fields]

    # The topic text is identical for every facet, so build it once. The
    # topic's fields are joined with a space so that the last word of one
    # field and the first word of the next are not fused into one token.
    # With no facets configured the topic is not looked up at all.
    topic_text = (
        " ".join(self.topics[topic_num].values()) if field_names else ""
    )

    return {"query": {"match": {name: topic_text for name in field_names}}}
Only search one document facet at a time
Parameters
- topic_num:
- kwargs:
Returns
def
generate_query_embedding(self, *args, **kwargs):
def use_config(self, *args, **kwargs):
    """
    Call the wrapped function with keyword arguments taken from self.config.

    When ``self.config`` is set, the incoming keyword arguments are handed to
    ``self.config.__update__`` and whatever it returns is used as the keyword
    arguments for the call. Without a config they are passed through untouched.

    :param self: Instance the wrapped function is bound to
    :param args: Positional arguments, forwarded unchanged
    :param kwargs: Keyword arguments, possibly replaced via the config
    :return: The return value of the wrapped function
    """
    config = self.config
    # ``func`` is the wrapped callable captured from the enclosing scope.
    merged = kwargs if config is None else config.__update__(**kwargs)

    return func(self, *args, **merged)
Computes the NIR score for a given topic
Score = log(BM25)/log(norm_weight) + embedding_score
Parameters
- topic_num:
- encoder:
- query_field_usage:
- embed_field_usage:
- cosine_weights:
- query_weight:
- norm_weight:
- ablations:
- automatic_scores:
- kwargs:
Returns
def
get_id_mapping(self, hit):
Get the document ID
Parameters
- hit: The raw document result
Returns
The document's ID
Inherited Members
class ClinicalTrialsElasticsearchExecutor(GenericElasticsearchExecutor):
    """
    Executes queries given a query object.
    """

    query: TrialsElasticsearchQuery

    def __init__(
        self,
        topics: Dict[Union[str, int], Dict[str, str]],
        client: Elasticsearch,
        index_name: str,
        output_file: str,
        query: TrialsElasticsearchQuery,
        encoder: Optional[Encoder] = None,
        config=None,
        *args,
        **kwargs,
    ):
        """
        Forward every argument to the generic executor and register the
        query generators available for the clinical trials index.

        :param topics: Topics keyed by topic id
        :param client: Elasticsearch client handed to the base executor
        :param index_name: Index name handed to the base executor
        :param output_file: Output file handed to the base executor
        :param query: Query object for the clinical trials index
        :param encoder: Optional encoder handed to the base executor
        :param config: Optional config handed to the base executor
        :param args: Extra positional arguments for the base executor
        :param kwargs: Extra keyword arguments for the base executor
        """
        base_positionals = (
            topics,
            client,
            index_name,
            output_file,
            query,
            encoder,
        )
        super().__init__(*base_positionals, *args, config=config, **kwargs)

        # Query type name -> bound generator method. These methods are not
        # defined in this class; presumably the base class provides them.
        self.query_fns = dict(
            query=self.generate_query,
            ablation=self.generate_query_ablation,
            embedding=self.generate_embedding_query,
        )
Executes queries given a query object.
ClinicalTrialsElasticsearchExecutor( topics: Dict[Union[str, int], Dict[str, str]], client: elasticsearch.AsyncElasticsearch, index_name: str, output_file: str, query: debeir.datasets.clinical_trials.TrialsElasticsearchQuery, encoder: Optional[debeir.rankers.transformer_sent_encoder.Encoder] = None, config=None, *args, **kwargs)
def __init__(
    self,
    topics: Dict[Union[str, int], Dict[str, str]],
    client: Elasticsearch,
    index_name: str,
    output_file: str,
    query: TrialsElasticsearchQuery,
    encoder: Optional[Encoder] = None,
    config=None,
    *args,
    **kwargs,
):
    """
    Forward every argument to the generic executor and register the
    query generators available for the clinical trials index.

    :param topics: Topics keyed by topic id
    :param client: Elasticsearch client handed to the base executor
    :param index_name: Index name handed to the base executor
    :param output_file: Output file handed to the base executor
    :param query: Query object for the clinical trials index
    :param encoder: Optional encoder handed to the base executor
    :param config: Optional config handed to the base executor
    :param args: Extra positional arguments for the base executor
    :param kwargs: Extra keyword arguments for the base executor
    """
    base_positionals = (
        topics,
        client,
        index_name,
        output_file,
        query,
        encoder,
    )
    super().__init__(*base_positionals, *args, config=config, **kwargs)

    # Query type name -> bound generator method. These methods are not
    # defined in this class; presumably the base class provides them.
    self.query_fns = dict(
        query=self.generate_query,
        ablation=self.generate_query_ablation,
        embedding=self.generate_embedding_query,
    )
class ClinicalTrialParser(Parser):
    """
    Parser for Clinical Trials topics
    """

    @classmethod
    def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
        """
        Read topics from an open CSV file.

        The first row is treated as a header and discarded. For every
        following row the first column is the topic id and the second column
        is the topic text. Ids are kept exactly as read from the file, i.e.
        as strings.

        :param csvfile: An iterable of CSV lines, e.g. an open text file
        :return: Mapping of topic id to ``{"text": <topic text>}``
        :raises IndexError: if a non-blank row has fewer than two columns
        """
        reader = csv.reader(csvfile)
        # Discard the header row; the default keeps an empty file from raising.
        next(reader, None)

        # csv.reader yields an empty list for a blank line; skip those rows
        # instead of failing on row[0].
        return {row[0]: {"text": row[1]} for row in reader if row}
Parser for Clinical Trials topics
@classmethod
def
get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
@classmethod
def get_topics(cls, csvfile) -> Dict[int, Dict[str, str]]:
    """
    Read topics from an open CSV file.

    The first row is treated as a header and discarded. For every following
    row the first column is the topic id and the second column is the topic
    text. Ids are kept exactly as read from the file, i.e. as strings.

    :param csvfile: An iterable of CSV lines, e.g. an open text file
    :return: Mapping of topic id to ``{"text": <topic text>}``
    :raises IndexError: if a non-blank row has fewer than two columns
    """
    reader = csv.reader(csvfile)
    # Discard the header row; the default keeps an empty file from raising.
    next(reader, None)

    # csv.reader yields an empty list for a blank line; skip those rows
    # instead of failing on row[0].
    return {row[0]: {"text": row[1]} for row in reader if row}
Instance method for getting topics; forwards the instance's own (self) parameters to the _get_topics class method.