Spaces:
Sleeping
Sleeping
| import os | |
| import urllib.request | |
| import chromadb | |
| from chromadb.utils import embedding_functions | |
| from datasets import load_dataset | |
def download_sakila_db(
    url: str = "https://github.com/ivanceras/sakila/raw/master/sqlite-sakila-db/sakila.db",
    dest: str = "./sakila.db",
) -> None:
    """Download the Sakila SQLite database unless it is already present.

    The file is fetched to a temporary sibling path (``dest`` + ``.part``)
    and moved onto ``dest`` only after the transfer has finished. An
    interrupted or failed download therefore never leaves a truncated file
    at ``dest`` that a later run would mistake for a complete database.

    Args:
        url: Location the database is fetched from.
        dest: Path the database file is stored at.

    Raises:
        OSError: If the download or the final move fails (this includes
            ``urllib.error.URLError``).
    """
    if os.path.exists(dest):
        print("β Sakila database already exists")
        return

    print("Downloading Sakila database...")
    partial_path = f"{dest}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Same-directory rename, so the finished file appears at ``dest`` in one step.
        os.replace(partial_path, dest)
    finally:
        # After a successful replace the partial file is gone; on any failure
        # (including Ctrl-C) remove whatever was written so far.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("β Sakila database downloaded")
def setup_agnews_chromadb() -> None:
    """Load a 500-article AG News sample into a fresh ChromaDB collection.

    Loads the ``train[:500]`` split of ``fancyzhx/ag_news``, recreates the
    ``ag_news`` collection in the persistent store at ``./chroma_agnews/``
    and adds every article together with its numeric label, the label's
    name and a title of at most 100 characters (plus ``...`` when cut).
    """
    print("\nLoading AG News dataset...")
    ds = load_dataset("fancyzhx/ag_news", split="train[:500]")
    print(f"β Loaded {len(ds)} articles")

    os.makedirs("./chroma_agnews/", exist_ok=True)
    client = chromadb.PersistentClient(path="./chroma_agnews/")

    # Drop any collection left over from an earlier run so the create below
    # starts from an empty one.
    try:
        client.delete_collection("ag_news")
    except Exception:
        # Best effort: presumably raised because the collection does not exist
        # yet; the exception type differs between chromadb versions -- TODO
        # confirm and narrow. Unlike a bare ``except``, this no longer
        # swallows KeyboardInterrupt / SystemExit.
        pass

    # Create collection with embedding function
    embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="all-mpnet-base-v2"
    )
    collection = client.create_collection(
        name="ag_news",
        embedding_function=embedding_fn,
        metadata={"hnsw:space": "cosine"},
    )

    # Label mapping
    label_names = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

    # Adding to ChromaDB
    print("Computing embeddings and adding to ChromaDB...")
    documents = []
    metadatas = []
    # Single pass over the dataset builds both parallel lists.
    for item in ds:
        text = item["text"]
        label = item["label"]
        documents.append(text)
        metadatas.append({
            "label": label,
            "label_text": label_names[label],
            # Only texts longer than 100 characters are cut and get the ellipsis.
            "title": (text[:100] + "...") if len(text) > 100 else text,
        })
    ids = [f"doc_{i}" for i in range(len(documents))]

    collection.add(
        ids=ids,
        documents=documents,
        metadatas=metadatas,
    )
    print(f"β Added {len(ds)} articles to ChromaDB")
if __name__ == "__main__":
    # Script entry point: run each setup step in order, then tell the user
    # how to start the app. A step that raises stops the remaining ones.
    print("=== Setting up databases ===\n")
    for setup_step in (download_sakila_db, setup_agnews_chromadb):
        setup_step()
    print("\n Setup complete! Run 'streamlit run chatbot.py'")