Spaces:
Runtime error
Runtime error
| from haystack.document_stores import InMemoryDocumentStore | |
| from haystack.nodes.retriever import TfidfRetriever | |
| from haystack.pipelines import DocumentSearchPipeline, ExtractiveQAPipeline | |
| from haystack.nodes.retriever import EmbeddingRetriever | |
| import pickle | |
| from pprint import pprint | |
# Human-readable dataset labels. The module-level do_search() compares its
# `dataset` argument against german_datset_name to choose an engine.
# (The "datset" spelling is kept because other code refers to these names.)
dutch_datset_name, german_datset_name = (
    'Partisan news 2019 (dutch)',
    'CDU election program 2021',
)
class ExportableInMemoryDocumentStore(InMemoryDocumentStore):
    """InMemoryDocumentStore whose ``indexes`` can be saved to / restored from disk.

    When the application runs on Huggingface Spaces no GPU is available, so
    pre-calculated data is pickled with :meth:`export` and loaded back into
    the store with :meth:`load_data` instead of being recomputed.
    """

    def export(self, file_name='in_memory_store.pkl'):
        """Pickle ``self.indexes`` into *file_name*, overwriting any existing file."""
        with open(file_name, 'wb') as out_file:
            pickle.dump(self.indexes, out_file)

    def load_data(self, file_name='in_memory_store.pkl'):
        """Replace ``self.indexes`` with the object unpickled from *file_name*.

        ``self.indexes`` is only reassigned if unpickling succeeds. Only use
        this on files you trust: unpickling can execute arbitrary code.
        """
        with open(file_name, 'rb') as in_file:
            restored = pickle.load(in_file)
        self.indexes = restored
class SearchEngine():
    """Compare sparse, base-dense and fine-tuned-dense retrieval over one corpus.

    Two pre-computed document stores are loaded from pickle files:

    * the *base* store serves the TF-IDF retriever and the base multilingual
      sentence-transformers retriever;
    * the *adapted* store serves the fine-tuned retriever. Presumably it holds
      embeddings from the fine-tuned model; verify against the export script.

    Args:
        document_store_name_base: Pickle file with the base store's indexes.
        document_store_name_adpated: Pickle file with the adapted store's
            indexes. The misspelled name is kept so keyword callers still work.
        adapted_retriever_path: Path or model name of the fine-tuned
            sentence-transformers model.
    """

    def __init__(self, document_store_name_base, document_store_name_adpated,
                 adapted_retriever_path):
        self.document_store = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store.load_data(document_store_name_base)
        self.document_store_adapted = ExportableInMemoryDocumentStore(similarity='cosine')
        self.document_store_adapted.load_data(document_store_name_adpated)
        # NOTE(review): the retrievers are created after load_data, presumably
        # because TfidfRetriever indexes the store's documents at construction
        # time. Confirm before reordering.
        self.retriever = TfidfRetriever(document_store=self.document_store)
        self.base_dense_retriever = EmbeddingRetriever(
            document_store=self.document_store,
            embedding_model='sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
            model_format='sentence_transformers'
        )
        self.fine_tuned_retriever = EmbeddingRetriever(
            document_store=self.document_store_adapted,
            embedding_model=adapted_retriever_path,
            model_format='sentence_transformers'
        )

    def sparse_retrieval(self, query):
        """Run the TF-IDF document-search pipeline for *query*.

        Returns:
            The pipeline's output dict. When both the pipeline result and the
            score mapping are non-empty, the first document's ``score`` is
            overwritten with the first value of the retriever's score mapping
            for the query.
        """
        scores = self.retriever._calc_scores(query)
        p_retrieval = DocumentSearchPipeline(self.retriever)
        documents = p_retrieval.run(query=query)
        # NOTE(review): relies on the private `_calc_scores` returning, per
        # query, a mapping ordered by descending score, so that its first value
        # is the top document's score. Confirm for the installed Haystack version.
        # Guarded so that an empty result no longer raises IndexError here.
        if documents['documents'] and scores and scores[0]:
            # next(iter(...)) reads the first value without copying every score.
            documents['documents'][0].score = next(iter(scores[0].values()))
        return documents

    def dense_retrieval(self, query, retriever='base'):
        """Run a dense (embedding) document-search pipeline for *query*.

        Args:
            query: The search text.
            retriever: ``'base'`` for the multilingual base model or
                ``'adapted'`` for the fine-tuned model.

        Returns:
            The pipeline's output dict.

        Raises:
            ValueError: If *retriever* is neither ``'base'`` nor ``'adapted'``.
        """
        if retriever == 'base':
            selected = self.base_dense_retriever
        elif retriever == 'adapted':
            selected = self.fine_tuned_retriever
        else:
            # Previously this fell through and returned None, which only
            # surfaced later as an opaque TypeError when the result was indexed.
            raise ValueError(
                f"retriever must be 'base' or 'adapted', got {retriever!r}")
        p_retrieval = DocumentSearchPipeline(selected)
        return p_retrieval.run(query=query)

    def do_search(self, query):
        """Return the top document from each retriever.

        Returns:
            A ``(sparse, dense_base, dense_adapted)`` tuple of the first document
            returned by each pipeline.
        """
        sparse_result = self.sparse_retrieval(query)['documents'][0]
        dense_base_result = self.dense_retrieval(query, 'base')['documents'][0]
        dense_adapted_result = self.dense_retrieval(query, 'adapted')['documents'][0]
        return sparse_result, dense_base_result, dense_adapted_result
# One engine per dataset, built once at import time. Each construction loads
# two pickled document stores and creates two EmbeddingRetrievers.
dutch_search_engine = SearchEngine(
    document_store_name_base='dutch-article-idx.pkl',
    document_store_name_adpated='dutch-article-idx_adapted.pkl',
    adapted_retriever_path='dutch-article-retriever',
)
german_search_engine = SearchEngine(
    document_store_name_base='documentstore_german-election-idx.pkl',
    document_store_name_adpated='documentstore_german-election-idx_adapted.pkl',
    adapted_retriever_path='adapted-retriever',
)
def do_search(query, dataset):
    """Search *query* in the engine that matches the *dataset* label.

    ``german_datset_name`` selects the German engine; any other value, not
    only the Dutch label, falls back to the Dutch engine. Returns that
    engine's ``(sparse, dense_base, dense_adapted)`` top-document tuple.
    """
    engine = german_search_engine if dataset == german_datset_name else dutch_search_engine
    return engine.do_search(query)
if __name__ == '__main__':
    # Smoke test: search the Dutch corpus and print the top hit from each
    # retriever. Reuses dutch_search_engine, which is already built at import
    # time with the same arguments. Constructing a second SearchEngine here
    # would reload both pickled stores and both embedding retrievers.
    query = 'Kindergarten'
    result = dutch_search_engine.do_search(query)
    pprint(result)