Spaces:
Runtime error
Runtime error
| import json | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from promptSearchEngine import PromptSearchEngine | |
| from vectorizer import Vectorizer | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
# Checkpoint name handed to SentenceTransformer below.
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
# Dataset identifier handed to load_dataset below.
DATASET = "Gustavosta/Stable-Diffusion-Prompts"

# Everything below executes at import time, so importing this module loads
# the embedding model and the dataset before any request handler can run.
model = SentenceTransformer(EMBEDDING_MODEL)
# The split expression limits loading to the first 1% of the "train" split.
dataset = load_dataset(DATASET, split="train[:1%]")
# The engine is constructed from the dataset's "Prompt" column plus the model.
promptSearchEngine = PromptSearchEngine(dataset["Prompt"], model)
class SearchRequest(BaseModel):
    """Request model consumed by ``searchPost``.

    Parsing and validation of the fields are delegated to ``BaseModel``.
    """

    # Text to search for.
    query: str
    # How many results to return; the field also accepts null and
    # defaults to 5 when omitted.
    n: int | None = 5
# FastAPI application instance for this module.
app = FastAPI()
# NOTE(review): this handler had no route registration, so the application
# never exposed it. The "/" path is assumed from the handler's name — confirm.
@app.get("/")
async def root():
    """Return a fixed message that points the caller at ``GET /docs``."""
    return {"message": "GET /docs"}
# NOTE(review): this handler had no route registration, so the application
# never exposed it. The "/search" path is assumed — confirm against clients.
@app.get("/search")
async def search(q: str, n: int = 5):
    """Look up the prompts most similar to ``q``.

    Args:
        q: Query text.
        n: Number of results requested from the search engine (default 5).

    Returns:
        ``{"message": "Enter query"}`` when ``q`` is empty or contains only
        whitespace; otherwise the value produced by
        ``promptSearchEngine.stringify_prompts`` for the matches.

    Raises:
        HTTPException: 404 when the search engine returns no matches.
    """
    # Guard clause: a blank query is answered without calling the engine.
    # ``not q.strip()`` covers both the empty and the whitespace-only case.
    if not q.strip():
        return {"message": "Enter query"}

    results = promptSearchEngine.most_similar(q, n)
    if not results:
        raise HTTPException(status_code=404, detail="No prompts found.")
    return promptSearchEngine.stringify_prompts(results)
# NOTE(review): this handler had no route registration, so the application
# never exposed it. The "/search" path is assumed — confirm against clients.
@app.post("/search")
async def searchPost(request: SearchRequest):
    """Look up the prompts most similar to ``request.query``.

    Args:
        request: Parsed request body carrying ``query`` and ``n``.

    Returns:
        ``{"data": [...]}`` where each item has a ``"similarity"`` float and
        the matching ``"prompt"``.

    Raises:
        HTTPException: 404 when the search engine returns no matches.
    """
    # ``SearchRequest.n`` is declared ``int | None``, so an explicit null can
    # arrive here. Fall back to the model's declared default instead of
    # forwarding ``None`` to the engine.
    # NOTE(review): assumes null was never meant to mean "all results" — confirm.
    n = 5 if request.n is None else request.n

    results = promptSearchEngine.most_similar(request.query, n)
    if not results:
        raise HTTPException(status_code=404, detail="No prompts found.")

    # Each result is unpacked as a (similarity, prompt) pair; ``float()``
    # converts the score to a plain Python float for the response.
    formatted_results = [
        {"similarity": float(similarity), "prompt": prompt}
        for similarity, prompt in results
    ]
    return {"data": formatted_results}