Counting chars with SIMD in Mojo
Mojo is a very young (actually a work in progress) programming language designed and developed by a new company called Modular. Here is a blurb from their website:
Mojo combines the usability of Python with the performance of C, unlocking unparalleled programmability of AI hardware and extensibility of AI models.
So, the language is designed as a Python superset, which should offer full Python interop and allow developers with the proper know-how to develop efficient modules. This sounds a little bit like Cython; however, Mojo has much more ambitious goals. If you are interested in more details, browse through the FAQ compiled by the Modular team.
To enable low-level programming, Mojo introduces extra keywords and concepts to Python. For instance, it introduces a struct keyword for defining efficient value types and an fn keyword for defining efficient functions. If you want to delve into more specific information, I recommend referring to the Mojo programming manual.
At some point in the future, Mojo is planned to be released as an open-source project. However, currently, external developers can only try out the language by requesting access to a Jupyter Notebook provided by the Modular team.
Yours truly was granted access a couple of days ago, and after spending some time studying the documentation, I became curious about the low-level implementation of strings in Mojo.
Mojo has three types to express a string:
StringLiteral: a built-in type
StringRef: also a built-in type
String: a type defined in the standard library
As the name implies, StringLiteral represents a string we write in the code:
let s = "hello" # s is of type StringLiteral
In Mojo, a StringRef can be created from a StringLiteral, and it appears to primarily serve as a type for ABI (Application Binary Interface) interoperation.
let s: StringRef = "hello"
# StringLiteral "hello" is converted to StringRef and assigned to s
A StringLiteral has a data method, which lets us get a raw pointer to the underlying data. However, the type of the pointer, pointer<scalar<si8>>, is quite mysterious since it's not explained in the docs at all. At first, I thought it could be related to MLIR, because Mojo allows referencing MLIR types directly (see Low-level IR in Mojo), but even after searching, I couldn't find any information about it on the MLIR side either. It's definitely an intriguing aspect of Mojo's implementation!
However, despite the mystery surrounding the pointer<scalar<si8>> type, it's worth noting that the Mojo standard library does include a module called Pointer. This module defines a Pointer struct that can be initialized with a pointer<type>. Therefore, after some experimentation and tinkering, I was able to write the following code:
let s = "hello"
let p = Pointer(s.data())
for i in range(len(s)):
    print(p[i])
This printed the internal byte stream of the string:
104
101
108
108
111
And what do we learn from this? Mojo uses UTF-8 encoding to store strings!
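As a cross-check in plain Python (used here purely as a reference model, not Mojo code), encoding "hello" as UTF-8 yields exactly these values. Strictly speaking, ASCII-only text looks the same in any ASCII-compatible encoding, so the multi-byte example further down is the stronger evidence for UTF-8.

```python
# Raw UTF-8 bytes of an ASCII-only string: one byte per character.
hello_bytes = list("hello".encode("utf-8"))
print(hello_bytes)  # [104, 101, 108, 108, 111]
```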
Ok, but what about the actual String type from the standard library?
If we want to use the String struct from the standard library, we can simply write:
from String import String
let s: String = "hello"
In order to output the chars of a string, I figured I should be able to write the following:
from String import String
let s: String = "hello"
for c in s:
    print(c)
But this ends up in an error: ‘String’ does not implement the ‘__iter__’ method. So let's try something else:
from String import String
let s: String = "hello"
for i in range(len(s)):
    print(s[i])
And we get the characters printed:
h
e
l
l
o
That is great, but what about a string with characters that are longer than one byte?
from String import String
let s: String = "hello 🔥"
for i in range(len(s)):
    print(s[i])
While this code compiles successfully, it doesn't produce any output in the Jupyter Notebook; it appears to result in a runtime error. Even attempting to print a specific character, such as print(s[6]), doesn't work as expected. To troubleshoot the issue, let's examine the underlying byte stream:
let s = "hello 🔥"
let p = Pointer(s.data())
for i in range(len(s)):
    print(p[i])
# returns
# 104
# 101
# 108
# 108
# 111
# 32
# -16
# -97
# -108
# -91
This printed output highlights one small oddity, which caught my eye earlier. The byte stream is typed as si8, which is a signed 1-byte integer. I think it is more logical to type a sequence of bytes as an unsigned 1-byte integer, ui8. But 🤷, that is not super important. What is important is that the byte sequence is 10 bytes long; if we execute print(len(s)), we also get 10 as a result. If we execute print(s[6:10]), we get 🔥 as the result.
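The negative numbers are simply the unsigned byte values read as two's-complement si8. A small plain-Python sketch (again only a reference model, using the standard struct module) shows both views of the emoji's four bytes and confirms the 10-byte total:

```python
import struct

fire = "🔥".encode("utf-8")   # U+1F525, four bytes in UTF-8
as_ui8 = list(fire)           # unsigned view, like ui8
as_si8 = list(struct.unpack(f"{len(fire)}b", fire))  # signed view, like si8

print(as_ui8)  # [240, 159, 148, 165]
print(as_si8)  # [-16, -97, -108, -91]
print(len("hello 🔥".encode("utf-8")))  # 10
```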
What did we learn? In the current implementation, the String struct is a lightweight wrapper around the UTF-8 byte sequence. The length corresponds to the number of bytes, not the number of characters, and if we try to access a multi-byte character with an incorrect range, we get a runtime error.
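The same byte-versus-character distinction can be sketched in plain Python, whose str type counts code points while bytes counts bytes. Python reports a bad slice with a UnicodeDecodeError; this is only an analogy for the runtime error seen in the Notebook, not a description of what Mojo does internally.

```python
s = "hello 🔥"
data = s.encode("utf-8")

print(len(data))  # 10: number of bytes
print(len(s))     # 7: number of characters (code points)
print(data[6:10].decode("utf-8") == "🔥")  # True: the full 4-byte sequence

try:
    data[6:8].decode("utf-8")  # only half of the emoji's bytes
except UnicodeDecodeError:
    print("slice splits a multi-byte character")
```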
To hone our Mojo skills, let's create our own function that takes a StringLiteral as input and returns the number of characters. However, before we proceed, it's essential to grasp how UTF-8 encoding works and how we can identify when multiple bytes form a single character. Fortunately, all the necessary information about UTF-8 is available in this Wikipedia article.
Looking at the table I copied from the Wikipedia article, we can observe that every Unicode character can be encoded in UTF-8 using 1 to 4 bytes. The first byte in the sequence representing a character has a special leading bit pattern that indicates the length of the sequence. Any byte that isn't in the first position (a continuation byte) always starts with the bit pattern 10. To determine whether a byte represents the start of a character, we can keep only its two most significant bits and apply the following check on each byte:
(byte >> 6) != 0b10
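The lead-byte patterns from that table, and the check itself, can be written down as a small plain-Python sketch (is_lead_byte is a hypothetical helper name used only for illustration):

```python
# UTF-8 bit patterns (x = payload bits):
#   0xxxxxxx                             1-byte sequence (ASCII)
#   110xxxxx 10xxxxxx                    2-byte sequence
#   1110xxxx 10xxxxxx 10xxxxxx           3-byte sequence
#   11110xxx 10xxxxxx 10xxxxxx 10xxxxxx  4-byte sequence
# Every continuation byte starts with 10, so its top two bits equal 0b10.

def is_lead_byte(byte: int) -> bool:
    """Return True if byte starts a character, i.e. is not a continuation byte."""
    return (byte >> 6) != 0b10

for b in "é🔥".encode("utf-8"):
    print(f"{b:08b}", is_lead_byte(b))
# 11000011 True
# 10101001 False
# 11110000 True
# 10011111 False
# 10010100 False
# 10100101 False
```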
As described above, in Mojo we have easy access to the underlying bytes of the string literal, so we can simply loop over each byte and apply the check. If the condition (byte >> 6) != 0b10 is true, we increment the character count. While this is a straightforward solution, we can further optimize it in Mojo.
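Before vectorizing, the straightforward byte-by-byte loop looks like this when sketched in plain Python as a reference (chars_len_scalar is a hypothetical name, not part of the Mojo code in this post). It counts Unicode code points, which is what Python's len reports for a str:

```python
def chars_len_scalar(data: bytes) -> int:
    """Count characters by counting bytes that are not UTF-8 continuation bytes."""
    count = 0
    for byte in data:
        if (byte >> 6) != 0b10:
            count += 1
    return count

text = "hello 🔥"
print(chars_len_scalar(text.encode("utf-8")))  # 7
```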
Mojo offers excellent support for SIMD (Single Instruction, Multiple Data) operations, which allow us to execute a single operation on a vector of values. In our case, we need to perform a right shift and a comparison on a sequence of 1-byte values. With SIMD, we can group these values into SIMD vectors and carry out both operations as follows:
let p = DTypePointer[DType.si8](string_literal.data()).bitcast[DType.ui8]()
(p.simd_load[64]() >> 6) != 0b10
On the first line, we extract the data from the string literal and wrap it in a DTypePointer. A DTypePointer represents a pointer to DType values, which is necessary for invoking the simd_load method. This method, in turn, returns a SIMD type that enables us to execute vectorized operations.
You might be curious about the .bitcast[DType.ui8]() operation. This is required in Mojo because the string literal data is initially typed as si8. By applying the .bitcast[DType.ui8]() operation, we rectify this issue, ensuring compatibility with the >> operator that we intend to use.
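The cast also matters for correctness, not just for operator support. Assuming that >> on a signed type is an arithmetic shift (as it is for Python integers), continuation bytes read as si8 shift to -2 rather than 2, so every non-ASCII byte would be counted as the start of a character. A plain-Python sketch of both interpretations of the emoji's bytes:

```python
import struct

fire = "🔥".encode("utf-8")
as_ui8 = list(fire)
as_si8 = list(struct.unpack(f"{len(fire)}b", fire))

# Python's >> on negative ints is arithmetic (it keeps the sign).
print([b >> 6 for b in as_ui8])  # [3, 2, 2, 2]
print([b >> 6 for b in as_si8])  # [-1, -2, -2, -2]

# Counting characters with the (byte >> 6) != 0b10 check:
print(sum((b >> 6) != 0b10 for b in as_ui8))  # 1, correct
print(sum((b >> 6) != 0b10 for b in as_si8))  # 4, every byte miscounted
```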
On the second line, we load 64 bytes from the pointer as a SIMD vector, shift all elements to the right by 6 bits, and then compare all elements with 0b10. The result of this comparison is a SIMD vector of booleans, which we can cast to a SIMD vector of ui8:
.cast[DType.ui8]()
Now we can sum all the elements of the SIMD vector and convert the sum to an Int to get the number of chars:
.reduce_add().to_int()
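To make the lane-wise semantics concrete, here is a plain-Python model of a single 64-byte chunk going through the same pipeline, with a list standing in for the SIMD vector (simd_chunk_count is a hypothetical name; a real SIMD unit performs each step on all lanes at once rather than in a loop):

```python
def simd_chunk_count(chunk: bytes) -> int:
    """Model one SIMD step: shift, compare, cast, reduce_add over all lanes."""
    shifted = [b >> 6 for b in chunk]       # lane-wise >> 6
    mask = [s != 0b10 for s in shifted]     # lane-wise != 0b10, a vector of bools
    as_u8 = [int(m) for m in mask]          # like .cast[DType.ui8]()
    return sum(as_u8)                       # like .reduce_add().to_int()

chunk = ("a" * 60 + "🔥").encode("utf-8")   # exactly 64 bytes, 61 characters
print(len(chunk), simd_chunk_count(chunk))  # 64 61
```

Since each lane contributes at most 1, the sum over 64 lanes is at most 64, so it fits comfortably in a ui8 before the conversion to Int.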
Below is my first implementation of a function that calculates the number of characters in a string literal:
fn chars_len(s: StringLiteral) -> Int:
    let p = DTypePointer[DType.si8](s.data()).bitcast[DType.ui8]()
    let l = len(s)
    var offset = 0
    var result = 0
    while l - offset >= 64:
        result += ((p.simd_load[64](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        offset += 64
    while l - offset >= 32:
        result += ((p.simd_load[32](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        offset += 32
    while l - offset >= 16:
        result += ((p.simd_load[16](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        offset += 16
    while l - offset >= 8:
        result += ((p.simd_load[8](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        offset += 8
    while l - offset >= 4:
        result += ((p.simd_load[4](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        offset += 4
    while l - offset >= 2:
        result += ((p.simd_load[2](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        offset += 2
    while l - offset >= 1:
        result += ((p.simd_load[1](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        offset += 1
    return result
Oh boy, this looks like a lot of copy pasta! How come?
Well, the length of the string literal is only known at runtime, while the SIMD width passed to simd_load is a compile-time parameter. The more bytes we can consume in a large SIMD vector, the better, so we “fall through” the different sizes of SIMD vectors.
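The cascade of while loops effectively splits the byte length into a run of 64-byte chunks plus at most one chunk of each smaller width, following the binary digits of length % 64. A plain-Python sketch (fall_through_schedule is a hypothetical helper) that replays the loop structure of chars_len and records which widths are used:

```python
WIDTHS = (64, 32, 16, 8, 4, 2, 1)

def fall_through_schedule(length: int) -> list[int]:
    """Return the SIMD widths the cascading while loops would load, in order."""
    schedule = []
    offset = 0
    for width in WIDTHS:
        while length - offset >= width:
            schedule.append(width)
            offset += width
    return schedule

print(fall_through_schedule(10))   # [8, 2]
print(fall_through_schedule(100))  # [64, 32, 4]
print(fall_through_schedule(130))  # [64, 64, 2]
```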
Is there a better way?
After the first implementation, I came up with another one, which is better in the sense of having fewer conditions, but requires a memcpy:
from Bit import *
from Memory import *
from Pointer import *
fn chars_len[simd_width: Int](s: StringLiteral) -> Int:
    let p = DTypePointer[DType.si8](s.data()).bitcast[DType.ui8]()
    let l = len(s)
    var offset = 0
    var result = 0
    while l - offset >= simd_width:
        result += ((p.simd_load[simd_width](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        offset += simd_width
    if offset < l:
        let rest_p: DTypePointer[DType.ui8] = stack_allocation[simd_width, UI8, 1]()
        memset_zero(rest_p, simd_width)
        memcpy(rest_p, p.offset(offset), l - offset)
        result += ((rest_p.simd_load[simd_width]() >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
        result -= simd_width - (l - offset)
    return result
In this implementation, we let the user provide the SIMD vector width. They should probably use the autotune feature, but if they know that the string is short, they could pass a better-suited width.
We run the algorithm explained above on as many bytes as possible for the given vector width. Then we check whether we consumed all the bytes through the vector transformation. If not, we allocate memory on the stack for another pass and copy the unprocessed bytes from the string literal's underlying byte sequence to the newly allocated stack region. As the rest_p memory region is padded with 0 bytes that do not belong to the string, and a 0 byte counts as the start of a character (0 >> 6 is 0, not 0b10), we need to subtract them from the result:
result -= simd_width - (l - offset)
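The padding trick is easy to check in a plain-Python model (chars_len_padded is a hypothetical name; bytes(n) plays the role of the zeroed stack buffer, and slicing plays the role of memcpy):

```python
def chars_len_padded(data: bytes, simd_width: int) -> int:
    """Full-width chunks, then one zero-padded chunk for the tail."""
    length = len(data)
    offset = 0
    result = 0
    while length - offset >= simd_width:
        chunk = data[offset:offset + simd_width]
        result += sum((b >> 6) != 0b10 for b in chunk)
        offset += simd_width
    if offset < length:
        rest = length - offset
        padded = data[offset:] + bytes(simd_width - rest)  # like memset_zero + memcpy
        result += sum((b >> 6) != 0b10 for b in padded)
        result -= simd_width - rest  # remove the zero bytes that were counted
    return result

text = "naïve café 🔥"
for width in (4, 16, 64):
    print(width, chars_len_padded(text.encode("utf-8"), width) == len(text))  # True each time
```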
And that concludes this blog post. However, I’m contemplating writing another one where we not only count the number of characters but also the number of graphemes in a string. Additionally, I plan to provide a function for safely truncating strings based on the number of bytes. If you find these topics interesting, please let me know in the comments section. Your feedback and suggestions are greatly appreciated!
Update 19th of May 2023
Every new technology needs a great community behind it and Mojo already has it!
Thanks to the community feedback, I was able to produce a third version of the code, which is very elegant without sacrificing performance:
from DType import DType
from Functional import vectorize
from Pointer import DTypePointer
from TargetInfo import dtype_simd_width
alias simd_width_u8 = dtype_simd_width[DType.ui8]()
fn chars_count(s: StringLiteral) -> Int:
    let p = DTypePointer[DType.si8](s.data()).bitcast[DType.ui8]()
    let string_byte_length = len(s)
    var result = 0
    @parameter
    fn count[simd_width: Int](offset: Int):
        result += ((p.simd_load[simd_width](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
    vectorize[simd_width_u8, count](string_byte_length)
    return result
There are a few concepts in this solution which I did not cover earlier in this post, so let me give a brief explanation.
alias simd_width_u8 = dtype_simd_width[DType.ui8]()
With the alias keyword we declare that simd_width_u8 is a constant, which will be computed at compile time. This value specifies how many entries a SIMD vector can have for the ui8 type. It can be computed at compile time, based on the target architecture and the kind of SIMD support it has. In the Notebook, this value evaluates to 64, as the system running the Notebook has AVX-512 support.
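The lane count is essentially the vector register width divided by the element width. A tiny plain-Python sketch of that arithmetic (simd_lanes is a hypothetical helper; the real dtype_simd_width queries the compilation target, and I am assuming it reduces to this division for ui8):

```python
def simd_lanes(register_bits: int, element_bits: int) -> int:
    """Number of elements of element_bits that fit in one vector register."""
    return register_bits // element_bits

print(simd_lanes(128, 8))  # 16, e.g. SSE or Arm NEON
print(simd_lanes(256, 8))  # 32, e.g. AVX2
print(simd_lanes(512, 8))  # 64, AVX-512, matching the Notebook
```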
@parameter
fn count[simd_width: Int](offset: Int):
    result += ((p.simd_load[simd_width](offset) >> 6) != 0b10).cast[DType.ui8]().reduce_add().to_int()
This is an inner function, which will be called with an offset to perform the count. The @parameter decorator is needed because we are capturing the result variable. For more details, please read the documentation.
vectorize[simd_width_u8, count](string_byte_length)
By calling the vectorize function, I avoid the copy-and-paste frenzy I introduced in the first solution. This function executes the count function for us, based on the compile-time parameters (the SIMD vector width and the function for the loop body) and the argument (the total loop count) we provide. My guess is that it does something similar to what I did manually in solution one, but this is no longer our burden to read and write. Which is great!
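For intuition, here is a plain-Python model of that pattern. vectorize_model is a hypothetical stand-in, and its tail handling (one element at a time) is my assumption about how the remainder might be processed, not a statement about Modular's implementation. The nonlocal declaration plays the role of the captured result variable:

```python
from typing import Callable

def vectorize_model(simd_width: int, body: Callable[[int, int], None], size: int) -> None:
    """Call body(width, offset) on full-width chunks, then element by element on the tail."""
    main = size - size % simd_width
    for offset in range(0, main, simd_width):
        body(simd_width, offset)
    for offset in range(main, size):
        body(1, offset)

def chars_count(data: bytes, simd_width: int = 64) -> int:
    result = 0

    def count(width: int, offset: int) -> None:
        nonlocal result  # like capturing `result` in the @parameter closure
        chunk = data[offset:offset + width]
        result += sum((b >> 6) != 0b10 for b in chunk)

    vectorize_model(simd_width, count, len(data))
    return result

text = "hello 🔥" * 20
print(chars_count(text.encode("utf-8")) == len(text))  # True
```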