RXTX: A Machine Learning-Guided Algorithm for Efficient Structured Matrix Multiplication
Discovering faster algorithms for matrix multiplication is a key pursuit in computer science and numerical linear algebra. Since the contributions of Strassen and Winograd in the late 1960s, several strategies have emerged, including gradient-based methods, heuristic techniques, group-theoretic frameworks, graph-based random walks, and deep reinforcement learning. Far less attention, however, has gone to structured matrix products, such as the product of a matrix with its own transpose or with itself. These products appear frequently in statistics, deep learning, and communications.
Expressions such as AAᵀ and XXᵀ play critical roles in these domains, notably in algorithms for training large language models. Structured matrix multiplication has been studied from both theoretical and machine learning-based perspectives, but recent improvements have been limited to particular settings and matrix sizes.
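Because XXᵀ is symmetric, even a direct implementation only needs the n(n+1)/2 entries on or above the diagonal. This is the saving the BLAS SYRK routine already exploits, and schemes like RXTX aim to go further. The following is a minimal pure-Python sketch of that direct symmetric approach. The function name `gram_upper` is ours and is used only for illustration.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def gram_upper(X: Matrix) -> Tuple[Matrix, int]:
    """Compute X @ X^T by filling only entries with i <= j and mirroring them.

    Returns the n-by-n result and the number of scalar multiplications used.
    """
    n = len(X)
    m = len(X[0]) if n else 0
    G = [[0.0] * n for _ in range(n)]
    mults = 0
    for i in range(n):
        for j in range(i, n):
            # Entry (i, j) is the dot product of rows i and j of X.
            s = 0.0
            for k in range(m):
                s += X[i][k] * X[j][k]
                mults += 1
            G[i][j] = s
            G[j][i] = s  # symmetry: (X X^T)[j][i] == (X X^T)[i][j]
    return G, mults

X = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
G, mults = gram_upper(X)
print(G)      # [[14.0, 32.0], [32.0, 77.0]]
print(mults)  # 9, versus n * n * m = 12 for a general product
```

For an n × m matrix this uses n(n+1)/2 · m multiplications instead of n²m.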
Introducing RXTX
Researchers from the Chinese University and the Shenzhen Research Institute of Big Data have developed RXTX, an algorithm designed specifically to compute XXᵀ efficiently for a real n × m matrix X. RXTX uses about 5% fewer operations (multiplications and additions combined) than the leading prior methods, and its advantage appears even at small sizes (e.g., n = 4). The algorithm was found through machine learning-guided search combined with combinatorial optimization. It exploits the specific structure of XXᵀ to obtain a constant-factor speedup.
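The kind of structure RXTX exploits is easiest to see in the classical block-recursive way of computing XXᵀ. Splitting X into quadrants [[A, B], [C, D]] gives XXᵀ = [[AAᵀ + BBᵀ, ACᵀ + BDᵀ], [(ACᵀ + BDᵀ)ᵀ, CCᵀ + DDᵀ]]. One level therefore needs four recursive XXᵀ-type calls and only two general products, and the lower-left block comes free by symmetry. The pure-Python sketch below implements that baseline and falls back to a plain product when a dimension is odd or small. It is not RXTX itself, whose specific 26 products are not reproduced in this article, and all helper names are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]

def matmul_abt(A: Matrix, B: Matrix) -> Matrix:
    """General product A @ B^T (naive; a real code would call BLAS or Strassen)."""
    return [[sum(a * b for a, b in zip(ra, rb)) for rb in B] for ra in A]

def add(A: Matrix, B: Matrix) -> Matrix:
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def transpose(A: Matrix) -> Matrix:
    return [list(col) for col in zip(*A)]

def split(X: Matrix) -> Tuple[Matrix, Matrix, Matrix, Matrix]:
    """Split X into quadrants [[A, B], [C, D]] (both dimensions must be even)."""
    h, w = len(X) // 2, len(X[0]) // 2
    A = [row[:w] for row in X[:h]]
    B = [row[w:] for row in X[:h]]
    C = [row[:w] for row in X[h:]]
    D = [row[w:] for row in X[h:]]
    return A, B, C, D

def xxt_recursive(X: Matrix) -> Matrix:
    """Block-recursive X @ X^T: 4 recursive calls + 2 general products per level."""
    n, m = len(X), len(X[0])
    if n % 2 or m % 2 or n <= 2:
        return matmul_abt(X, X)  # base case: plain product
    A, B, C, D = split(X)
    top_left = add(xxt_recursive(A), xxt_recursive(B))      # AA^T + BB^T
    bottom_right = add(xxt_recursive(C), xxt_recursive(D))  # CC^T + DD^T
    top_right = add(matmul_abt(A, C), matmul_abt(B, D))     # AC^T + BD^T
    bottom_left = transpose(top_right)  # CA^T + DB^T = (AC^T + BD^T)^T
    top = [l + r for l, r in zip(top_left, top_right)]
    bottom = [l + r for l, r in zip(bottom_left, bottom_right)]
    return top + bottom

X = [[float(4 * i + j) for j in range(4)] for i in range(4)]
print(xxt_recursive(X) == matmul_abt(X, X))  # True
```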
Performance and Efficiency
RXTX computes the product using 26 general matrix multiplications together with optimized addition schemes, which lowers the total operation count. Theoretical analysis shows reductions both in the number of multiplications and in the combined count of multiplications and additions, with the gains most pronounced for larger matrices. In practical tests on 6144 × 6144 matrices, RXTX ran about 9% faster than standard BLAS routines and was faster in 99% of runs. These results highlight RXTX's effectiveness for large-scale symmetric matrix products relative to both traditional and state-of-the-art algorithms.
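The roughly 5% figure can be sanity-checked with a back-of-the-envelope recurrence. The sketch below rests on two assumptions that this article does not state. The first is that one RXTX step splits X into 4 × 4 blocks and uses 8 recursive XXᵀ calls plus the 26 general products mentioned above, which is our reading of the paper. The second is that general products use Strassen's algorithm, so a block product of half the size costs 1/7 as much. If a scheme costs S(n) ≈ c·M(n), solving S(n) = r·S(n/b) + g·M(n/b) gives c = g / (7^log₂b − r). The function and parameter names are illustrative.

```python
from fractions import Fraction

def leading_constant(halvings: int, recursive_calls: int, general_products: int) -> Fraction:
    """Constant c in S(n) ~ c * M(n) for a recursive X X^T scheme.

    Each step splits X into (2**halvings) x (2**halvings) blocks, makes
    `recursive_calls` X X^T-type calls and `general_products` general block
    products. General products are ASSUMED to use Strassen, so one block
    product costs M(n) / 7**halvings. Substituting S = c * M into
    S(n) = r * S(n/b) + g * M(n/b) gives c = g / (7**halvings - r).
    """
    q = 7 ** halvings
    if recursive_calls >= q:
        raise ValueError("recursive part dominates; the c * M(n) form does not apply")
    return Fraction(general_products, q - recursive_calls)

# Classical 2x2 block recursion: 4 recursive calls + 2 general products.
baseline = leading_constant(halvings=1, recursive_calls=4, general_products=2)
# RXTX on 4x4 blocks: 8 recursive calls (assumed) + 26 general products.
rxtx = leading_constant(halvings=2, recursive_calls=8, general_products=26)

print(baseline)                          # 2/3
print(rxtx)                              # 26/41
print(f"{float(rxtx / baseline):.3f}")   # 0.951
```

The ratio 39/41 ≈ 0.951 counts multiplications only and is asymptotic, yet it is consistent with the roughly 5% reduction quoted above. The 9% wall-clock result is a separate empirical measurement.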
Methodology Overview
The search behind RXTX combines reinforcement learning (RL) with a two-tier Mixed Integer Linear Programming (MILP) pipeline to find efficient multiplication algorithms, here for computing XXᵀ. An RL-guided Large Neighborhood Search generates a broad pool of candidate rank-1 bilinear products. MILP-A then searches for linear combinations of these products that reproduce each target expression. MILP-B selects the smallest subset of products that is sufficient to express all targets. Splitting the problem this way shrinks the RL action space and lets the search work with lower-dimensional tensor products.
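MILP-B's role can be illustrated on a toy instance. Assume MILP-A has already produced, for each target, a list of alternative representations, each given as the set of candidate products it uses. MILP-B then needs the smallest product set that contains at least one complete representation of every target. The sketch below solves that selection by exhaustive search, which only works for tiny inputs. It stands in for the actual MILP solve, and the target and product labels are hypothetical.

```python
from itertools import combinations
from typing import Dict, FrozenSet, List, Optional, Set

def minimal_product_subset(
    representations: Dict[str, List[FrozenSet[str]]],
) -> Optional[Set[str]]:
    """Smallest set of products such that every target has at least one
    representation built entirely from chosen products.

    Brute-force stand-in for MILP-B: tries subsets in increasing size.
    """
    products = sorted({p for reps in representations.values() for rep in reps for p in rep})
    for size in range(len(products) + 1):
        for subset in combinations(products, size):
            chosen = set(subset)
            # Each target must be fully expressible from the chosen products.
            if all(any(rep <= chosen for rep in reps) for reps in representations.values()):
                return chosen
    return None

# Hypothetical MILP-A output: two alternative representations per target.
reps = {
    "T1": [frozenset({"p1", "p2"}), frozenset({"p3", "p4"})],
    "T2": [frozenset({"p2", "p5"}), frozenset({"p6", "p7"})],
    "T3": [frozenset({"p1", "p5"}), frozenset({"p3", "p8"})],
}
print(sorted(minimal_product_subset(reps)))  # ['p1', 'p2', 'p5']
```

Sharing matters here: three products suffice for three two-product targets. A real pipeline would encode the same condition with binary variables for products and representations and pass it to a MILP solver instead of enumerating subsets.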
For example, when computing XXᵀ for a 2×2 matrix X = [[x1, x2], [x3, x4]], the targets are the entries x1² + x2², x1x3 + x2x4, and x3² + x4². The RL policy samples thousands of bilinear products with coefficients drawn from {−1, 0, +1}. MILP-A identifies combinations of these products that match the target expressions, and MILP-B selects the fewest products needed to cover all targets. This framework led to the discovery of RXTX, which uses 5% fewer operations than prior methods.
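The 2×2 example can be made concrete with a small sketch. Each candidate is a product (u·x)(v·x) of two linear forms in x = (x1, x2, x3, x4) with u, v ∈ {−1, 0, +1}⁴. It is stored as its coefficients on the ten monomials xᵢxⱼ. Two substitutions are assumptions of this sketch, not the paper's method. First, the pool is enumerated exhaustively (3⁴ × 3⁴ = 6561 pairs) in place of the RL policy's sampling. Second, MILP-A is replaced by a hash lookup that only finds two-term sums target = p + q. All names are illustrative.

```python
from itertools import product
from typing import Dict, List, Tuple

# X = [[x1, x2], [x3, x4]]; x1..x4 are stored as indices 0..3.
PAIRS = [(i, j) for i in range(4) for j in range(i, 4)]  # monomials x_i*x_j, i <= j
Vec = Tuple[int, ...]
Quad = Tuple[int, ...]  # coefficients of the monomials, in PAIRS order

def bilinear(u: Vec, v: Vec) -> Quad:
    """Expand (u . x)(v . x) into monomial coefficients."""
    return tuple(u[i] * v[i] if i == j else u[i] * v[j] + u[j] * v[i] for i, j in PAIRS)

def quad(terms: Dict[Tuple[int, int], int]) -> Quad:
    """Build a quadratic form from {(i, j): coefficient} with i <= j."""
    return tuple(terms.get(p, 0) for p in PAIRS)

def evaluate(form: Quad, x: Vec) -> int:
    """Evaluate a quadratic form at the point x."""
    return sum(c * x[i] * x[j] for c, (i, j) in zip(form, PAIRS))

# Upper-triangular entries of X X^T.
TARGETS = {
    "x1^2 + x2^2": quad({(0, 0): 1, (1, 1): 1}),
    "x1*x3 + x2*x4": quad({(0, 2): 1, (1, 3): 1}),
    "x3^2 + x4^2": quad({(2, 2): 1, (3, 3): 1}),
}

# Candidate pool: every (u, v) in {-1, 0, +1}^4 x {-1, 0, +1}^4, enumerated
# exhaustively in place of RL sampling and keyed by its expansion.
COEFFS = list(product((-1, 0, 1), repeat=4))
candidates: Dict[Quad, Tuple[Vec, Vec]] = {}
for u in COEFFS:
    for v in COEFFS:
        candidates.setdefault(bilinear(u, v), (u, v))

def two_term_representations(target: Quad) -> List[Tuple[Quad, Quad]]:
    """Stand-in for MILP-A: unordered pairs p, q in the pool with p + q == target."""
    found = []
    for p in candidates:
        rest = tuple(t - a for t, a in zip(target, p))
        if rest in candidates and p <= rest:  # report each unordered pair once
            found.append((p, rest))
    return found

for name, t in TARGETS.items():
    print(name, "->", len(two_term_representations(t)), "two-product representations")
```

Since −(u·x)(v·x) = ((−u)·x)(v·x) is also in the pool, searching only sums also covers signed combinations of two products.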
RXTX is efficient for both small and large matrices, and it shows how machine learning-based search can be combined successfully with combinatorial optimization.
For further technical details, see the research paper.
All credit for this research goes to the researchers of this project.