{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Strengthened Latency Benchmark Methodology\n",
        "\n",
        "This notebook is a stronger benchmark harness derived from the original `run_experiment.py` flow. It keeps the same load-once, repeat-many structure but expands the methodology:\n",
        "\n",
        "- the default run count is at least **N = 20** per prompt and output-token setting;\n",
        "- output-token settings are swept across **32 / 64 / 128 / 256**;\n",
        "- multiple prompts are tested;\n",
        "- summary reporting includes mean, standard deviation, p50, p90, and p99;\n",
        "- plots cover latency versus output length, component breakdown, and run variance.\n",
        "\n",
        "## How to run\n",
        "\n",
        "Run this notebook from the local Gemma 3N E4B experiment directory that contains `CPU_CORE.py`, `safeTensor.py`, `main.py`, and the local weight files. The benchmark cells do not require changes to the published site. Review the configuration cell first, then run the hardware-dependent cells. Per-run results are written to `strengthened_latency_results.csv`, and per-run profiler reports are copied into `strengthened_latency_profiles/`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from pathlib import Path\n",
        "import gc\n",
        "import os\n",
        "import shutil\n",
        "import time\n",
        "\n",
        "# Thread-count environment variables are read when the BLAS/OpenMP runtime\n",
        "# initializes, so set them before importing NumPy.\n",
        "os.environ.setdefault(\"OMP_NUM_THREADS\", \"6\")\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "WORKSPACE = Path.cwd()\n",
        "OUTPUT_CSV = WORKSPACE / \"strengthened_latency_results.csv\"\n",
        "PROFILE_DIR = WORKSPACE / \"strengthened_latency_profiles\"\n",
        "\n",
        "# Runtime and sampling configuration.\n",
        "WEIGHT_MODE = \"INT4\"\n",
        "FEATURE_MODE = \"FP32\"\n",
        "TEMPERATURE = 0.65\n",
        "TOP_P = 0.9\n",
        "REP_PENALTY = 1.15\n",
        "KV_DIM = 512\n",
        "\n",
        "# Benchmark sweep: N_RUNS runs per prompt and output-token setting.\n",
        "N_RUNS = 20\n",
        "OUTPUT_TOKEN_SETTINGS = [32, 64, 128, 256]\n",
        "PROMPTS = [\n",
        "    \"What is on-device AI and why is it important?\",\n",
        "    \"Explain why KV-cache size matters during long-context decoding.\",\n",
        "    \"Summarize the tradeoffs between quantization and model quality.\",\n",
        "    \"Write a concise note about memory bandwidth bottlenecks in transformer inference.\",\n",
        "]"
      ]
    },
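    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The reporting checklist at the end asks for the runtime mode and accelerator backend. The next cell is an optional sketch, using only the standard library plus the packages imported above, that records basic host metadata next to the results CSV so each run can be tied to the machine that produced it. The `environment.json` filename is an illustrative choice, not part of the original flow."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import platform\n",
        "\n",
        "def capture_environment(path: Path = WORKSPACE / \"environment.json\") -> dict:\n",
        "    \"\"\"Record basic host metadata alongside the benchmark CSV (optional sketch).\"\"\"\n",
        "    info = {\n",
        "        \"python\": platform.python_version(),\n",
        "        \"platform\": platform.platform(),\n",
        "        \"machine\": platform.machine(),\n",
        "        \"cpu_count\": os.cpu_count(),\n",
        "        \"omp_num_threads\": os.environ.get(\"OMP_NUM_THREADS\"),\n",
        "        \"numpy\": np.__version__,\n",
        "        \"pandas\": pd.__version__,\n",
        "    }\n",
        "    path.write_text(json.dumps(info, indent=2))\n",
        "    return info\n",
        "\n",
        "# capture_environment()"
      ]
    },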
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def load_runtime_modules(workspace: Path):\n",
        "    import sys\n",
        "\n",
        "    os.chdir(workspace)\n",
        "    if str(workspace) not in sys.path:\n",
        "        sys.path.insert(0, str(workspace))\n",
        "\n",
        "    import CPU_CORE\n",
        "    import safeTensor\n",
        "\n",
        "    try:\n",
        "        import IGPU_CORE as FAST_MATRIX_CORE\n",
        "        accel = \"IGPU\"\n",
        "    except Exception:\n",
        "        import CPU_MATRIX_CORE as FAST_MATRIX_CORE\n",
        "        accel = \"CPU\"\n",
        "\n",
        "    from main import (\n",
        "        forward_one_token,\n",
        "        decode_logits,\n",
        "        _sample,\n",
        "        GLOBAL_PROFILE_DATA,\n",
        "        generate_profile_html,\n",
        "        _IGPU_WEIGHT_KEYS,\n",
        "        NUM_LAYERS,\n",
        "        _variant_dir_for_mode,\n",
        "    )\n",
        "\n",
        "    return {\n",
        "        \"CPU_CORE\": CPU_CORE,\n",
        "        \"safeTensor\": safeTensor,\n",
        "        \"FAST_MATRIX_CORE\": FAST_MATRIX_CORE,\n",
        "        \"ACCEL\": accel,\n",
        "        \"forward_one_token\": forward_one_token,\n",
        "        \"decode_logits\": decode_logits,\n",
        "        \"_sample\": _sample,\n",
        "        \"GLOBAL_PROFILE_DATA\": GLOBAL_PROFILE_DATA,\n",
        "        \"generate_profile_html\": generate_profile_html,\n",
        "        \"_IGPU_WEIGHT_KEYS\": _IGPU_WEIGHT_KEYS,\n",
        "        \"NUM_LAYERS\": NUM_LAYERS,\n",
        "        \"_variant_dir_for_mode\": _variant_dir_for_mode,\n",
        "    }\n",
        "\n",
        "\n",
        "def load_model(runtime):\n",
        "    FAST_MATRIX_CORE = runtime[\"FAST_MATRIX_CORE\"]\n",
        "    safeTensor = runtime[\"safeTensor\"]\n",
        "    _variant_dir_for_mode = runtime[\"_variant_dir_for_mode\"]\n",
        "    _IGPU_WEIGHT_KEYS = runtime[\"_IGPU_WEIGHT_KEYS\"]\n",
        "\n",
        "    FAST_MATRIX_CORE.warmup()\n",
        "    vdir = _variant_dir_for_mode(WEIGHT_MODE)\n",
        "    if vdir is None:\n",
        "        if WEIGHT_MODE.lower() == \"int4\" and os.path.isdir(safeTensor.mmap_dir):\n",
        "            vdir = safeTensor.mmap_dir\n",
        "        else:\n",
        "            raise FileNotFoundError(f\"No weights found for {WEIGHT_MODE}\")\n",
        "\n",
        "    W_embed, W_ple_packed, W_ple_scale, norm_ple, W_ple_proj, \\\n",
        "        altup_projs, altup_unprojs, W_final_norm, W_lm_head, W = \\\n",
        "        safeTensor.load_local_weights(model_dir=vdir, mode=WEIGHT_MODE.lower())\n",
        "\n",
        "    FAST_MATRIX_CORE.preload_and_free(W, _IGPU_WEIGHT_KEYS)\n",
        "    FAST_MATRIX_CORE._get_or_upload_weight(W_lm_head)\n",
        "    return (W, W_embed, W_ple_packed, W_ple_scale, norm_ple,\n",
        "            W_ple_proj, altup_projs, altup_unprojs, W_final_norm, W_lm_head)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def run_once(runtime, prompt, max_tokens, weights):\n",
        "    CPU_CORE = runtime[\"CPU_CORE\"]\n",
        "    forward_one_token = runtime[\"forward_one_token\"]\n",
        "    decode_logits = runtime[\"decode_logits\"]\n",
        "    _sample = runtime[\"_sample\"]\n",
        "    GLOBAL_PROFILE_DATA = runtime[\"GLOBAL_PROFILE_DATA\"]\n",
        "    NUM_LAYERS = runtime[\"NUM_LAYERS\"]\n",
        "\n",
        "    (W, W_embed, W_ple_packed, W_ple_scale, norm_ple,\n",
        "     W_ple_proj, altup_projs, altup_unprojs, W_final_norm, W_lm_head) = weights\n",
        "\n",
        "    K = np.zeros((NUM_LAYERS, 2048, KV_DIM), dtype=np.float16)\n",
        "    V = np.zeros((NUM_LAYERS, 2048, KV_DIM), dtype=np.float16)\n",
        "    pos = 0\n",
        "\n",
        "    full = f\"<start_of_turn>user\\n{prompt}<end_of_turn>\\n<start_of_turn>model\\n\"\n",
        "    tokens = CPU_CORE.tokenize(full)\n",
        "    GLOBAL_PROFILE_DATA.clear()\n",
        "\n",
        "    for tid in tokens:\n",
        "        xs = forward_one_token(tid, pos, W, W_embed, W_ple_packed, W_ple_scale,\n",
        "                               norm_ple, W_ple_proj, altup_projs, K, V)\n",
        "        GLOBAL_PROFILE_DATA[-1][\"stage\"] = \"Prefill\"\n",
        "        pos += 1\n",
        "\n",
        "    generated = []\n",
        "    for _ in range(max_tokens):\n",
        "        logits = decode_logits(xs, altup_unprojs, W_final_norm, W_lm_head)\n",
        "        logits = 30.0 * np.tanh(logits / 30.0)\n",
        "        nt = _sample(logits, TEMPERATURE, TOP_P, REP_PENALTY, generated)\n",
        "        if nt in [1, 106]:\n",
        "            break\n",
        "        generated.append(nt)\n",
        "        xs = forward_one_token(nt, pos, W, W_embed, W_ple_packed, W_ple_scale,\n",
        "                               norm_ple, W_ple_proj, altup_projs, K, V)\n",
        "        GLOBAL_PROFILE_DATA[-1][\"stage\"] = \"Decode\"\n",
        "        pos += 1\n",
        "\n",
        "    prefill = [r for r in GLOBAL_PROFILE_DATA if r.get(\"stage\") == \"Prefill\"]\n",
        "    decode = [r for r in GLOBAL_PROFILE_DATA if r.get(\"stage\") == \"Decode\"]\n",
        "    decode_ms = [r.get(\"_total\", 0) * 1000 for r in decode]\n",
        "\n",
        "    component_keys = [\"ffn\", \"qkv\", \"o_proj\", \"attn\", \"ple\"]\n",
        "    component_ms = {\n",
        "        key: float(np.mean([r.get(key, 0) * 1000 for r in decode])) if decode else 0.0\n",
        "        for key in component_keys\n",
        "    }\n",
        "\n",
        "    mean_decode_ms = float(np.mean(decode_ms)) if decode_ms else 0.0\n",
        "    return {\n",
        "        \"input_tokens\": len(tokens),\n",
        "        \"output_tokens\": len(generated),\n",
        "        \"prefill_ms\": round(sum(r.get(\"_total\", 0) for r in prefill) * 1000, 3),\n",
        "        \"decode_mean_ms\": round(mean_decode_ms, 3),\n",
        "        \"decode_std_ms\": round(float(np.std(decode_ms)), 3) if decode_ms else 0.0,\n",
        "        \"tok_s\": round(1000 / mean_decode_ms, 4) if mean_decode_ms > 0 else 0.0,\n",
        "        \"ffn_ms\": round(component_ms[\"ffn\"], 3),\n",
        "        \"qkv_ms\": round(component_ms[\"qkv\"], 3),\n",
        "        \"oproj_ms\": round(component_ms[\"o_proj\"], 3),\n",
        "        \"attn_ms\": round(component_ms[\"attn\"], 3),\n",
        "        \"ple_ms\": round(component_ms[\"ple\"], 3),\n",
        "    }"
      ]
    },
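    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before committing to the full sweep (4 prompts × 4 output-token settings × 20 runs), a single short run is a cheap way to confirm that the runtime loads and that the per-component timings look sensible. Like the `run_benchmark()` call in the next cell, this sketch needs the local runtime and weights, so it is left commented out."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional sanity check: one prompt, a small token budget, printed as a single row.\n",
        "# Uncomment when the local runtime and weights are available.\n",
        "# runtime = load_runtime_modules(WORKSPACE)\n",
        "# weights = load_model(runtime)\n",
        "# print(run_once(runtime, PROMPTS[0], 8, weights))"
      ]
    },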
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def run_benchmark():\n",
        "    runtime = load_runtime_modules(WORKSPACE)\n",
        "    PROFILE_DIR.mkdir(parents=True, exist_ok=True)\n",
        "\n",
        "    print(f\"Mode: W={WEIGHT_MODE} A={FEATURE_MODE} Accel={runtime['ACCEL']}\")\n",
        "    print(f\"Runs per setting: {N_RUNS}\")\n",
        "    print(f\"Output token settings: {OUTPUT_TOKEN_SETTINGS}\")\n",
        "    print(f\"Prompt count: {len(PROMPTS)}\")\n",
        "\n",
        "    t0 = time.time()\n",
        "    weights = load_model(runtime)\n",
        "    print(f\"Model loaded in {time.time() - t0:.1f}s\")\n",
        "\n",
        "    rows = []\n",
        "    total = len(PROMPTS) * len(OUTPUT_TOKEN_SETTINGS) * N_RUNS\n",
        "    idx = 0\n",
        "    for prompt_id, prompt in enumerate(PROMPTS, start=1):\n",
        "        for max_tokens in OUTPUT_TOKEN_SETTINGS:\n",
        "            for run_id in range(1, N_RUNS + 1):\n",
        "                idx += 1\n",
        "                print(f\"[{idx}/{total}] prompt={prompt_id} max_new={max_tokens} run={run_id}\")\n",
        "                t1 = time.time()\n",
        "                row = run_once(runtime, prompt, max_tokens, weights)\n",
        "                row.update({\n",
        "                    \"prompt_id\": prompt_id,\n",
        "                    \"prompt\": prompt,\n",
        "                    \"max_new_tokens\": max_tokens,\n",
        "                    \"run\": run_id,\n",
        "                    \"elapsed_s\": round(time.time() - t1, 3),\n",
        "                    \"weight_mode\": WEIGHT_MODE,\n",
        "                    \"feature_mode\": FEATURE_MODE,\n",
        "                    \"accel\": runtime[\"ACCEL\"],\n",
        "                })\n",
        "                rows.append(row)\n",
        "\n",
        "                runtime[\"generate_profile_html\"]()\n",
        "                if Path(\"ProfilerReport.html\").exists():\n",
        "                    shutil.copy2(\"ProfilerReport.html\", PROFILE_DIR / f\"p{prompt_id}_t{max_tokens}_r{run_id}.html\")\n",
        "                gc.collect()\n",
        "\n",
        "    result_df = pd.DataFrame(rows)\n",
        "    result_df.to_csv(OUTPUT_CSV, index=False)\n",
        "    return result_df\n",
        "\n",
        "# Uncomment when the local runtime and weights are available.\n",
        "# results = run_benchmark()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def summarize_results(results: pd.DataFrame) -> pd.DataFrame:\n",
        "    grouped = results.groupby([\"max_new_tokens\", \"prompt_id\"], as_index=False)\n",
        "    return grouped.agg(\n",
        "        runs=(\"run\", \"count\"),\n",
        "        actual_output_mean=(\"output_tokens\", \"mean\"),\n",
        "        prefill_mean_ms=(\"prefill_ms\", \"mean\"),\n",
        "        decode_mean_ms=(\"decode_mean_ms\", \"mean\"),\n",
        "        decode_std_ms=(\"decode_mean_ms\", \"std\"),\n",
        "        decode_p50_ms=(\"decode_mean_ms\", lambda s: s.quantile(0.50)),\n",
        "        decode_p90_ms=(\"decode_mean_ms\", lambda s: s.quantile(0.90)),\n",
        "        decode_p99_ms=(\"decode_mean_ms\", lambda s: s.quantile(0.99)),\n",
        "        tok_s_mean=(\"tok_s\", \"mean\"),\n",
        "        elapsed_mean_s=(\"elapsed_s\", \"mean\"),\n",
        "        ffn_mean_ms=(\"ffn_ms\", \"mean\"),\n",
        "        qkv_mean_ms=(\"qkv_ms\", \"mean\"),\n",
        "        oproj_mean_ms=(\"oproj_ms\", \"mean\"),\n",
        "        attn_mean_ms=(\"attn_ms\", \"mean\"),\n",
        "        ple_mean_ms=(\"ple_ms\", \"mean\"),\n",
        "    ).round(3)\n",
        "\n",
        "# If a benchmark CSV already exists, load it for analysis.\n",
        "if OUTPUT_CSV.exists():\n",
        "    results = pd.read_csv(OUTPUT_CSV)\n",
        "    summary = summarize_results(results)\n",
        "    display(summary)\n",
        "else:\n",
        "    print(f\"No benchmark CSV found at {OUTPUT_CSV}. Run run_benchmark() first.\")"
      ]
    },
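    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With 20 runs per setting, the standard deviation alone can understate how uncertain the reported mean is. The cell below is a small bootstrap sketch added on top of the original flow (NumPy only) that resamples the per-run decode means to give an approximate 95% confidence interval for each output-token setting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def bootstrap_ci(results: pd.DataFrame, n_boot: int = 2000, seed: int = 0) -> pd.DataFrame:\n",
        "    \"\"\"Approximate 95% bootstrap CI of mean decode latency per output-token setting.\"\"\"\n",
        "    rng = np.random.default_rng(seed)\n",
        "    rows = []\n",
        "    for max_tokens, group in results.groupby(\"max_new_tokens\"):\n",
        "        values = group[\"decode_mean_ms\"].to_numpy()\n",
        "        # Resample runs with replacement and recompute the mean each time.\n",
        "        means = np.array([rng.choice(values, size=len(values), replace=True).mean()\n",
        "                          for _ in range(n_boot)])\n",
        "        lo, hi = np.percentile(means, [2.5, 97.5])\n",
        "        rows.append({\n",
        "            \"max_new_tokens\": max_tokens,\n",
        "            \"decode_mean_ms\": round(float(values.mean()), 3),\n",
        "            \"ci95_low_ms\": round(float(lo), 3),\n",
        "            \"ci95_high_ms\": round(float(hi), 3),\n",
        "        })\n",
        "    return pd.DataFrame(rows)\n",
        "\n",
        "# Requires `results` from the cell above:\n",
        "# display(bootstrap_ci(results))"
      ]
    },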
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def plot_latency(summary: pd.DataFrame):\n",
        "    by_tokens = summary.groupby(\"max_new_tokens\", as_index=False).agg(\n",
        "        decode_mean_ms=(\"decode_mean_ms\", \"mean\"),\n",
        "        decode_p50_ms=(\"decode_p50_ms\", \"mean\"),\n",
        "        decode_p90_ms=(\"decode_p90_ms\", \"mean\"),\n",
        "        decode_p99_ms=(\"decode_p99_ms\", \"mean\"),\n",
        "    )\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(8, 4))\n",
        "    ax.plot(by_tokens[\"max_new_tokens\"], by_tokens[\"decode_mean_ms\"], marker=\"o\", label=\"mean\")\n",
        "    ax.plot(by_tokens[\"max_new_tokens\"], by_tokens[\"decode_p50_ms\"], marker=\"o\", label=\"p50\")\n",
        "    ax.plot(by_tokens[\"max_new_tokens\"], by_tokens[\"decode_p90_ms\"], marker=\"o\", label=\"p90\")\n",
        "    ax.plot(by_tokens[\"max_new_tokens\"], by_tokens[\"decode_p99_ms\"], marker=\"o\", label=\"p99\")\n",
        "    ax.set_title(\"Decode latency vs output-token setting\")\n",
        "    ax.set_xlabel(\"Max new tokens\")\n",
        "    ax.set_ylabel(\"Decode ms/token\")\n",
        "    ax.grid(True, alpha=0.3)\n",
        "    ax.legend()\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "def plot_component_breakdown(summary: pd.DataFrame):\n",
        "    components = [\"ffn_mean_ms\", \"qkv_mean_ms\", \"oproj_mean_ms\", \"attn_mean_ms\", \"ple_mean_ms\"]\n",
        "    labels = [\"FFN\", \"QKV\", \"O-proj\", \"Attention\", \"PLE\"]\n",
        "    by_tokens = summary.groupby(\"max_new_tokens\", as_index=True)[components].mean()\n",
        "\n",
        "    fig, ax = plt.subplots(figsize=(8, 4))\n",
        "    bottom = np.zeros(len(by_tokens))\n",
        "    x = by_tokens.index.astype(str)\n",
        "    for column, label in zip(components, labels):\n",
        "        values = by_tokens[column].to_numpy()\n",
        "        ax.bar(x, values, bottom=bottom, label=label)\n",
        "        bottom += values\n",
        "    ax.set_title(\"Per-component decode breakdown\")\n",
        "    ax.set_xlabel(\"Max new tokens\")\n",
        "    ax.set_ylabel(\"Mean ms/token\")\n",
        "    ax.legend()\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "def plot_run_variance(results: pd.DataFrame):\n",
        "    fig, ax = plt.subplots(figsize=(9, 4))\n",
        "    for max_tokens, group in results.groupby(\"max_new_tokens\"):\n",
        "        ax.scatter(group[\"run\"], group[\"decode_mean_ms\"], s=18, alpha=0.65, label=str(max_tokens))\n",
        "    ax.set_title(\"Run-to-run decode variance\")\n",
        "    ax.set_xlabel(\"Run index\")\n",
        "    ax.set_ylabel(\"Decode ms/token\")\n",
        "    ax.grid(True, alpha=0.3)\n",
        "    ax.legend(title=\"Max new\")\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "# After loading or producing results:\n",
        "# plot_latency(summary)\n",
        "# plot_component_breakdown(summary)\n",
        "# plot_run_variance(results)"
      ]
    },
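    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a numeric companion to the stacked-bar breakdown, the helper below (a reporting convenience added here, not part of the original harness) converts the per-component means into percentage shares of the summed component time for each output-token setting."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def component_share_table(summary: pd.DataFrame) -> pd.DataFrame:\n",
        "    \"\"\"Each component's share (%) of the summed component time, per output-token setting.\"\"\"\n",
        "    components = [\"ffn_mean_ms\", \"qkv_mean_ms\", \"oproj_mean_ms\", \"attn_mean_ms\", \"ple_mean_ms\"]\n",
        "    by_tokens = summary.groupby(\"max_new_tokens\")[components].mean()\n",
        "    shares = by_tokens.div(by_tokens.sum(axis=1), axis=0) * 100\n",
        "    return shares.round(1)\n",
        "\n",
        "# Requires `summary` from the summarize cell:\n",
        "# component_share_table(summary)"
      ]
    },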
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Reporting checklist\n",
        "\n",
        "For each benchmark run, report:\n",
        "\n",
        "- runtime mode and accelerator backend;\n",
        "- prompt list and output-token settings;\n",
        "- run count per setting;\n",
        "- mean, standard deviation, p50, p90, and p99 decode latency;\n",
        "- mean prefill latency;\n",
        "- per-component decode breakdown;\n",
        "- run-to-run variance plot.\n",
        "\n",
        "Do not base claims about benchmark behavior on a single prompt, a single output length, or only five runs."
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "pygments_lexer": "ipython3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
