핵심 개념
FAX는 JAX의 분할 메커니즘을 활용하여 TPU와 최신 JAX 런타임을 네이티브로 대상으로 하는 대규모 분산 및 연합 계산을 지원합니다. FAX는 연합 계산을 기본 연산으로 구현하여 세 가지 주요 이점을 제공합니다: 1) FAX 계산을 XLA HLO로 변환할 수 있습니다. 2) FAX는 연합 자동 미분(federated AD)의 완전한 구현을 제공하여 연합 계산의 표현을 크게 단순화합니다. 3) FAX 계산을 기존 프로덕션 크로스 디바이스 연합 컴퓨팅 시스템으로 해석할 수 있습니다.
초록
FAX는 JAX 기반 라이브러리로, 대규모 분산 및 연합 계산을 지원하도록 설계되었습니다. FAX는 JAX의 분할 메커니즘을 활용하여 TPU와 최신 JAX 런타임을 네이티브로 대상으로 합니다.
FAX는 연합 계산을 기본 연산으로 구현하여 세 가지 주요 이점을 제공합니다:
FAX 계산을 XLA HLO로 변환할 수 있습니다. 이를 통해 데이터 센터에서 효율적이고 확장 가능한 계산을 수행할 수 있습니다.
FAX는 연합 자동 미분(federated AD)의 완전한 구현을 제공하여 연합 계산의 표현을 크게 단순화합니다. 이를 통해 연합 학습 알고리즘 개발을 가속화할 수 있습니다.
FAX 계산을 기존 프로덕션 크로스 디바이스 연합 컴퓨팅 시스템으로 해석할 수 있습니다. 이를 통해 데이터 센터 성능과 프로덕션 시스템 간의 격차를 해소할 수 있습니다.
FAX는 연합 학습뿐만 아니라 다양한 병렬 및 분산 알고리즘을 표현, 분할 및 실행하는 데 사용될 수 있습니다. 이는 데이터 최소화가 필요하지 않거나 이기종 데이터에서 작동하지 않는 알고리즘을 포함합니다.
통계
350M 모델의 경우 라운드당 최대 3.355 × 10^7개의 토큰을 처리하고 2.293 × 10^13 FLOP을 수행합니다.
1B 모델의 경우 라운드당 최대 8.389 × 10^6개의 토큰을 처리하고 1.638 × 10^13 FLOP을 수행합니다.
8B 모델의 경우 라운드당 최대 2.097 × 10^6개의 토큰을 처리하고 3.277 × 10^13 FLOP을 수행합니다.