Upload train_classifier.ipynb
Browse files- train_classifier.ipynb +20 -5
train_classifier.ipynb
CHANGED
|
@@ -339,7 +339,7 @@
|
|
| 339 |
" min_samples = counts.min()\n",
|
| 340 |
" # Calculate 2.0 times the minimum sample size, rounded down to the nearest integer\n",
|
| 341 |
" # target_samples = int(2.0 * min_samples)\n",
|
| 342 |
-
" target_samples =
|
| 343 |
" \n",
|
| 344 |
" indices_to_keep = np.hstack([\n",
|
| 345 |
" np.random.choice(\n",
|
|
@@ -521,7 +521,7 @@
|
|
| 521 |
"# Loss and optimizer\n",
|
| 522 |
"criterion = nn.CrossEntropyLoss()\n",
|
| 523 |
"optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5) \n",
|
| 524 |
-
"lambda_l1 = 1e-
|
| 525 |
]
|
| 526 |
},
|
| 527 |
{
|
|
@@ -539,7 +539,7 @@
|
|
| 539 |
"metadata": {},
|
| 540 |
"outputs": [],
|
| 541 |
"source": [
|
| 542 |
-
"epochs =
|
| 543 |
"train_losses, test_losses = [], []\n",
|
| 544 |
"\n",
|
| 545 |
"for epoch in range(epochs):\n",
|
|
@@ -577,7 +577,7 @@
|
|
| 577 |
" \n",
|
| 578 |
" precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
|
| 579 |
" accuracy = accuracy_score(all_targets, all_preds) # Compute accuracy\n",
|
| 580 |
-
" if epoch %
|
| 581 |
" print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
|
| 582 |
]
|
| 583 |
},
|
|
@@ -620,6 +620,9 @@
|
|
| 620 |
"metadata": {},
|
| 621 |
"outputs": [],
|
| 622 |
"source": [
|
|
|
|
|
|
|
|
|
|
| 623 |
"conf_matrix = confusion_matrix(all_targets, all_preds)\n",
|
| 624 |
"labels = [\"background\", \"tackle-live\", \"tackle-replay\",]\n",
|
| 625 |
" # \"tackle-live-incomplete\", \"tackle-replay-incomplete\"]\n",
|
|
@@ -627,7 +630,19 @@
|
|
| 627 |
"# plt.title('Confusion Matrix')\n",
|
| 628 |
"plt.xlabel('Predicted Label')\n",
|
| 629 |
"plt.ylabel('True Label')\n",
|
| 630 |
-
"plt.show()"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
]
|
| 632 |
},
|
| 633 |
{
|
|
|
|
| 339 |
" min_samples = counts.min()\n",
|
| 340 |
" # Calculate 2.0 times the minimum sample size, rounded down to the nearest integer\n",
|
| 341 |
" # target_samples = int(2.0 * min_samples)\n",
|
| 342 |
+
" target_samples = 7500\n",
|
| 343 |
" \n",
|
| 344 |
" indices_to_keep = np.hstack([\n",
|
| 345 |
" np.random.choice(\n",
|
|
|
|
| 521 |
"# Loss and optimizer\n",
|
| 522 |
"criterion = nn.CrossEntropyLoss()\n",
|
| 523 |
"optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5) \n",
|
| 524 |
+
"lambda_l1 = 1e-5 # L1 regularization strength"
|
| 525 |
]
|
| 526 |
},
|
| 527 |
{
|
|
|
|
| 539 |
"metadata": {},
|
| 540 |
"outputs": [],
|
| 541 |
"source": [
|
| 542 |
+
"epochs = 10\n",
|
| 543 |
"train_losses, test_losses = [], []\n",
|
| 544 |
"\n",
|
| 545 |
"for epoch in range(epochs):\n",
|
|
|
|
| 577 |
" \n",
|
| 578 |
" precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
|
| 579 |
" accuracy = accuracy_score(all_targets, all_preds) # Compute accuracy\n",
|
| 580 |
+
" if epoch % 2==0:\n",
|
| 581 |
" print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
|
| 582 |
]
|
| 583 |
},
|
|
|
|
| 620 |
"metadata": {},
|
| 621 |
"outputs": [],
|
| 622 |
"source": [
|
| 623 |
+
"print(np.unique(all_targets, return_counts=True))\n",
|
| 624 |
+
"print(np.unique(all_preds, return_counts=True))\n",
|
| 625 |
+
"\n",
|
| 626 |
"conf_matrix = confusion_matrix(all_targets, all_preds)\n",
|
| 627 |
"labels = [\"background\", \"tackle-live\", \"tackle-replay\",]\n",
|
| 628 |
" # \"tackle-live-incomplete\", \"tackle-replay-incomplete\"]\n",
|
|
|
|
| 630 |
"# plt.title('Confusion Matrix')\n",
|
| 631 |
"plt.xlabel('Predicted Label')\n",
|
| 632 |
"plt.ylabel('True Label')\n",
|
| 633 |
+
"plt.show()\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"def showClassWiseAcc(conf_matrix):\n",
|
| 636 |
+
" # Calculate accuracy per class\n",
|
| 637 |
+
" class_accuracies = conf_matrix.diagonal() / conf_matrix.sum(axis=1)\n",
|
| 638 |
+
"\n",
|
| 639 |
+
" # Prepare accuracy data for writing to file\n",
|
| 640 |
+
" accuracy_data = \"\\n\".join([f\"Accuracy for class {i}: {class_accuracies[i]:.4f}\" for i in range(len(class_accuracies))])\n",
|
| 641 |
+
"\n",
|
| 642 |
+
" # Print accuracy per class and write to a file\n",
|
| 643 |
+
" print(accuracy_data) # Print to console\n",
|
| 644 |
+
"\n",
|
| 645 |
+
"showClassWiseAcc(conf_matrix)"
|
| 646 |
]
|
| 647 |
},
|
| 648 |
{
|