Introduction
This article:
- Discusses how to enable model running and quality validation processes on the Argonne National Laboratory supercomputing cluster using the JAX framework and Intel® Extension for OpenXLA* back end.
- Covers the hardware overview of the Argonne National Laboratory supercomputer and version compatibility.
- Gives updates for JAX and Intel Extension for OpenXLA, model deployment testing, and quality monitoring systems.
- Serves as a guide for the development and validation of JAX projects related to Argonne.
Aurora Supercomputer
Aurora is the first supercomputer to deploy Intel® Data Center GPU Max Series and is also the world’s largest system based on Intel® Xeon® CPU Max Series. Additionally, it boasts the largest GPU cluster currently in existence. Aurora is equipped with 21,248 Intel Xeon CPU Max Series processors, providing a total of more than 1.1 million cores, and 63,744 Intel Data Center GPU Max Series GPUs for handling AI and HPC workloads.
Our hardware is the Intel Data Center GPU Max Series 1550 (see Intel GPU: Intel Data Center GPU Max Series Overview). This GPU is based on the Xe HPC microarchitecture and targets the highly parallel computing models used in AI and HPC. It is supported by the oneAPI open ecosystem, which offers the flexibility of both Single Instruction Multiple Data (SIMD) and Single Instruction Multiple Threads (SIMT) programming.
JAX and Intel® Extension for OpenXLA*
JAX is a Python* library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. This open source AI framework, led by Google*, is used on the Argonne National Laboratory supercomputer.
In the Argonne project, JAX uses Intel Extension for OpenXLA as its compilation and execution back end to optimize and parallelize computations, especially for the Intel Data Center GPU Max Series (code-named Ponte Vecchio, or PVC) AI hardware accelerators in the Aurora supercomputer.
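As a rough illustration of what this means in practice, the sketch below compiles a small function with `jax.jit`, which hands the computation to XLA and whichever back end is installed. The function `scaled_sum` is a hypothetical toy example, not part of the Argonne workloads. On a machine without the Intel plug-in, JAX falls back to the CPU back end, so the device list printed here depends on the installation.

```python
import jax
import jax.numpy as jnp


@jax.jit  # trace once, then compile through XLA for the active back end
def scaled_sum(x, y):
    """Hypothetical toy kernel: sum of 2*x + y."""
    return jnp.sum(2.0 * x + y)


# Lists the devices JAX can see; the names depend on the installed back end.
print(jax.devices())

result = scaled_sum(jnp.arange(4.0), jnp.ones(4))
print(float(result))
```

The same Python source runs unchanged on CPU or GPU; only the back end that XLA compiles for differs.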
How to Enable Cases and Ensure the Quality for Argonne
Argonne Workload
The Argonne workload mainly consists of the following public test suites and models, which can be accessed and downloaded to local hardware: JAX, Gordon, Janet, QMC, and AlphaFold.
- JAX: Public unit test (UT) cases
- Gordon: Code demonstrating joint use of JAX and Numba* when calculating model gradients
- Janet: Implementation of forward models for galaxy SED and photometry based on JAX
- QMC: Implementation of the Quantum Monte Carlo (QMC) method based on the JAX framework. The QMC method has broad applications in computational physics, chemistry, and materials science.
- AlphaFold: A protein structure prediction model based on deep learning
Quality Validation Structure
Prepare Test Environment
We need permission to access the Aurora computing cluster through the designated ports, and we need to request a specified number of compute nodes. The JAX project requested four nodes from Aurora, each equipped with six Intel Data Center GPU Max Series 1550 (PVC) cards.
Install Intel® oneAPI Base Toolkit (Base Kit) Packages
Install the following Base Kit components:
- Intel® oneAPI DPC++ Compiler
- Intel® oneAPI Math Kernel Library (oneMKL)
$ wget https://registrationcenter-download.intel.com/akdlm/IRC_NAS/e6ff8e9c-ee28-47fb-abd7-5c524c983e1c/l_BaseKit_p_2024.2.1.100_offline.sh
# Two components are required: DPC++/C++ Compiler and oneMKL
$ sudo sh l_BaseKit_p_2024.2.1.100_offline.sh
# Source the oneAPI environment
$ source /opt/intel/oneapi/compiler/2024.2/env/vars.sh
$ source /opt/intel/oneapi/mkl/2024.2/env/vars.sh
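A quick way to confirm that the environment scripts took effect is to look for what they are expected to provide. The sketch below rests on assumptions: it looks for the `icpx` compiler driver on `PATH` and for the `MKLROOT` variable, which the compiler and oneMKL `vars.sh` scripts normally set up. The helper name `check_oneapi_env` is hypothetical.

```python
import os
import shutil


def check_oneapi_env(env=None, which=shutil.which):
    """Return a list of human-readable problems; empty means the checks passed.

    Assumption: sourcing the compiler vars.sh puts `icpx` on PATH and
    sourcing the oneMKL vars.sh exports MKLROOT.
    """
    env = os.environ if env is None else env
    problems = []
    if which("icpx") is None:
        problems.append("icpx not found on PATH (compiler vars.sh not sourced?)")
    if not env.get("MKLROOT"):
        problems.append("MKLROOT is not set (mkl vars.sh not sourced?)")
    return problems


# Report anything that looks wrong in the current shell environment.
for problem in check_oneapi_env():
    print("WARNING:", problem)
```

Passing `env` and `which` as parameters keeps the check easy to exercise without a real oneAPI installation.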
Install JAX and Intel Extension for OpenXLA
$ pip install jax==0.4.26
$ pip install jaxlib==0.4.26
$ pip install --upgrade intel-extension-for-openxla
The table tracks intel-extension-for-openxla versions and the compatible versions of jax and jaxlib. Compatibility between jax and jaxlib is maintained by the JAX project. This version restriction will be relaxed over time as the plug-in API matures.
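Because the plug-in is tied to specific jax and jaxlib releases, it can help to check the installed versions before running any tests. The following sketch uses only the Python standard library. The `PINNED` dictionary mirrors the versions from the install commands above, and the helper names `installed_version` and `find_mismatches` are hypothetical, not part of JAX or the extension.

```python
from importlib import metadata

# Versions taken from the install commands above.
PINNED = {"jax": "0.4.26", "jaxlib": "0.4.26"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package whose installed version differs from its pin
    to an (expected, found) tuple."""
    mismatches = {}
    for package, expected in pinned.items():
        found = lookup(package)
        if found != expected:
            mismatches[package] = (expected, found)
    return mismatches


for package, (expected, found) in find_mismatches(PINNED).items():
    print(f"{package}: expected {expected}, found {found}")
```

An empty result means the environment matches the pins; a missing package shows up with `None` as the found version.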
Jenkins CI/CD
- Continuous testing to detect regression or bugs that may occur during optimization activities.
- Performance testing to check the performance gain from optimization activities with A100.
- Functional testing to check that calculation results are not broken by optimization activities or by Intel driver and Base Kit upgrades.
- Maintaining pipelines in Jenkins and other CI-required components.
- Monitoring test coverage.
To meet the testing requirements of the preceding cases, we set up a quality validation system.
The previous image illustrates the overall case coverage and validation process for JAX.
The Jenkins jobs include PreCI, JAX nightly tests, and weekly tests:
- Pull Request (PreCI): Monitors code submissions for intel-extension-for-openxla. PreCI primarily verifies the impact of Intel Extension for OpenXLA pull requests (PRs) on JAX and Argonne workloads. The tests include a subset of JAX UT and the Argonne workload three-model UT.
- JAX Nightly: The JAX nightly test cases include JAX UT and Argonne workloads, verifying the daily updates to Intel Extension for OpenXLA code.
- Weekly Test: There are separate jobs for feature tests and model tests. These individual jobs make it easier to debug and troubleshoot each testing task. Once all the test jobs are ready, they will be merged into the weekly test, which includes UT tests, feature tests, and model tests.
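The relationship between the three job types and the suites they run can be summarized as a small lookup table. This is only a sketch of the structure described above; the job keys and suite names are hypothetical labels, not actual Jenkins job names.

```python
# Hypothetical labels for the jobs and suites described above.
JOB_SUITES = {
    "pre_ci": ["jax_ut_subset", "argonne_three_model_ut"],
    "nightly": ["jax_ut", "argonne_workloads"],
    "weekly": ["ut_tests", "feature_tests", "model_tests"],
}


def suites_for(job):
    """Return a copy of the suites a job runs, raising on an unknown job."""
    try:
        return list(JOB_SUITES[job])
    except KeyError:
        raise ValueError(f"unknown job: {job!r}") from None


for job in JOB_SUITES:
    print(job, "->", ", ".join(suites_for(job)))
```

Keeping the mapping in one place makes it easy to see that coverage widens from PreCI to nightly to weekly.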
To ensure product quality, validation focuses on the two main sources of change: Intel Extension for OpenXLA code pushes and driver/Base Kit upgrades.
The driver and Base Kit upgrades should maintain the same configuration as Intel Extension for OpenXLA. Once Intel Extension for OpenXLA completes its upgrade, JAX performs a synchronized upgrade to ensure compatibility. Driver and Base Kit upgrades are generally managed by the Borealis system administrators, but some specific versions of Base Kit may need to be installed manually. After an upgrade is completed, weekly tests are run to ensure that JAX workloads pass with the new versions of the driver and Base Kit.
Each source of change has its own validation requirement:
- Intel Extension for OpenXLA Code Validation: Pull request jobs and nightly jobs test PR merges to ensure the functionality of JAX UT and the Argonne workload three-model UT.
- Driver/Base Kit Upgrade Validation: All UT tests, feature tests, and model tests must pass with a performance regression of no more than 5%.
If any case fails or model performance degrades, a Jira ticket is created for tracking and analysis.
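The 5% criterion can be expressed as a simple comparison against a stored baseline. The sketch below assumes the metric is a throughput, where higher is better. The function names and the sample numbers are hypothetical and only illustrate the arithmetic.

```python
def regression_percent(baseline, current):
    """Percentage drop of `current` relative to `baseline`.

    Assumes a higher-is-better metric such as throughput; a negative
    result means performance improved.
    """
    if baseline <= 0:
        raise ValueError("baseline must be positive")
    return (baseline - current) * 100.0 / baseline


def within_threshold(baseline, current, threshold=5.0):
    """True if the regression is less than or equal to `threshold` percent."""
    return regression_percent(baseline, current) <= threshold


# Made-up numbers: a 4% drop passes, an 8% drop would need a ticket.
print(within_threshold(100.0, 96.0))
print(within_threshold(100.0, 92.0))
```

For latency-style metrics, where lower is better, the sign of the difference would need to be flipped.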