| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | import matplotlib as mpl |
| | import argparse |
| |
|
| |
|
# Global typography applied to every figure this script produces: a serif
# face with enlarged text so titles, labels and ticks stay legible in the
# exported SVG/PNG files.
mpl.rcParams.update({
    'font.family': 'serif',
    # Georgia is missing on many systems (notably most Linux installs).
    # Overriding font.serif with only ['Georgia'] removes Matplotlib's serif
    # fallbacks, so a missing Georgia triggers a "findfont" warning and a
    # switch to the default sans-serif font. 'DejaVu Serif' ships with
    # Matplotlib, so listing it second keeps the output serif everywhere.
    'font.serif': ['Georgia', 'DejaVu Serif'],
    'font.size': 20,
    'axes.titlesize': 20,
    'axes.labelsize': 18,
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
})
| | |
| |
|
def plot_two_loss_curves(
    csv_file1,
    csv_file2,
    title="Loss Comparison on Qwen3-8B",
    dataset1_name="Dataset1",
    dataset2_name="Dataset2",
    svg_path="loss_comparison_qwen38b.svg",
    png_path="loss_comparison.png",
    show=True,
):
    """Overlay two training-loss curves read from CSV files and save the plot.

    Each CSV must contain a ``Step`` column (x-axis) and a ``Loss`` column
    (y-axis). The figure is written both as SVG and as a 300-dpi PNG.

    Args:
        csv_file1: Path (or any source ``pandas.read_csv`` accepts) for the
            first curve.
        csv_file2: Path (or source) for the second curve.
        title: Figure title.
        dataset1_name: Legend label for the first curve.
        dataset2_name: Legend label for the second curve.
        svg_path: Destination of the SVG copy of the figure.
        png_path: Destination of the PNG copy of the figure.
        show: If True, display the figure with ``plt.show()``; otherwise the
            figure is closed right after saving.

    Raises:
        ValueError: If either CSV lacks a ``Step`` or ``Loss`` column.
    """
    df1 = pd.read_csv(csv_file1)
    df2 = pd.read_csv(csv_file2)

    # Validate both inputs before creating a figure so a bad file does not
    # leave a half-built plot behind.
    for df, path in ((df1, csv_file1), (df2, csv_file2)):
        if 'Step' not in df.columns or 'Loss' not in df.columns:
            raise ValueError(f"Missing 'Step' or 'Loss' columns in {path}")

    fig = plt.figure(figsize=(12, 8))

    # The dataset names used to be accepted but ignored; attach them as line
    # labels so the legend tells the two curves apart.
    plt.plot(df1['Step'], df1['Loss'],
             color='#1f77b4', linewidth=2.5, label=dataset1_name)
    plt.plot(df2['Step'], df2['Loss'],
             color='#2ca02c', linewidth=2.5, label=dataset2_name)
    plt.legend()

    plt.title(title, fontweight='bold')
    plt.xlabel('Steps', fontweight='bold')
    plt.ylabel('Loss', fontweight='bold')

    plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout(pad=3.0)

    plt.savefig(svg_path, format='svg')
    plt.savefig(png_path, dpi=300)

    if show:
        plt.show()
    else:
        # Nothing else releases the figure in non-interactive use; close it so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    # Report the files actually written (the old message named a
    # non-existent "loss_comparison.svg").
    print(f"Saved: {svg_path}, {png_path}")
| |
|
| |
|
def main():
    """Command-line entry point: parse arguments, then draw the comparison plot."""
    cli = argparse.ArgumentParser(description='Plot comparison of two training loss curves')

    # Required positional inputs: the two CSV logs to overlay.
    cli.add_argument('csv_file1', help='Path to the first CSV file')
    cli.add_argument('csv_file2', help='Path to the second CSV file')

    # Optional presentation flags as (flag, default, help) triples,
    # registered in the same order as before so --help output is unchanged.
    optional_flags = (
        ('--title', 'Training Loss Comparison', 'Title for the plot'),
        ('--dataset1-name', 'Original Dataset', 'Name for the first dataset'),
        ('--dataset2-name', 'Revised Dataset', 'Name for the second dataset'),
    )
    for flag, default_value, help_text in optional_flags:
        cli.add_argument(flag, default=default_value, help=help_text)

    ns = cli.parse_args()

    plot_two_loss_curves(
        ns.csv_file1,
        ns.csv_file2,
        title=ns.title,
        dataset1_name=ns.dataset1_name,
        dataset2_name=ns.dataset2_name,
    )
| |
|
| |
|
# Run the command-line interface only when this file is executed as a script,
# so importing it (e.g. to reuse plot_two_loss_curves) has no side effects
# beyond the module-level rcParams configuration.
if __name__ == "__main__":
    main()