Conversation

@jiecaoyu jiecaoyu commented Oct 29, 2023

With @jspark1105's commits enabling FP8 allgather, we can run test_te.py and also run local training without PP.

However, enabling PP exposes some issues with FP8 allgather that need to be fixed. This diff copies the changes from @jspark1105's PR and adds the fixes we need in fairscale.

The fixes are as follows (most are naive and need better implementations):

  • When running model_chunk._rebuild_full_params_recursive() in xlformers/src/model_parallel_core/pipeline_parallel/fwd_bwd_schedules.py, we need to pass the FP8-training-related settings into the context (see the rebuild sketch after this list). All changes in xlformers are included in this commit.

  • In TransformerEngine, we don't need to return the weight gradients for FP8 training, since the gradients are accumulated into .main_grad (see the main_grad sketch after this list). All changes in TE are included in this commit.

  • FlattenParamsWrapper creates views of the parameters on every forward pass. This is unnecessary when we are not resharding after forward, and it breaks FP8 allgather + PP: .main_grad is created on the parameter views at the beginning of the forward, but later only the most recent views can be accessed, so the earlier views (and the .main_grad attached to them) are no longer accessible (see the parameter-view sketch after this list).

  • We should not free the FP16 parameter shard (via _free_fp16_param_shard) in _post_backward_hook. The FP16 shard needs to be kept, since every backward pass (one per micro-batch) needs to use it (see the post-backward sketch after this list).
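
A minimal sketch of the rebuild fix, under the assumption that the FP8 settings can simply be carried into the rebuild via TransformerEngine's fp8_autocast context; the function name and the fp8_enabled/fp8_recipe/fp8_group arguments are illustrative, not the actual xlformers change:

import transformer_engine.pytorch as te

def rebuild_full_params_with_fp8(model_chunk, fp8_enabled, fp8_recipe=None, fp8_group=None):
    # Rebuild the sharded parameters under the same FP8 settings as the
    # forward pass so the FP8 allgather sees consistent scaling metadata.
    if not fp8_enabled:
        model_chunk._rebuild_full_params_recursive()
        return
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group):
        model_chunk._rebuild_full_params_recursive()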
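
For the TransformerEngine bullet, a runnable toy example (not the TE code) of the underlying pattern: the weight gradient is accumulated into weight.main_grad inside backward, so the autograd Function can return None for the weight instead of a gradient tensor:

import torch

class LinearWithMainGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        dgrad = grad_out @ weight
        # Accumulate the weight gradient out-of-band instead of returning it.
        weight.main_grad.add_(grad_out.t() @ inp)
        return dgrad, None  # no wgrad returned; it lives in weight.main_grad

weight = torch.nn.Parameter(torch.randn(4, 3))
weight.main_grad = torch.zeros(4, 3)
x = torch.randn(2, 3, requires_grad=True)
LinearWithMainGrad.apply(x, weight).sum().backward()
print(weight.grad is None, weight.main_grad.abs().sum().item() > 0)  # True True

In the real integration the fused GEMM writes into main_grad directly; the add_ above only stands in for that.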
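
For the FlattenParamsWrapper bullet, a tiny runnable illustration (again not the fairscale code) of why re-creating the views each forward loses .main_grad: the attribute lives on a specific view object, and a freshly created view is a different tensor:

import torch

flat = torch.zeros(6)  # stand-in for the flattened parameter

view_a = flat[:6].view(2, 3)          # views created in the first forward
view_a.main_grad = torch.zeros(2, 3)  # attached at the start of forward

view_b = flat[:6].view(2, 3)          # views re-created in a later forward
print(hasattr(view_b, "main_grad"))   # False: the earlier .main_grad is unreachable

Keeping the first set of views when not resharding after forward avoids both the redundant work and the lost .main_grad references.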
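
For the last bullet, a hedged fragment showing the shape of the change; the real fairscale _post_backward_hook does much more, and keep_fp16_shard_for_fp8_pp is a made-up name for whatever condition the actual fix uses:

def _post_backward_hook(self, param, *unused):
    # ... reduce-scatter / accumulate the gradient into main_grad as before ...
    if self.mixed_precision and not self.keep_fp16_shard_for_fp8_pp:
        # With FP8 allgather + PP, every micro-batch's backward reuses the FP16
        # shard, so it must not be freed after the first backward; free it once
        # at the end of the step instead.
        self._free_fp16_param_shard([param])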

@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 29, 2023
@jiecaoyu jiecaoyu changed the base branch from main to ngoyal_changes_for_pp_fp8 October 29, 2023 05:43
FP8GlobalStateManager.copy_amax_from_global_buffer(
    m.fp8_meta, forward=True
)
# FIXME update_weight_scale_inv is only True for the first micro-batch

As discussed, it's OK to update weight_scale_inv multiple times, but it can still be annoying if we see numerical differences between no-PP and PP (actually just micro-batching). I wonder whether we can check is_first_microbatch in kwargs to skip this.
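
A hedged sketch of that suggestion; the surrounding hook and its kwargs dict are assumed, while is_first_microbatch is the flag TE modules already accept at forward time:

is_first_microbatch = kwargs.get("is_first_microbatch")
if is_first_microbatch is None or is_first_microbatch:
    # Only refresh the amax/scale state on the first micro-batch so that
    # PP / micro-batched runs stay numerically aligned with the no-PP run.
    FP8GlobalStateManager.copy_amax_from_global_buffer(
        m.fp8_meta, forward=True
    )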

@jiecaoyu jiecaoyu force-pushed the jiecaoyu_fp8allgather_debug branch from fb0b563 to 8bebf15 November 9, 2023 05:53