Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| import chess | |
| from datasets import load_dataset | |
| import pickle | |
| import time | |
| from pyroaring import BitMap | |
def board_to_tokens(board):
    """Encode a board's piece placement as a list of tokens.

    Squares are visited in ``chess.SQUARES`` order. Every occupied square
    contributes one ``(piece.symbol(), chess.square_name(square))`` tuple,
    and empty squares are skipped.
    """
    tokens = []
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tokens.append((piece.symbol(), chess.square_name(square)))
    return tokens
def get_puzzle_positions(fen, moves_uci):
    """Replay whitespace-separated UCI moves from ``fen`` and snapshot each step.

    The returned list holds a copy of the board after each move: entry 0 is
    the board after the first move, and the starting ``fen`` position itself
    is not included. (Presumably the first move is skipped as a start position
    because, in Lichess puzzle rows, it is the move leading into the
    puzzle. Confirm against the dataset docs.)

    Raises:
        IndexError: if ``moves_uci`` contains no moves.
        Errors from ``chess.Board`` / ``push_uci`` (bad FEN, illegal move)
        propagate unchanged.
    """
    board = chess.Board(fen)
    moves = moves_uci.split()
    board.push_uci(moves[0])
    snapshots = [board.copy()]
    for uci in moves[1:]:
        board.push_uci(uci)
        snapshots.append(board.copy())
    return snapshots
def load_index(path='chess_index.pkl'):
    """Load the pickled search index stored at *path*.

    The unpickled object must be subscriptable by ``'index'`` and
    ``'metadata'``. Those two values are returned as ``(index, metadata)``,
    and a missing key raises ``KeyError``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file,
    so only point this at index files produced by a trusted source.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    index_map = payload['index']
    position_meta = payload['metadata']
    return index_map, position_meta
def query_positions(index, metadata, query_tokens):
    """Return every indexed position that contains all of ``query_tokens``.

    Args:
        index: mapping from token to a set-like collection of position ids
            (presumably ``pyroaring.BitMap``). Values must support
            ``.copy()``, in-place ``&=`` and truthiness.
        metadata: indexable by position id; ``metadata[pos_id]`` is returned
            alongside each matching id.
        query_tokens: sequence of tokens, e.g. as produced by
            ``board_to_tokens``.

    Returns:
        A list of ``(pos_id, metadata[pos_id])`` tuples in the iteration order
        of the intersected id collection. The result is always a list. It is
        empty when ``query_tokens`` is empty, when any token is missing from
        ``index``, or when the intersection is empty.
    """
    # An empty query (e.g. a board with no pieces) has nothing to intersect.
    if not query_tokens:
        return []

    first = query_tokens[0]
    if first not in index:
        return []
    # Copy so the in-place &= below never mutates the shared index entry.
    result = index[first].copy()

    for token in query_tokens[1:]:
        if token not in index:
            return []
        result &= index[token]
        if not result:
            # Intersections only shrink; once empty, nothing can match.
            break

    return [(pos_id, metadata[pos_id]) for pos_id in result]
# --- Module-level setup: everything below runs once, at import time. ---

# Puzzle dataset ("train" split). search() fetches rows by integer position
# (dset[puzzle_row]), so presumably the index below was built against this
# same split and row order. Confirm whenever either one is rebuilt.
# NOTE(review): load_dataset may fetch from the Hugging Face Hub when the data
# is not already cached, which makes startup slow and network-dependent.
dset = load_dataset("Lichess/chess-puzzles", split="train")

# Inverted index and per-position metadata, read from load_index's default
# path (chess_index.pkl).
index, metadata = load_index()

app = FastAPI()
# Serve files from the local "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Jinja2 templates (e.g. index.html rendered by read_root) live in "templates".
templates = Jinja2Templates(directory="templates")
def read_root(request: Request):
    """Render ``index.html`` with the incoming request in the template context.

    NOTE(review): no route decorator is visible on this function. Presumably
    it is meant to be registered as ``@app.get("/", response_class=HTMLResponse)``
    (HTMLResponse is imported but otherwise unused), so confirm. Newer Starlette
    versions also prefer ``TemplateResponse(request, name)``; check the
    installed version before switching.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
def _puzzle_result(row, move_idx):
    """Build the JSON-ready result dict for one dataset row matched at ``move_idx``."""
    # move_idx indexes get_puzzle_positions' list, whose entry 0 is the board
    # after the first move of the puzzle line.
    matched_board = get_puzzle_positions(row['FEN'], row['Moves'])[move_idx]
    return {
        "PuzzleId": row['PuzzleId'],
        "FEN": matched_board.fen(),
        "Moves": row['Moves'],
        "Rating": row['Rating'],
        "Popularity": row['Popularity'],
        "Themes": row['Themes'],
        "MatchedMove": move_idx,
    }


async def search(data: dict):
    """Find puzzles with a position containing every piece of the submitted board.

    ``data['fen']`` is parsed with ``chess.Board`` and reduced to
    ``board_to_tokens`` tokens, so only piece placement takes part in
    matching. A puzzle position matches when it contains all of those
    (piece, square) tokens. Each puzzle is reported once, at the first
    matching position the index yields for it.

    Returns:
        dict with ``count`` (number of puzzles), ``results`` (one dict per
        puzzle, including the board FEN at the matched move) and ``time_ms``
        (handler time in milliseconds).

    Raises:
        KeyError: if ``data`` has no ``'fen'`` key.
        ValueError: if ``chess.Board`` cannot parse the FEN.

    NOTE(review): no route decorator is visible here, so confirm how this is
    registered. The body does only synchronous work (dataset reads, move
    replay) inside ``async def``, so it blocks the event loop for the whole
    request; a plain ``def`` would let FastAPI run it in a threadpool. The
    number of results is also unbounded (a sparse query can match very many
    puzzles).
    """
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock and can jump, producing wrong or negative durations.
    start = time.perf_counter()

    board = chess.Board(data['fen'])
    query_tokens = board_to_tokens(board)

    # A board with no pieces yields no tokens and there is nothing to
    # intersect, so report no matches instead of failing in the index lookup.
    matches = query_positions(index, metadata, query_tokens) if query_tokens else []

    # Keep only the first matched move index per puzzle row.
    first_match = {}
    for _pos_id, (puzzle_row, move_idx) in matches:
        first_match.setdefault(puzzle_row, move_idx)

    results = [
        _puzzle_result(dset[puzzle_row], move_idx)
        for puzzle_row, move_idx in first_match.items()
    ]

    elapsed_ms = (time.perf_counter() - start) * 1000
    return {"count": len(results), "results": results, "time_ms": elapsed_ms}