视频生成模型太贵太慢怎么办?
普林斯顿大学和 Meta 联合推出的新框架LinGen,以 MATE 线性复杂度块取代传统自注意力,将视频生成从像素数的平方复杂度压到线性复杂度,使单张 GPU 就能在分钟级长度下生成高质量视频,大幅提高了模型的可扩展性和生成效率。
实验结果表明,LinGen在视频质量上优于 DiT(胜率达 75.6%),并且最高可减少 15 ×(11.5 ×)FLOPs(延迟)。此外,自动指标和人工评估均显示,LinGen-4B在视频质量上与最先进模型相当(分别以 50.5%、52.1%、49.1% 的胜率优于 Gen-3、Luma Labs 和 Kling)。
方法:线性复杂度的 MATE 模块
LinGen 维持 Diffusion Transformer(DiT)中的其他结构不变,而将其计算瓶颈——平方复杂度的自注意力模块替换为线性复杂度的 MATE 模块,它由 MA 分支和 TE 分支组成。
其中,MA 分支包含一个双向的 Mamba2 模块。
Mamba2 作为 State Space Model(SSM)的变体,善于处理超长的 token 序列,同时又对硬件非常友好,可以使用 attention 的各种硬件加速核,如 xformers,FlashAttention 等。但是 Mamba 系列模型在语言任务上的优秀表现难以直接迁移到大型视觉任务上,生成的高分辨率视频往往一致性很差、质量不高。
一些特殊的 scan 方法尝试解决这一问题,如 Zigzag scan,Hilbert scan,但它们都要求对序列做复杂的顺序变换,而这个操作对硬件极其不友好。在处理高分辨率、长视频时,会带来显著的额外延迟。
针对于此,LinGen 提出Rotary Major Scan(RMS),相邻层中四种 scan 方式交替切换。
以上图的方式为例,W,H 和 T 分别在展开时有第一、第二和第三优先级,通过交换展开的优先级,就可以实现不同的 scan 方式。
相比于已有方法,该方法最大的好处是对硬件非常友好、可以通过简单的 tensor reshaping 实现,因此也几乎没有额外开销,同时还把 scan 后原相邻 token 的平均距离降到了和已有特殊 scan 方式相同的水平。
然而,所有这些特殊的 scan 方式仍然不足以完全解决 Mamba 的临近信息丢失问题,因为在模型的任意一层中,只会有一种 scan 方式被应用,如果不考虑跨层交流,大量临近信息在单层中依旧有损失。
针对于此,LinGen 在 TE 分支中应用了TEmporal Swin Attention(TESA):它是一种特殊的 3D window attention,窗口范围在不同层中会滑动,每一个窗口都很小,并且窗口大小不随视频分辨率和长度(即 3D tensor 的大小)的变化而变化。
这是因为 TESA 仅用来处理最临近的信息,这一固定的窗口大小也使得 TESA 实现了相对 3D tensor 中 token 数的线性复杂度。
作为额外的补充,LinGen 还在 MA 分支中引入了review tokens。它被用以增强视频中极长程的一致性,例如在 60 秒视频的结尾复现视频前几秒消失的人。它把待处理 video tensor 的概览提前写入 Mamba 的 hidden state memory 中,为后续的视频处理提供帮助。
评估:远超基线,对标 SOTA
从人类评测和模型自动评测两个角度将 LinGen 与已有的先进视频生成模型、以及 DiT baseline 进行比较。
无论是人类评测的结果,还是在 VBench 上的自动评测的结果,都显示 LinGen 与先进的商业模型 Kling、Runway Gen-3 生成的视频质量接近,并且远胜于 OpenSora v1.2。
可以看到,在 FLOPs 方面,当生成 17 秒、34 秒和 68 秒长度的 512p 视频时,LinGen-4B 相对于 DiT-4B 分别实现了 5 ×、8 × 和 15 × 的加速;
在延迟方面,当在单个 H100 上生成 512p 和 768p 的 17 秒视频时,LinGen-4B 相对于 DiT-4B 分别实现了 2.0 × 和 3.6 × 的加速;
当生成 17 秒、34 秒和 68 秒长度的 512p 视频时,LinGen-4B 相对于 DiT-4B 分别实现了 2.0 ×、3.9 × 和 11.5 × 的延迟加速。
这说明 LinGen 具有线性复杂度,可以在单卡上实现分钟级视频生成,速度远快于 DiT。与相同大小的 DiT 相比,LinGen 可实现推理速度 11 倍以上的提升。
另外,LinGen 和相同大小、在相同数据集上以相同 training recipe 训练的 DiT baseline 相比,在视频质量和文字 - 视频一致性上取得全面领先。相比起 DiT,LinGen 可以更快地适应更长的 token 序列。
通常认为自注意力模块的线性替代是对完整自注意力的近似,虽然在速度上有显著优势,但在模型性能上往往略逊一筹,而 LinGen 打破了这个惯有的看法。
在整个预训练过程中,模型从低分辨率图像生成开始,学习低分辨率视频生成,再不断增加所生成视频的分辨率和长度,所处理的 token 数增长了上千倍。
而在从少 token 数的任务迁移到多 token 数的任务时,LinGen 的适应性远强于 DiT(a 图中是从 256x256 分辨率视频生成迁移到 512x512 分辨率视频生成任务时的 loss curve),这可能是受益于 Mamba 对于长序列的高适应性,这一特征已经在语言任务上被观察到。
为了进一步验证这里推理,选取这一预训练阶段的早期 checkpoint 进行比较,发现 LinGen 比 DiT 的 win rate 优势变得更加显著。这暗示了虽然 LinGen 在任务迁移的早期能大幅领先 DiT,但是这种优势随着预训练的进行,在不断减小。
尽管如此,在训练资源有限的情况下,LinGen 在预训练的极长一段时间内仍旧能对 DiT 保持优势。
项目主页:https://lineargen.github.io/
论文链接:https://arxiv.org/abs/2412.09856
项目代码:https://github.com/jha-lab/LinGen
一键三连「点赞」「转发」「小心心」
欢迎在评论区留下你的想法!
— 完 —
点亮星标
科技前沿进展每日见
登录后才可以发布评论哦
打开小程序可以发布评论哦