Abstract
Modern AI relies on huge matrix multiplications (MatMuls), whose computation
poses a scalability problem for inference and training. We propose an
alternative, GPU-native bilinear operator to MatMuls in neural networks, which
offers a three-way tradeoff between speed, accuracy, and parameter count. In
particular, this operator requires substantially fewer FLOPs to evaluate ($\ll
n^3$), yet increases the parameter count compared to MatMul ($\gg n^2$). We
call this operator Strassen-Tile (STL). The key idea behind STL is a local
learnable change-of-basis, applied to tiles of the weight and activation
matrices, followed by an element-wise product between the tiles, implemented
simultaneously via MatMul. The key technical question we study is how to
optimize the change-of-basis of a given layer, which is a highly non-convex
problem. We show that theory-backed initializations (inspired by fast matrix
and polynomial multiplication) lead to substantially better accuracy than
randomly initialized SGD. This phenomenon motivates further algorithmic study
of STL optimization in DNNs. Our experiments demonstrate that STL can
approximate $4\times 4$ MatMul of tiles while reducing FLOPs by a factor of 2.66, and
can improve the ImageNet-1K accuracy of the SoTA T2T-ViT-7 (4.3M parameters) while
lowering FLOPs. Even with non-CUDA-optimized PyTorch code, STL achieves
wall-clock speedups in the compute-bound regime. These results, together with
its theoretical grounding, suggest STL as a promising building block for scalable
and cost-efficient AI.
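To make the operator concrete, here is a minimal PyTorch sketch of an STL-style layer under stated assumptions: square $(n, n)$ inputs with tile size $t$ dividing $n$, and learnable encoders and decoder named E_x, E_w, D with rank r. These names and shapes are illustrative placeholders, not the paper's notation.

```python
import torch

def stl_matmul(X, W, E_x, E_w, D, t=4):
    """Minimal sketch of a Strassen-Tile (STL) style product.

    E_x, E_w have shape (r, t*t) and D has shape (t*t, r); they are the
    learnable change-of-basis matrices. Names and rank r are illustrative.
    """
    n = X.shape[0]                    # assume square (n, n) inputs, t | n
    m = n // t
    # Cut X and W into m x m grids of vectorized t x t tiles.
    Xt = X.reshape(m, t, m, t).permute(0, 2, 1, 3).reshape(m, m, t * t)
    Wt = W.reshape(m, t, m, t).permute(0, 2, 1, 3).reshape(m, m, t * t)
    # Local change-of-basis: encode each tile into r coefficients.
    Xe = Xt @ E_x.T                   # (m, m, r)
    We = Wt @ E_w.T                   # (m, m, r)
    # Element-wise product between encoded tiles, summed over the inner
    # tile index k; per coordinate this is an (m, m) MatMul, so all r
    # coordinates run simultaneously as one batched MatMul.
    Ye = torch.einsum('ikr,kjr->ijr', Xe, We)   # (m, m, r)
    # Decode the accumulated coefficients back to t x t output tiles.
    Yt = Ye @ D.T                     # (m, m, t*t)
    return Yt.reshape(m, m, t, t).permute(0, 2, 1, 3).reshape(n, n)
```

In this framing, exact tile MatMul corresponds to choosing r at least the bilinear rank of $t\times t$ matrix multiplication (Strassen-style decompositions are the classic example, and motivate the theory-backed initializations above), while choosing r below that rank is what trades accuracy for FLOPs.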
Authors
Nir Ailon
Akhiad Bercovich
Yahel Uffenheimer
Omri Weinstein
Key Contributions
This paper proposes Strassen-Tile (STL), a GPU-native bilinear operator that serves as an efficient alternative to matrix multiplication (MatMul) in DNNs. STL trades a larger parameter count for fewer FLOPs, achieving substantial computational savings by applying a local learnable change-of-basis to matrix tiles before an element-wise product.
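As a rough sanity check on the reported 2.66x figure, here is a hypothetical multiplication count for the element-wise stage, assuming tile size t = 4 and an assumed rank r = 24 (chosen so the ratio matches; the abstract does not state r), and ignoring the encode/decode cost, which amortizes over the shared inner dimension:

```python
# Hypothetical FLOP accounting for one tile product; r = 24 is an
# assumption chosen so t**3 / r matches the reported ~2.66x reduction.
t = 4
r = 24
dense_muls = t ** 3           # 64 scalar multiplications for a dense 4x4 tile MatMul
stl_muls = r                  # one multiplication per encoded coordinate
print(dense_muls / stl_muls)  # -> 2.666..., consistent with the abstract
```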
Business Value
Enables faster and more efficient deployment of deep learning models, particularly for inference on resource-constrained devices or for large-scale training, reducing operational costs and latency.