2025年10月1日晚,福建省科技馆化身为一座充满未来感的“科学之城”。以“智趣AI幻夜奇旅”为主题的大型科普活动“科学之夜”正式拉开帷幕,为市民奉上...
2025-10-04 0
JAX生态这两年在LLM训练这块追赶得挺快。PyTorch虽然还是主流但JAX在并行计算、TPU加速和API组合性上确实有些独特的优势。Google今天放出了Tunix这个库,专门做LLM的后训练——微调、强化学习、知识蒸馏这些都能搞。
这是个构建在JAX之上的后训练库,和Flax NNX集成得比较紧密。主要解决三类问题:
现在还在早期开发阶段,功能在持续迭代,支持的模型也在慢慢扩展。
监督微调:既支持全参数微调,也支持LoRA和Q-LoRA这类参数高效的方法。内存和算力受限的时候,PEFT方案还是挺实用的。
强化学习:实现了几个主流算法:PPO(Proximal Policy Optimization)、GRPO(Group Relative Policy Optimization)、还有token级别的GSPO。另外还有DPO(Direct Preference Optimization)做偏好对齐,这个在RLHF场景用得比较多。
知识蒸馏:支持几种策略,包括基于logit的概率分布匹配、注意力机制的转移和投影、跨架构的特征池化与投影。这几种方法在不同场景下各有用处。
库的设计比较模块化,组件可以自由组合,想扩展自定义流程也不算麻烦。分布式训练支持数据并行(DP)、完全分片数据并行(FSDP)和张量并行(TP),对TPU做了专门优化。
三种装法:
从PyPI装(推荐):
pip install "tunix[prod]"
或者直接从GitHub主分支:
pip install git+https://github.com/google/tunix
开发模式从源码装:
git clone https://github.com/google/tunix.git
cd tunix
pip install -e ".[dev]"
拿个英译法的任务来演示。用的是Google的Gemma 2B模型,跑在TPU v5e-8上。
环境准备
pip install -q kagglehub safetensors tensorflow tensorflow_datasets tensorboardX transformers grain datasets
pip install -q git+https://github.com/google/tunix
pip install -q git+https://github.com/google/qwix
# Flax需要升级到最新版
pip uninstall -q -y flax
pip install -q git+https://github.com/google/flax.git
完整流程
第一步,从Kaggle拉预训练checkpoint:
import kagglehub
model_path = "google/gemma/flax/2b"
kaggle_ckpt_path = kagglehub.model_download(model_path)
初始化模型和tokenizer:
from flax import nnx
from tunix.models.gemma import model as gemma_lib, params as params_lib
from tunix.generate import tokenizer_adapter as tokenizer_lib
base_model = gemma_lib.Transformer.from_params(
params_lib.load_and_format_params(kaggle_ckpt_path, "2b"),
version="2b"
)
tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=f"{kaggle_ckpt_path}/tokenizer.model")
挂上QLoRA adapter:
import qwix
lora_provider = qwix.LoraProvider(
module_path=".*(q_einsum|kv_einsum|proj)",
rank=16,
alpha=2.0,
weight_qtype="nf4" # enable QLoRA quantization
)
lora_model = qwix.apply_lora_to_model(base_model, lora_provider)
这里rank设成16,alpha是2.0,weight_qtype指定nf4量化格式。
加载训练数据:
from tunix.examples.data import translation_dataset
train_ds, validation_ds = translation_dataset.create_datasets(
dataset_name="mtnt/en-fr",
global_batch_size=16,
max_target_length=256,
num_train_epochs=3,
tokenizer=tokenizer,
)
用的是mtnt的英法平行语料,batch size 16,目标序列最长256个token。
开始训练:
from tunix.sft import peft_trainer, utils
import optax
trainer = peft_trainer.PeftTrainer(
lora_model,
optimizer=optax.adamw(1e-3),
config=peft_trainer.TrainingConfig(max_steps=100)
)
trainer.train(train_ds, validation_ds)
优化器用AdamW,学习率1e-3,跑100步看看效果。
推理测试:
训练完直接用adapter过的模型做生成。Tunix提供了Sampler工具:
from tunix.generate import sampler as sampler_lib
# initialize sampler
sampler = sampler_lib.Sampler(
transformer=lora_model,
tokenizer=tokenizer,
cache_config=sampler_lib.CacheConfig(
cache_size=256,
num_layers=base_model.num_layers,
num_kv_heads=base_model.num_kv_heads,
head_dim=base_model.head_dim,
),
)
# test prompts
input_batch = [
"Translate this into French:\nHello, my name is Morgane.\n",
"Translate this into French:\nThis dish is delicious!\n",
"Translate this into French:\nI am a student.\n",
"Translate this into French:\nHow's the weather today?\n",
]
# generate predictions
out_data = sampler(
input_strings=input_batch,
max_generation_steps=20,
)
# print results
for input_string, out_string in zip(input_batch, out_data.text):
print(f"----------------------")
print(f"Prompt:\n{input_string}")
print(f"Output:\n{out_string}")
如果用的是QLoRA,把lora_model换成qlora_model就行。生产环境可以考虑把adapter合并回基模型,推理延迟能降下来。
100步训练之后,模型已经能生成一些翻译结果了,虽然质量还不够好。多训练一段时间,准确率会明显提升,而且内存开销和训练速度都保持在不错的水平。
Tunix现在还比较新,但已经能看出一些潜力。TPU优先的设计、模块化的API、LoRA/QLoRA支持、完整的分布式训练策略,这些对做LLM适配研究的人来说都挺有用。
后续应该会继续扩展支持的模型类型和训练算法,值得关注。
作者:Abish Pius
相关文章
2025年10月1日晚,福建省科技馆化身为一座充满未来感的“科学之城”。以“智趣AI幻夜奇旅”为主题的大型科普活动“科学之夜”正式拉开帷幕,为市民奉上...
2025-10-04 0
10月1日,阜阳颍州万达广场内,一家集汽车、智能手机与智能家居于一体的全新体验店正式亮相,标志着小米在皖北地区的渠道布局迈出关键一步。当天上午,小米之...
2025-10-04 0
IT之家 10 月 4 日消息,消息人士 Evan Blass 今天在 X 平台放出了摩托罗拉 Moto Edge 70 手机的宣传物料。根据消息人士...
2025-10-04 0
JAX生态这两年在LLM训练这块追赶得挺快。PyTorch虽然还是主流但JAX在并行计算、TPU加速和API组合性上确实有些独特的优势。Google今...
2025-10-04 0
当大家第一次看到“拉布布被拍出108万”的消息时,第一反应是:这玩意值那么多钱?才过了一夜,“明星晒拉布布”、“黄牛排队囤货”、“情绪价值顶流”就反复...
2025-10-04 0
您好:这款游戏可以开挂,确实是有挂的,很多玩家在这款游戏中打牌都会发现很多用户的牌特别好,总是好牌,而且好像能看到-人的牌一样。所以很多小伙伴就怀疑这...
2025-10-04 8
手机麻将怎么让系统给你发好牌是一款可以让一直输的玩家,快速成为一个“必胜”的ai辅助神器,有需要的用户可以加微下载使用。 手机打牌可以一键让你轻松成...
2025-10-04 14
微乐保皇如何让系统发好牌是一款可以让一直输的玩家,快速成为一个“必胜”的ai辅助神器,有需要的用户可以加微下载使用。 手机打牌可以一键让你轻松成为“...
2025-10-04 11
发表评论