[{"data":1,"prerenderedAt":567},["ShallowReactive",2],{"blog-teaching-models-to-write-kernels":3},{"id":4,"title":5,"author":6,"body":7,"categories":552,"date":557,"description":558,"extension":559,"hidden":560,"meta":561,"navigation":210,"path":562,"seo":563,"stem":564,"thumbnail":565,"__hash__":566},"blog\u002Fblog\u002Fteaching-models-to-write-kernels.md","Teaching models to write kernels; dataset, training, and honest results","Anurag Kanade",{"type":8,"value":9,"toc":542},"minimark",[10,14,17,22,43,47,59,69,77,95,98,108,110,114,117,120,171,178,180,183,186,373,375,378,381,415,417,421,426,455,460,500,503,505,508,519,521,525,533,535,538],[11,12,13],"p",{},"i wanted a model that actually understands low‑level GPU code: masked loads, boundary checks, shared memory... the stuff people sweat about and models usually mess up. this is a short, practical writeup: what i collected, how i cleaned it, how i trained, how i evaluated.",[15,16],"hr",{},[18,19,21],"h2",{"id":20},"tldr","tl;dr",[23,24,25,29,40],"ul",{},[26,27,28],"li",{},"a focused dataset of triton kernel bodies and cuda→triton translation pairs (a few thousand examples).",[26,30,31,32,39],{},"cleaned, ",[33,34,38],"a",{"href":35,"rel":36},"https:\u002F\u002Fen.wikipedia.org\u002Fwiki\u002FData_deduplication",[37],"nofollow","deduped",", licensed... ready for adapter-style fine‑tuning.",[26,41,42],{},"i fine‑tuned adapters on a qwen 8b checkpoint using a single RTX 3090... 
results are useful as kernel drafts.",[18,44,46],{"id":45},"whats-in-the-dataset","what's in the dataset",[11,48,49,53,54],{},[50,51,52],"strong",{},"HuggingFace Dataset:"," ",[33,55,58],{"href":56,"rel":57},"https:\u002F\u002Fhuggingface.co\u002Fdatasets\u002Fedwixx\u002Ftriton-code-dataset",[37],"edwixx\u002Ftriton-code-dataset",[11,60,61,53,64],{},[50,62,63],{},"Fine-tuned Model:",[33,65,68],{"href":66,"rel":67},"https:\u002F\u002Fhuggingface.co\u002Fedwixx\u002Fqwen3-8b-triton-finetune",[37],"edwixx\u002Fqwen3-8b-triton-finetune",[11,70,71,72,76],{},"Two CSV splits, simple schema (",[73,74,75],"code",{},"prompt,completion","):",[23,78,79,87],{},[26,80,81,86],{},[50,82,83],{},[73,84,85],{},"fim_sft.csv"," — fill‑in‑the‑middle for triton kernels (prompt: signature + context, completion: body).",[26,88,89,94],{},[50,90,91],{},[73,92,93],{},"cu2triton_sft.csv"," — CUDA functions paired with hand‑crafted or curated Triton rewrites.",[11,96,97],{},"Repo layout:",[99,100,105],"pre",{"className":101,"code":103,"language":104},[102],"language-text","\u002Ffim_sft.csv\n\u002Fcu2triton_sft.csv\n\u002FANALYSIS.json\n\u002FANALYSIS.md\n\u002FPROVENANCE.md\n\u002Ffigs\u002F*.png\n","text",[73,106,103],{"__ignoreMap":107},"",[15,109],{},[18,111,113],{"id":112},"how-i-built-it","how i built it",[11,115,116],{},"i started by searching repositories, documentation pages, and blog posts for kernel examples that show real GPU concerns. 
i prioritized permissively licensed sources and files that were self-contained or had clear argument shapes.",[11,118,119],{},"next i normalized the text — unify newlines, trim whitespace, collapse repeated spaces and tabs:",[99,121,125],{"className":122,"code":123,"language":124,"meta":107,"style":107},"def normalize_code(s: str) -> str:\n    import re\n    s = s.replace('\\r\\n', '\\n')\n    s = s.strip()\n    s = re.sub(r'[ \\t]+', ' ', s)\n    s = re.sub(r'\\n{3,}', '\\n\\n', s)\n    return s\n","python",[73,126,127,135,141,147,153,159,165],{"__ignoreMap":107},[128,129,132],"span",{"class":130,"line":131},"line",1,[128,133,134],{},"def normalize_code(s: str) -> str:\n",[128,136,138],{"class":130,"line":137},2,[128,139,140],{},"    import re\n",[128,142,144],{"class":130,"line":143},3,[128,145,146],{},"    s = s.replace('\\r\\n', '\\n')\n",[128,148,150],{"class":130,"line":149},4,[128,151,152],{},"    s = s.strip()\n",[128,154,156],{"class":130,"line":155},5,[128,157,158],{},"    s = re.sub(r'[ \\t]+', ' ', s)\n",[128,160,162],{"class":130,"line":161},6,[128,163,164],{},"    s = re.sub(r'\\n{3,}', '\\n\\n', s)\n",[128,166,168],{"class":130,"line":167},7,[128,169,170],{},"    return s\n",[11,172,173,174,177],{},"after cleaning i deduplicated exactly: compute a SHA‑1 hash over ",[73,175,176],{},"normalize(prompt) ||| normalize(completion)"," and drop exact matches. then run a leakage filter using Jaccard word overlap and flag pairs with overlap ≥ 0.8 for manual review.",[15,179],{},[18,181,182],{"id":182},"training",[11,184,185],{},"i used qwen 8b and trained LoRA adapters on a single RTX 3090. 
i loaded the base model with 4‑bit quantization using bitsandbytes.",[99,187,189],{"className":122,"code":188,"language":124,"meta":107,"style":107},"from transformers import BitsAndBytesConfig\nfrom peft import LoraConfig\nfrom trl import SFTConfig\n\nbnb = BitsAndBytesConfig(\n    load_in_4bit=True,\n    bnb_4bit_quant_type=\"nf4\",\n    bnb_4bit_compute_dtype=\"bfloat16\",\n)\n\npeft_cfg = LoraConfig(\n    r=32,\n    lora_alpha=32,\n    lora_dropout=0.05,\n    bias=\"none\",\n    target_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"],\n    task_type=\"CAUSAL_LM\",\n)\n\ncfg = SFTConfig(\n    output_dir=\"qwen3_8b_triton_fim_lora\",\n    num_train_epochs=1,\n    per_device_train_batch_size=1,\n    gradient_accumulation_steps=16,\n    learning_rate=1e-4,\n    lr_scheduler_type=\"cosine\",\n    warmup_ratio=0.02,\n    packing=True,\n    max_length=1024,\n    bf16=True,\n    gradient_checkpointing=True,\n)\n",[73,190,191,196,201,206,212,217,222,227,233,239,244,250,256,262,268,274,280,286,291,296,302,308,314,320,326,332,338,344,350,356,362,368],{"__ignoreMap":107},[128,192,193],{"class":130,"line":131},[128,194,195],{},"from transformers import BitsAndBytesConfig\n",[128,197,198],{"class":130,"line":137},[128,199,200],{},"from peft import LoraConfig\n",[128,202,203],{"class":130,"line":143},[128,204,205],{},"from trl import SFTConfig\n",[128,207,208],{"class":130,"line":149},[128,209,211],{"emptyLinePlaceholder":210},true,"\n",[128,213,214],{"class":130,"line":155},[128,215,216],{},"bnb = BitsAndBytesConfig(\n",[128,218,219],{"class":130,"line":161},[128,220,221],{},"    load_in_4bit=True,\n",[128,223,224],{"class":130,"line":167},[128,225,226],{},"    bnb_4bit_quant_type=\"nf4\",\n",[128,228,230],{"class":130,"line":229},8,[128,231,232],{},"    bnb_4bit_compute_dtype=\"bfloat16\",\n",[128,234,236],{"class":130,"line":235},9,[128,237,238],{},")\n",[128,240,242],{"class":130,"line":241},10,[128,243,211],{"emptyLinePlaceholder":210},[128,245,247],{"class":130,"line":246},11,[128,248,249],{},"peft_cfg = LoraConfig(\n",[128,251,253],{"class":130,"line":252},12,[128,254,255],{},"    r=32,\n",[128,257,259],{"class":130,"line":258},13,[128,260,261],{},"    lora_alpha=32,\n",[128,263,265],{"class":130,"line":264},14,[128,266,267],{},"    lora_dropout=0.05,\n",[128,269,271],{"class":130,"line":270},15,[128,272,273],{},"    bias=\"none\",\n",[128,275,277],{"class":130,"line":276},16,[128,278,279],{},"    target_modules=[\"q_proj\",\"k_proj\",\"v_proj\",\"o_proj\",\"gate_proj\",\"up_proj\",\"down_proj\"],\n",[128,281,283],{"class":130,"line":282},17,[128,284,285],{},"    task_type=\"CAUSAL_LM\",\n",[128,287,289],{"class":130,"line":288},18,[128,290,238],{},[128,292,294],{"class":130,"line":293},19,[128,295,211],{"emptyLinePlaceholder":210},[128,297,299],{"class":130,"line":298},20,[128,300,301],{},"cfg = SFTConfig(\n",[128,303,305],{"class":130,"line":304},21,[128,306,307],{},"    output_dir=\"qwen3_8b_triton_fim_lora\",\n",[128,309,311],{"class":130,"line":310},22,[128,312,313],{},"    num_train_epochs=1,\n",[128,315,317],{"class":130,"line":316},23,[128,318,319],{},"    per_device_train_batch_size=1,\n",[128,321,323],{"class":130,"line":322},24,[128,324,325],{},"    gradient_accumulation_steps=16,\n",[128,327,329],{"class":130,"line":328},25,[128,330,331],{},"    learning_rate=1e-4,\n",[128,333,335],{"class":130,"line":334},26,[128,336,337],{},"    lr_scheduler_type=\"cosine\",\n",[128,339,341],{"class":130,"line":340},27,[128,342,343],{},"    warmup_ratio=0.02,\n",[128,345,347],{"class":130,"line":346},28,[128,348,349],{},"    packing=True,\n",[128,351,353],{"class":130,"line":352},29,[128,354,355],{},"    max_length=1024,\n",[128,357,359],{"class":130,"line":358},30,[128,360,361],{},"    bf16=True,\n",[128,363,365],{"class":130,"line":364},31,[128,366,367],{},"    gradient_checkpointing=True,\n",[128,369,371],{"class":130,"line":370},32,[128,372,238],{},[15,374],{},[18,376,377],{"id":377},"evaluation",[11,379,380],{},"i kept eval simple — a mix of automated and manual checks:",[23,382,383,397,403,409],{},[26,384,385,388,389,392,393,396],{},[50,386,387],{},"static checks:"," grep for ",[73,390,391],{},"@triton.jit",", ",[73,394,395],{},"tl.load",", masked stores. fast signal.",[26,398,399,402],{},[50,400,401],{},"token overlap:"," for translation pairs, measure how much the output matches the reference.",[26,404,405,408],{},[50,406,407],{},"spot compile:"," try compiling a few examples. catches syntax errors.",[26,410,411,414],{},[50,412,413],{},"human read:"," i read ~100 outputs. catches weird patterns machines miss.",[15,416],{},[18,418,420],{"id":419},"example-the-model-learned","example the model learned",[11,422,423],{},[50,424,425],{},"prompt",[99,427,429],{"className":122,"code":428,"language":124,"meta":107,"style":107},"import triton, triton.language as tl\n\n@triton.jit\ndef add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):\n    # complete: load x,y (masked), add, store\n",[73,430,431,436,440,445,450],{"__ignoreMap":107},[128,432,433],{"class":130,"line":131},[128,434,435],{},"import triton, triton.language as tl\n",[128,437,438],{"class":130,"line":137},[128,439,211],{"emptyLinePlaceholder":210},[128,441,442],{"class":130,"line":143},[128,443,444],{},"@triton.jit\n",[128,446,447],{"class":130,"line":149},[128,448,449],{},"def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):\n",[128,451,452],{"class":130,"line":155},[128,453,454],{},"    # complete: load x,y (masked), add, store\n",[11,456,457],{},[50,458,459],{},"completion",[99,461,463],{"className":122,"code":462,"language":124,"meta":107,"style":107},"pid = tl.program_id(0)\noffsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\nmask = offsets \u003C n\nx = tl.load(x_ptr + offsets, mask=mask)\ny = tl.load(y_ptr + offsets, mask=mask)\nout = x + y\ntl.store(out_ptr + offsets, out, mask=mask)\n",[73,464,465,470,475,480,485,490,495],{"__ignoreMap":107},[128,466,467],{"class":130,"line":131},[128,468,469],{},"pid = tl.program_id(0)\n",[128,471,472],{"class":130,"line":137},[128,473,474],{},"offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n",[128,476,477],{"class":130,"line":143},[128,478,479],{},"mask = offsets \u003C n\n",[128,481,482],{"class":130,"line":149},[128,483,484],{},"x = tl.load(x_ptr + offsets, mask=mask)\n",[128,486,487],{"class":130,"line":155},[128,488,489],{},"y = tl.load(y_ptr + offsets, mask=mask)\n",[128,491,492],{"class":130,"line":161},[128,493,494],{},"out = x + y\n",[128,496,497],{"class":130,"line":167},[128,498,499],{},"tl.store(out_ptr + offsets, out, mask=mask)\n",[11,501,502],{},"This pattern repeats across kernels; a few dozen examples are enough for the model to reproduce it reliably.",[15,504],{},[18,506,507],{"id":507},"limitations",[23,509,510,513,516],{},[26,511,512],{},"small dataset: teaches idioms, not everything. expect hallucinations.",[26,514,515],{},"translations may not compile in edge cases. 
treat outputs as drafts.",[26,517,518],{},"no automatic proof of correctness across all outputs.",[15,520],{},[18,522,524],{"id":523},"next-steps","next steps",[23,526,527,530],{},[26,528,529],{},"add fuzzy near‑duplicate detection (minhash) to cut semantic leakage.",[26,531,532],{},"run compile tests for a subset and label examples as \"compilable.\"",[15,534],{},[11,536,537],{},"PS: i did all this just cause i wanted to — this was a fun project, nothing more.",[539,540,541],"style",{},"html .default .shiki span {color: var(--shiki-default);background: var(--shiki-default-bg);font-style: var(--shiki-default-font-style);font-weight: var(--shiki-default-font-weight);text-decoration: var(--shiki-default-text-decoration);}html .shiki span {color: var(--shiki-default);background: var(--shiki-default-bg);font-style: var(--shiki-default-font-style);font-weight: var(--shiki-default-font-weight);text-decoration: var(--shiki-default-text-decoration);}html .dark .shiki span {color: var(--shiki-dark);background: var(--shiki-dark-bg);font-style: var(--shiki-dark-font-style);font-weight: var(--shiki-dark-font-weight);text-decoration: var(--shiki-dark-text-decoration);}html.dark .shiki span {color: var(--shiki-dark);background: var(--shiki-dark-bg);font-style: var(--shiki-dark-font-style);font-weight: var(--shiki-dark-font-weight);text-decoration: var(--shiki-dark-text-decoration);}",{"title":107,"searchDepth":137,"depth":137,"links":543},[544,545,546,547,548,549,550,551],{"id":20,"depth":137,"text":21},{"id":45,"depth":137,"text":46},{"id":112,"depth":137,"text":113},{"id":182,"depth":137,"text":182},{"id":377,"depth":137,"text":377},{"id":419,"depth":137,"text":420},{"id":507,"depth":137,"text":507},{"id":523,"depth":137,"text":524},[553,554,555,556],"Machine Learning","GPU","Triton","CUDA","2025-10-29","A practical guide to building a focused dataset and fine-tuning models to understand low-level GPU code - masked loads, boundary checks, shared memory, and CUDA→Triton 
translations.","md",false,{},"\u002Fblog\u002Fteaching-models-to-write-kernels",{"title":5,"description":558},"blog\u002Fteaching-models-to-write-kernels",null,"MAIkJ9W_QOPCoWdg4ukOtLjeT92rpozVzYRtpOe09Tw",1775296369322]