import os
from collections.abc import Iterator

import numpy as np
import pypdfium2 as pdfium
import torch
import tqdm
from model import encode_images, encode_queries
from PIL import Image
from sqlitedict import SqliteDict
from voyager import Index, Space


def iter_batch(
    X: list, batch_size: int, tqdm_bar: bool = True, desc: str = ""
) -> Iterator[list]:
    """Iterate over a list of elements by batch."""
    batches = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]
    if tqdm_bar:
        for batch in tqdm.tqdm(
            iterable=batches,
            position=0,
            total=len(batches),
            desc=desc,
        ):
            yield batch
    else:
        yield from batches
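

# A quick sanity check of `iter_batch` (illustrative): the final batch keeps
# the remainder when `len(X)` is not a multiple of `batch_size`.
#
#     >>> list(iter_batch(["a", "b", "c", "d", "e"], batch_size=2, tqdm_bar=False))
#     [['a', 'b'], ['c', 'd'], ['e']]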


class Voyager:
    """Voyager index. The Voyager index is a fast and efficient index for approximate nearest neighbor search.

    Parameters
    ----------
    index_folder
        The folder in which the index files are stored.
    index_name
        The name of the collection.
    override
        Whether to override the collection if it already exists.
    embedding_size
        The number of dimensions of the embeddings.
    M
        The maximum number of links per node in the HNSW graph.
    ef_construction
        The number of candidates to evaluate during the construction of the index.
    ef_search
        The number of candidates to evaluate during the search.
    """

    def __init__(
        self,
        index_folder: str = "indexes",
        index_name: str = "base_collection",
        override: bool = False,
        embedding_size: int = 128,
        M: int = 64,
        ef_construction: int = 200,
        ef_search: int = 200,
    ) -> None:
        self.ef_search = ef_search

        if not os.path.exists(index_folder):
            os.makedirs(index_folder)

        self.index_path = os.path.join(index_folder, f"{index_name}.voyager")
        self.page_ids_to_data_path = os.path.join(
            index_folder, f"{index_name}_page_ids_to_data.sqlite"
        )

        self.index = self._create_collection(
            index_path=self.index_path,
            embedding_size=embedding_size,
            M=M,
            ef_construction=ef_construction,
            override=override,
        )

    def _load_page_ids_to_data(self) -> SqliteDict:
        """Load the SQLite database that maps page IDs to page data (path, image, page number)."""
        return SqliteDict(self.page_ids_to_data_path, outer_stack=False)

    def _create_collection(
        self,
        index_path: str,
        embedding_size: int,
        M: int,
        ef_construction: int,
        override: bool,
    ) -> Index:
        """Create a new Voyager collection, or load an existing one.

        Parameters
        ----------
        index_path
            The path to the index file.
        embedding_size
            The number of dimensions of the embeddings.
        M
            The maximum number of links per node in the HNSW graph.
        ef_construction
            The number of candidates to evaluate during the construction of the index.
        override
            Whether to override the collection if it already exists.
        """
        if os.path.exists(index_path) and not override:
            return Index.load(index_path)

        if os.path.exists(index_path):
            os.remove(index_path)

        # Create the Voyager index.
        index = Index(
            Space.Cosine,
            num_dimensions=embedding_size,
            M=M,
            ef_construction=ef_construction,
        )
        index.save(index_path)

        if override and os.path.exists(self.page_ids_to_data_path):
            os.remove(self.page_ids_to_data_path)

        # Create the SQLite database that backs the page-id -> data mapping.
        page_ids_to_data = self._load_page_ids_to_data()
        page_ids_to_data.close()

        return index
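
    # Note (illustrative): `Index.save` / `Index.load` round-trip the HNSW
    # graph on disk, which is what lets `_create_collection` reopen an
    # existing collection instead of rebuilding it:
    #
    #     index = Index(Space.Cosine, num_dimensions=128, M=64, ef_construction=200)
    #     index.save("indexes/demo.voyager")
    #     index = Index.load("indexes/demo.voyager")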

    def add_documents(
        self,
        paths: str | list[str],
        batch_size: int = 1,
    ) -> "Voyager":
        """Add documents to the index. Note that `batch_size` is the number of pages to encode at once, not the number of documents."""
        if isinstance(paths, str):
            paths = [paths]

        page_ids_to_data = self._load_page_ids_to_data()

        # Render every page of every document to a PIL image.
        images = []
        num_pages = []
        for path in paths:
            if path.lower().endswith(".pdf"):
                pdf = pdfium.PdfDocument(path)
                n_pages = len(pdf)
                num_pages.append(n_pages)
                for page_number in range(n_pages):
                    page = pdf.get_page(page_number)
                    pil_image = page.render(
                        scale=1,
                        rotation=0,
                    ).to_pil()
                    images.append(pil_image)
                pdf.close()
            else:
                pil_image = Image.open(path)
                images.append(pil_image)
                num_pages.append(1)

        # Encode the pages by batch and add the embeddings to the index.
        embeddings = []
        for batch in iter_batch(
            X=images, batch_size=batch_size, desc=f"Encoding pages (bs={batch_size})"
        ):
            embeddings.extend(encode_images(batch))

        embeddings_ids = self.index.add_items(embeddings)

        # Map each embedding ID back to its source page. Keys are stored as
        # strings to match the lookup in `__call__`.
        current_index = 0
        for i, path in enumerate(paths):
            for page_number in range(num_pages[i]):
                page_ids_to_data[str(embeddings_ids[current_index])] = {
                    "path": path,
                    "image": images[current_index],
                    "page_number": page_number,
                }
                current_index += 1

        page_ids_to_data.commit()
        page_ids_to_data.close()
        self.index.save(self.index_path)
        return self

    def __call__(
        self,
        queries: np.ndarray | torch.Tensor,
        k: int = 10,
    ) -> dict:
        """Query the index for the nearest neighbors of the queries.

        Parameters
        ----------
        queries
            The queries to encode and search for.
        k
            The number of nearest neighbors to return per query.
        """
        queries_embeddings = encode_queries(queries)
        page_ids_to_data = self._load_page_ids_to_data()

        if len(page_ids_to_data) == 0:
            raise ValueError("Index is empty, add documents before querying.")

        k = min(k, len(page_ids_to_data))
        n_queries = len(queries_embeddings)

        indices, distances = self.index.query(
            queries_embeddings, k, query_ef=self.ef_search
        )

        # Map the returned embedding IDs back to the stored page data.
        documents = [
            [page_ids_to_data[str(indice)] for indice in query_indices]
            for query_indices in indices
        ]
        page_ids_to_data.close()

        return {
            "documents": documents,
            "distances": distances.reshape(n_queries, -1, k),
        }
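

# A minimal end-to-end sketch, assuming the `model` module imported above
# provides ColPali-style `encode_images` / `encode_queries` that accept PIL
# images and raw text queries, and that "docs/report.pdf" exists (both names
# are illustrative, not part of this module).
if __name__ == "__main__":
    index = Voyager(index_folder="indexes", index_name="demo", override=True)

    # Index every page of the PDF; batch_size counts pages, not documents.
    index.add_documents("docs/report.pdf", batch_size=4)

    # Retrieve the 5 closest pages for a single text query.
    results = index(["What was the revenue in 2024?"], k=5)
    for page in results["documents"][0]:
        print(page["path"], page["page_number"])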