MMS • Sergio De Simone
Apple’s MLX combines familiar APIs, composable function transformations, and lazy computation to create a machine learning framework inspired by NumPy and PyTorch that is optimized for Apple Silicon. Implemented in Python and C++, the framework aims to provide a user-friendly and efficient solution to train and deploy machine learning models on Apple Silicon.
According to Apple, MLX is designed by machine learning researchers for machine learning researchers and released under the MIT license to make it easier for them to extend and improve it. It supports transformer language model training, large-scale text generation with Mistral, image generation with Stable Diffusion, and speech recognition with Whisper.
MLX offers a NumPy-inspired, low-level Python API as well as a fully featured C++ API that closely mirrors it. In addition, it provides a higher-level API, modeled on the PyTorch API, that can be used to build more complex models.
The framework supports automatic differentiation, automatic vectorization, and computation graph optimization through composable functions which make it easier to build complex array transformations. MLX also supports lazy computation, meaning it materializes arrays only when necessary to improve computational efficiency. Likewise, computation graphs are built dynamically, so that changing the shapes of function arguments does not trigger slow compilations.
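To make the idea of lazy computation concrete, here is a minimal conceptual sketch in plain Python. It is not MLX's implementation and uses none of its APIs; the `LazyArray` class and its methods are hypothetical names introduced only for illustration. It shows an operation being recorded as a graph node first and computed only when the result is explicitly requested.

```python
# Conceptual sketch of lazy evaluation; this is NOT how MLX is implemented.
class LazyArray:
    """A node in a computation graph; values are computed only on eval()."""

    def __init__(self, compute, parents=()):
        self._compute = compute   # function that produces the list of values
        self._parents = parents   # upstream nodes this node depends on
        self._value = None        # cached result, filled in by eval()

    @classmethod
    def from_list(cls, data):
        values = list(data)
        return cls(lambda: values)

    def __add__(self, other):
        # Record the addition instead of performing it right away.
        return LazyArray(
            lambda: [x + y for x, y in zip(self.eval(), other.eval())],
            parents=(self, other),
        )

    def is_materialized(self):
        return self._value is not None

    def eval(self):
        # Materialize the values on first request, then reuse the cache.
        if self._value is None:
            self._value = self._compute()
        return self._value


a = LazyArray.from_list([1.0, 2.0, 3.0])
b = LazyArray.from_list([10.0, 20.0, 30.0])
c = a + b                    # builds a graph node, no arithmetic yet
print(c.is_materialized())   # False
print(c.eval())              # [11.0, 22.0, 33.0]
print(c.is_materialized())   # True
```

Because the graph is just a chain of Python objects built at call time, passing inputs of a different length simply creates new nodes, which loosely mirrors why dynamically built graphs avoid a recompilation step.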
A distinctive feature of MLX is its use of Apple Silicon unified memory, which sets it apart from other ML frameworks, says Apple. In short, this means arrays live in shared memory and array operations can be performed on either the CPU or the GPU without copying data between them. For example, when you create an array, you do not need to specify its location, since it lives in unified memory. Instead, you choose whether an operation on it runs on the CPU or the GPU at the time you execute it:
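import mlx.core as mx
# Arrays are created in unified memory; no device is specified
a = mx.random.normal((100,))
b = mx.random.normal((100,))
# The device is selected per operation through the stream argument
mx.add(a, b, stream=mx.cpu)
mx.add(a, b, stream=mx.gpu)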
MLX can be used on any Apple Silicon chip, including the M1, and can leverage the integrated GPU, so researchers can choose the hardware that is best suited for their needs.
The MLX repo includes several examples of how to use it with different models, including BERT, Llama, Mistral, Stable Diffusion, and more. Each example lists its required dependencies in a requirements.txt file and provides ready-to-use CLI tools. For example, to generate images with Stable Diffusion, you first install all required dependencies, then run the txt2image.py script:
pip install -r requirements.txt
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
The Stable Diffusion example also includes a performance comparison of the Stable Diffusion UNet when run through PyTorch and through MLX. It shows that MLX achieves ~40% higher throughput than PyTorch with a batch size of 16 and ~15% higher throughput when comparing each framework at its optimal batch size.
PyTorch, though, performs better for smaller batch sizes, with ~50% higher throughput for a batch size of 1 and ~10% higher for a batch size of 4. According to Apple, PyTorch’s advantage in those cases can be attributed to compilation speed when the models are not loaded in memory and PyTorch’s MPS graph kernels are not cached.
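For readers who want to check what a figure such as "~40% higher throughput" means, the short sketch below shows the arithmetic, assuming throughput is measured in images per second. The function name `relative_gain` and the input numbers are hypothetical placeholders chosen for illustration; they are not values from Apple's benchmark.

```python
def relative_gain(throughput_a, throughput_b):
    """Return the percentage by which throughput_a exceeds throughput_b."""
    return (throughput_a / throughput_b - 1.0) * 100.0


# Placeholder values in images per second, NOT measured results.
mlx_ips = 1.4
pytorch_ips = 1.0
print(f"{relative_gain(mlx_ips, pytorch_ips):.0f}% higher")
```

A negative result would indicate that the first framework is slower, which is the situation the article describes for MLX at batch sizes of 1 and 4.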