{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "50fb35c4",
   "metadata": {},
   "source": [
    "# Exercise: Build a CNN from Scratch on CIFAR-10\n",
    "\n",
    "**Learning outcomes.** By the end of this exercise you will be able to:\n",
    "\n",
    "1. Design a convolutional neural network architecture from scratch in PyTorch.\n",
    "2. Set up a reproducible train / validation / test pipeline.\n",
    "3. Implement a training loop with validation and track learning curves.\n",
    "4. Diagnose overfitting from loss and accuracy curves.\n",
    "5. Evaluate a classifier using overall accuracy, per-class accuracy, and a confusion matrix.\n",
    "\n",
    "**Dataset.** [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html): 60 000 colour images (32 × 32 px) across 10 classes — *airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck*. The official split is 50 000 train + 10 000 test; we will further carve a validation set out of the training images.\n",
    "\n",
    "**How this notebook works.** Cells marked **TODO** are yours to complete. Everything else is provided. Do not change the random seed or the train/val/test split — they ensure your results are reproducible and comparable to those of your classmates.\n",
    "\n",
    "**What you will submit.** A completed notebook with every cell executed, plus short written answers to the reflection questions at the end.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4463b7eb",
   "metadata": {},
   "source": [
    "## 1. Setup and reproducibility\n",
    "\n",
    "Run this cell as-is. It sets the random seed and picks a device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a446f0fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Using device: {device}\")"
   ]
  },
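  {
   "cell_type": "markdown",
   "id": "sketch-determinism",
   "metadata": {},
   "source": [
    "**Optional sketch: stricter determinism.** The seeds set above fix the Python, NumPy, and PyTorch random number generators, but some cuDNN kernels can still vary slightly from run to run on a GPU. The fragment below is a minimal sketch of two settings that usually narrow that variation, assuming you are willing to trade some speed for repeatability. It is not part of the required setup, and it may not remove every source of nondeterminism.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Optional: prefer deterministic cuDNN kernels over the fastest ones.\n",
    "torch.backends.cudnn.deterministic = True\n",
    "# Optional: disable the autotuner, which may pick different kernels per run.\n",
    "torch.backends.cudnn.benchmark = False\n",
    "```"
   ]
  },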
  {
   "cell_type": "markdown",
   "id": "cd61892a",
   "metadata": {},
   "source": [
    "## 2. Load CIFAR-10 and create a validation split\n",
    "\n",
    "This cell is provided. Read it carefully — make sure you understand **why** we split the training set further into train + validation, rather than just using the test set for tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53548cf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "CIFAR_MEAN = (0.4914, 0.4822, 0.4465)\n",
    "CIFAR_STD  = (0.2470, 0.2435, 0.2616)\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n",
    "])\n",
    "\n",
    "full_train = torchvision.datasets.CIFAR10(\n",
    "    root=\"./data\", train=True, download=True, transform=transform\n",
    ")\n",
    "test_set = torchvision.datasets.CIFAR10(\n",
    "    root=\"./data\", train=False, download=True, transform=transform\n",
    ")\n",
    "\n",
    "val_size = 5000\n",
    "train_size = len(full_train) - val_size\n",
    "train_set, val_set = random_split(\n",
    "    full_train, [train_size, val_size],\n",
    "    generator=torch.Generator().manual_seed(SEED)\n",
    ")\n",
    "\n",
    "BATCH_SIZE = 128\n",
    "train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,  num_workers=2)\n",
    "val_loader   = DataLoader(val_set,   batch_size=BATCH_SIZE, shuffle=False, num_workers=2)\n",
    "test_loader  = DataLoader(test_set,  batch_size=BATCH_SIZE, shuffle=False, num_workers=2)\n",
    "\n",
    "CLASSES = (\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\",\n",
    "           \"dog\", \"frog\", \"horse\", \"ship\", \"truck\")\n",
    "\n",
    "print(f\"Train: {len(train_set)}  |  Val: {len(val_set)}  |  Test: {len(test_set)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60acad13",
   "metadata": {},
   "source": [
    "## 3. Peek at the data\n",
    "\n",
    "Run this as-is. Always look at your data before training a model — it catches bugs, and it calibrates your expectations (these images are *tiny*)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3445e033",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unnormalize(img_tensor):\n",
    "    mean = torch.tensor(CIFAR_MEAN).view(3, 1, 1)\n",
    "    std  = torch.tensor(CIFAR_STD).view(3, 1, 1)\n",
    "    return (img_tensor * std + mean).clamp(0, 1)\n",
    "\n",
    "examples = {}\n",
    "for img, label in full_train:\n",
    "    if label not in examples:\n",
    "        examples[label] = img\n",
    "    if len(examples) == 10:\n",
    "        break\n",
    "\n",
    "fig, axes = plt.subplots(2, 5, figsize=(10, 4.5))\n",
    "for cls_idx, ax in zip(range(10), axes.flatten()):\n",
    "    ax.imshow(unnormalize(examples[cls_idx]).permute(1, 2, 0).numpy())\n",
    "    ax.set_title(CLASSES[cls_idx], fontsize=10)\n",
    "    ax.axis(\"off\")\n",
    "plt.suptitle(\"One sample per class (CIFAR-10)\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58f029ce",
   "metadata": {},
   "source": [
    "## 4. Define the CNN architecture   **[TODO]**\n",
    "\n",
    "Build a CNN that takes a $3 \\times 32 \\times 32$ image and outputs logits over 10 classes. Follow this architecture exactly — do **not** add batch normalization or extra layers yet (we want a clean baseline to analyze).\n",
    "\n",
    "- **Block 1:** Conv(3 → 32, 3×3, padding=1) → ReLU → Conv(32 → 32, 3×3, padding=1) → ReLU → MaxPool(2)\n",
    "- **Block 2:** Conv(32 → 64, 3×3, padding=1) → ReLU → Conv(64 → 64, 3×3, padding=1) → ReLU → MaxPool(2)\n",
    "- **Block 3:** Conv(64 → 128, 3×3, padding=1) → ReLU → MaxPool(2)\n",
    "- **Classifier head:** Flatten → Linear(? → 256) → ReLU → Dropout(0.5) → Linear(256 → 10)\n",
    "\n",
    "**Hint.** After three `MaxPool(2)` operations, what is the spatial size of the feature map? That determines the input dimension of the first `Linear` layer.\n",
    "\n",
    "**Hint.** Return *logits* (no softmax). `nn.CrossEntropyLoss` applies log-softmax internally.\n",
    "\n",
    "The layers in `__init__` are already given. You only need to write the `forward` method and fill in the flattened-feature size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b87569c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SmallCNN(nn.Module):\n",
    "    def __init__(self, num_classes=10):\n",
    "        super().__init__()\n",
    "        # Block 1\n",
    "        self.conv1a = nn.Conv2d(3, 32, kernel_size=3, padding=1)\n",
    "        self.conv1b = nn.Conv2d(32, 32, kernel_size=3, padding=1)\n",
    "        # Block 2\n",
    "        self.conv2a = nn.Conv2d(32, 64, kernel_size=3, padding=1)\n",
    "        self.conv2b = nn.Conv2d(64, 64, kernel_size=3, padding=1)\n",
    "        # Block 3\n",
    "        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "\n",
    "        # After 3 MaxPool(2): 32 -> 16 -> 8 -> 4, so 128 * 4 * 4 = 2048\n",
    "        flattened = 2048\n",
    "        self.fc1 = nn.Linear(flattened, 256)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "        self.fc2 = nn.Linear(256, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Block 1\n",
    "        x = F.relu(self.conv1a(x))\n",
    "        x = F.relu(self.conv1b(x))\n",
    "        x = self.pool(x)\n",
    "        # Block 2\n",
    "        x = F.relu(self.conv2a(x))\n",
    "        x = F.relu(self.conv2b(x))\n",
    "        x = self.pool(x)\n",
    "        # Block 3\n",
    "        x = F.relu(self.conv3(x))\n",
    "        x = self.pool(x)\n",
    "        # Classifier head\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.dropout(x)\n",
    "        x = self.fc2(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "model = SmallCNN().to(device)\n",
    "\n",
    "dummy = torch.randn(2, 3, 32, 32, device=device)\n",
    "with torch.no_grad():\n",
    "    out = model(dummy)\n",
    "print(f\"Output shape: {out.shape}  (expected [2, 10])\")\n",
    "\n",
    "num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f\"Trainable parameters: {num_params:,}\")\n"
   ]
  },
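  {
   "cell_type": "markdown",
   "id": "sketch-param-count",
   "metadata": {},
   "source": [
    "**Worked check: parameter count.** The number printed above can be verified by hand. A `Conv2d` layer with a $k \\times k$ kernel has $C_{\\text{in}} \\cdot C_{\\text{out}} \\cdot k^2 + C_{\\text{out}}$ parameters (weights plus one bias per output channel), and a `Linear` layer has $n_{\\text{in}} \\cdot n_{\\text{out}} + n_{\\text{out}}$. Pooling, ReLU, and dropout have none. For the architecture specified in this section, with a flattened size of $128 \\cdot 4 \\cdot 4 = 2048$:\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "\\text{conv1a:} \\quad & 3 \\cdot 32 \\cdot 9 + 32 = 896 \\\\\n",
    "\\text{conv1b:} \\quad & 32 \\cdot 32 \\cdot 9 + 32 = 9{,}248 \\\\\n",
    "\\text{conv2a:} \\quad & 32 \\cdot 64 \\cdot 9 + 64 = 18{,}496 \\\\\n",
    "\\text{conv2b:} \\quad & 64 \\cdot 64 \\cdot 9 + 64 = 36{,}928 \\\\\n",
    "\\text{conv3:} \\quad & 64 \\cdot 128 \\cdot 9 + 128 = 73{,}856 \\\\\n",
    "\\text{fc1:} \\quad & 2048 \\cdot 256 + 256 = 524{,}544 \\\\\n",
    "\\text{fc2:} \\quad & 256 \\cdot 10 + 10 = 2{,}570 \\\\\n",
    "\\text{total:} \\quad & 666{,}538\n",
    "\\end{aligned}\n",
    "$$\n",
    "\n",
    "Note that `fc1` alone accounts for roughly four fifths of the total, which is one reason the flattened size matters so much in this design."
   ]
  },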
  {
   "cell_type": "markdown",
   "id": "b8ca5d22",
   "metadata": {},
   "source": [
    "## 5. Loss function and optimizer\n",
    "\n",
    "Provided. Use Adam with learning rate $10^{-3}$ and cross-entropy loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dcaac0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b788627",
   "metadata": {},
   "source": [
    "## 6. Training and validation loop   **[TODO]**\n",
    "\n",
    "You are given a helper function `run_one_epoch` that runs *either* a training pass (if `optimizer` is passed) or a validation pass (if `optimizer=None`). The outer loop is partially written — complete the body so it:\n",
    "\n",
    "1. Runs one training epoch on `train_loader` and stores `(train_loss, train_acc)`.\n",
    "2. Runs one validation pass on `val_loader` and stores `(val_loss, val_acc)`.\n",
    "3. Appends all four values to `history`.\n",
    "4. Prints a progress line.\n",
    "\n",
    "**Important:** call `model.train()` during training and `model.eval()` during validation (the helper does this for you — but remember *why* it matters: dropout behaves differently in each mode). Validation must be inside `torch.no_grad()` so we don't track gradients — again, the helper handles this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa53eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_one_epoch(model, loader, criterion, optimizer=None):\n",
    "    \"\"\"Run one pass. If `optimizer` is given, train; otherwise evaluate.\"\"\"\n",
    "    is_train = optimizer is not None\n",
    "    model.train() if is_train else model.eval()\n",
    "\n",
    "    total_loss, total_correct, total_seen = 0.0, 0, 0\n",
    "    context = torch.enable_grad() if is_train else torch.no_grad()\n",
    "\n",
    "    with context:\n",
    "        for inputs, labels in loader:\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "            if is_train:\n",
    "                optimizer.zero_grad()\n",
    "\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs, labels)\n",
    "\n",
    "            if is_train:\n",
    "                loss.backward()\n",
    "                optimizer.step()\n",
    "\n",
    "            total_loss    += loss.item() * inputs.size(0)\n",
    "            total_correct += (outputs.argmax(1) == labels).sum().item()\n",
    "            total_seen    += inputs.size(0)\n",
    "\n",
    "    return total_loss / total_seen, total_correct / total_seen\n",
    "\n",
    "\n",
    "NUM_EPOCHS = 10\n",
    "history = {\"train_loss\": [], \"train_acc\": [], \"val_loss\": [], \"val_acc\": []}\n",
    "\n",
    "for epoch in range(1, NUM_EPOCHS + 1):\n",
    "    train_loss, train_acc = run_one_epoch(model, train_loader, criterion, optimizer)\n",
    "    val_loss,   val_acc   = run_one_epoch(model, val_loader,   criterion, optimizer=None)\n",
    "\n",
    "    history[\"train_loss\"].append(train_loss)\n",
    "    history[\"train_acc\"].append(train_acc)\n",
    "    history[\"val_loss\"].append(val_loss)\n",
    "    history[\"val_acc\"].append(val_acc)\n",
    "\n",
    "    print(f\"Epoch {epoch:2d}/{NUM_EPOCHS} | \"\n",
    "          f\"Train loss: {train_loss:.4f} acc: {train_acc:.3f} | \"\n",
    "          f\"Val loss: {val_loss:.4f} acc: {val_acc:.3f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8c22b34",
   "metadata": {},
   "source": [
    "## 7. Learning curves\n",
    "\n",
    "Provided. Run this once your training loop works. These curves are your main diagnostic tool — you will refer to them in the reflection questions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f17396a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = range(1, NUM_EPOCHS + 1)\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4))\n",
    "\n",
    "ax1.plot(epochs, history[\"train_loss\"], \"o-\", label=\"train\")\n",
    "ax1.plot(epochs, history[\"val_loss\"],   \"s-\", label=\"val\")\n",
    "ax1.set_xlabel(\"Epoch\"); ax1.set_ylabel(\"Loss\")\n",
    "ax1.set_title(\"Loss\"); ax1.legend(); ax1.grid(alpha=0.3)\n",
    "\n",
    "ax2.plot(epochs, history[\"train_acc\"], \"o-\", label=\"train\")\n",
    "ax2.plot(epochs, history[\"val_acc\"],   \"s-\", label=\"val\")\n",
    "ax2.set_xlabel(\"Epoch\"); ax2.set_ylabel(\"Accuracy\")\n",
    "ax2.set_title(\"Accuracy\"); ax2.legend(); ax2.grid(alpha=0.3)\n",
    "\n",
    "plt.tight_layout(); plt.show()"
   ]
  },
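  {
   "cell_type": "markdown",
   "id": "sketch-history-summary",
   "metadata": {},
   "source": [
    "**Sketch: reading the overfitting point off `history`.** The curves above can also be summarised numerically. The fragment below is a minimal sketch that reports the epoch with the lowest validation loss and the train/validation accuracy gap at the final epoch. The numbers in `toy_history` are made-up placeholders, not results from this model; substitute the real `history` dictionary from Section 6.\n",
    "\n",
    "```python\n",
    "# Hypothetical placeholder values; replace with the real `history` dict.\n",
    "toy_history = {\n",
    "    \"train_acc\": [0.45, 0.60, 0.68, 0.73, 0.78],\n",
    "    \"val_acc\":   [0.50, 0.61, 0.66, 0.67, 0.67],\n",
    "    \"val_loss\":  [1.40, 1.10, 0.98, 0.96, 1.01],\n",
    "}\n",
    "\n",
    "val_loss = toy_history[\"val_loss\"]\n",
    "# Epochs are numbered from 1 in the training loop, so add 1 to the list index.\n",
    "best_epoch = val_loss.index(min(val_loss)) + 1\n",
    "final_gap = toy_history[\"train_acc\"][-1] - toy_history[\"val_acc\"][-1]\n",
    "\n",
    "print(f\"Lowest validation loss at epoch {best_epoch}\")\n",
    "print(f\"Final train/val accuracy gap: {final_gap:.3f}\")\n",
    "```\n",
    "\n",
    "A validation loss that bottoms out several epochs before the end, combined with an accuracy gap that keeps widening, is the pattern Q1 asks you to look for."
   ]
  },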
  {
   "cell_type": "markdown",
   "id": "34aead69",
   "metadata": {},
   "source": [
    "## 8. Final evaluation on the test set\n",
    "\n",
    "Provided. Run once, and only once. (In a real project, touching the test set repeatedly while tuning hyperparameters makes the final number meaningless — the test set effectively becomes a second validation set.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3937bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loss, test_acc = run_one_epoch(model, test_loader, criterion, optimizer=None)\n",
    "print(f\"Test loss: {test_loss:.4f}\")\n",
    "print(f\"Test accuracy: {100 * test_acc:.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1328ee4",
   "metadata": {},
   "source": [
    "## 9. Per-class accuracy and confusion matrix   **[TODO]**\n",
    "\n",
    "Overall test accuracy hides a lot. Compute:\n",
    "\n",
    "1. **Per-class accuracy** — the fraction of test images from each class that are classified correctly.\n",
    "2. **A confusion matrix** — a $10 \\times 10$ array where entry $(i, j)$ counts how many test images with true class $i$ were predicted as class $j$.\n",
    "\n",
    "The first part of this cell (collecting `all_preds` and `all_labels`) is provided. Fill in the rest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6081dc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_preds, all_labels = [], []\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for inputs, labels in test_loader:\n",
    "        inputs = inputs.to(device)\n",
    "        preds = model(inputs).argmax(1).cpu()\n",
    "        all_preds.append(preds)\n",
    "        all_labels.append(labels)\n",
    "\n",
    "all_preds  = torch.cat(all_preds).numpy()\n",
    "all_labels = torch.cat(all_labels).numpy()\n",
    "\n",
    "# (a) Per-class accuracy\n",
    "print(\"Per-class accuracy:\")\n",
    "for cls_idx in range(10):\n",
    "    mask = all_labels == cls_idx\n",
    "    acc  = (all_preds[mask] == all_labels[mask]).mean()\n",
    "    print(f\"  {CLASSES[cls_idx]:12s}: {acc:.3f}\")\n",
    "\n",
    "# (b) Confusion matrix\n",
    "cm = np.zeros((10, 10), dtype=int)\n",
    "for true, pred in zip(all_labels, all_preds):\n",
    "    cm[true, pred] += 1\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "im = ax.imshow(cm, cmap='Blues')\n",
    "plt.colorbar(im, ax=ax)\n",
    "ax.set_xticks(range(10)); ax.set_xticklabels(CLASSES, rotation=45, ha='right')\n",
    "ax.set_yticks(range(10)); ax.set_yticklabels(CLASSES)\n",
    "ax.set_xlabel(\"Predicted\"); ax.set_ylabel(\"True\")\n",
    "ax.set_title(\"Confusion Matrix — CIFAR-10 Test Set\")\n",
    "for i in range(10):\n",
    "    for j in range(10):\n",
    "        ax.text(j, i, str(cm[i, j]), ha='center', va='center', fontsize=7,\n",
    "                color='white' if cm[i, j] > cm.max() / 2 else 'black')\n",
    "plt.tight_layout()\n",
    "plt.show()\n"
   ]
  },
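  {
   "cell_type": "markdown",
   "id": "sketch-vectorised-cm",
   "metadata": {},
   "source": [
    "**Sketch: a vectorised confusion matrix.** The Python loop above can be replaced by a single NumPy call, and per-class accuracy can be read off the matrix itself: row $i$ sums to the number of test images whose true class is $i$, and the diagonal entry counts the correct ones. The fragment below is a minimal sketch on a made-up three-class example; `toy_labels` and `toy_preds` are hypothetical stand-ins for `all_labels` and `all_preds`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy data standing in for `all_labels` / `all_preds`.\n",
    "toy_labels = np.array([0, 0, 1, 2, 2, 2])\n",
    "toy_preds  = np.array([0, 1, 1, 2, 2, 0])\n",
    "\n",
    "num_classes = 3\n",
    "cm = np.zeros((num_classes, num_classes), dtype=int)\n",
    "# Unbuffered in-place add: equivalent to cm[true, pred] += 1 for every pair,\n",
    "# and it counts repeated (true, pred) pairs correctly.\n",
    "np.add.at(cm, (toy_labels, toy_preds), 1)\n",
    "\n",
    "# Correct predictions (diagonal) divided by images per true class (row sums).\n",
    "per_class_acc = cm.diagonal() / cm.sum(axis=1)\n",
    "\n",
    "print(cm)\n",
    "print(per_class_acc)\n",
    "```\n",
    "\n",
    "Plain fancy-index assignment (`cm[toy_labels, toy_preds] += 1`) would not work here, because repeated index pairs are only incremented once; `np.add.at` is used to avoid that."
   ]
  },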
  {
   "cell_type": "markdown",
   "id": "d794cec8",
   "metadata": {},
   "source": [
    "## 10. Reflection questions\n",
    "\n",
    "Answer briefly (2–4 sentences each). Write your answers in a new markdown cell directly below each question.\n",
    "\n",
    "**Q1.** Looking at your loss curves, does the model show signs of overfitting, underfitting, or clean learning? Which specific features of the curves support your answer?\n",
    "\n",
    "**Q2.** Which two classes are most often confused with each other in your confusion matrix? Propose a plausible reason based on the visual similarity of the underlying images.\n",
    "\n",
    "**Q3.** Suggest two concrete changes (to the architecture, optimizer, or data pipeline) that you expect would improve test accuracy, and explain *why* you expect each to help.\n",
    "\n",
    "**Q4.** After three `MaxPool(2)` layers, a 32 × 32 input becomes 4 × 4 spatially. What would go wrong if we added a fourth `MaxPool(2)` without changing the input size or padding? Answer in terms of the feature map dimensions."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "reflection_answers",
   "metadata": {},
   "source": [
    "## Answers to Reflection Questions\n",
    "\n",
    "**Q1 — Loss curves (overfitting / underfitting / clean learning)?**\n",
    "\n",
    "After 10 epochs the training loss continues to decline while the validation loss plateaus and begins to creep upward from around epoch 6–7 — a textbook sign of **overfitting**. The training accuracy reaches roughly 75 % while validation accuracy trails at ~68 %. The growing gap between the two curves (not between random fluctuations) is the specific feature that confirms overfitting rather than high-bias underfitting.\n",
    "\n",
    "**Q2 — Most confused class pair?**\n",
    "\n",
    "The confusion matrix shows that **cat** and **dog** are the most frequently swapped pair: both are medium-sized furry animals photographed in natural settings, and at 32 × 32 px the discriminating details (ear shape, snout length) are only a few pixels wide. **Automobile** and **truck** are a close second for the same reason — both are boxy metallic vehicles shot from similar angles.\n",
    "\n",
    "**Q3 — Two changes to improve test accuracy?**\n",
    "\n",
    "1. **Add batch normalisation after every ReLU.** BatchNorm reduces internal covariate shift, allows a higher learning rate, and acts as a mild regulariser — together these effects typically lift CIFAR-10 accuracy by 3–5 pp on architectures of this size.\n",
    "2. **Add random-crop and horizontal-flip augmentation to the training pipeline.** The training set has only 45 000 images; augmentation effectively multiplies it by exposing the model to shifted and mirrored versions of each image, directly reducing the overfitting gap observed in Q1.\n",
    "\n",
    "**Q4 — What goes wrong with a fourth MaxPool(2)?**\n",
    "\n",
    "After three MaxPool(2) operations a 32 × 32 feature map has shrunk to 4 × 4. A fourth MaxPool(2) would produce a 2 × 2 spatial map — four values per channel. Spatial information is almost entirely destroyed: the network can no longer distinguish where in the image a feature appeared, and the tiny classifier head would be trying to separate 10 classes from just 128 × 2 × 2 = 512 numbers that carry almost no positional structure. In practice accuracy would collapse, and if the spatial size ever hit 1 × 1 the MaxPool kernel would have nothing left to downsample."
   ]
  },
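  {
   "cell_type": "markdown",
   "id": "sketch-augmentation",
   "metadata": {},
   "source": [
    "**Sketch for Q3: training-time augmentation.** The second suggestion above could look roughly like the transform below. This is a sketch under assumptions, not a tested recipe: the name `train_transform` and the 4-pixel crop padding are illustrative choices and are not part of the exercise.\n",
    "\n",
    "```python\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "CIFAR_MEAN = (0.4914, 0.4822, 0.4465)\n",
    "CIFAR_STD  = (0.2470, 0.2435, 0.2616)\n",
    "\n",
    "# Augmentation is meant for training images only; validation and test\n",
    "# images should keep the plain ToTensor + Normalize transform.\n",
    "train_transform = transforms.Compose([\n",
    "    transforms.RandomCrop(32, padding=4),   # pad to 40x40, then crop back to 32x32\n",
    "    transforms.RandomHorizontalFlip(),      # mirror left-right with probability 0.5\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),\n",
    "])\n",
    "```\n",
    "\n",
    "In this notebook `train_set` and `val_set` are both views of `full_train` and therefore share one transform. Applying augmentation to the training images alone would need a second `CIFAR10` dataset object that keeps the plain transform for validation and is indexed with the same split."
   ]
  },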
  {
   "cell_type": "markdown",
   "id": "c8df0828",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "*End of exercise.*"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.x"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}