# teacher_agent_dev/compare_strategies.py
"""
Compare three training strategies:
1. Random: Random questions until student can pass difficult questions
2. Progressive: Easy β†’ Medium β†’ Hard within each family sequentially
3. Teacher: RL teacher agent learns optimal curriculum
Uses LM Student (DistilBERT) instead of MockStudentAgent.
"""
import sys
import os
import random  # global seeding and the per-strategy random.Random instances
import numpy as np
from pathlib import Path
# Add student_agent_dev to path for LM student import
student_agent_dev_path = Path(__file__).parent.parent / "student_agent_dev"
if str(student_agent_dev_path) not in sys.path:
sys.path.insert(0, str(student_agent_dev_path))
from typing import Dict
from interfaces import Task
try:
from tqdm import tqdm
HAS_TQDM = True
except ImportError:
HAS_TQDM = False
tqdm = None
# Import LM Student instead of MockStudentAgent
try:
from student_agent import StudentAgent as LMStudentAgent
USE_LM_STUDENT = True
print("βœ… Using LM Student (DistilBERT)")
except ImportError as e:
print(f"⚠️ Could not import LM Student: {e}")
print(" Falling back to MockStudentAgent")
from mock_student import MockStudentAgent
USE_LM_STUDENT = False
from mock_task_generator import MockTaskGenerator
from teacher_agent import TeacherAgent, compute_reward
from train_teacher import train_teacher
def evaluate_difficult_questions(student, generator: MockTaskGenerator, num_questions: int = 20) -> float:
"""
Evaluate student on difficult questions from all topics.
"""
topics = generator.get_available_topics()
eval_tasks = []
# Generate difficult questions from all topics
questions_per_topic = max(1, num_questions // len(topics))
for topic in topics:
for _ in range(questions_per_topic):
eval_tasks.append(generator.generate_task(topic, 'hard'))
return student.evaluate(eval_tasks)
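# Example usage (assumes a student has already been constructed, e.g. by one of
# the train_strategy_* functions below):
#     generator = MockTaskGenerator()
#     hard_acc = evaluate_difficult_questions(student, generator, num_questions=20)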
def train_strategy_random(num_iterations: int = 500, seed: int = 42, target_accuracy: float = 0.75) -> Dict:
"""
    Strategy 1: uniformly random questions until the student can confidently pass difficult questions.
"""
# Set global seeds to ensure MockTaskGenerator behaves deterministically
random.seed(seed)
np.random.seed(seed)
rng = random.Random(seed)
    device = os.environ.get("CUDA_DEVICE", "cpu")
    if device == "cuda":
        try:
            import torch
            if torch.cuda.is_available():
                print(f"✅ Using GPU: {torch.cuda.get_device_name(0)}")
            else:
                device = "cpu"
        except Exception:
            device = "cpu"
    print(f"🔧 LM Student device: {device}")
    if USE_LM_STUDENT:
        student = LMStudentAgent(
            learning_rate=5e-5,
            retention_constant=80.0,
            device=device,
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)
    # MockTaskGenerator takes no seed argument; determinism comes from the
    # global random/numpy seeds set above.
    generator = MockTaskGenerator()
topics = generator.get_available_topics()
difficulties = generator.get_available_difficulties()
    # Build a fixed hard-question eval set once so every iteration is scored
    # against the same tasks.
hard_eval_tasks = []
eval_difficulty = 'expert' if 'expert' in difficulties else 'hard'
for topic in topics:
for _ in range(5):
hard_eval_tasks.append(generator.generate_task(topic, eval_difficulty))
    # Fixed medium-difficulty eval set for tracking general accuracy
general_eval_tasks = [
generator.generate_task(topic, 'medium')
for topic in topics
for _ in range(3)
]
history = {
'iterations': [],
'student_accuracies': [],
'difficult_accuracies': [],
'teacher_rewards': [],
'topics': [],
'difficulties': [],
'strategy': 'random'
}
    reached_target = False
    iterator = range(num_iterations)
if HAS_TQDM:
iterator = tqdm(iterator, desc="Random Strategy", unit="iter")
for iteration in iterator:
topic = rng.choice(topics)
difficulty = rng.choice(difficulties)
task = generator.generate_task(topic, difficulty)
accuracy_before = student.evaluate(hard_eval_tasks)
student.learn(task)
accuracy_after = student.evaluate(hard_eval_tasks)
general_accuracy = student.evaluate(general_eval_tasks)
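        # Advance simulated time by one step so the student's retention/forgetting
        # dynamics (retention_constant / forgetting_rate) take effect.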
student.advance_time(1.0)
history['iterations'].append(iteration)
history['student_accuracies'].append(general_accuracy)
history['difficult_accuracies'].append(accuracy_after)
history['teacher_rewards'].append(accuracy_after - accuracy_before)
history['topics'].append(topic)
history['difficulties'].append(difficulty)
        if accuracy_after >= target_accuracy and iteration > 50 and not reached_target:
            print(f"  Random strategy reached target accuracy {target_accuracy:.2f} at iteration {iteration}")
            reached_target = True
return history
def train_strategy_progressive(num_iterations: int = 500, seed: int = 42) -> Dict:
"""
Strategy 2: Progressive difficulty within each family.
"""
random.seed(seed)
np.random.seed(seed)
    if USE_LM_STUDENT:
        # NOTE: device is fixed to 'cpu' here; only the random strategy reads CUDA_DEVICE.
        student = LMStudentAgent(
            learning_rate=5e-5,
            retention_constant=80.0,
            device='cpu',
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)
    # MockTaskGenerator takes no seed argument; determinism comes from the global seeds above.
    generator = MockTaskGenerator()
topics = generator.get_available_topics()
    difficulties = generator.get_available_difficulties()
    hard_eval_tasks = []
    eval_difficulty = 'expert' if 'expert' in difficulties else 'hard'
for topic in topics:
for _ in range(5):
hard_eval_tasks.append(generator.generate_task(topic, eval_difficulty))
general_eval_tasks = [
generator.generate_task(topic, 'medium')
for topic in topics
for _ in range(3)
]
history = {
'iterations': [],
'student_accuracies': [],
'difficult_accuracies': [],
'teacher_rewards': [],
'topics': [],
'difficulties': [],
'strategy': 'progressive'
}
questions_per_difficulty = max(1, num_iterations // (len(topics) * len(difficulties)))
iterator = range(num_iterations)
if HAS_TQDM:
iterator = tqdm(iterator, desc="Progressive Strategy", unit="iter")
for iteration in iterator:
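        # Schedule: the phase advances every questions_per_difficulty iterations;
        # the difficulty index cycles fastest, so each topic sees every difficulty
        # in order before the topic index advances (wrapping if iterations remain).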
        phase = iteration // questions_per_difficulty
topic_idx = (phase // len(difficulties)) % len(topics)
diff_idx = phase % len(difficulties)
topic = topics[topic_idx]
difficulty = difficulties[diff_idx]
task = generator.generate_task(topic, difficulty)
accuracy_before = student.evaluate(hard_eval_tasks)
student.learn(task)
accuracy_after = student.evaluate(hard_eval_tasks)
general_accuracy = student.evaluate(general_eval_tasks)
student.advance_time(1.0)
history['iterations'].append(iteration)
history['student_accuracies'].append(general_accuracy)
history['difficult_accuracies'].append(accuracy_after)
history['teacher_rewards'].append(accuracy_after - accuracy_before)
history['topics'].append(topic)
history['difficulties'].append(difficulty)
return history
def train_strategy_teacher(num_iterations: int = 500, seed: int = 42) -> Dict:
"""
    Strategy 3: an RL teacher agent adaptively selects the curriculum.
"""
random.seed(seed)
np.random.seed(seed)
    # MockTaskGenerator takes no seed argument; determinism comes from the global seeds above.
    generator = MockTaskGenerator()
teacher = TeacherAgent(exploration_bonus=2.0, task_generator=generator)
    if USE_LM_STUDENT:
        student = LMStudentAgent(
            learning_rate=5e-5,
            retention_constant=80.0,
            device='cpu',
            max_length=256,
            gradient_accumulation_steps=4
        )
    else:
        student = MockStudentAgent(learning_rate=0.15, forgetting_rate=0.01, seed=seed)
topics = generator.get_available_topics()
eval_tasks = [
generator.generate_task(topic, 'medium')
for topic in topics
for _ in range(3)
]
all_difficulties = generator.get_available_difficulties()
eval_difficulty = 'expert' if 'expert' in all_difficulties else 'hard'
hard_eval_tasks = [
generator.generate_task(topic, eval_difficulty)
for topic in topics
for _ in range(5)
]
history = {
'iterations': [],
'student_accuracies': [],
'difficult_accuracies': [],
'teacher_rewards': [],
'actions': [],
'topics': [],
'difficulties': [],
'is_reviews': [],
'strategy': 'teacher'
}
iterator = range(num_iterations)
if HAS_TQDM:
iterator = tqdm(iterator, desc="Teacher Strategy", unit="iter")
for iteration in iterator:
student_state = student.get_state()
action = teacher.select_action(student_state)
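        # Review actions revisit a topic at fixed 'medium' difficulty; otherwise
        # the teacher chooses both the topic and the difficulty.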
if action.is_review:
task = generator.generate_task(action.topic, 'medium')
else:
task = generator.generate_task(action.topic, action.difficulty)
accuracy_before = student.evaluate(eval_tasks)
difficult_acc_before = student.evaluate(hard_eval_tasks)
student.learn(task)
accuracy_after = student.evaluate(eval_tasks)
difficult_acc_after = student.evaluate(hard_eval_tasks)
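        # The reward combines the change in general accuracy with the chosen
        # difficulty and the review flag (see compute_reward in teacher_agent).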
reward = compute_reward(
accuracy_before,
accuracy_after,
action.difficulty,
action.is_review
)
teacher.update(action, reward)
student.advance_time(1.0)
history['iterations'].append(iteration)
history['student_accuracies'].append(accuracy_after)
history['difficult_accuracies'].append(difficult_acc_after)
history['teacher_rewards'].append(reward)
history['actions'].append(action)
history['topics'].append(action.topic)
history['difficulties'].append(action.difficulty)
history['is_reviews'].append(action.is_review)
return history
def plot_comparison(histories: Dict[str, Dict], save_path: str = 'teacher_agent_dev/comparison_all_strategies.png'):
"""
Create comprehensive comparison plots of all three strategies.
"""
import matplotlib.pyplot as plt
    # Ensure the output directory exists; guard against an empty dirname,
    # which os.makedirs would reject when save_path has no directory part.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
fig, axes = plt.subplots(4, 1, figsize=(16, 14))
colors = {
'Random': '#FF6B6B', # Red
'Progressive': '#4ECDC4', # Teal
'Teacher': '#2ECC71' # Green
}
line_styles = {
'Random': '--',
'Progressive': '-.',
'Teacher': '-'
}
line_widths = {
'Random': 2.0,
'Progressive': 2.0,
'Teacher': 3.5
}
    # Plot 1: General Accuracy
ax = axes[0]
for name, history in histories.items():
iterations = history['iterations']
accuracies = history['student_accuracies']
if len(accuracies) > 50:
# Smooth curves
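            # Moving average: np.ones(window)/window is a uniform smoothing kernel;
            # mode='same' preserves the array length but biases the first and last
            # few points toward zero (zero-padded edges).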
window = 10
smoothed = np.convolve(accuracies, np.ones(window)/window, mode='same')
ax.plot(iterations, smoothed,
label=name,
color=colors[name],
linestyle=line_styles[name],
linewidth=line_widths[name],
alpha=0.9)
else:
ax.plot(iterations, accuracies,
label=name,
color=colors[name],
linestyle=line_styles[name],
linewidth=line_widths[name])
ax.set_xlabel('Training Iteration')
ax.set_ylabel('General Accuracy')
ax.set_title('Learning Curves')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
ax.set_ylim([0.0, 1.0])
    # Plot 2: Difficult Question Accuracy
ax = axes[1]
for name, history in histories.items():
iterations = history['iterations']
difficult_accuracies = history['difficult_accuracies']
if len(difficult_accuracies) > 50:
window = 10
smoothed = np.convolve(difficult_accuracies, np.ones(window)/window, mode='same')
ax.plot(iterations, smoothed,
label=name,
color=colors[name],
linestyle=line_styles[name],
linewidth=line_widths[name])
else:
ax.plot(iterations, difficult_accuracies,
label=name,
color=colors[name],
linestyle=line_styles[name],
linewidth=line_widths[name])
ax.set_xlabel('Training Iteration')
ax.set_ylabel('Accuracy on Hard Questions')
ax.set_title('Performance on Difficult Content')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
ax.set_ylim([0.0, 1.0])
    # Plot 3: Topic Coverage
ax = axes[2]
for name, history in histories.items():
iterations = history['iterations']
topics_seen = history['topics']
unique_topics = []
seen_so_far = set()
for topic in topics_seen:
seen_so_far.add(topic)
unique_topics.append(len(seen_so_far))
ax.plot(iterations, unique_topics,
label=name,
color=colors[name],
linestyle=line_styles[name],
linewidth=line_widths[name])
ax.set_xlabel('Training Iteration')
ax.set_ylabel('Unique Topics Seen')
ax.set_title('Curriculum Diversity')
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
    # Plot 4: Learning Efficiency
ax = axes[3]
target_acc = 0.75
strategy_stats = {}
for name, history in histories.items():
difficult_accuracies = history['difficult_accuracies']
iterations = history['iterations']
reached_target = False
target_iteration = len(iterations) - 1
for i, acc in enumerate(difficult_accuracies):
if acc >= target_acc:
target_iteration = i
reached_target = True
break
strategy_stats[name] = {
'reached': reached_target,
'iteration': target_iteration,
'final_acc': difficult_accuracies[-1]
}
names = list(strategy_stats.keys())
iterations_to_target = [
strategy_stats[n]['iteration'] if strategy_stats[n]['reached'] else len(histories[n]['iterations'])
for n in names
]
final_accs = [strategy_stats[n]['final_acc'] for n in names]
x = np.arange(len(names))
width = 0.35
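    # The two bar series share one y-axis: iterations-to-target is plotted directly,
    # while final accuracy is rescaled by the largest iteration count so both fit.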
ax.bar(x - width/2, iterations_to_target, width, label='Iterations to 75% on Hard',
color=[colors[n] for n in names], alpha=0.7)
ax.bar(x + width/2, [acc * max(iterations_to_target) for acc in final_accs], width,
label='Final Hard Accuracy (scaled)',
color=[colors[n] for n in names], alpha=0.5)
ax.set_title('Learning Efficiency')
ax.set_xticks(x)
ax.set_xticklabels(names)
ax.legend()
plt.tight_layout()
plt.savefig(save_path, dpi=150)
print(f"\nβœ… Saved comparison plot to {save_path}")
plt.close()
if __name__ == "__main__":
import argparse
import time
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=None)
parser.add_argument('--iterations', type=int, default=500)
parser.add_argument('--deterministic', action='store_true')
args = parser.parse_args()
if args.deterministic:
seed = 42
print("⚠️ Using deterministic mode (seed=42)")
elif args.seed is not None:
seed = args.seed
else:
seed = int(time.time()) % 10000
print(f"Using seed: {seed}")
num_iterations = args.iterations
# Run strategies
print("Training Random Strategy...")
history_random = train_strategy_random(num_iterations=num_iterations, seed=seed)
print("\nTraining Progressive Strategy...")
history_progressive = train_strategy_progressive(num_iterations=num_iterations, seed=seed)
print("\nTraining Teacher Strategy...")
history_teacher = train_strategy_teacher(num_iterations=num_iterations, seed=seed)
histories = {
'Random': history_random,
'Progressive': history_progressive,
'Teacher': history_teacher
}
plot_comparison(histories, save_path='teacher_agent_dev/comparison_all_strategies.png')
print("\nβœ… Comparison complete!")