{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# FFN vs KV-cache Bottleneck Analysis\n",
        "\n",
        "This notebook publishes the FFN versus KV-cache attention bottleneck analysis for a Gemma 3N E4B-style on-device small language model (SLM) experiment. The goal is to make the transition visible: at short contexts decode time is dominated by FFN weight reads, while at longer contexts repeated KV-cache reads become increasingly important.\n",
        "\n",
        "The numbers below come from the context sweep recorded in the original analysis script. The notebook keeps the arithmetic explicit so the methodology can be reviewed and rerun."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Methodology\n",
        "\n",
        "For each context length, compute two memory-read quantities per generated token:\n",
        "\n",
        "- **FFN read per token**: independent of context length, because every decode step reads all FFN weights (gate, up, and down projections in every layer) once.\n",
        "- **Attention KV read per token**: grows linearly with context length, because every decode step reads all cached keys and values in every layer.\n",
        "\n",
        "The measured stage percentages are then compared with the analytical attention-to-FFN read-size ratio. The percentages are runtime-profile measurements; the read sizes are analytical estimates from model dimensions. The KV estimate uses the input context length only and does not count the cache entries appended by tokens generated during decode."
      ]
    },
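    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Written out as formulas, the two per-token reads are shown below. The symbols are introduced here for illustration: $L$ layers, $H$ KV heads, $d_h$ head dimension, $d$ hidden dimension, $d_{ff}$ FFN dimension, $T$ cached tokens, and $b_{kv}$, $b_w$ bytes per KV element and per weight element. The factor 3 counts the gate, up, and down FFN projections; the factor 2 counts keys and values. The numeric values are the dimensions used in the analysis code.\n",
        "\n",
        "$$R_{\\mathrm{FFN}} = 3\\,L\\,d\\,d_{ff}\\,b_w = 3 \\cdot 35 \\cdot 2048 \\cdot 16384 \\cdot 0.5 = 1{,}761{,}607{,}680 \\text{ bytes} \\approx 1680 \\text{ MB}$$\n",
        "\n",
        "$$R_{\\mathrm{KV}}(T) = 2\\,L\\,H\\,d_h\\,T\\,b_{kv} = 2 \\cdot 35 \\cdot 2 \\cdot 256 \\cdot 2 \\cdot T = 71{,}680\\,T \\text{ bytes}$$\n",
        "\n",
        "$$\\frac{R_{\\mathrm{KV}}(T)}{R_{\\mathrm{FFN}}} = \\frac{71{,}680\\,T}{1{,}761{,}607{,}680} = \\frac{T}{24{,}576}$$\n",
        "\n",
        "For example, $T = 658$ gives $658 / 24{,}576 \\approx 0.027$, the attention/FFN read ratio reported for the long context."
      ]
    },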
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
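        "# Simplified Gemma 3N E4B-style dimensions used for the read estimates.\n",
        "NUM_LAYERS = 35\n",
        "KV_HEADS = 2                  # key/value heads per layer\n",
        "HEAD_DIM = 256\n",
        "KV_DIM = KV_HEADS * HEAD_DIM  # width of the K and V projections\n",
        "HIDDEN_DIM = 2048\n",
        "FFN_DIM = 16384\n",
        "BYTES_PER_ELEM = 2            # FP16 KV-cache\n",
        "WEIGHT_BYTES_PER_ELEM = 0.5   # INT4 weights\n",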
        "\n",
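        "# Recorded sweep: *_pct are profiler-reported shares of decode; *_ms are the measured stage times.\n",
        "# The percentages are stored as recorded, not recomputed from the ms columns.\n",
        "contexts = [\n",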
        "    {\"label\": \"short\",  \"input_tokens\": 13,  \"ffn_pct\": 75.3, \"attn_pct\": 0.9,\n",
        "     \"ffn_ms\": 993.8, \"attn_ms\": 11.8, \"decode_ms\": 1372.1},\n",
        "    {\"label\": \"medium\", \"input_tokens\": 172, \"ffn_pct\": 68.0, \"attn_pct\": 6.7,\n",
        "     \"ffn_ms\": 719.8, \"attn_ms\": 70.9, \"decode_ms\": 1103.8},\n",
        "    {\"label\": \"long\",   \"input_tokens\": 658, \"ffn_pct\": 49.1, \"attn_pct\": 28.9,\n",
        "     \"ffn_ms\": 528.5, \"attn_ms\": 311.3, \"decode_ms\": 1108.5},\n",
        "]\n",
        "\n",
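        "# Per-layer projection parameters (biases, norms, and embeddings are ignored).\n",
        "per_layer_params = (\n",
        "    HIDDEN_DIM * HIDDEN_DIM +  # Q projection\n",
        "    HIDDEN_DIM * KV_DIM +      # K projection\n",
        "    HIDDEN_DIM * KV_DIM +      # V projection\n",
        "    HIDDEN_DIM * HIDDEN_DIM +  # attention output projection\n",
        "    HIDDEN_DIM * FFN_DIM +     # FFN gate projection\n",
        "    HIDDEN_DIM * FFN_DIM +     # FFN up projection\n",
        "    FFN_DIM * HIDDEN_DIM       # FFN down projection\n",
        ")\n",
        "total_weight_mb = NUM_LAYERS * per_layer_params * WEIGHT_BYTES_PER_ELEM / (1024**2)\n",
        "print(f\"Projection weight estimate: {total_weight_mb:,.0f} MB\")\n",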
        "\n",
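        "# FFN weights are read once per generated token, independent of context length.\n",
        "ffn_weight_bytes = NUM_LAYERS * (HIDDEN_DIM * FFN_DIM * 3) * WEIGHT_BYTES_PER_ELEM\n",
        "\n",
        "rows = []\n",
        "for ctx in contexts:\n",
        "    tokens = ctx[\"input_tokens\"]\n",
        "    # K and V (factor 2) for every layer and cached position; decode reads the whole cache.\n",
        "    kv_bytes = 2 * NUM_LAYERS * KV_HEADS * HEAD_DIM * tokens * BYTES_PER_ELEM\n",
        "    attn_read_bytes = kv_bytes\n",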
        "    rows.append({\n",
        "        \"label\": ctx[\"label\"],\n",
        "        \"input_tokens\": tokens,\n",
        "        \"kv_cache_mb\": round(kv_bytes / (1024**2), 2),\n",
        "        \"ffn_read_per_tok_mb\": round(ffn_weight_bytes / (1024**2), 1),\n",
        "        \"attn_read_per_tok_mb\": round(attn_read_bytes / (1024**2), 2),\n",
        "        \"attn_to_ffn_read_ratio\": round(attn_read_bytes / ffn_weight_bytes, 3),\n",
        "        \"measured_ffn_pct\": ctx[\"ffn_pct\"],\n",
        "        \"measured_attn_pct\": ctx[\"attn_pct\"],\n",
        "        \"ffn_ms\": ctx[\"ffn_ms\"],\n",
        "        \"attn_ms\": ctx[\"attn_ms\"],\n",
        "        \"decode_ms\": ctx[\"decode_ms\"],\n",
        "    })\n",
        "\n",
        "df = pd.DataFrame(rows)\n",
        "df"
      ]
    },
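    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The same arithmetic can be wrapped in two small functions to estimate reads at context lengths outside the recorded sweep. This is a sketch: the helper names `kv_read_bytes` and `ffn_read_bytes` are hypothetical, and it inherits the notebook's simplifications (every layer caches every position in FP16, weights are INT4, tokens generated during decode are not counted).\n",
        "\n",
        "```python\n",
        "# Hypothetical helpers; constants copied from the notebook's simplified accounting.\n",
        "NUM_LAYERS = 35\n",
        "KV_HEADS = 2\n",
        "HEAD_DIM = 256\n",
        "HIDDEN_DIM = 2048\n",
        "FFN_DIM = 16384\n",
        "KV_BYTES_PER_ELEM = 2        # FP16 KV-cache\n",
        "WEIGHT_BYTES_PER_ELEM = 0.5  # INT4 weights\n",
        "MB = 1024**2\n",
        "\n",
        "def kv_read_bytes(context_tokens):\n",
        "    # K and V for every layer and cached position, read once per decode step.\n",
        "    return 2 * NUM_LAYERS * KV_HEADS * HEAD_DIM * context_tokens * KV_BYTES_PER_ELEM\n",
        "\n",
        "def ffn_read_bytes():\n",
        "    # Gate, up, and down INT4 weights for every layer, read once per decode step.\n",
        "    return NUM_LAYERS * 3 * HIDDEN_DIM * FFN_DIM * WEIGHT_BYTES_PER_ELEM\n",
        "\n",
        "# Context length at which the KV read would equal the FFN weight read.\n",
        "parity_tokens = ffn_read_bytes() / kv_read_bytes(1)\n",
        "\n",
        "for tokens in (658, 2048, 8192, int(parity_tokens)):\n",
        "    ratio = kv_read_bytes(tokens) / ffn_read_bytes()\n",
        "    print(f'{tokens:>6} tokens: KV read {kv_read_bytes(tokens) / MB:8.2f} MB, attention/FFN ratio {ratio:.3f}')\n",
        "```\n",
        "\n",
        "Under this byte-only accounting the KV read reaches parity with the FFN read only at 24,576 tokens. The measured shares show attention becoming significant far earlier, so byte parity should not be read as the point where attention starts to matter."
      ]
    },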
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Recorded context sweep\n",
        "\n",
        "| Context | Input tokens | KV-cache MB | FFN read/token MB | Attention read/token MB | Attention/FFN read ratio | Measured FFN % | Measured attention % |\n",
        "|---|---:|---:|---:|---:|---:|---:|---:|\n",
        "| short | 13 | 0.89 | 1680.0 | 0.89 | 0.001 | 75.3 | 0.9 |\n",
        "| medium | 172 | 11.76 | 1680.0 | 11.76 | 0.007 | 68.0 | 6.7 |\n",
        "| long | 658 | 44.98 | 1680.0 | 44.98 | 0.027 | 49.1 | 28.9 |\n",
        "\n",
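        "Model weight estimate: about **1,855 MB** of INT4 projection weights (attention Q/K/V/output and FFN gate/up/down across all 35 layers; embeddings excluded) under the simplified accounting used here. All MB values are 1024² bytes. The attention read per token equals the KV-cache size because each decode step reads the entire cache."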
      ]
    },
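    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The comparison described in the methodology can be made explicit by dividing the measured attention/FFN time ratio by the analytical attention/FFN byte ratio. A result near 1 would mean both paths cost the same time per byte; larger values mean each KV-cache byte costs more. This is a sketch that copies the table's rounded values and treats the profile percentages as proportional to time; `per_byte_cost` is a name introduced here for illustration.\n",
        "\n",
        "```python\n",
        "# Values copied from the recorded context sweep table (MB = 1024**2 bytes).\n",
        "FFN_READ_MB = 1680.0\n",
        "sweep = {\n",
        "    # label: (attention read MB/token, measured FFN %, measured attention %)\n",
        "    'short': (0.89, 75.3, 0.9),\n",
        "    'medium': (11.76, 68.0, 6.7),\n",
        "    'long': (44.98, 49.1, 28.9),\n",
        "}\n",
        "\n",
        "per_byte_cost = {}\n",
        "for label, (attn_mb, ffn_pct, attn_pct) in sweep.items():\n",
        "    byte_ratio = attn_mb / FFN_READ_MB  # analytical attention/FFN read ratio\n",
        "    time_ratio = attn_pct / ffn_pct     # measured attention/FFN profile ratio\n",
        "    # How many times more profile time each attention byte costs than each FFN byte.\n",
        "    per_byte_cost[label] = time_ratio / byte_ratio\n",
        "    print(f'{label:>6}: read ratio {byte_ratio:.4f}, '\n",
        "          f'time ratio {time_ratio:.3f}, per-byte cost x{per_byte_cost[label]:.1f}')\n",
        "```\n",
        "\n",
        "With these numbers the multiplier is roughly 14 to 23 across the three buckets (the short bucket is sensitive to rounding of its 0.9% share). The sketch shows that the attention path is slower per byte than the FFN weight stream in this profile; it does not identify why."
      ]
    },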
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
        "\n",
        "axes[0].plot(df[\"input_tokens\"], df[\"ffn_read_per_tok_mb\"], marker=\"o\", label=\"FFN read/token\")\n",
        "axes[0].plot(df[\"input_tokens\"], df[\"attn_read_per_tok_mb\"], marker=\"o\", label=\"Attention KV read/token\")\n",
        "axes[0].set_title(\"Per-token memory read estimate\")\n",
        "axes[0].set_xlabel(\"Input context tokens\")\n",
        "axes[0].set_ylabel(\"MB per generated token\")\n",
        "axes[0].set_yscale(\"log\")  # FFN and KV reads differ by orders of magnitude\n",
        "axes[0].grid(True, alpha=0.3)\n",
        "axes[0].legend()\n",
        "\n",
        "axes[1].plot(df[\"input_tokens\"], df[\"measured_ffn_pct\"], marker=\"o\", label=\"Measured FFN share\")\n",
        "axes[1].plot(df[\"input_tokens\"], df[\"measured_attn_pct\"], marker=\"o\", label=\"Measured attention share\")\n",
        "axes[1].set_title(\"Measured decode component share\")\n",
        "axes[1].set_xlabel(\"Input context tokens\")\n",
        "axes[1].set_ylabel(\"Percent of decode profile\")\n",
        "axes[1].grid(True, alpha=0.3)\n",
        "axes[1].legend()\n",
        "\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "fig, ax = plt.subplots(figsize=(7, 4))\n",
        "ax.bar(df[\"label\"], df[\"ffn_ms\"], label=\"FFN ms/token\")\n",
        "ax.bar(df[\"label\"], df[\"attn_ms\"], bottom=df[\"ffn_ms\"], label=\"Attention ms/token\")\n",
        "ax.set_title(\"FFN and attention contribution by context\")\n",
        "ax.set_xlabel(\"Context bucket\")\n",
        "ax.set_ylabel(\"Measured ms/token\")\n",
        "ax.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()"
      ]
    },
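    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The stacked bars suggest that attention time grows steadily with context. A least-squares line through the three measured attention times gives a rough slope per context token. This sketch assumes the recorded ms values are comparable across buckets and uses `numpy.polyfit`, which the notebook does not otherwise use.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Input context lengths and measured attention times copied from the notebook's `contexts` list.\n",
        "input_tokens = np.array([13, 172, 658], dtype=float)\n",
        "attn_ms = np.array([11.8, 70.9, 311.3])\n",
        "\n",
        "# Degree-1 fit attn_ms ~ intercept + slope * tokens (polyfit returns the highest degree first).\n",
        "slope, intercept = np.polyfit(input_tokens, attn_ms, 1)\n",
        "predicted = intercept + slope * input_tokens\n",
        "r_squared = 1 - np.sum((attn_ms - predicted) ** 2) / np.sum((attn_ms - attn_ms.mean()) ** 2)\n",
        "\n",
        "print(f'slope: {slope:.3f} ms per context token')\n",
        "print(f'intercept: {intercept:.1f} ms')\n",
        "print(f'R^2: {r_squared:.3f}')\n",
        "```\n",
        "\n",
        "With only three context buckets the fit is descriptive rather than a validated scaling law. Its near-linear trend is consistent with a KV-cache read that grows linearly with context length."
      ]
    },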
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Interpretation\n",
        "\n",
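        "At short context the KV-cache is tiny, so attention reads are negligible and the FFN GEMVs (matrix-vector products against the FFN weight matrices) dominate decode time. Decode attention is also a GEMV-like, memory-bound read, but the matrix it reads is the accumulated KV-cache, which grows with context length, rather than a fixed set of weights. That is why attention rises from 0.9% of the decode profile at 13 input tokens to 28.9% at 658 tokens.\n",
        "\n",
        "Byte counts alone understate this shift. At 658 tokens the attention read is only about 2.7% of the FFN read, yet the measured attention share (28.9%) is more than half the FFN share (49.1%), so in this runtime each KV-cache byte costs considerably more decode time than each weight byte.\n",
        "\n",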
        "This does not mean FFN stops mattering. It means decode has two memory-bound paths whose relative pressure changes with context length."
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "pygments_lexer": "ipython3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
