import os
import json
import numpy as np
import faiss
from typing import List, Dict, Optional
from sentence_transformers import SentenceTransformer


class Retriever:
    """Semantic course retriever backed by a FAISS inner-product index."""

    def __init__(self):
        self.model = None        # lazily loaded SentenceTransformer
        self.index = None        # FAISS index over course embeddings
        self.meta = {}           # maps vector position (as str) -> course id
        self.embeddings = None   # cached embedding matrix (float32)
        self._load_index()

    def _load_index(self):
        """Load a previously saved index, embeddings and metadata from disk."""
        try:
            if os.path.exists('data/index/index.faiss') and os.path.exists('data/index/meta.json'):
                self.index = faiss.read_index('data/index/index.faiss')
                self.embeddings = np.load('data/index/embeddings.npy')
                with open('data/index/meta.json', 'r', encoding='utf-8') as f:
                    self.meta = json.load(f)
                print('Index loaded from cache')
            else:
                print('Index not found, it will be built on first use')
        except Exception as e:
            print(f'Failed to load index: {e}')

    def _load_model(self):
        """Load the embedding model on first use (lazy initialization)."""
        if self.model is None:
            try:
                self.model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
                print('Embedding model loaded')
            except Exception as e:
                print(f'Failed to load embedding model: {e}')
                raise

    def _build_index(self, courses: List[Dict]):
        """Embed course names/descriptions and build a cosine-similarity FAISS index."""
        if not courses:
            return
        self._load_model()
        texts = []
        meta_data = {}
        for i, course in enumerate(courses):
            text = f"{course.get('name', '')} {course.get('short_desc', '')}"
            text = text.lower().strip()
            if len(text) > 220:
                text = text[:220]
            texts.append(text)
            # Use string keys so the mapping survives a JSON save/load round-trip.
            meta_data[str(i)] = course.get('id', str(i))
        if not texts:
            return
        embeddings = self.model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
        embeddings = embeddings.astype(np.float32)
        # L2-normalize so inner product equals cosine similarity.
        faiss.normalize_L2(embeddings)
        self.index = faiss.IndexFlatIP(embeddings.shape[1])
        self.index.add(embeddings)
        self.embeddings = embeddings
        self.meta = meta_data
        self._save_index()

    def _save_index(self):
        """Persist the FAISS index, embeddings and metadata to data/index/."""
        try:
            os.makedirs('data/index', exist_ok=True)
            faiss.write_index(self.index, 'data/index/index.faiss')
            np.save('data/index/embeddings.npy', self.embeddings)
            with open('data/index/meta.json', 'w', encoding='utf-8') as f:
                json.dump(self.meta, f, ensure_ascii=False, indent=2)
            print('Index saved')
        except Exception as e:
            print(f'Failed to save index: {e}')

    def retrieve(self, query: str, k: int = 6, threshold: float = 0.35) -> List[Dict]:
        """Return up to k courses whose cosine similarity to the query meets the threshold."""
        if self.index is None:
            return []
        self._load_model()
        query_embedding = self.model.encode([query.lower().strip()], convert_to_numpy=True)
        query_embedding = query_embedding.astype(np.float32)
        faiss.normalize_L2(query_embedding)
        scores, indices = self.index.search(query_embedding, k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing results with -1; metadata keys are strings (see _build_index).
            key = str(int(idx))
            if idx >= 0 and score >= threshold and key in self.meta:
                results.append({
                    'course_id': self.meta[key],
                    'score': float(score)
                })
        return results

    def build_or_load_index(self, courses: Optional[List[Dict]] = None):
        """Build the index from the given courses if nothing was loaded from cache."""
        if self.index is None and courses:
            print('Building index...')
            self._build_index(courses)
        elif self.index is None:
            print('Index not found and no course data provided')

    def get_embedding_dim(self) -> int:
        if self.embeddings is not None:
            return self.embeddings.shape[1]
        return 0

    def get_index_size(self) -> int:
        if self.index is not None:
            return self.index.ntotal
        return 0
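

# Minimal usage sketch (illustrative only): the course dicts below are made-up
# placeholders; real data is assumed to provide the 'id', 'name' and 'short_desc'
# keys that _build_index reads above.
if __name__ == '__main__':
    retriever = Retriever()
    sample_courses = [
        {'id': 'py101', 'name': 'Python basics', 'short_desc': 'Introductory programming with Python'},
        {'id': 'ml201', 'name': 'Machine learning', 'short_desc': 'Classification, regression and embeddings'},
    ]
    # Builds the index only if it was not already loaded from data/index/.
    retriever.build_or_load_index(sample_courses)
    print(retriever.retrieve('learn python programming', k=3))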