Sujay S Kumar


Sujay is an AI Infrastructure Engineer at Tesla Autopilot, building ML infrastructure for training Tesla Full Self Driving and Optimus robot models. His main focus is enabling efficient training of large language and vision models at scale. Prior to Tesla, Sujay graduated from Carnegie Mellon University, where he trained large speech recognition models. Having been part of early-stage startups, he is also keenly aware of the challenges of building AI products from zero to one.


Training Large Deep Learning Models: PyTorch vs JAX


As software developers tasked with building and deploying AI products into production, the first and foremost decision we need to make is the choice of framework for training our deep learning models. This decision has far-reaching consequences, as it is almost impossible to switch frameworks once the development and research teams have traction. Hence, software developers starting on AI/ML projects crucial to their business and careers need to be aware of the common pros and cons of each training framework and pick the one best suited to their needs. In this talk I delve into these tradeoffs so that participants can make an informed decision for their use cases. Only two major frameworks are widely adopted in the industry, PyTorch and JAX, and they differ widely in their core philosophies.

  1. PyTorch is an open source training framework originally developed by Facebook’s AI Research lab, with a strong focus on dynamic computation graphs and ease of use for researchers.
  2. It provides a rich ecosystem of libraries and tools for tasks like computer vision, natural language processing, and reinforcement learning.
  3. It is well established in both research and production environments, with a large community and extensive documentation.
  4. It supports large model training through Distributed Data Parallel, Fully Sharded Data Parallel, Model Parallelism, and Tensor Parallelism.

  1. JAX is an open source training framework by Google Brain, with a focus on high-performance numerical computing and automatic differentiation.
  2. It uses a functional programming style, which can be more concise and easier to reason about for certain types of models and computations.
  3. It is newer and has a smaller community than PyTorch, but is rapidly gaining adoption in research and in specific domains like scientific computing.
  4. It enables large model training through automatic compiler-level sharding with minimal development effort.
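
To make the last JAX point concrete, here is a minimal sketch of what compiler-level sharding can look like. It is an illustration rather than material from the talk: the mesh axis name "data", the array shapes, and the `forward` function are assumptions chosen for the example. It builds a one-dimensional device mesh over whatever devices are visible (a single device on a plain CPU machine), shards the batch dimension across it, and lets `jax.jit` work out how to partition the computation.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh over all visible devices; the axis name "data" is an
# illustrative choice, not a required name.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
n_dev = devices.size

# Split the batch dimension across the "data" axis; replicate the weights.
batch_sharding = NamedSharding(mesh, P("data", None))
replicated = NamedSharding(mesh, P())

# Batch size is a multiple of the device count so it divides evenly.
x = jax.device_put(jnp.ones((8 * n_dev, 4)), batch_sharding)
w = jax.device_put(jnp.full((4, 2), 0.5), replicated)

@jax.jit
def forward(x, w):
    # No explicit communication code: the compiler propagates the input
    # shardings and partitions the matmul accordingly.
    return jnp.tanh(x @ w)

y = forward(x, w)
print(y.shape, y.sharding)
```

The same function runs unchanged on one device or many; only the mesh and the sharding annotations on the inputs change.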

In terms of efficiency vs. flexibility, PyTorch generally provides more flexibility and ease of use through dynamic computation graphs and rapid prototyping, while JAX prioritizes computational efficiency and performance optimization through its functional programming model and advanced automatic differentiation capabilities. In this talk, I explore the differences in programming styles and explain the tools available for large model training: Distributed Data Parallel (DDP), Fully Sharded Data Parallel (FSDP), Model Parallelism (MP), and Tensor Parallelism (TP) in PyTorch, versus device meshes and automatic sharding in JAX. I will also spend some time on the tooling and the ecosystem of open source resources available in each framework, to help participants choose the right one.
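
As a rough illustration of the functional programming style mentioned above, the following sketch trains a tiny linear model in JAX. The model, the synthetic data, the `LEARNING_RATE` constant, and the names `loss_fn` and `sgd_step` are all assumptions made for this example, not material from the talk. The point to notice is that parameters are passed in and returned explicitly instead of being mutated in place, and gradients come from transforming a pure function with `jax.grad`.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # illustrative value

def loss_fn(params, x, y):
    # Pure function: the output depends only on the inputs.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # jax.grad transforms loss_fn into a function that returns gradients
    # with respect to its first argument (params).
    grads = jax.grad(loss_fn)(params, x, y)
    # Return new parameters rather than updating the old ones in place.
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )

# Synthetic regression data (hypothetical, for demonstration only).
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
initial_loss = loss_fn(params, x, y)
for _ in range(100):
    params = sgd_step(params, x, y)
final_loss = loss_fn(params, x, y)
print(float(initial_loss), float(final_loss))
```

A typical PyTorch training loop would instead hold parameters inside a module object and update them in place through an optimizer, which is the stylistic contrast the talk refers to.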