debeir.core.callbacks
Callbacks that run before and after a pipeline run: `before` is for setup, and `after` is for evaluation, serialization, etc.
"""
Callbacks that run before and after a pipeline run.

``before`` is used for setup (e.g. resolving the run output file) and
``after`` for evaluation, serialization, etc.
"""

import abc
import os
import tempfile
import uuid
from typing import List

import loguru
from debeir.datasets.factory import query_factory
from debeir.evaluation.evaluator import Evaluator
from debeir.core.config import GenericConfig, NIRConfig
from debeir.core.pipeline import Pipeline


class Callback:
    """
    Base class for callbacks invoked around a pipeline run.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abc.abstractmethod`` markers below are not enforced at instantiation.
    """

    def __init__(self):
        # The running pipeline; subclasses set this in ``before``.
        self.pipeline = None

    @abc.abstractmethod
    def before(self, pipeline: Pipeline):
        pass

    @abc.abstractmethod
    def after(self, results: List):
        pass


class SerializationCallback(Callback):
    """Writes pipeline results to a TREC-style run file."""

    def __init__(self, config: GenericConfig, nir_config: NIRConfig):
        super().__init__()
        self.config = config
        self.nir_config = nir_config
        self.output_file = None
        self.query_cls = query_factory[self.config.query_fn]

    def before(self, pipeline: Pipeline):
        """
        Resolve the run output file and attach it to ``pipeline``.

        The path is ``nir_config.output_directory/config.index/<name>``, where
        ``<name>`` is ``config.output_file`` or, when that is None, a random
        UUID. The parent directory is created if missing. An existing file is
        truncated when ``config.overwrite_output_if_exists`` is set.

        :param pipeline: The pipeline about to run; its ``output_file`` is set
            to the resolved path (also stored on ``self.output_file``).
        :raises RuntimeError: If the file exists and overwriting is not enabled.
        """
        # Keep the pipeline *instance*, not the Pipeline class, so attribute
        # reads/writes affect the running pipeline only.
        self.pipeline = pipeline

        output_file = self.config.output_file
        output_dir = os.path.join(self.nir_config.output_directory, self.config.index)

        if output_file is None:
            output_file = os.path.join(output_dir, str(uuid.uuid4()))
            loguru.logger.info(f"Output file not specified, writing to: {output_file}")
        else:
            output_file = os.path.join(output_dir, output_file)

        # Create the parent directory in both branches: a configured file name
        # in a not-yet-existing directory would otherwise fail when appending.
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if os.path.exists(output_file):
            if not self.config.overwrite_output_if_exists:
                raise RuntimeError("Output file exists and isn't explicitly overwritten "
                                   "in config with overwrite_output_if_exists=True")

            loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
            # Truncate now, because ``after`` opens the file in append mode.
            open(output_file, "w+").close()

        pipeline.output_file = output_file
        self.output_file = output_file

    def after(self, results: List):
        """
        Append ``results`` to ``self.output_file`` in a TREC-style format.

        :param results: Documents exposing ``topic_num``, ``doc_id``,
            ``scores['rank']`` and ``score``; one line is written per document.

        The run name is taken from ``config.run_name`` (``NO_RUN_NAME`` if None).
        """
        self._after(results,
                    output_file=self.output_file,
                    run_name=self.config.run_name)

    @classmethod
    def _after(cls, results: List, output_file, run_name=None):
        """
        Append one tab-separated TREC-style line per document to ``output_file``.

        Line layout: ``topic_num  Q0  doc_id  rank  score  run_name``.

        :param results: Documents exposing ``topic_num``, ``doc_id``,
            ``scores['rank']`` and ``score``.
        :param output_file: Path of the run file (opened in append mode).
        :param run_name: Value for the last column; ``NO_RUN_NAME`` if None.
        """
        if run_name is None:
            run_name = "NO_RUN_NAME"

        with open(output_file, "a+t") as writer:
            for doc in results:
                writer.write(f"{doc.topic_num}\t"
                             f"Q0\t"
                             f"{doc.doc_id}\t"
                             f"{doc.scores['rank']}\t"
                             f"{doc.score}\t"
                             f"{run_name}\n")


class EvaluationCallback(Callback):
    """Evaluates the pipeline's run file with an ``Evaluator``."""

    def __init__(self, evaluator: Evaluator, config):
        super().__init__()
        self.evaluator = evaluator
        self.config = config
        self.parsed_run = None

    def before(self, pipeline: Pipeline):
        """
        Remember the running pipeline so ``after`` can locate its run file.

        :param pipeline: The pipeline about to run.
        """
        # Store the instance, not the Pipeline class: ``after`` reads and may
        # set ``output_file``, which must not leak onto the class.
        self.pipeline = pipeline

    def after(self, results: List, id_field="id"):
        """
        Evaluate the run and store the parsed result on ``self.parsed_run``.

        If the pipeline has no output file, ``results`` are first serialized
        to a new file in a fresh temporary directory, and that path becomes the
        pipeline's output file. Otherwise the existing file is evaluated as-is.
        This assumes it has already been written, e.g. by a
        SerializationCallback that ran earlier.

        :param results: Documents to serialize when no run file exists yet.
        :param id_field: Assigned to ``id_field`` on the configured query class
            before serializing.
        :return: The value returned by ``evaluator.evaluate_runs``.
        """
        if self.pipeline.output_file is None:
            fp = os.path.join(tempfile.mkdtemp(), str(uuid.uuid4()))

            # NOTE(review): mutates the query class shared via query_factory.
            query = query_factory[self.config.query_fn]
            query.id_field = id_field

            SerializationCallback._after(results,
                                         output_file=fp,
                                         run_name=self.config.run_name)

            self.pipeline.output_file = fp

        self.parsed_run = self.evaluator.evaluate_runs(self.pipeline.output_file,
                                                       disable_cache=True)

        return self.parsed_run
class
Callback:
class Callback:
    """
    Base class for callbacks invoked around a pipeline run.

    NOTE(review): ``Callback`` does not derive from ``abc.ABC``, so the
    ``@abc.abstractmethod`` markers are not enforced at instantiation.
    """

    def __init__(self):
        # Subclasses assign this in their ``before`` hook.
        self.pipeline = None

    @abc.abstractmethod
    def before(self, pipeline: Pipeline):
        """Hook executed before the pipeline runs."""

    @abc.abstractmethod
    def after(self, results: List):
        """Hook executed once the pipeline has produced ``results``."""
class SerializationCallback(Callback):
    """Writes pipeline results to a TREC-style run file."""

    def __init__(self, config: GenericConfig, nir_config: NIRConfig):
        super().__init__()
        self.config = config
        self.nir_config = nir_config
        self.output_file = None
        self.query_cls = query_factory[self.config.query_fn]

    def before(self, pipeline: Pipeline):
        """
        Resolve the run output file and attach it to ``pipeline``.

        The path is ``nir_config.output_directory/config.index/<name>``, where
        ``<name>`` is ``config.output_file`` or, when that is None, a random
        UUID. The parent directory is created if missing. An existing file is
        truncated when ``config.overwrite_output_if_exists`` is set.

        :param pipeline: The pipeline about to run; its ``output_file`` is set
            to the resolved path (also stored on ``self.output_file``).
        :raises RuntimeError: If the file exists and overwriting is not enabled.
        """
        # Keep the pipeline *instance*, not the Pipeline class.
        self.pipeline = pipeline

        output_file = self.config.output_file
        output_dir = os.path.join(self.nir_config.output_directory, self.config.index)

        if output_file is None:
            output_file = os.path.join(output_dir, str(uuid.uuid4()))
            loguru.logger.info(f"Output file not specified, writing to: {output_file}")
        else:
            output_file = os.path.join(output_dir, output_file)

        # Create the parent directory in both branches: a configured file name
        # in a not-yet-existing directory would otherwise fail when appending.
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if os.path.exists(output_file):
            if not self.config.overwrite_output_if_exists:
                raise RuntimeError("Output file exists and isn't explicitly overwritten "
                                   "in config with overwrite_output_if_exists=True")

            loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
            # Truncate now, because ``after`` opens the file in append mode.
            open(output_file, "w+").close()

        pipeline.output_file = output_file
        self.output_file = output_file

    def after(self, results: List):
        """
        Append ``results`` to ``self.output_file`` in a TREC-style format.

        :param results: Documents exposing ``topic_num``, ``doc_id``,
            ``scores['rank']`` and ``score``; one line is written per document.

        The run name is taken from ``config.run_name`` (``NO_RUN_NAME`` if None).
        """
        self._after(results,
                    output_file=self.output_file,
                    run_name=self.config.run_name)

    @classmethod
    def _after(cls, results: List, output_file, run_name=None):
        """
        Append one tab-separated TREC-style line per document to ``output_file``.

        Line layout: ``topic_num  Q0  doc_id  rank  score  run_name``.

        :param results: Documents exposing ``topic_num``, ``doc_id``,
            ``scores['rank']`` and ``score``.
        :param output_file: Path of the run file (opened in append mode).
        :param run_name: Value for the last column; ``NO_RUN_NAME`` if None.
        """
        if run_name is None:
            run_name = "NO_RUN_NAME"

        with open(output_file, "a+t") as writer:
            for doc in results:
                writer.write(f"{doc.topic_num}\t"
                             f"Q0\t"
                             f"{doc.doc_id}\t"
                             f"{doc.scores['rank']}\t"
                             f"{doc.score}\t"
                             f"{run_name}\n")
SerializationCallback( config: debeir.core.config.GenericConfig, nir_config: debeir.core.config.NIRConfig)
def before(self, pipeline: Pipeline):
    """
    Resolve the run output file and attach it to ``pipeline``.

    The path is ``nir_config.output_directory/config.index/<name>``, where
    ``<name>`` is ``config.output_file`` or, when that is None, a random
    UUID. The parent directory is created if missing. An existing file is
    truncated when ``config.overwrite_output_if_exists`` is set.

    :param pipeline: The pipeline about to run; its ``output_file`` is set
        to the resolved path (also stored on ``self.output_file``).
    :raises RuntimeError: If the file exists and overwriting is not enabled.
    """
    # Keep the pipeline *instance*; assigning the Pipeline class here made
    # later attribute reads/writes hit the class instead of the running object.
    self.pipeline = pipeline

    output_file = self.config.output_file
    output_dir = os.path.join(self.nir_config.output_directory, self.config.index)

    if output_file is None:
        output_file = os.path.join(output_dir, str(uuid.uuid4()))
        loguru.logger.info(f"Output file not specified, writing to: {output_file}")
    else:
        output_file = os.path.join(output_dir, output_file)

    # Create the parent directory in both branches: a configured file name
    # in a not-yet-existing directory would otherwise fail when appending.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if os.path.exists(output_file):
        if not self.config.overwrite_output_if_exists:
            raise RuntimeError("Output file exists and isn't explicitly overwritten "
                               "in config with overwrite_output_if_exists=True")

        loguru.logger.info(f"Output file exists: {output_file}. Overwriting...")
        # Truncate now, because results are later appended to this file.
        open(output_file, "w+").close()

    pipeline.output_file = output_file
    self.output_file = output_file
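Resolve the output file path under `output_directory/index` (a random UUID file name is used when none is configured). If the file already exists, it is truncated when `overwrite_output_if_exists` is set; otherwise a RuntimeError is raised.
Returns
None. The resolved path is stored on `pipeline.output_file` and `self.output_file`.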
def
after(self, results: List):
def after(self, results: List):
    """
    Write ``results`` to ``self.output_file`` as TREC-style run lines.

    :param results: Documents to serialize, one line per document.

    The run name comes from ``self.config.run_name``; ``_after`` substitutes
    ``NO_RUN_NAME`` when it is None.
    """
    # Formatting and appending are delegated to the shared ``_after`` helper.
    target = self.output_file
    tag = self.config.run_name
    self._after(results, output_file=target, run_name=tag)
Serialize results to self.output_file in a TREC-style format
Parameters
- results: Documents to write, one TREC-style line per document (topic_num, Q0, doc_id, scores['rank'], score, run name).
The run name is taken from config.run_name (default: NO_RUN_NAME).
class EvaluationCallback(Callback):
    """Evaluates the pipeline's run file with an ``Evaluator``."""

    def __init__(self, evaluator: Evaluator, config):
        super().__init__()
        self.evaluator = evaluator
        self.config = config
        self.parsed_run = None

    def before(self, pipeline: Pipeline):
        """
        Remember the running pipeline so ``after`` can locate its run file.

        :param pipeline: The pipeline about to run.
        """
        # Store the instance, not the Pipeline class: ``after`` reads and may
        # set ``output_file``, which must not leak onto the class.
        self.pipeline = pipeline

    def after(self, results: List, id_field="id"):
        """
        Evaluate the run and store the parsed result on ``self.parsed_run``.

        If the pipeline has no output file, ``results`` are first serialized
        to a new file in a fresh temporary directory, and that path becomes the
        pipeline's output file. Otherwise the existing file is evaluated as-is.
        This assumes it has already been written, e.g. by a
        SerializationCallback that ran earlier.

        :param results: Documents to serialize when no run file exists yet.
        :param id_field: Assigned to ``id_field`` on the configured query class
            before serializing.
        :return: The value returned by ``evaluator.evaluate_runs``.
        """
        if self.pipeline.output_file is None:
            fp = os.path.join(tempfile.mkdtemp(), str(uuid.uuid4()))

            # NOTE(review): mutates the query class shared via query_factory.
            query = query_factory[self.config.query_fn]
            query.id_field = id_field

            SerializationCallback._after(results,
                                         output_file=fp,
                                         run_name=self.config.run_name)

            self.pipeline.output_file = fp

        self.parsed_run = self.evaluator.evaluate_runs(self.pipeline.output_file,
                                                       disable_cache=True)

        return self.parsed_run
EvaluationCallback(evaluator: debeir.evaluation.evaluator.Evaluator, config)
def
after(self, results: List, id_field='id'):
def after(self, results: List, id_field="id"):
    """
    Evaluate the pipeline's run file and cache the parsed result.

    When ``self.pipeline.output_file`` is None, ``results`` are first written
    to a new file inside a fresh temporary directory, and that path is
    recorded as the pipeline's output file.

    :param results: Documents to serialize if no run file is set yet.
    :param id_field: Value assigned to ``id_field`` on the configured query
        class before serializing.
    :return: Whatever ``self.evaluator.evaluate_runs`` returns; also stored
        on ``self.parsed_run``.
    """
    pipeline = self.pipeline

    if pipeline.output_file is None:
        run_path = os.path.join(tempfile.mkdtemp(), str(uuid.uuid4()))

        # NOTE(review): this sets an attribute on the (presumably shared)
        # query class returned by query_factory, not on a per-call object.
        query_factory[self.config.query_fn].id_field = id_field

        SerializationCallback._after(results,
                                     output_file=run_path,
                                     run_name=self.config.run_name)
        pipeline.output_file = run_path

    self.parsed_run = self.evaluator.evaluate_runs(pipeline.output_file,
                                                   disable_cache=True)
    return self.parsed_run