Teaching models to write kernels: dataset, training, and honest results
i wanted a model that actually understands low‑level GPU code: masked loads, boundary checks, shared memory... the stuff people sweat about and models usually mess up. this is a short, practical writeup: what i collected, how i cleaned it, how i trained, how i evaluated.
tl;dr
- a focused dataset of triton kernel bodies and cuda→triton translation pairs (a few thousand examples).
- cleaned, deduped, license‑checked... ready for adapter-style fine‑tuning.
- i fine‑tuned adapters on a qwen 8b checkpoint using a single RTX 3090... results are useful as kernel drafts.
what's in the dataset
HuggingFace Dataset: edwixx/triton-code-dataset
Fine-tuned Model: edwixx/qwen3-8b-triton-finetune
Two CSV splits, simple schema (prompt, completion):
- fim_sft.csv — fill‑in‑the‑middle for triton kernels (prompt: signature + context, completion: body).
- cu2triton_sft.csv — CUDA functions paired with hand‑crafted or curated Triton rewrites.
Repo layout:
/fim_sft.csv
/cu2triton_sft.csv
/ANALYSIS.json
/ANALYSIS.md
/PROVENANCE.md
/figs/*.png
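if you want to poke at the splits, loading straight from the hub should work (a sketch assuming the datasets library; data_files picks which csv you get):

from datasets import load_dataset

# each csv has two string columns: prompt, completion
fim = load_dataset("edwixx/triton-code-dataset", data_files="fim_sft.csv", split="train")
cu2t = load_dataset("edwixx/triton-code-dataset", data_files="cu2triton_sft.csv", split="train")
print(fim[0]["prompt"][:200])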
how i built it
i started by searching repositories, documentation pages, and blog posts for kernel examples that show real GPU concerns. i prioritized permissively licensed sources and files that were self-contained or had clear argument shapes.
next i normalized the text — unify newlines, trim whitespace, collapse repeated spaces and tabs:
import re

def normalize_code(s: str) -> str:
    # unify line endings and trim leading/trailing whitespace
    s = s.replace('\r\n', '\n')
    s = s.strip()
    # collapse runs of spaces and tabs to a single space
    s = re.sub(r'[ \t]+', ' ', s)
    # cap blank-line runs at one empty line
    s = re.sub(r'\n{3,}', '\n\n', s)
    return s
after cleaning i deduplicated exactly: compute a SHA‑1 hash over normalize(prompt) ||| normalize(completion) and drop exact matches. then i ran a leakage filter using Jaccard word overlap, flagging pairs with overlap ≥ 0.8 for manual review.
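a minimal sketch of both passes (the separator and word-level tokenization mirror what i described above; the leakage check assumes you're comparing prompts across two splits):

import hashlib

def pair_key(prompt: str, completion: str) -> str:
    # exact-dedup key: sha-1 over the normalized pair with an explicit separator
    blob = normalize_code(prompt) + ' ||| ' + normalize_code(completion)
    return hashlib.sha1(blob.encode('utf-8')).hexdigest()

def jaccard(a: str, b: str) -> float:
    # word-level jaccard overlap between two snippets
    wa, wb = set(a.split()), set(b.split())
    return len(wa & wb) / max(1, len(wa | wb))

def flag_for_review(split_a, split_b, threshold=0.8):
    # brute-force pairwise scan; fine at a few thousand examples
    return [(a, b) for a in split_a for b in split_b if jaccard(a, b) >= threshold]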
training
i used qwen 8b and trained LoRA adapters on a single RTX 3090, loading the base model in 4‑bit with bitsandbytes.
from transformers import BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",          # nf4 keeps the 8b base inside the 3090's 24 GB
    bnb_4bit_compute_dtype="bfloat16",
)

peft_cfg = LoraConfig(
    r=32,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

cfg = SFTConfig(
    output_dir="qwen3_8b_triton_fim_lora",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,     # effective batch size of 16
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.02,
    packing=True,
    max_length=1024,
    bf16=True,
    gradient_checkpointing=True,
)
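for completeness, here's roughly how those configs plug together. this is a sketch, not the exact script: the base checkpoint id is an assumption, and it presumes a trl version that accepts prompt/completion columns directly:

from datasets import load_dataset
from transformers import AutoModelForCausalLM
from trl import SFTTrainer

# assumed base checkpoint; swap in whichever qwen 8b variant you actually use
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    quantization_config=bnb,   # the BitsAndBytesConfig from above
    device_map="auto",
)

train_ds = load_dataset("edwixx/triton-code-dataset", data_files="fim_sft.csv", split="train")

trainer = SFTTrainer(
    model=model,
    args=cfg,                  # the SFTConfig from above
    train_dataset=train_ds,
    peft_config=peft_cfg,      # LoRA adapters get attached here
)
trainer.train()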
evaluation
i kept eval simple — mix of automated and manual:
- static checks: grep for @triton.jit, tl.load, and masked stores. fast signal (sketch after this list).
- token overlap: for translation pairs, measure how much the output matches the reference.
- spot compile: try compiling a few examples. catches syntax errors.
- human read: i read ~100 outputs. catches weird patterns machines miss.
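a minimal version of the automated checks (my sketch, not the exact script; the marker list is a choice, and ast.parse stands in for a real triton compile, which needs a GPU):

import ast

MARKERS = ("@triton.jit", "tl.load", "mask=")

def static_check(src: str) -> bool:
    # grep-style signal: did the output use the idioms we trained for?
    return all(m in src for m in MARKERS)

def spot_compile(src: str) -> bool:
    # syntax-level only; catching launch-time errors needs real hardware
    try:
        ast.parse(src)
        return True
    except SyntaxError:
        return False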
example the model learned
prompt
import triton, triton.language as tl
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    # complete: load x,y (masked), add, store
completion
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    out = x + y
    tl.store(out_ptr + offsets, out, mask=mask)
This pattern repeats across kernels; a few dozen examples are enough for the model to reproduce it reliably.
limitations
- small dataset: teaches idioms, not everything. expect hallucinations.
- translations may not compile in edge cases. treat outputs as drafts.
- no automatic proof of correctness across all outputs.
next steps
- add fuzzy near‑duplicate detection (minhash) to cut semantic leakage (sketch after this list).
- run compile tests for a subset and label examples as "compilable."
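for the minhash step, datasketch is the obvious tool; a sketch with guessed parameters (num_perm and threshold would need tuning):

from datasketch import MinHash, MinHashLSH

def minhash_of(src: str, num_perm: int = 128) -> MinHash:
    # hash the set of words; n-gram shingles would be stricter
    m = MinHash(num_perm=num_perm)
    for word in set(src.split()):
        m.update(word.encode('utf-8'))
    return m

def near_duplicates(examples, threshold=0.8, num_perm=128):
    # returns indices whose minhash collides with an earlier example
    lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
    dupes = []
    for i, src in enumerate(examples):
        m = minhash_of(src, num_perm)
        if lsh.query(m):
            dupes.append(i)
        else:
            lsh.insert(str(i), m)
    return dupes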
PS: i did all this just cause i wanted to — this was a fun project, nothing more.