Teaching models to write kernels: dataset, training, and honest results
i wanted a model that actually understands low‑level GPU code: masked loads, boundary checks, shared memory... the stuff people sweat about and models usually mess up. this is a short, practical writeup: what i collected, how i cleaned it, how i trained, how i evaluated.
tl;dr
- a focused dataset of triton kernel bodies and cuda→triton translation pairs (a few thousand examples).
- cleaned, deduped, license-tracked... ready for adapter-style fine‑tuning.
- i fine‑tuned LoRA adapters on a Qwen3-8B checkpoint using a single RTX 3090... the results are useful as kernel drafts.
what's in the dataset
🤗 HuggingFace Dataset: edwixx/triton-code-dataset
🤗 Fine-tuned Model: edwixx/qwen3-8b-triton-finetune
two CSV splits, simple schema (prompt,completion):
- fim_sft.csv: fill‑in‑the‑middle for triton kernels (prompt: signature + context, completion: body).
- cu2triton_sft.csv: CUDA functions paired with hand‑crafted or curated Triton rewrites.
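to make the FIM schema concrete, here is a minimal sketch of turning one kernel into a (prompt, completion) row. the helper name split_fim and the optional hint comment are hypothetical, and the split rule (everything through the signature is the prompt, the body is the completion) is an assumption read off the schema above, not necessarily how the dataset was built.

```python
from typing import Optional, Tuple

def split_fim(kernel_src: str, hint: Optional[str] = None) -> Tuple[str, str]:
    """Hypothetical FIM split: prompt = lines through the signature (+ optional hint), completion = body."""
    lines = kernel_src.strip('\n').split('\n')
    # locate the line that opens the kernel signature
    def_idx = next((i for i, ln in enumerate(lines) if ln.lstrip().startswith('def ')), None)
    if def_idx is None:
        raise ValueError('no function definition found')
    # a signature may span several lines: advance to the line that ends with ':'
    end = def_idx
    while end + 1 < len(lines) and not lines[end].rstrip().endswith(':'):
        end += 1
    prompt_lines = lines[: end + 1]
    if hint:
        # short instruction comment, like the "# complete: ..." line in the example prompt
        prompt_lines.append('    # ' + hint)
    return '\n'.join(prompt_lines), '\n'.join(lines[end + 1:])

src = (
    "@triton.jit\n"
    "def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):\n"
    "    pid = tl.program_id(0)\n"
    "    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n"
)
prompt, completion = split_fim(src, hint='complete: load x,y (masked), add, store')
print(prompt, completion, sep='\n---\n')
```

the body keeps its indentation in the completion, so the model learns to emit correctly indented kernel code.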
repo layout i publish:
/fim_sft.csv
/cu2triton_sft.csv
/ANALYSIS.json
/ANALYSIS.md
/PROVENANCE.md
/figs/*.png
quick visual tour
FIM: prompt tokens
prompts are compact... mostly function signatures, short docstrings, and one or two lines of context. that makes them perfect for body completion: the model only needs to learn common patterns, not long surrounding prose.
FIM: completion tokens
bodies are short to medium in length; many sit in the 20 to 200 token range. this is where the useful signal is: masked loads and stores, offset arithmetic, simple reductions. those recurring idioms are what the model picks up fastest.
CUDA to Triton: prompts & completions
translation pairs expose the model to two different styles... c++ indexing and macro‑heavy patterns on one side, compact, pythonic triton code on the other. that contrast helps the model map idioms across languages and produce cleaner Triton rewrites (not necessarily lmao).
how i built it
i started by searching repositories, documentation pages, and blog posts for kernel examples that show real GPU concerns. i prioritized permissively licensed sources and files that were self-contained or had clear argument shapes. i copied the minimal unit needed to understand the kernel: the function or kernel plus one or two lines of context describing input shapes. smaller context is better... it forces the model to learn from code, not from surrounding narrative.
next i normalized the text. unify newlines, trim leading and trailing whitespace, collapse repeated spaces and tabs, and collapse long runs of blank lines. this keeps token counts honest and makes deduplication reliable. here is a small python helper i used:
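import re
def normalize_code(s: str) -> str:
    s = s.replace('\r\n', '\n')       # unify newlines
    s = s.strip()                      # trim leading and trailing whitespace
    s = re.sub(r'[ \t]+', ' ', s)      # collapse runs of spaces/tabs (this also flattens leading indentation)
    s = re.sub(r'\n{3,}', '\n\n', s)   # collapse 3+ newlines into a single blank line
    return s
# usage
# clean = normalize_code(open('example.cu', 'r', encoding='utf-8').read())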
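one caveat on that helper: the [ \t]+ substitution also runs at the start of lines, so python/triton indentation collapses to a single space. that is harmless if the normalized text is only used for hashing and token counts. if indentation has to survive, a variant like this sketch (the name normalize_code_keep_indent is hypothetical) keeps each line's leading whitespace and only collapses runs after it.

```python
import re

def normalize_code_keep_indent(s: str) -> str:
    """Like normalize_code, but leading indentation on each line is left untouched."""
    s = s.replace('\r\n', '\n').replace('\r', '\n')
    out = []
    for line in s.split('\n'):
        m = re.match(r'[ \t]*', line)            # leading indentation (possibly empty)
        indent, rest = line[:m.end()], line[m.end():]
        rest = re.sub(r'[ \t]+', ' ', rest).rstrip()  # collapse inner runs, drop trailing spaces
        out.append(indent + rest if rest else '')     # whitespace-only lines become empty
    s = '\n'.join(out).strip('\n')               # trim blank lines at both ends only
    return re.sub(r'\n{3,}', '\n\n', s)

print(normalize_code_keep_indent("def f():\r\n    x  =\t1   \n\n\n\n    return x\n"))
```

stripping only newlines at the ends (not all whitespace) also keeps the first line's indentation, which matters for body-only completions.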
after cleaning i deduplicated exactly: compute a SHA‑1 hash over the string formed by normalize_code(prompt) ||| normalize_code(completion) and drop exact matches. exact duplicates skew metrics and give the model false confidence. once exact duplicates are removed, run a leakage filter using Jaccard word overlap between each prompt and its completion, and flag pairs with overlap ≥ 0.8 for manual review. that catches cases where the prompt already contains most of the answer.
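a minimal sketch of both filters. the ||| separator follows the description above; whitespace word splitting for the Jaccard score and the helper names are assumptions.

```python
import hashlib
import re

def normalize_code(s: str) -> str:
    # same steps as the normalization helper above
    s = s.replace('\r\n', '\n').strip()
    s = re.sub(r'[ \t]+', ' ', s)
    return re.sub(r'\n{3,}', '\n\n', s)

def dedup_exact(pairs):
    """Keep the first occurrence of each pair, keyed by SHA-1 of normalize(prompt) ||| normalize(completion)."""
    seen, kept = set(), []
    for prompt, completion in pairs:
        joined = normalize_code(prompt) + '|||' + normalize_code(completion)
        key = hashlib.sha1(joined.encode('utf-8')).hexdigest()
        if key not in seen:
            seen.add(key)
            kept.append((prompt, completion))
    return kept

def jaccard_words(a: str, b: str) -> float:
    # |A ∩ B| / |A ∪ B| over whitespace-split word sets
    wa, wb = set(a.split()), set(b.split())
    union = wa | wb
    return len(wa & wb) / len(union) if union else 0.0

def flag_leakage(pairs, threshold: float = 0.8):
    # indices of pairs whose prompt already shares most of its words with the completion
    return [i for i, (p, c) in enumerate(pairs) if jaccard_words(p, c) >= threshold]

pairs = [
    ('def f():', '    return 1'),
    ('def f():', '  return   1'),   # identical to the first after normalization
    ('a b c d e', 'a b c d e'),     # prompt already contains the answer
    ('x', 'y'),
]
kept = dedup_exact(pairs)
print(len(kept), flag_leakage(kept))
```

flagged indices go to manual review rather than being dropped automatically, matching the workflow described above.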
keep provenance at every step... preserve upstream license headers in the files you include and record the repository name, file path, and a short license snippet in PROVENANCE.md so origins are traceable. when the data looks clean and reviewed, export two CSVs named fim_sft.csv and cu2triton_sft.csv with a simple prompt,completion row per example.
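a sketch of the export step with python's csv module (an assumption; any CSV writer that quotes fields works). quoting matters because code rows contain commas, quotes, and newlines. the helper name write_sft_csv is hypothetical.

```python
import csv
import io

def write_sft_csv(rows, fh) -> int:
    """Write (prompt, completion) pairs as a two-column CSV with a header; returns rows written."""
    writer = csv.writer(fh, quoting=csv.QUOTE_ALL)  # quote every field so multi-line code round-trips
    writer.writerow(['prompt', 'completion'])
    count = 0
    for prompt, completion in rows:
        writer.writerow([prompt, completion])
        count += 1
    return count

# round trip through an in-memory buffer
buf = io.StringIO(newline='')
write_sft_csv([('def k(a, b):', '    return a + b\n    # done'), ('say "hi"', 'x,y')], buf)
buf.seek(0)
for row in csv.reader(buf):
    print(row)
```

when writing to disk, open the file with newline='' as the csv docs recommend, e.g. open('fim_sft.csv', 'w', newline='', encoding='utf-8').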
training (what i actually did)
i used Qwen3-8B and trained LoRA adapters on a single RTX 3090... that is the whole stack. i did not run multi‑gpu or cloud A100 jobs for this work. i wanted a practical setup you can reproduce on a local machine with a 24 GB 3090.
how i fit it on a 3090... i loaded the base model with 4‑bit NF4 quantization using bitsandbytes and kept compute in half precision (the exact config below sets bfloat16, which the 3090's Ampere architecture supports). that lets the model live on a single GPU while training only the adapter weights.
main config i ran:
# exact config i used
from transformers import BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
)
peft_cfg = LoraConfig(
    r=32,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
cfg = SFTConfig(
    output_dir="qwen3_8b_triton_fim_lora",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.02,
    logging_steps=20,
    save_steps=1000,
    packing=True,
    max_length=1024,
    bf16=True,
    gradient_checkpointing=True,
)
# condensed summary (fp16 compute, attention-only targets; the exact config above differs)
MODEL = 'Qwen/Qwen3-8B'
quant = {
    'load_in_4bit': True,
    'bnb_4bit_quant_type': 'nf4',
    'bnb_4bit_compute_dtype': 'float16',
}
peft = {
    'r': 32,
    'lora_alpha': 32,
    'lora_dropout': 0.05,
    'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
}
TRAIN = {
    'epochs': '1-3',  # 1 to 3 epochs; the exact config above uses 1
    'per_device_batch_size': 1,
    'gradient_accumulation_steps': 16,
    'learning_rate': 1e-4,
    'max_length': 1024,
    'bf16': False,  # fp16 variant; the 3090 (Ampere) also supports bf16, which the exact config uses
    'gradient_checkpointing': True,
}
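to see what the two target_modules lists cost, here is a back-of-envelope count of trainable LoRA parameters. each adapted Linear(in, out) gains an A matrix (r × in) and a B matrix (out × r). the layer shapes below are assumed from Qwen3-8B's published config (hidden 4096, MLP 12288, 36 layers, 32 query heads and 8 KV heads of dim 128); check them against the checkpoint's config.json.

```python
# assumed Qwen3-8B shapes (verify against the checkpoint's config.json)
HIDDEN = 4096         # hidden_size
INTERMEDIATE = 12288  # intermediate_size
LAYERS = 36           # num_hidden_layers
N_HEADS, N_KV_HEADS, HEAD_DIM = 32, 8, 128

# module -> (in_features, out_features) for one decoder layer
SHAPES = {
    'q_proj': (HIDDEN, N_HEADS * HEAD_DIM),
    'k_proj': (HIDDEN, N_KV_HEADS * HEAD_DIM),
    'v_proj': (HIDDEN, N_KV_HEADS * HEAD_DIM),
    'o_proj': (N_HEADS * HEAD_DIM, HIDDEN),
    'gate_proj': (HIDDEN, INTERMEDIATE),
    'up_proj': (HIDDEN, INTERMEDIATE),
    'down_proj': (INTERMEDIATE, HIDDEN),
}

def lora_params(r, modules):
    """LoRA adds A (r x in) and B (out x r) per adapted Linear: r * (in + out) weights."""
    per_layer = sum(r * (SHAPES[m][0] + SHAPES[m][1]) for m in modules)
    return per_layer * LAYERS

all_seven = list(SHAPES)
attn_only = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
print(f"r=32, all seven projections: {lora_params(32, all_seven):,}")
print(f"r=32, attention only:        {lora_params(32, attn_only):,}")
```

under those assumptions, r=32 on all seven projections trains about 87M parameters, roughly three times the ~31M of the attention-only list; either way it is around 1% or less of the ~8B frozen base weights.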
practical notes... on a 3090 you have to rely on small per‑device batches plus gradient accumulation. packing helps when examples are short, since several of them share one max_length sequence and each step sees more examples. i kept the training runs short so i could iterate quickly and manually review outputs after every few hundred steps.
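the batch arithmetic and the packing idea can be sketched as follows. greedy_pack is a hypothetical first-fit-decreasing packer meant only to illustrate why packing helps with short examples; TRL's own packing implementation differs and depends on the version.

```python
def effective_batch(per_device: int, grad_accum: int, n_gpus: int = 1) -> int:
    # sequences contributing to each optimizer step
    return per_device * grad_accum * n_gpus

def greedy_pack(lengths, max_length=1024):
    """Hypothetical first-fit-decreasing packer: returns lists of example indices per sequence.

    Examples longer than max_length are clamped (a trainer would truncate them).
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins, room = [], []
    for i in order:
        need = min(lengths[i], max_length)
        for b, free in enumerate(room):
            if free >= need:          # first sequence with enough space left
                bins[b].append(i)
                room[b] -= need
                break
        else:                         # nothing fits: start a new sequence
            bins.append([i])
            room.append(max_length - need)
    return bins

print(effective_batch(1, 16))
print(greedy_pack([600, 500, 400, 300, 200], max_length=1024))  # -> [[0, 2], [1, 3, 4]]
```

with per_device_train_batch_size=1 and gradient_accumulation_steps=16, each optimizer step sees 16 sequences, so with packing at max_length=1024 that is at most 16 × 1024 = 16,384 tokens per step. in the toy call, five examples totaling 2,000 tokens fit into two 1,024-token sequences instead of five separate ones.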
monitoring and validation... i watched token accuracy and simple static checks (presence of masked load/store patterns). loss is noisy on small datasets, so use heuristics and spot checks to decide whether to continue.
evaluation — quick and useful
i kept eval simple. mix of automated and manual:
- static checks: grep for @triton.jit, tl.load, masked stores. fast signal.
- token overlap: for translation pairs, measure how much the output matches reference. coarse but works.
- spot compile: try compiling a few examples. most won't run on random shapes, but it catches syntax errors.
- human read: i read ~100 outputs. catches weird patterns machines miss.
good enough to separate usable drafts from garbage.
example the model learned
prompt
import triton, triton.language as tl
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr):
    # complete: load x,y (masked), add, store
completion
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    out = x + y
    tl.store(out_ptr + offsets, out, mask=mask)
this pattern repeats across kernels; a few dozen examples are enough for the model to reproduce it reliably.
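to see why the mask in that completion matters, here is a CPU emulation of add_kernel in numpy. numpy stands in for triton here and this is not how triton executes: each loop iteration plays one program id, and the real launch (shown only in a comment) would use a grid of triton.cdiv(n, BLOCK_SIZE) programs.

```python
import numpy as np

def emulate_add_kernel(x: np.ndarray, y: np.ndarray, block_size: int = 4) -> np.ndarray:
    """Sequential CPU emulation of add_kernel; each loop iteration plays one program id."""
    n = x.size
    out = np.empty_like(x)
    # on a GPU: grid = (triton.cdiv(n, BLOCK_SIZE),); add_kernel[grid](x, y, out, n, BLOCK_SIZE=...)
    num_programs = (n + block_size - 1) // block_size  # ceil(n / block_size)
    for pid in range(num_programs):
        offsets = pid * block_size + np.arange(block_size)
        mask = offsets < n                  # lanes past the end are switched off
        idx = offsets[mask]
        out[idx] = x[idx] + y[idx]          # masked load, add, masked store
    return out

x = np.arange(10, dtype=np.float32)
y = np.ones(10, dtype=np.float32)
print(emulate_add_kernel(x, y))
```

with n=10 and block_size=4 the last program covers offsets 8..11. without the mask, offsets 10 and 11 would index past the end of the arrays: numpy raises IndexError, while on a GPU the unmasked access could touch memory outside the tensors.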
provenance & licensing
- original license headers are kept in source files (bsd/mit/apache).
- PROVENANCE.md lists repo names and short license snippets.
- the dataset wrapper and scripts are MIT; underlying snippets retain upstream licenses.
if you reuse this dataset, keep headers intact and read provenance.
limitations
- small dataset: teaches idioms, not everything. expect hallucinations.
- translations may not compile in edge cases. treat outputs as drafts.
- no automatic proof of correctness across all outputs — that's future work.
next steps (practical)
- add fuzzy near‑duplicate detection (minhash) to cut semantic leakage.
- run compile tests for a subset and label examples as "compilable."
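a minimal MinHash sketch for the near-duplicate step, in pure python: k-token shingles, one salted blake2b hash per "permutation", and the fraction of matching signature slots as the Jaccard estimate. shingle size, permutation count, and tokenization are assumptions; at dataset scale a library such as datasketch (MinHash plus MinHashLSH for bucketing) would be the practical choice.

```python
import hashlib
import re

def shingles(code: str, k: int = 5) -> set:
    # k-token shingles over a simple word/symbol tokenization
    toks = re.findall(r'\w+|[^\w\s]', code)
    if len(toks) < k:
        return {tuple(toks)}
    return {tuple(toks[i:i + k]) for i in range(len(toks) - k + 1)}

def minhash(sh: set, num_perm: int = 128) -> list:
    # one salted hash function per slot; keep the minimum hash over all shingles
    sig = []
    for seed in range(num_perm):
        salt = seed.to_bytes(8, 'little')
        sig.append(min(
            int.from_bytes(hashlib.blake2b(repr(s).encode('utf-8'), digest_size=8, salt=salt).digest(), 'little')
            for s in sh
        ))
    return sig

def estimated_jaccard(sig_a: list, sig_b: list) -> float:
    # share of matching slots estimates the Jaccard similarity of the shingle sets
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)

def exact_jaccard(a: set, b: set) -> float:
    return len(a & b) / len(a | b) if (a | b) else 0.0

k1 = "offsets = pid * BLOCK + tl.arange(0, BLOCK)\nmask = offsets < n\nx = tl.load(x_ptr + offsets, mask=mask)"
k2 = k1.replace('x_ptr', 'a_ptr')  # renamed pointer: a near-duplicate
s1, s2 = shingles(k1), shingles(k2)
print(exact_jaccard(s1, s2), estimated_jaccard(minhash(s1), minhash(s2)))
```

pairs whose estimate clears a threshold (for example the same 0.8 used for the leakage filter, an assumption here) would go to manual review, just like the exact-overlap flags.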
how to regenerate the graphs locally
run in python:
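import pandas as pd
from matplotlib import pyplot as plt
rows = pd.read_csv('triton-code-dataset/fim_sft.csv')
# whitespace-split word counts, a rough proxy for model token counts
plt.hist([len(str(x).split()) for x in rows['prompt']], bins=50)
plt.xlabel('prompt length (whitespace-split words)')
plt.savefig('re_fim_prompt.png')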
PS: i ran all these experiments just because i wanted to; it was a fun project, nothing more.