Simplifying CUDA kernels with Triton
A Pythonic Approach to GPU Programming
Writing custom CUDA kernels, for whatever reason, has always looked like a daunting task. This is where OpenAI’s Triton comes in handy. With its Pythonic, torch-like syntax, writing kernels is now a bit more approachable. So let's dive in and see what Triton has to offer.
Understanding GPU Memory Hierarchy
Before that, let's have a quick introduction to how GPU memory works. I want to keep this short, so I will leave the resources for extra reading here. Long story short, there are two main kinds of memory in GPUs.
GPU DRAM (High Bandwidth Memory, HBM, on data-center GPUs such as the A100): This is the memory we usually mean when we say a GPU has 16 GB or 80 GB of memory. Think of it as the warehouse in your backyard. It is slower than SRAM but cheaper to make. On the A100, for example, the bandwidth of this memory is ~2 TB/s.
GPU SRAM (L1/L2 caches): This is much smaller, usually measured in KBs and MBs, and sits on the chip itself (within the die). It is therefore much faster to access, but also more expensive to make. This is sort of like the working memory of the chip. On the A100, the L1 cache is 192 KB per Streaming Multiprocessor (SM; think of it as a bunch of cores lined up together) and the L2 cache is 40 MB (think of it as memory inside the house).
For A100s again:
L1 Cache Bandwidth: ~100–200 TB/s theoretical bandwidth
L2 Cache Bandwidth: ~4–7 TB/s theoretical bandwidth
L1 caches are like the memory you have on the table you are working at, L2 is somewhere in the room you are working in, and DRAM is out in your backyard warehouse. Hence the difference in speed. The key point is that reading from these caches is much faster, so the SM spends far less time waiting when the data is already there. A DRAM access costs a modern core several hundred cycles of waiting, and that drops to only tens of cycles if the data is in the L1 cache. This is the cost of data movement!
Triton Basics: Thread Launching and Block Processing
Okay, let's get back. Imagine you have to add or multiply two vectors of 60k elements each. I know 60k is arbitrary; I chose it to make a point. With only 192 KB of fast L1 memory available per SM, you want to make the best use of it. If these vectors hold FP32 numbers, a single vector (60k × 4 bytes = 240 KB) is already beyond what fits in these tiny units of memory.
And you have two of these, and you still need just as much space to store the output. In Triton, or CUDA in general, this is dealt with by working in blocks. You still get the same output if each SM only looks at a partition of the data whose size is the block size.
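A quick back-of-the-envelope check of that argument, as a plain-Python sketch. It assumes 192 KB means 192 × 1024 bytes and that the whole L1 is available to one block, which is optimistic; treat the resulting block size as an upper bound, not a recommendation.

```python
# Figures quoted in the text; the variable names are just for illustration.
N_ELEMENTS = 60_000             # vector length ("60k")
BYTES_PER_FP32 = 4
L1_BYTES_PER_SM = 192 * 1024    # assumption: 192 KB = 196,608 bytes

vector_bytes = N_ELEMENTS * BYTES_PER_FP32   # one input vector
total_bytes = 3 * vector_bytes               # x, y and the output

# Largest block (in elements) for which a block of x, y and the output
# would all fit in L1 at the same time.
max_block_elems = L1_BYTES_PER_SM // (3 * BYTES_PER_FP32)

print(vector_bytes, total_bytes, max_block_elems)  # 240000 720000 16384
```

So a single FP32 vector already overflows one SM's L1, while a block of a few thousand elements fits comfortably.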
Practical Example: Vector Addition in Triton
Now, for example, if we divide our 60k vector into non-overlapping blocks of 8k elements each, the last block will only have 4k numbers. If we always assume a full block's worth of elements, we need a mechanism to tell CUDA or Triton to look only at the 4k valid elements in this last block and ignore the other half. This is where masks come in. Now I will put this all together into a logical flow:
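As a rough CPU-side sketch of that flow (plain Python, not Triton), the block count, per-block offsets and mask can be computed like this. The helper names are made up for illustration. One caveat: the 8,000 here mirrors the numbers in the text, whereas in a real Triton kernel the block size passed to tl.arange has to be a power of two, so you would pick something like 8192 instead.

```python
N_ELEMENTS = 60_000
BLOCK_SIZE = 8_000   # matches the text; a real Triton kernel needs a power of two


def cdiv(a, b):
    """Ceiling division: how many blocks of size b are needed to cover a."""
    return (a + b - 1) // b


def block_offsets_and_mask(block_id):
    """Offsets a block would touch, and which of them are inside the vector."""
    start = block_id * BLOCK_SIZE
    offsets = [start + i for i in range(BLOCK_SIZE)]
    mask = [off < N_ELEMENTS for off in offsets]
    return offsets, mask


num_blocks = cdiv(N_ELEMENTS, BLOCK_SIZE)
_, first_mask = block_offsets_and_mask(0)
_, last_mask = block_offsets_and_mask(num_blocks - 1)

# Number of blocks, valid elements in the first block, valid elements in the last
print(num_blocks, sum(first_mask), sum(last_mask))  # 8 8000 4000
```

Every block computes a full set of offsets; the mask is what stops the last one from reading or writing past the end of the vector.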
Let's see how to do this. I will start with how we launch these threads in the first place. I will use the snippet from the official Triton docs here, with some changes to make it a bit more readable.
So here we launched a 1-D grid of threads/processes, each of which handles a non-overlapping part of the data and independently writes its results into the output. As the note says, a key thing here is that the x_ptr or y_ptr that Triton sees is the address of the very first element of the tensor (an array in this case). Okay, now let's see how to do this exactly:
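To make the pointer arithmetic concrete without a GPU, here is a plain-Python sketch that emulates a 1-D grid for vector addition. It is not Triton code: the flat `memory` list, the integer "pointers" and the function names are all stand-ins I am assuming for illustration, and the programs run one after another instead of in parallel.

```python
def cdiv(a, b):
    """Ceiling division."""
    return (a + b - 1) // b


def add_program(memory, x_ptr, y_ptr, out_ptr, n_elements, block_size, pid):
    """One emulated program instance: handles block number `pid` only."""
    block_start = pid * block_size
    for i in range(block_size):
        offset = block_start + i
        if offset < n_elements:  # the mask: skip positions past the end
            memory[out_ptr + offset] = memory[x_ptr + offset] + memory[y_ptr + offset]


def launch_add(x, y, block_size):
    """Lay x, y and the output out in one flat buffer and run the whole grid."""
    n = len(x)
    memory = list(x) + list(y) + [0.0] * n
    # Each "pointer" is just the index of a tensor's first element.
    x_ptr, y_ptr, out_ptr = 0, n, 2 * n
    grid = cdiv(n, block_size)      # number of program instances
    for pid in range(grid):         # on a GPU these would run in parallel
        add_program(memory, x_ptr, y_ptr, out_ptr, n, block_size, pid)
    return memory[out_ptr:out_ptr + n]


x = [float(i) for i in range(10)]
y = [10.0 * i for i in range(10)]
print(launch_add(x, y, block_size=4))  # 3 programs; the last one is half masked
```

Note what the mask buys us in this layout: without it, the last program would read past the end of x straight into y's memory.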
To summarize, this is like any parallel processing logic. We just need to decide on an access pattern, so to speak, that processes chunks of data without race conditions, such as two processes simultaneously trying to write to the same memory location.
Advanced Implementation: Matrix Multiplication
Now that that's out of the way, let's look at a slightly more complicated example: matrix multiplication. Bear in mind that the kernel you are about to read is not the most optimal implementation of matrix multiplication. This is just so that we can get our hands dirty; not everything you do has to be efficient. It just has to work, and you will have learned a lot in the process. I am just going to leave the kernel here for you to get the hang of it.
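For readers without a GPU at hand, the same tiling idea can be emulated in plain Python. This is a hedged sketch, not the kernel from this post and certainly not an efficient one: the function names and tile sizes are my own, the matrices are stored as flat row-major lists, and each "program" is simply a function call that owns one output tile.

```python
def cdiv(a, b):
    """Ceiling division."""
    return (a + b - 1) // b


def matmul_program(a, b, c, M, N, K, pid_m, pid_n, BM, BN, BK):
    """One emulated program: computes the (pid_m, pid_n) tile of C = A @ B."""
    rows = [pid_m * BM + i for i in range(BM)]
    cols = [pid_n * BN + j for j in range(BN)]
    acc = [[0.0] * BN for _ in range(BM)]          # tile accumulator
    for k_start in range(0, K, BK):                # walk along K in chunks
        ks = [k for k in range(k_start, k_start + BK) if k < K]  # mask on K
        for i, r in enumerate(rows):
            for j, col in enumerate(cols):
                if r < M and col < N:              # mask on rows and columns
                    for k in ks:
                        acc[i][j] += a[r * K + k] * b[k * N + col]
    for i, r in enumerate(rows):
        for j, col in enumerate(cols):
            if r < M and col < N:                  # masked store
                c[r * N + col] = acc[i][j]


def matmul(a, b, M, N, K, BM=2, BN=2, BK=2):
    """Run a 2-D grid of programs, one per output tile."""
    c = [0.0] * (M * N)
    for pid_m in range(cdiv(M, BM)):
        for pid_n in range(cdiv(N, BN)):
            matmul_program(a, b, c, M, N, K, pid_m, pid_n, BM, BN, BK)
    return c


a = [1, 2, 3, 4, 5, 6]   # 3 x 2, row-major
b = [1, 2, 3, 4, 5, 6]   # 2 x 3, row-major
print(matmul(a, b, M=3, N=3, K=2))
```

Since every program writes to its own tile of C, no two programs ever touch the same output location, which is the race-free access pattern we want.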
I am also leaving a rough sketch of how these access patterns came about. The numbers you see below are the indexes of those positions if we were to flatten the matrix out. These in turn are our memory offsets, i.e. the differences in the pointer from the first element of the matrix. The flattening goes row-wise first.
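The row-wise flattening can be reproduced in a few lines of plain Python. The helper names below are hypothetical; the only thing they encode is the row-major rule offset = row × (number of columns) + column.

```python
def flat_offsets(n_rows, n_cols):
    """Offset of every element of a row-major matrix from its first element."""
    return [[r * n_cols + c for c in range(n_cols)] for r in range(n_rows)]


def tile_offsets(row_start, col_start, tile_rows, tile_cols, stride_row, stride_col=1):
    """Offsets of a tile whose top-left corner is at (row_start, col_start)."""
    return [
        [(row_start + i) * stride_row + (col_start + j) * stride_col
         for j in range(tile_cols)]
        for i in range(tile_rows)
    ]


for row in flat_offsets(3, 4):
    print(row)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9, 10, 11]

# A 2 x 2 tile starting at row 1, column 2 of that 3 x 4 matrix
print(tile_offsets(1, 2, 2, 2, stride_row=4))  # [[6, 7], [10, 11]]
```

Adding these offsets to the pointer of the first element gives the addresses a block would load from.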
Resources for Further Learning
Next steps would be to try implementing custom kernels for all sorts of things, like a cross-entropy loss or a softmax. Hope you get the idea. I tried to keep it short. If you want to learn more, I am leaving you with a few good resources to have a look at:
An advanced implementation of a custom Cross Entropy Loss kernel from Unsloth
And Claude, of course. It makes the learning a whole lot easier!
As ML models continue to grow, what role do you see for Python-based GPU programming languages versus traditional CUDA? Is Triton filling a temporary gap or establishing a new paradigm?










