Investigate and implement Flash Attention #5

@0x000011b

Description

For our use case of fine-tuning LMs on sequences of up to 2048 tokens, Flash Attention might get us a ~2-4x speedup and a VRAM usage reduction of up to 10x. Sounds pretty amazing, so I'd like to give it a shot. Some code inspiration:
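As a rough starting point, here is a minimal sketch, assuming PyTorch >= 2.0, where `torch.nn.functional.scaled_dot_product_attention` can dispatch to a FlashAttention kernel on supported CUDA GPUs with fp16/bf16 inputs. The shapes and the backend toggle below are illustrative only, not tied to any particular model we'd patch:

```python
# Minimal sketch (assumes PyTorch >= 2.0 and a CUDA GPU supported by the
# flash kernel, e.g. Ampere or newer, with fp16/bf16 inputs).
import torch
import torch.nn.functional as F

batch, n_heads, seq_len, head_dim = 2, 16, 2048, 64  # illustrative sizes
q = torch.randn(batch, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Force the FlashAttention backend so we get an error if the inputs can't
# use it, instead of silently falling back to the vanilla math kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 16, 2048, 64])
```

The other obvious route is the standalone flash-attn package, which ships its own CUDA kernels and attention functions at the cost of an extra build-time dependency; either way, the change should stay contained to the attention module of whatever model we patch.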

Metadata

Labels: enhancement (New feature or request)

Status: ✅ Done