Spaces:
Runtime error
Runtime error
| import uuid | |
| from flask import Flask, render_template, request, redirect, url_for, send_from_directory | |
| import json | |
| import random | |
| import os | |
| import string | |
| import logging | |
| from datetime import datetime | |
| from huggingface_hub import login, HfApi, hf_hub_download | |
| from statistics import mean | |
# Configure logging at import time: records at INFO and above are sent to
# both the app.log file and a default stream handler.
logging.basicConfig(
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Log in to the Hugging Face Hub with the token taken from the HF_TOKEN
# environment variable; a missing token is logged as an error.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    logger.error("HF_TOKEN not found in environment variables")
else:
    login(token=hf_token)
app = Flask(__name__)
# Prefer a secret supplied through the SECRET_KEY environment variable so the
# key does not have to live in the source file. The previous hard-coded
# value is kept only as a fallback for deployments that do not set it.
app.config['SECRET_KEY'] = os.environ.get('SECRET_KEY', 'supersecretkey')
# Directory holding one JSON file per session (see save_session_data and
# load_session_data). It is created at import time if it does not exist.
SESSION_DIR = '/tmp/sessions'
os.makedirs(SESSION_DIR, exist_ok=True)
# Directory containing the visualization HTML files of each ranked method,
# keyed by the method's display name.
VISUALIZATION_DIRS = {
    "Text2SQL": "htmls_Text2SQL",
    "Dater": "htmls_DATER_mod2",
    "Chain-of-Table": "htmls_COT_mod",
    "Plan-of-SQLs": "htmls_POS_mod2",
}
def get_method_dir(method):
    """Return the short identifier used for ``method`` in file names.

    The identifier is used when building metadata file names and metadata
    keys. Returns None when ``method`` is not one of the known methods.
    """
    short_names = dict([
        ('Text2SQL', 'Text2SQL'),
        ('Dater', 'DATER'),
        ('Chain-of-Table', 'COT'),
        ('Plan-of-SQLs', 'POS'),
    ])
    if method in short_names:
        return short_names[method]
    return None
# Display names of the four methods that participants rank, in the order
# they are iterated throughout this module.
METHODS = [
    "Text2SQL",
    "Dater",
    "Chain-of-Table",
    "Plan-of-SQLs",
]
def generate_session_id():
    """Return a new random (version 4) UUID rendered as a string."""
    new_uuid = uuid.uuid4()
    return str(new_uuid)
def save_session_data(session_id, data):
    """Write ``data`` as JSON to ``<SESSION_DIR>/<session_id>.json``.

    The file is opened in write mode, so any previous content stored for
    the same session is replaced.
    """
    session_file = os.path.join(SESSION_DIR, '{}.json'.format(session_id))
    with open(session_file, 'w') as handle:
        json.dump(data, handle)
    logger.info(f"Session data saved for session {session_id}")
def load_session_data(session_id):
    """Return the parsed JSON stored for ``session_id``.

    Returns None when ``<SESSION_DIR>/<session_id>.json`` does not exist.
    """
    session_file = os.path.join(SESSION_DIR, '{}.json'.format(session_id))
    if not os.path.exists(session_file):
        return None
    with open(session_file, 'r') as handle:
        return json.load(handle)
def save_session_data_to_hf(session_id, data):
    """Upload a session's data as a JSON file to the Hugging Face repo.

    The file name is built from the session's username, seed, start time and
    ``session_id``, keeping only alphanumerics, ``_``, ``-`` and ``.``. The
    JSON is written to a temporary file under /tmp, uploaded to
    ``session_data_preference_ranking/`` in the
    ``luulinh90s/Tabular-LLM-Study-Data`` space, and the temporary file is
    removed afterwards whether or not the upload succeeded.

    Args:
        session_id: Identifier of the session being saved.
        data: Dict of session data; must be JSON-serializable.

    Returns:
        None. This is best-effort: any failure is logged and not re-raised.
    """
    temp_file_path = None
    try:
        username = data.get('username', 'unknown')
        seed = data.get('seed', 'unknown')
        start_time = data.get('start_time', datetime.now().isoformat())
        file_name = f'{username}_seed{seed}_{start_time}_{session_id}_session.json'
        # Drop every character that is not safe in a file name (e.g. ':' and
        # '/' coming from the timestamp or the username).
        file_name = "".join(c for c in file_name if c.isalnum() or c in ['_', '-', '.'])
        json_data = json.dumps(data, indent=4)
        temp_file_path = f"/tmp/{file_name}"
        with open(temp_file_path, 'w') as f:
            f.write(json_data)
        api = HfApi()
        repo_path = "session_data_preference_ranking"
        api.upload_file(
            path_or_fileobj=temp_file_path,
            path_in_repo=f"{repo_path}/{file_name}",
            repo_id="luulinh90s/Tabular-LLM-Study-Data",
            repo_type="space",
        )
        logger.info(f"Session data saved for session {session_id} in Hugging Face Data Space")
    except Exception as e:
        logger.exception(f"Error saving session data for session {session_id}: {e}")
    finally:
        # Always clean up the temporary file, including when the upload
        # failed, so failed uploads do not accumulate files in /tmp.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as e:
                logger.warning(f"Error removing temporary file {temp_file_path}: {e}")
def load_samples_for_all_methods(metadata_files):
    """Collect the samples whose visualization file exists for every method.

    For each method in METHODS, the TP/TN/FP/FN sub-directories of the
    method's visualization directory are listed. Each file found becomes a
    sample dict with the keys ``category``, ``file`` and ``metadata``. The
    metadata is looked up under ``"<method dir>_test-<index>.html"``, where
    ``<index>`` is parsed from the file name; it is ``{}`` when absent.

    Args:
        metadata_files: Mapping from method name to that method's metadata
            dict.

    Returns:
        A list of sample groups, ordered by file name so that the result is
        deterministic. Each group maps every method name to its sample dict
        for one file name shared by all methods.
    """
    categories = ("TP", "TN", "FP", "FN")

    # method -> {file name -> sample dict}. If a file name shows up in more
    # than one category for a method, the first one encountered is kept.
    samples_by_method = {}
    for method in METHODS:
        method_dir = VISUALIZATION_DIRS[method]
        samples_by_file = {}
        for category in categories:
            try:
                # Sorted so results do not depend on directory listing order.
                files = sorted(set(os.listdir(f'{method_dir}/{category}')))
            except OSError as e:
                logger.error(f"Error loading samples for method {method}, category {category}: {e}")
                continue
            for file in files:
                try:
                    index = file.split('-')[1].split('.')[0]
                    metadata_key = f"{get_method_dir(method)}_test-{index}.html"
                    sample_metadata = metadata_files[method].get(metadata_key, {})
                except (IndexError, KeyError, AttributeError) as e:
                    # Skip only the offending file rather than abandoning
                    # the remaining files of this category.
                    logger.error(f"Error loading samples for method {method}, category {category}: {e}")
                    continue
                samples_by_file.setdefault(file, {
                    'category': category,
                    'file': file,
                    'metadata': sample_metadata,
                })
        samples_by_method[method] = samples_by_file

    if not samples_by_method:
        return []

    # File names present for every method.
    common_files = set.intersection(
        *(set(by_file) for by_file in samples_by_method.values())
    )

    # Iterate in sorted order: set iteration order of strings varies between
    # interpreter runs, which would make seeded sampling of the returned
    # list irreproducible.
    return [
        {method: samples_by_method[method][file_name] for method in METHODS}
        for file_name in sorted(common_files)
    ]
def select_balanced_samples(samples):
    """Pick up to ten sample groups, balanced between two category pairs.

    Each group's category is read from the first entry of the group. When
    at least five TP/FP groups and five TN/FN groups are available, five of
    each are drawn at random and the ten are shuffled. Otherwise a warning
    is logged and up to ten groups are drawn from ``samples`` as a whole.

    Returns:
        The list of selected groups, or ``[]`` if any error occurs (the
        error is logged).
    """
    try:
        positive_groups = []
        negative_groups = []
        for group in samples:
            first_entry = next(iter(group.values()))
            category = first_entry['category']
            if category in ('TP', 'FP'):
                positive_groups.append(group)
            elif category in ('TN', 'FN'):
                negative_groups.append(group)

        if len(positive_groups) < 5 or len(negative_groups) < 5:
            logger.warning(
                f"Not enough samples for balanced selection. TP+FP: {len(positive_groups)}, TN+FN: {len(negative_groups)}")
            return random.sample(samples, min(10, len(samples)))

        chosen = random.sample(positive_groups, 5) + random.sample(negative_groups, 5)
        random.shuffle(chosen)
        return chosen
    except Exception:
        logger.exception("Error selecting balanced samples")
        return []
def root():
    """Redirect to the ``consent`` endpoint."""
    # NOTE(review): no route decorator is visible on this function in this
    # view; url_for('consent') needs a registered 'consent' endpoint - confirm.
    consent_url = url_for('consent')
    return redirect(consent_url)
def consent():
    """Render the consent page; a POST redirects to the introduction."""
    if request.method != 'POST':
        return render_template('consent.html')
    return redirect(url_for('introduction'))
def introduction():
    """Render the ``introduction.html`` template."""
    page = render_template('introduction.html')
    return page
def attribution():
    """Render the ``attribution.html`` template."""
    page = render_template('attribution.html')
    return page
def index():
    """Show the start form and, on POST, create a new ranking session.

    On POST the ``username`` and ``seed`` form fields are required. The seed
    must be an integer and is used to seed the module-level random generator
    before samples are selected. Per-method metadata is loaded from
    ``Tabular_LLMs_human_study_vis_6_<method dir>.json``, samples common to
    all methods are selected, and the new session is saved to disk before
    redirecting to the ``experiment`` endpoint.

    Returns:
        The rendered ``index.html`` (with an ``error`` message when the
        input is invalid or something fails), or a redirect to the
        experiment page for the newly created session.
    """
    if request.method != 'POST':
        return render_template('index.html')

    username = request.form.get('username')
    seed = request.form.get('seed')
    if not username or not seed:
        return render_template('index.html', error="Please fill in all fields.")

    try:
        seed = int(seed)
    except ValueError:
        # Bad user input gets a specific message instead of being reported
        # (and logged with a traceback) as a generic server-side failure.
        return render_template('index.html', error="Seed must be an integer.")

    try:
        random.seed(seed)

        # Load metadata for all methods
        metadata_files = {}
        for method in METHODS:
            json_file = f'Tabular_LLMs_human_study_vis_6_{get_method_dir(method)}.json'
            with open(json_file, 'r', encoding='utf-8') as f:
                metadata_files[method] = json.load(f)

        # Load and select samples
        all_samples = load_samples_for_all_methods(metadata_files)
        selected_samples = select_balanced_samples(all_samples)
        if not selected_samples:
            return render_template('index.html', error="No common samples were found")

        # Create session
        session_id = generate_session_id()
        session_data = {
            'username': username,
            'seed': str(seed),
            'selected_samples': selected_samples,
            'current_index': 0,
            'responses': [],
            'start_time': datetime.now().isoformat(),
            'session_id': session_id,
        }
        save_session_data(session_id, session_data)
        return redirect(url_for('experiment', session_id=session_id))
    except Exception as e:
        logger.exception(f"Error in index route: {e}")
        return render_template('index.html', error="An error occurred. Please try again.")
def experiment(session_id):
    """Show the current sample of a session, or record its rankings on POST.

    Args:
        session_id: Identifier of the session to load from disk.
            NOTE(review): presumably supplied by the URL route; the route
            decorator is not visible in this view - confirm.

    Returns:
        - A redirect to ``index`` when the session does not exist.
        - A redirect to ``completed`` when every sample has been answered.
        - On POST: a 400 response when the rankings are missing, not
          integers, out of range or not unique; otherwise the rankings are
          stored, the session advances and the page redirects to itself.
        - On GET: the rendered ``experiment.html`` for the current sample.
        - A 500 response on any unexpected error (which is logged).
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            return redirect(url_for('index'))

        selected_samples = session_data['selected_samples']
        current_index = session_data['current_index']
        if current_index >= len(selected_samples):
            return redirect(url_for('completed', session_id=session_id))

        if request.method == 'POST':
            num_methods = len(METHODS)
            invalid_message = f"Invalid rankings. Please use numbers 1-{num_methods}."

            # Validate and save rankings. A missing field makes int(None)
            # raise TypeError and a non-numeric one raises ValueError; both
            # are client errors, not server failures.
            try:
                rankings = {method: int(request.form.get(method)) for method in METHODS}
            except (TypeError, ValueError):
                return invalid_message, 400
            if not all(1 <= rank <= num_methods for rank in rankings.values()):
                return invalid_message, 400
            if len(set(rankings.values())) != num_methods:
                return "Each method must have a unique rank.", 400

            session_data['responses'].append({
                'sample_id': current_index,
                'rankings': rankings,
            })
            session_data['current_index'] += 1
            save_session_data(session_id, session_data)
            return redirect(url_for('experiment', session_id=session_id))

        # Get current sample group and prepare visualizations
        sample_group = selected_samples[current_index]
        visualizations = {
            method: url_for('send_visualization',
                            filename=f"{VISUALIZATION_DIRS[method]}/{sample['category']}/{sample['file']}")
            for method, sample in sample_group.items()
        }

        # The statement is read from the metadata of the first method in the
        # group. NOTE(review): assumes all methods share the same statement
        # for a given sample - confirm.
        sample_metadata = next(iter(sample_group.values()))['metadata']
        statement = sample_metadata.get('statement', '')

        return render_template('experiment.html',
                               sample_id=current_index,
                               statement=statement,
                               visualizations=visualizations,
                               methods=METHODS,
                               session_id=session_id)
    except Exception as e:
        logger.exception(f"An error occurred in the experiment route: {e}")
        return "An error occurred", 500
def completed(session_id):
    """Finish a session: compute average rankings, upload, and show results.

    Args:
        session_id: Identifier of the session to load from disk.
            NOTE(review): presumably supplied by the URL route; the route
            decorator is not visible in this view - confirm.

    Returns:
        - A redirect to ``index`` when the session does not exist.
        - A 400 response when the session has no recorded responses.
        - Otherwise the rendered ``completed.html`` with the per-method
          average rankings and the methods sorted by ascending average.
        - A 500 response on any unexpected error (which is logged).
    """
    try:
        session_data = load_session_data(session_id)
        if not session_data:
            return redirect(url_for('index'))

        responses = session_data['responses']
        if not responses:
            # mean() raises StatisticsError on an empty sequence; report the
            # situation explicitly instead of failing with a 500.
            return "No responses have been recorded for this session.", 400

        session_data['end_time'] = datetime.now().isoformat()

        # Calculate average ranking for each method
        average_rankings = {
            method: mean(r['rankings'][method] for r in responses)
            for method in METHODS
        }

        # Sort methods by average ranking (ascending)
        sorted_methods = sorted(
            average_rankings.items(),
            key=lambda x: x[1]
        )

        session_data['average_rankings'] = average_rankings
        save_session_data_to_hf(session_id, session_data)

        # Clean up local session file
        try:
            os.remove(os.path.join(SESSION_DIR, f'{session_id}.json'))
        except Exception as e:
            logger.warning(f"Error removing session file: {e}")

        return render_template(
            'completed.html',
            average_rankings=average_rankings,
            sorted_methods=sorted_methods
        )
    except Exception as e:
        logger.exception(f"An error occurred in the completed route: {e}")
        return "An error occurred", 500
def send_visualization(filename):
    """Serve a file located under the current working directory.

    Args:
        filename: Path of the requested file, relative to the working
            directory. NOTE(review): presumably supplied by the URL route;
            the route decorator is not visible in this view - confirm.

    Returns:
        - ``("Access denied", 403)`` when the normalized path falls outside
          the working directory.
        - ``("File not found", 404)`` when the path does not exist.
        - Otherwise the response produced by ``send_from_directory``.
    """
    base_dir = os.getcwd()
    file_path = os.path.normpath(os.path.join(base_dir, filename))

    # Compare whole path components. A plain string prefix test would also
    # accept sibling directories such as "<base_dir>_other/...".
    try:
        inside_base = os.path.commonpath([base_dir, file_path]) == base_dir
    except ValueError:
        # Raised e.g. when the paths are on different drives.
        inside_base = False
    if not inside_base:
        return "Access denied", 403

    if not os.path.exists(file_path):
        return "File not found", 404

    directory = os.path.dirname(file_path)
    file_name = os.path.basename(file_path)
    return send_from_directory(directory, file_name)
if __name__ == "__main__":
    # The interactive debugger enabled by debug mode lets a client execute
    # arbitrary code, so it must not be on by default for a server bound to
    # all interfaces. Set FLASK_DEBUG=1 to opt in for local development.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)