Sujay is an AI Infrastructure Engineer at Tesla Autopilot, where he builds ML infrastructure for training Tesla Full Self-Driving and Optimus robot models. His main focus is enabling efficient training of large language and vision models at scale. Prior to Tesla, Sujay graduated from Carnegie Mellon University, where he trained large speech recognition models. Having been part of early-stage startups, he is also keenly aware of the challenges of taking AI products from zero to one.
Training Large Deep Learning Models: PyTorch vs JAX
As software developers tasked with building and deploying AI products into production, one of the first decisions we need to make is the choice of framework for training our deep learning models. This decision has far-reaching consequences, because switching frameworks becomes very difficult once the development and research teams have gained traction. It is therefore important for software developers starting AI/ML projects that are crucial to their business and careers to understand the common pros and cons of each training framework and to pick the one best suited to their needs. In this talk I examine these tradeoffs so that participants can make an informed decision for their own use cases. Although only two major frameworks, PyTorch and JAX, are widely adopted in industry, they differ markedly in their core philosophies.
In terms of efficiency vs. flexibility, PyTorch generally offers more flexibility and ease of use, with dynamic computation graphs that suit rapid prototyping, while JAX prioritizes computational efficiency through its functional programming model, just-in-time compilation via XLA, and composable automatic differentiation. In this talk, I explore the differences in programming style and explain the tools available for large model training: Distributed Data Parallel (DDP), Fully Sharded Data Parallel (FSDP), Model Parallelism (MP), and Tensor Parallelism (TP) in PyTorch vs. device meshes and automatic sharding in JAX. I will also spend some time on the tooling and ecosystem of open source resources available around each framework, to help participants choose the right one.
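As a rough illustration of the JAX side of this comparison, the sketch below combines the functional style (a pure loss function transformed by `jax.value_and_grad` and `jax.jit`) with a one-dimensional device mesh for data parallelism. It is a minimal sketch and not material from the talk: the toy linear-regression problem, the names `loss_fn`, `train_step`, `run_demo`, and `LEARNING_RATE`, and the mesh axis label `"data"` are all assumptions made for illustration. On a machine with a single device the mesh has size one and the code still runs unchanged.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LEARNING_RATE = 0.1  # hypothetical value chosen for this toy problem


def loss_fn(params, x, y):
    # Pure function: parameters are passed in explicitly, no hidden state.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # traced once and compiled with XLA
def train_step(params, x, y):
    # value_and_grad differentiates loss_fn with respect to its first argument.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


def run_demo(steps=50):
    # 1-D mesh over all visible devices; "data" is an arbitrary axis label.
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("data",))
    batch_sharding = NamedSharding(mesh, P("data"))  # split leading (batch) axis
    replicated = NamedSharding(mesh, P())            # full copy on every device

    # Toy data; the batch size is a multiple of the device count so it splits evenly.
    n = 8 * devices.size
    x = jax.random.normal(jax.random.PRNGKey(0), (n, 4))
    y = x @ jnp.array([1.0, -2.0, 0.5, 3.0])
    params = {"w": jnp.zeros(4), "b": jnp.zeros(())}

    # Place inputs according to the shardings; jit then partitions the
    # computation to match and inserts the needed collectives.
    x = jax.device_put(x, batch_sharding)
    y = jax.device_put(y, batch_sharding)
    params = jax.device_put(params, replicated)

    first_loss = None
    for _ in range(steps):
        params, loss = train_step(params, x, y)
        if first_loss is None:
            first_loss = float(loss)
    return first_loss, float(loss)


if __name__ == "__main__":
    first, last = run_demo()
    print(f"loss before: {first:.4f}, after: {last:.4f}")
```

The training step itself contains no parallelism code: the shardings attached to the inputs determine how the work is split. In PyTorch, the comparable data-parallel setup is usually expressed by wrapping the model in a class such as `DistributedDataParallel` and launching one process per device.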