NVIDIA ускоряет предобучение LLM: NVFP4 на Blackwell в связке с JAX и MaxText
NVIDIA опубликовала технический гайд по предобучению LLM на чипах Blackwell: формат NVFP4 в связке с JAX и MaxText сокращает время тренировки и…
AI-обработка оригинала NVIDIA Developer Blog; редакция Hamidun News
Предобучение frontier LLM упирается в пропускную способность вычислительных систем. NVIDIA показала, как связка JAX, MaxText и нового формата NVFP4 на чипах Blackwell позволяет значительно ускорить этот процесс без потери качества.
Почему каждый процент важен
Когда обучение идёт на триллионах токенов через тысячи ускорителей, экономия даже одного процента времени на каждом шаге выливается в несколько дней реального calendar-time. На масштабах frontier-обучения это прямой перевод в миллионы долларов compute-расходов. NVFP4 — четырёхбитный формат с плавающей точкой, дебютировавший в архитектуре Blackwell, — стал одним из ключевых инструментов для ускорения матричных операций. По сравнению с FP8 он вдвое плотнее упаковывает числа, что снижает нагрузку на память и повышает эффективную пропускную способность тензорных ядер. Главная сложность: четырёхбитная числовая сетка разреженна. При неправильной настройке градиенты легко выходят за её пределы — это приводит к расхождению обучения. NVIDIA и команда MaxText решали это через кастомные схемы масштабирования и динамический loss scaling.
Как работает mixed-precision с NVFP4
Mixed-precision обучение — подход не новый: FP8 и BF16 уже стали индустриальным стандартом. NVFP4 идёт на шаг дальше, позволяя использовать 4-битные веса в самых нагруженных матричных умножениях при сохранении более высокой точности там, где это действительно важно.
- NVFP4 применяется для весов и активаций в операциях GEMM BF16 или FP32 остаются для аккумуляторов и нормализации MaxText автоматически маршрутизирует операции в нужный формат JAX компилирует граф вычислений через XLA, оптимизируя ядра под Blackwell Итог — рост throughput при сопоставимом или меньшем энергопотреблении ## Стек и что менять в коде MaxText — открытый высокопроизводительный тренировочный фреймворк на JAX, разработанный Google. Изначально он создавался под TPU, но активно адаптируется под GPU-кластеры, и партнёрство с NVIDIA здесь закономерно. NVIDIA включила низкоуровневые NVFP4-ядра в состав cuBLAS и cuDNN, а JAX/XLA получил поддержку этих операций через специальные адаптеры. Разработчикам не нужно переписывать тренировочный код вручную — достаточно включить нужные флаги в конфигурации MaxText и убедиться, что в кластере установлены чипы Blackwell (B100, B200, GB200).
«Числовая точность — один из самых рычажных параметров, но low-bit mixed-precision предобучение сложно реализовать правильно», — отмечает команда NVIDIA
Developer Blog.
Что это значит
Для команд, занимающихся предобучением frontier-моделей, NVFP4 на Blackwell — это практически бесплатное ускорение: существующий стек на JAX и MaxText требует минимальных изменений конфигурации. На масштабах сотен и тысяч GPU даже 10–15% роста throughput напрямую сокращает time-to-checkpoint и общий compute-бюджет. Гонка за эффективностью предобучения переходит в фазу битвы за числовую точность.
Хотите не читать про ИИ, а внедрить его?
«AI News» — это полезные новости из мира ИИ. Системно научиться работать с нейросетями и применять их в работе — в Hamidun Academy.
Главное из мира ИИ — раз в неделю
7 ключевых событий недели, отобранных вручную. Без шума, репостов и пресс-релизов.
Готово! Проверьте почту — мы отправили подтверждение.