# chess / app.py
# christopher's picture
# Update app.py
# a8850d7 verified
import pickle
import time

import chess
from datasets import load_dataset
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pyroaring import BitMap
def board_to_tokens(board):
    """Describe each occupied square of *board* as a ``(symbol, square)`` pair.

    Squares are visited in ``chess.SQUARES`` order and empty squares are
    skipped.

    Args:
        board: Object exposing ``piece_at(square)``, e.g. a ``chess.Board``.

    Returns:
        A list of ``(piece_symbol, square_name)`` tuples, one per piece.
    """
    tokens = []
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tokens.append((piece.symbol(), chess.square_name(square)))
    return tokens
def get_puzzle_positions(fen, moves_uci):
    """Replay a puzzle's moves and snapshot the board after each one.

    Args:
        fen: FEN string of the position before the first move is played.
        moves_uci: Whitespace-separated UCI moves, applied in order.

    Returns:
        A list of ``chess.Board`` copies where element ``i`` is the position
        reached after the first ``i + 1`` moves. The list is empty when
        *moves_uci* contains no moves.

    Raises:
        ValueError: Propagated from python-chess when *fen* cannot be parsed
            or a move is malformed or illegal in the current position.
    """
    board = chess.Board(fen)
    positions = []
    # Every move, including the first, is handled the same way: play it, then
    # store a copy so later pushes do not alter the saved snapshot.
    for move_uci in moves_uci.split():
        board.push_uci(move_uci)
        positions.append(board.copy())
    return positions
def load_index(path='chess_index.pkl'):
    """Read the pickled search index from *path*.

    The file must unpickle to a mapping that has ``'index'`` and
    ``'metadata'`` keys.

    Args:
        path: Location of the pickle file.

    Returns:
        A ``(index, metadata)`` tuple taken from the unpickled mapping.
    """
    # NOTE: unpickling can execute arbitrary code; only point this at a file
    # you produced yourself.
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    inverted_index = payload['index']
    position_metadata = payload['metadata']
    return inverted_index, position_metadata
def query_positions(index, metadata, query_tokens):
    """Return every indexed position that contains all of *query_tokens*.

    Args:
        index: Mapping from token to a set-like collection of position ids.
            The collections must support ``copy()`` and in-place
            intersection (``&=``).
        metadata: Container that can be subscripted by position id.
        query_tokens: Sequence of tokens that must all be present in a
            position for it to match.

    Returns:
        A list of ``(pos_id, metadata[pos_id])`` pairs, one per matching
        position. The list is empty when *query_tokens* is empty or when any
        token is missing from *index*.
    """
    if not query_tokens:
        # Nothing to intersect; avoids indexing into an empty sequence.
        return []
    if any(token not in index for token in query_tokens):
        # The query is a conjunction, so one unknown token rules out every
        # position. Return a list so the result type is always the same.
        return []

    first, *rest = query_tokens
    # Copy the first posting set so the in-place intersections below never
    # modify the shared index.
    result = index[first].copy()
    for token in rest:
        result &= index[token]
    return [(pos_id, metadata[pos_id]) for pos_id in result]
# Module-level state shared by the request handlers below. Everything here
# runs once, at import time.

# "train" split of the Lichess puzzle dataset; search() subscripts it with the
# puzzle row stored in the index metadata.
dset = load_dataset("Lichess/chess-puzzles", split="train")
# Token index and per-position metadata, read from load_index()'s default path.
index, metadata = load_index()
app = FastAPI()
# Files in the "static" directory are served under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Templates are looked up in the "templates" directory.
templates = Jinja2Templates(directory="templates")
@app.get("/")
def read_root(request: Request):
    """Render ``index.html`` for the site root.

    Args:
        request: The incoming request, passed to the template context under
            the ``"request"`` key.

    Returns:
        The template response produced by ``templates.TemplateResponse``.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.post("/search")
async def search(data: dict):
    """Find puzzles containing a position that includes every posted piece.

    The posted FEN is turned into ``(piece, square)`` tokens and matched
    against the index. Each puzzle is reported once, using the first matching
    position returned for it.

    Args:
        data: JSON request body; must contain a ``'fen'`` string.

    Returns:
        A dict with ``"count"`` (number of puzzles), ``"results"`` (one dict
        per puzzle, whose ``"FEN"`` is the matched position rather than the
        puzzle's starting position) and ``"time_ms"`` (handler time in
        milliseconds).

    Raises:
        HTTPException: 422 when ``'fen'`` is missing, not a string, or cannot
            be parsed as a FEN.
    """
    # NOTE(review): this coroutine never awaits, so the index lookup and move
    # replay run directly on the event loop. A plain ``def`` handler may be
    # preferable -- confirm that ``dset`` and ``index`` are safe to read from
    # worker threads first.
    start = time.perf_counter()  # monotonic clock; time.time() can jump

    fen = data.get('fen')
    if not isinstance(fen, str):
        raise HTTPException(
            status_code=422,
            detail="Request body must contain a 'fen' string.",
        )
    try:
        board = chess.Board(fen)
    except ValueError as exc:
        raise HTTPException(status_code=422, detail=f"Invalid FEN: {exc}") from exc

    query_tokens = board_to_tokens(board)
    # A board with no pieces gives no tokens; treat that as "no matches"
    # instead of passing an empty query to the index.
    matches = query_positions(index, metadata, query_tokens) if query_tokens else []

    # Keep only the first matching move index for each puzzle row.
    first_match = {}
    for _pos_id, (puzzle_row, move_idx) in matches:
        first_match.setdefault(puzzle_row, move_idx)

    results = []
    for puzzle_row, move_idx in first_match.items():
        row = dset[puzzle_row]
        positions = get_puzzle_positions(row['FEN'], row['Moves'])
        matched_board = positions[move_idx]
        results.append({
            "PuzzleId": row['PuzzleId'],
            "FEN": matched_board.fen(),
            "Moves": row['Moves'],
            "Rating": row['Rating'],
            "Popularity": row['Popularity'],
            "Themes": row['Themes'],
            "MatchedMove": move_idx,
        })

    elapsed_ms = (time.perf_counter() - start) * 1000
    return {"count": len(results), "results": results, "time_ms": elapsed_ms}