| """ | |
| Compare three training strategies: | |
| 1. Random: Random questions until student can pass difficult questions | |
| 2. Progressive: Easy β Medium β Hard within each family sequentially | |
| 3. Teacher: RL teacher agent learns optimal curriculum | |
| Uses LM Student (DistilBERT) instead of MockStudentAgent. | |
| """ | |
import sys
import os
import random  # for global seeding
import numpy as np  # for global seeding and plotting
from pathlib import Path
from typing import Dict

# Add student_agent_dev to the path for the LM student import
student_agent_dev_path = Path(__file__).parent.parent / "student_agent_dev"
if str(student_agent_dev_path) not in sys.path:
    sys.path.insert(0, str(student_agent_dev_path))

from interfaces import Task

try:
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
    tqdm = None

# Prefer the LM student; fall back to MockStudentAgent if it cannot be imported
try:
    from student_agent import StudentAgent as LMStudentAgent
    USE_LM_STUDENT = True
    print("✅ Using LM Student (DistilBERT)")
except ImportError as e:
    print(f"⚠️ Could not import LM Student: {e}")
    print("   Falling back to MockStudentAgent")
    from mock_student import MockStudentAgent
    USE_LM_STUDENT = False

from mock_task_generator import MockTaskGenerator
from teacher_agent import TeacherAgent, compute_reward
from train_teacher import train_teacher
def evaluate_difficult_questions(student, generator: MockTaskGenerator, num_questions: int = 20) -> float:
    """
    Evaluate the student on hard questions drawn evenly from all topics.
    Returns the student's accuracy on the generated evaluation set.
    """
    topics = generator.get_available_topics()
    eval_tasks = []
    # Split the question budget evenly across topics (at least one per topic)
    questions_per_topic = max(1, num_questions // len(topics))
    for topic in topics:
        for _ in range(questions_per_topic):
            eval_tasks.append(generator.generate_task(topic, 'hard'))
    return student.evaluate(eval_tasks)
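# Example usage (illustrative; assumes a constructed student exposing .evaluate()):
#   hard_acc = evaluate_difficult_questions(student, MockTaskGenerator(), num_questions=20)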
def train_strategy_random(num_iterations: int = 500, seed: int = 42, target_accuracy: float = 0.75) -> Dict:
    """
    Strategy 1: uniformly random questions until the student can pass difficult
    questions. Training always runs for the full budget; crossing the target
    accuracy is only reported, so all strategies see the same number of tasks.
    """
    # Set global seeds so MockTaskGenerator behaves deterministically
    random.seed(seed)
    np.random.seed(seed)
    rng = random.Random(seed)

    # Pick the compute device (set CUDA_DEVICE=cuda to request the GPU)
    device = os.environ.get("CUDA_DEVICE", "cpu")
    if device == "cuda":
        try:
            import torch
            if torch.cuda.is_available():
                print(f"✅ Using GPU: {torch.cuda.get_device_name(0)}")
            else:
                device = "cpu"
        except ImportError:
            device = "cpu"
    print(f"🧠 LM Student device: {device}")

    if USE_LM_STUDENT:
        student = LMStudentAgent(
            learning_rate=5e-5,
            retention_constant=80.0,
            device=device,
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)

    # MockTaskGenerator draws from the globally seeded RNGs above, so it takes no seed argument
    generator = MockTaskGenerator()
    topics = generator.get_available_topics()
    difficulties = generator.get_available_difficulties()

    # Build a FIXED hard evaluation set once so accuracy deltas are comparable
    hard_eval_tasks = []
    eval_difficulty = 'expert' if 'expert' in difficulties else 'hard'
    for topic in topics:
        for _ in range(5):
            hard_eval_tasks.append(generator.generate_task(topic, eval_difficulty))

    # Fixed general (medium-difficulty) evaluation set
    general_eval_tasks = [
        generator.generate_task(topic, 'medium')
        for topic in topics
        for _ in range(3)
    ]

    # Per-iteration training history; schema is shared across all three strategies
    history = {
        'iterations': [],
        'student_accuracies': [],
        'difficult_accuracies': [],
        'teacher_rewards': [],
        'topics': [],
        'difficulties': [],
        'strategy': 'random'
    }
    iterator = range(num_iterations)
    if HAS_TQDM:
        iterator = tqdm(iterator, desc="Random Strategy", unit="iter")

    reached_target = False
    for iteration in iterator:
        # Sample a uniformly random topic and difficulty
        topic = rng.choice(topics)
        difficulty = rng.choice(difficulties)
        task = generator.generate_task(topic, difficulty)

        accuracy_before = student.evaluate(hard_eval_tasks)
        student.learn(task)
        accuracy_after = student.evaluate(hard_eval_tasks)
        general_accuracy = student.evaluate(general_eval_tasks)
        student.advance_time(1.0)

        history['iterations'].append(iteration)
        history['student_accuracies'].append(general_accuracy)
        history['difficult_accuracies'].append(accuracy_after)
        history['teacher_rewards'].append(accuracy_after - accuracy_before)
        history['topics'].append(topic)
        history['difficulties'].append(difficulty)

        # Report the first time the target is crossed, but keep training for
        # the full budget so all strategies are compared on equal footing
        if accuracy_after >= target_accuracy and iteration > 50 and not reached_target:
            print(f"  Random strategy reached target accuracy {target_accuracy:.2f} at iteration {iteration}")
            reached_target = True

    return history
def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dict:
    """
    Strategy 2: progressive difficulty within each topic family, taught sequentially.
    """
    random.seed(seed)
    np.random.seed(seed)

    if USE_LM_STUDENT:
        student = LMStudentAgent(
            learning_rate=5e-5,
            retention_constant=80.0,
            device='cpu',
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)

    # MockTaskGenerator draws from the globally seeded RNGs above, so it takes no seed argument
    generator = MockTaskGenerator()
    topics = generator.get_available_topics()
    all_difficulties = generator.get_available_difficulties()
    difficulties = all_difficulties
    # Fixed hard and general evaluation sets (same construction as the other strategies)
    hard_eval_tasks = []
    eval_difficulty = 'expert' if 'expert' in all_difficulties else 'hard'
    for topic in topics:
        for _ in range(5):
            hard_eval_tasks.append(generator.generate_task(topic, eval_difficulty))
    general_eval_tasks = [
        generator.generate_task(topic, 'medium')
        for topic in topics
        for _ in range(3)
    ]

    history = {
        'iterations': [],
        'student_accuracies': [],
        'difficult_accuracies': [],
        'teacher_rewards': [],
        'topics': [],
        'difficulties': [],
        'strategy': 'progressive'
    }

    questions_per_difficulty = max(1, num_iterations // (len(topics) * len(difficulties)))
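    # Schedule: each "phase" lasts questions_per_difficulty iterations, and
    # difficulty cycles fastest, so a topic family is taught easy -> ... -> hard
    # before moving on. Worked example (topic/difficulty counts illustrative):
    # with 5 topics, 4 difficulties and 500 iterations,
    # questions_per_difficulty = 500 // 20 = 25, so iterations 0-24 teach
    # (topic 0, difficulty 0), iterations 25-49 teach (topic 0, difficulty 1), etc.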
    iterator = range(num_iterations)
    if HAS_TQDM:
        iterator = tqdm(iterator, desc="Progressive Strategy", unit="iter")
    for iteration in iterator:
        # questions_per_difficulty is at least 1 (see the max(1, ...) above)
        phase = iteration // questions_per_difficulty
        topic_idx = (phase // len(difficulties)) % len(topics)
        diff_idx = phase % len(difficulties)
        topic = topics[topic_idx]
        difficulty = difficulties[diff_idx]
        task = generator.generate_task(topic, difficulty)

        accuracy_before = student.evaluate(hard_eval_tasks)
        student.learn(task)
        accuracy_after = student.evaluate(hard_eval_tasks)
        general_accuracy = student.evaluate(general_eval_tasks)
        student.advance_time(1.0)

        history['iterations'].append(iteration)
        history['student_accuracies'].append(general_accuracy)
        history['difficult_accuracies'].append(accuracy_after)
        history['teacher_rewards'].append(accuracy_after - accuracy_before)
        history['topics'].append(topic)
        history['difficulties'].append(difficulty)
    return history
def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
    """
    Strategy 3: an RL teacher agent selects the curriculum adaptively from
    the student's state and observed learning gains.
    """
    random.seed(seed)
    np.random.seed(seed)

    # MockTaskGenerator draws from the globally seeded RNGs above, so it takes no seed argument
    generator = MockTaskGenerator()
    teacher = TeacherAgent(exploration_bonus=2.0, task_generator=generator)
    if USE_LM_STUDENT:
        student = LMStudentAgent(
            learning_rate=5e-5,
            retention_constant=80.0,
            device='cpu',
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)

    topics = generator.get_available_topics()
    # Fixed general (medium) and hard evaluation sets
    eval_tasks = [
        generator.generate_task(topic, 'medium')
        for topic in topics
        for _ in range(3)
    ]
    all_difficulties = generator.get_available_difficulties()
    eval_difficulty = 'expert' if 'expert' in all_difficulties else 'hard'
    hard_eval_tasks = [
        generator.generate_task(topic, eval_difficulty)
        for topic in topics
        for _ in range(5)
    ]
    history = {
        'iterations': [],
        'student_accuracies': [],
        'difficult_accuracies': [],
        'teacher_rewards': [],
        'actions': [],
        'topics': [],
        'difficulties': [],
        'is_reviews': [],
        'strategy': 'teacher'
    }
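    # Teaching loop: observe the student's state -> the teacher picks an action
    # (topic, difficulty, review flag) -> the student trains on the generated
    # task -> the accuracy change is converted into a reward that updates the
    # teacher's policy. The exploration_bonus above suggests a bandit-style
    # teacher, but the exact update rule lives in teacher_agent.TeacherAgent.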
    iterator = range(num_iterations)
    if HAS_TQDM:
        iterator = tqdm(iterator, desc="Teacher Strategy", unit="iter")
    for iteration in iterator:
        student_state = student.get_state()
        action = teacher.select_action(student_state)
        # Review actions revisit a topic at medium difficulty
        if action.is_review:
            task = generator.generate_task(action.topic, 'medium')
        else:
            task = generator.generate_task(action.topic, action.difficulty)

        accuracy_before = student.evaluate(eval_tasks)
        student.learn(task)
        accuracy_after = student.evaluate(eval_tasks)
        difficult_acc_after = student.evaluate(hard_eval_tasks)

        # Reward the teacher for the student's learning gain on the general set
        reward = compute_reward(
            accuracy_before,
            accuracy_after,
            action.difficulty,
            action.is_review
        )
        teacher.update(action, reward)
        student.advance_time(1.0)

        history['iterations'].append(iteration)
        history['student_accuracies'].append(accuracy_after)
        history['difficult_accuracies'].append(difficult_acc_after)
        history['teacher_rewards'].append(reward)
        history['actions'].append(action)
        history['topics'].append(action.topic)
        history['difficulties'].append(action.difficulty)
        history['is_reviews'].append(action.is_review)
    return history
def plot_comparison(histories: Dict[str, Dict], save_path: str = 'teacher_agent_dev/comparison_all_strategies.png'):
    """
    Create comparison plots of all three strategies: learning curves,
    hard-question accuracy, topic coverage, and learning efficiency.
    """
    import matplotlib.pyplot as plt

    # Ensure the output directory exists (skip when save_path has no directory part)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig, axes = plt.subplots(4, 1, figsize=(16, 14))
    colors = {
        'Random': '#FF6B6B',       # red
        'Progressive': '#4ECDC4',  # teal
        'Teacher': '#2ECC71'       # green
    }
    line_styles = {
        'Random': '--',
        'Progressive': '-.',
        'Teacher': '-'
    }
    # The Teacher curve is drawn thickest so it stands out
    line_widths = {
        'Random': 2.0,
        'Progressive': 2.0,
        'Teacher': 3.5
    }
    # Plot 1: general accuracy
    ax = axes[0]
    for name, history in histories.items():
        iterations = history['iterations']
        accuracies = history['student_accuracies']
        if len(accuracies) > 50:
            # Smooth long runs with a moving average (mode='same' keeps the
            # x-axis aligned but biases the first/last few points toward zero)
            window = 10
            smoothed = np.convolve(accuracies, np.ones(window) / window, mode='same')
            ax.plot(iterations, smoothed,
                    label=name,
                    color=colors[name],
                    linestyle=line_styles[name],
                    linewidth=line_widths[name],
                    alpha=0.9)
        else:
            ax.plot(iterations, accuracies,
                    label=name,
                    color=colors[name],
                    linestyle=line_styles[name],
                    linewidth=line_widths[name])
    ax.set_xlabel('Training Iteration')
    ax.set_ylabel('General Accuracy')
    ax.set_title('Learning Curves')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0.0, 1.0])

    # Plot 2: accuracy on difficult questions
    ax = axes[1]
    for name, history in histories.items():
        iterations = history['iterations']
        difficult_accuracies = history['difficult_accuracies']
        if len(difficult_accuracies) > 50:
            window = 10
            smoothed = np.convolve(difficult_accuracies, np.ones(window) / window, mode='same')
            ax.plot(iterations, smoothed,
                    label=name,
                    color=colors[name],
                    linestyle=line_styles[name],
                    linewidth=line_widths[name])
        else:
            ax.plot(iterations, difficult_accuracies,
                    label=name,
                    color=colors[name],
                    linestyle=line_styles[name],
                    linewidth=line_widths[name])
    ax.set_xlabel('Training Iteration')
    ax.set_ylabel('Accuracy on Hard Questions')
    ax.set_title('Performance on Difficult Content')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0.0, 1.0])
    # Plot 3: topic coverage (cumulative count of distinct topics taught)
    ax = axes[2]
    for name, history in histories.items():
        iterations = history['iterations']
        topics_seen = history['topics']
        unique_topics = []
        seen_so_far = set()
        for topic in topics_seen:
            seen_so_far.add(topic)
            unique_topics.append(len(seen_so_far))
        ax.plot(iterations, unique_topics,
                label=name,
                color=colors[name],
                linestyle=line_styles[name],
                linewidth=line_widths[name])
    ax.set_xlabel('Training Iteration')
    ax.set_ylabel('Unique Topics Seen')
    ax.set_title('Curriculum Diversity')
    ax.legend(loc='lower right')
    ax.grid(True, alpha=0.3)
    # Plot 4: learning efficiency (iterations to reach 75% on hard questions)
    ax = axes[3]
    target_acc = 0.75
    strategy_stats = {}
    for name, history in histories.items():
        difficult_accuracies = history['difficult_accuracies']
        iterations = history['iterations']
        reached_target = False
        target_iteration = len(iterations) - 1
        for i, acc in enumerate(difficult_accuracies):
            if acc >= target_acc:
                target_iteration = i
                reached_target = True
                break
        strategy_stats[name] = {
            'reached': reached_target,
            'iteration': target_iteration,
            'final_acc': difficult_accuracies[-1]
        }
    names = list(strategy_stats.keys())
    iterations_to_target = [
        strategy_stats[n]['iteration'] if strategy_stats[n]['reached'] else len(histories[n]['iterations'])
        for n in names
    ]
    final_accs = [strategy_stats[n]['final_acc'] for n in names]
    x = np.arange(len(names))
    width = 0.35
    ax.bar(x - width / 2, iterations_to_target, width,
           label='Iterations to 75% on Hard',
           color=[colors[n] for n in names], alpha=0.7)
    # Final accuracies are rescaled onto the iteration axis so both bar
    # series can share one y-axis
    ax.bar(x + width / 2, [acc * max(iterations_to_target) for acc in final_accs], width,
           label='Final Hard Accuracy (scaled)',
           color=[colors[n] for n in names], alpha=0.5)
    ax.set_title('Learning Efficiency')
    ax.set_xticks(x)
    ax.set_xticklabels(names)
    ax.legend()

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    print(f"\n✅ Saved comparison plot to {save_path}")
    plt.close()
| if __name__ == "__main__": | |
| import argparse | |
| import time | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--seed', type=int, default=None) | |
| parser.add_argument('--iterations', type=int, default=500) | |
| parser.add_argument('--deterministic', action='store_true') | |
| parser.add_argument('--runs', type=int, default=1) | |
| args = parser.parse_args() | |
| if args.deterministic: | |
| seed = 42 | |
| print("β οΈ Using deterministic mode (seed=42)") | |
| elif args.seed is not None: | |
| seed = args.seed | |
| else: | |
| seed = int(time.time()) % 10000 | |
| print(f"Using seed: {seed}") | |
| num_iterations = args.iterations | |
| # Run strategies | |
| print("Training Random Strategy...") | |
| history_random = train_strategy_random(num_iterations=num_iterations, seed=seed) | |
| print("\nTraining Progressive Strategy...") | |
| history_progressive = train_strategy_progressive(num_iterations=num_iterations, seed=seed) | |
| print("\nTraining Teacher Strategy...") | |
| history_teacher = train_strategy_teacher(num_iterations=num_iterations, seed=seed) | |
| histories = { | |
| 'Random': history_random, | |
| 'Progressive': history_progressive, | |
| 'Teacher': history_teacher | |
| } | |
| plot_comparison(histories, save_path='teacher_agent_dev/comparison_all_strategies.png') | |
| print("\nβ Comparison complete!") |