Faster prefix sum computation with SIMD and Mojo
For a couple of months now, I have been experimenting with a new programming language called Mojo. The language is still at a very early stage, and so is its standard library. To try things out, I am currently concentrating on very basic functionality like hash functions, sorting and some tree data structures.
While working on sorting, I implemented counting sort and radix sort, both of which rely on a prefix sum. A prefix sum algorithm is generally very easy to implement, but I came across an article whose author claims to make it a couple of times faster by employing SIMD operations.
One of the key features of Mojo is first-class SIMD support, so I decided to go down the rabbit hole and check whether I could implement a faster prefix sum algorithm in Mojo.
Let's start from the beginning: what is a prefix sum?
For an in-depth understanding, I think it is best to read the Wikipedia article on prefix sums, but the simple implementation of the algorithm in Mojo is actually self-explanatory:
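var element = array[0]  # running total, starts with the first element
for i in range(1, len(array)):
    array[i] += element  # add the sum of all previous elements
    element = array[i]   # update the running total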
So, after the loop, the element at index i holds its original value plus the (already accumulated) element at index i - 1, i.e. the sum of all elements up to and including index i.
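For readers who want to experiment outside of Mojo, here is the same scalar loop as a minimal Python sketch. Python is used only because it runs anywhere; the function name scalar_prefix_sum is hypothetical and not part of the article's code. Like the Mojo loop, it updates the list in place, and the result is cross-checked against itertools.accumulate from the Python standard library.

```python
from itertools import accumulate


def scalar_prefix_sum(array: list[int]) -> None:
    """In-place inclusive prefix sum, mirroring the scalar Mojo loop."""
    if not array:
        return  # nothing to do for an empty array
    element = array[0]  # running total
    for i in range(1, len(array)):
        array[i] += element  # add the sum of everything before index i
        element = array[i]   # remember the new running total


data = [1, 1, 1, 1, 1, 1, 1, 1]
scalar_prefix_sum(data)
print(data)  # [1, 2, 3, 4, 5, 6, 7, 8]
assert data == list(accumulate([1] * 8))
```

Each iteration reads the value written by the previous one, which is exactly the dependency that makes this loop hard to vectorize directly.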
Given an array: 1, 1, 1, 1, 1, 1, 1, 1
A prefix sum of this array is: 1, 2, 3, 4, 5, 6, 7, 8
It is computed with 7 additions (the first element is taken as is):
1,
(1 + 1) = 2,
(2 + 1) = 3,
(3 + 1) = 4,
(4 + 1) = 5,
(5 + 1) = 6,
(6 + 1) = 7,
(7 + 1) = 8
To me, it was hard to imagine how one would implement this with SIMD, since every iteration depends on the result of the previous one. But there is a way:
1, 1, 1, 1, 1, 1, 1, 1
+ 0, 1, 1, 1, 1, 1, 1, 1
= 1, 2, 2, 2, 2, 2, 2, 2
+ 0, 0, 1, 2, 2, 2, 2, 2
= 1, 2, 3, 4, 4, 4, 4, 4
+ 0, 0, 0, 0, 1, 2, 3, 4
= 1, 2, 3, 4, 5, 6, 7, 8
So it is possible to compute the prefix sum of an 8-element vector in 3 (log2(8)) steps, where each step combines a vector shift right with a vector addition. In Mojo, the code looks as follows:
var v1 = SIMD[DType.uint8, 8](1, 1, 1, 1, 1, 1, 1, 1)
print(v1) # [1, 1, 1, 1, 1, 1, 1, 1]
v1 += v1.shift_right[1]()
print(v1) # [1, 2, 2, 2, 2, 2, 2, 2]
v1 += v1.shift_right[2]()
print(v1) # [1, 2, 3, 4, 4, 4, 4, 4]
v1 += v1.shift_right[4]()
print(v1) # [1, 2, 3, 4, 5, 6, 7, 8]
On the first line, we define an 8-element-wide SIMD vector with values of type uint8. We then print the vector and mutate it 3 times, each time adding a right-shifted copy of it with the += operator, and print the intermediate and final results along the way.
You might be surprised by the syntax of the shift_right method call: the number of places to shift is passed in square brackets instead of parentheses. This means the value is passed not at runtime but at compile time, so it must be known at compile time. For more information on this topic, please consult the Mojo programming manual.
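To make the shift-and-add steps concrete, here is a minimal Python sketch that emulates a vector with a plain list. The helper shift_right is hypothetical: it only mimics the lane behaviour shown above (lanes move towards higher indices and zeros fill in from the left) and is not Mojo's implementation. Python has no compile-time parameters, so the shift amount is an ordinary runtime argument here. This log-step pattern is commonly known as a Hillis-Steele scan.

```python
def shift_right(v: list[int], k: int) -> list[int]:
    """Emulate a lane shift: move every lane k places right, fill with zeros."""
    return [0] * k + v[: len(v) - k]


def simd_style_prefix_sum(v: list[int]) -> list[int]:
    """Log-step inclusive scan on a power-of-two number of lanes."""
    width = len(v)
    assert width & (width - 1) == 0, "width must be a power of two"
    shift = 1
    while shift < width:  # shifts 1, 2, 4, ... -> log2(width) steps
        v = [a + b for a, b in zip(v, shift_right(v, shift))]
        shift *= 2
    return v


v1 = [1] * 8
print(shift_right(v1, 1))         # [0, 1, 1, 1, 1, 1, 1, 1]
print(simd_style_prefix_sum(v1))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

For 8 lanes this performs 3 vector-wide additions instead of 7 dependent scalar ones; each vector addition does more lane work in total, but on SIMD hardware all lanes are added at once.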
How can we compute the prefix sum of an arbitrary array whose size is only known at runtime?
To do this, we need to split the array into chunks and apply the operations defined at compile time to each of those chunks.
To compute the SIMD prefix sum chunk by chunk, we actually need to change the algorithm a bit.
Say we still have the array: 1, 1, 1, 1, 1, 1, 1, 1
But now we want to split it into two 4-element chunks. The computation looks as follows:
First chunk:
1, 1, 1, 1
+ 0, 1, 1, 1
= 1, 2, 2, 2
+ 0, 0, 1, 2
= 1, 2, 3, 4
+ 0, 0, 0, 0
= 1, 2, 3, 4
Second chunk:
1, 1, 1, 1
+ 0, 1, 1, 1
= 1, 2, 2, 2
+ 0, 0, 1, 2
= 1, 2, 3, 4
+ 4, 4, 4, 4
= 5, 6, 7, 8
This is reflected in the following Mojo code:
var v1 = SIMD[DType.uint8, 4](1, 1, 1, 1)
var v2 = SIMD[DType.uint8, 4](1, 1, 1, 1)
print(v1) # [1, 1, 1, 1]
v1 += v1.shift_right[1]()
print(v1) # [1, 2, 2, 2]
v1 += v1.shift_right[2]()
print(v1) # [1, 2, 3, 4]
v1 += 0
print(v1) # [1, 2, 3, 4]
print(v2) # [1, 1, 1, 1]
v2 += v2.shift_right[1]()
print(v2) # [1, 2, 2, 2]
v2 += v2.shift_right[2]()
print(v2) # [1, 2, 3, 4]
v2 += v1[3]
print(v2) # [5, 6, 7, 8]
As you can see above, we need to carry over the last value of the previous chunk and add it to every element of the current chunk.
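The chunked variant can be sketched in Python as well. This is a minimal illustration, not the article's simd_prefix_sum: the function name chunked_prefix_sum and its width parameter are hypothetical, and the array length is assumed to be a multiple of the chunk width. Each chunk gets the log-step scan, then the carry (the last value of the previous chunk's result) is added to every lane.

```python
def chunked_prefix_sum(array: list[int], width: int) -> list[int]:
    """Prefix sum computed chunk by chunk, carrying the last value forward.

    Assumes width is a power of two and len(array) is a multiple of width.
    """
    assert width > 0 and width & (width - 1) == 0
    assert len(array) % width == 0
    result = []
    carry = 0  # the first chunk has nothing to carry over
    for start in range(0, len(array), width):
        chunk = array[start : start + width]
        shift = 1
        while shift < width:  # log2(width) shift-and-add steps
            shifted = [0] * shift + chunk[: width - shift]
            chunk = [a + b for a, b in zip(chunk, shifted)]
            shift *= 2
        chunk = [x + carry for x in chunk]  # like `v2 += v1[3]` in the Mojo code
        carry = chunk[-1]  # the last lane feeds the next chunk
        result.extend(chunk)
    return result


print(chunked_prefix_sum([1] * 8, 4))  # [1, 2, 3, 4, 5, 6, 7, 8]
```

Note that the carry still creates a dependency between chunks, but only one per chunk instead of one per element.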
So, given a chunk of size n, we need to perform
result += result.shift_right[1 << i]()
log2(n) times, where i goes from 0 to log2(n) - 1.
My first instinct was to put the above statement in a for loop:
var v1 = SIMD[DType.uint8, 4](1, 1, 1, 1)
for i in range(0, 2):
    v1 += v1.shift_right[1 << i]()
print(v1)
However, this will not compile. As mentioned before, the shift_right method expects a value known at compile time, but i in the for loop is a runtime value (even though the bounds of the range are known at compile time). No worries, though: the Mojo standard library has our back. It provides a function that performs loop unrolling in a more functional style:
from algorithm import unroll
fn prefix_sum_on_chunk(inout v1: SIMD[DType.uint8, 4], carry_over: UInt8):
    @parameter
    fn add[i: Int]():
        v1 += v1.shift_right[1 << i]()

    unroll[2, add]()  # expands to add[0]() and add[1]() at compile time
    v1 += carry_over  # add the last value of the previous chunk
var v1 = SIMD[DType.uint8, 4](1, 1, 1, 1)
print(v1) # [1, 1, 1, 1]
prefix_sum_on_chunk(v1, 0)
print(v1) # [1, 2, 3, 4]
This way, the compiler emits code equivalent to the explicit sequence of shifts and additions we wrote earlier, without any runtime branching or loop condition checks.
Next question: how big should the chunks be?
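For reference, here is what the unrolled shift-and-add sequence computes on one chunk, written as a recurrence. Let a_0, ..., a_{n-1} be the lanes of the chunk, with n a power of two, c the carry-over from the previous chunk, and x^{(i)} the vector after i steps. Shifting right by 2^i lanes fills the first 2^i lanes with zeros, which is what the second case expresses:

$$
x^{(0)}_j = a_j, \qquad
x^{(i+1)}_j =
\begin{cases}
x^{(i)}_j + x^{(i)}_{j-2^i}, & j \ge 2^i,\\
x^{(i)}_j, & j < 2^i,
\end{cases}
\qquad i = 0, \dots, \log_2 n - 1
$$

$$
x^{(i)}_j = \sum_{m=\max(0,\, j-2^i+1)}^{j} a_m
\quad\Longrightarrow\quad
y_j = c + x^{(\log_2 n)}_j = c + \sum_{m=0}^{j} a_m
$$

The closed form for x^{(i)}_j follows by induction on i: each step doubles the window of consecutive elements a lane has summed, so after log2(n) steps the window (of size n > j) covers every lane from 0 to j.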
This depends on the element type of the array (how many bytes one element occupies) and on the hardware we run the algorithm on. The standard library does provide an autotune function, which should automate this kind of decision, but to be honest with you, I did some manual tuning and came to the conclusion that on my laptop (11th Gen Intel(R) Core(TM) i7-1165G7 @ 2.80GHz) the following vector widths perform best:
- 1-byte elements (int8, uint8): 256-wide SIMD vectors
- 2-byte elements (int16, uint16, float16): 128-wide SIMD vectors
- 4-byte elements (int32, uint32, float32): 64-wide SIMD vectors
- 8-byte elements (int64, uint64, float64): 32-wide SIMD vectors
This implies that if the array is smaller than the preferred vector width, or its length is not an exact multiple of it, we need to compute the rest with a smaller vector width.
To make this clear: say we have an array of 80 uint64 elements. As mentioned before, we prefer chunks of 32 elements for uint64 arrays, so we can compute the first 64 elements with two 32-wide SIMD vectors, and then we need to reduce the vector width to 16 to compute the remaining 16 elements.
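The width selection and the remainder handling can be sketched as a small planning function. Two things here are assumptions rather than the article's code: the widths are simply copied from the measurements listed above (each of them amounts to 256 bytes per vector), and the remainder is covered by repeatedly halving the width, which is one plausible way to reduce the vector size. The names PREFERRED_WIDTH and chunk_plan are hypothetical.

```python
# Preferred SIMD widths per element size in bytes, taken from the manual
# tuning on an i7-1165G7 (each combination is 256 bytes per vector).
PREFERRED_WIDTH = {1: 256, 2: 128, 4: 64, 8: 32}


def chunk_plan(length: int, element_bytes: int) -> list[int]:
    """Return the vector widths used to cover `length` elements.

    Full chunks use the preferred width; the remainder is covered by
    repeatedly halving the width (an assumed strategy).
    """
    width = PREFERRED_WIDTH[element_bytes]
    plan = []
    remaining = length
    while remaining > 0:
        if remaining >= width:
            plan.append(width)
            remaining -= width
        else:
            width //= 2  # fall back to the next smaller power-of-two width
    return plan


print(chunk_plan(80, 8))  # [32, 32, 16]
```

Because the width eventually reaches 1, any length is covered; for example, 83 uint64 elements would be split into widths 32, 32, 16, 2 and 1.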
You can find the complete simd_prefix_sum implementation if you follow the link.
Given the explanation above, you should be able to follow the code. That said, please don’t hesitate to leave a comment if you have questions or suggestions.
Last but not least, I would like to talk about the runtime characteristics of the scalar and SIMD prefix sum functions, but first, another disclaimer.
After I implemented the SIMD prefix sum and pushed it to the GitHub repository, I announced it on the Mojo Discord server, where a user pointed out that there is already a prefix sum function in the standard library, which I had missed. So for the benchmark comparison, I included the runtime characteristics of the standard library function as well.
The table and chart above show that the SIMD prefix sum takes from 0.03 to 0.2 nanoseconds per array element, depending on the element type and array size, whereas the scalar prefix sum is very stable at around 0.5 nanoseconds per element. The SIMD speedup is between 2.5x and 15x, which is great. The results also show that something strange is going on with the prefix sum function in the standard library: sometimes it is comparable with my SIMD implementation, but in some cases it is slower than the scalar prefix sum.
You can find the code I used for benchmarks here.
Thank you for reading, and feel free to leave a clap or two.