--fp16
, 对应megatron/arguments.py
中的定义如下: group.add_argument('--fp16', action='store_true',
help='Run model in fp16 mode.')
lm-cross-entropy
时默认是使用fp32来计算的,在开启--fp16
选项的前提下可以通过指定--fp16-lm-cross-entropy
来使用fp16计算lm-cross-entropy
,对应megatron/arguments.py
中的定义如下: group.add_argument('--fp16-lm-cross-entropy', action='store_true',
help='Move the cross entropy unreduced loss calculation'
'for lm head to fp16.')
args.fp32_residual_connection
,这里设置了的话会在计算残差连接的时候转为fp32再进行计算,这里残差连接在网络中对应是Embedding模块。 if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
validate_args
函数用于check参数有效性,fp16相关实现如下:def validate_args(args, defaults={}):
......
args.params_dtype = torch.float
if args.fp16:
assert not args.bf16
args.params_dtype = torch.half
......
# Mixed precision checks.
if args.fp16_lm_cross_entropy:
assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
if args.fp32_residual_connection:
assert args.fp16 or args.bf16, \
'residual connection in fp32 only supported when using fp16 or bf16.'
......
如果指定了fp16,这里的args.fp16
为True,对应的args.params_dtype
参数类型为torch.half
。
ParallelAttention中有self.query_key_value
、self.core_attention
和self.dense
等子模块,fp16对训练的影响会应用在子模块中。
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
...
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
...
self.core_attention = CoreAttention(self.layer_number,
self.attn_mask_type)
...
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
bias=args.add_bias_linear,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
**_args_to_kwargs())
对于self.query_key_value
和self.dense
模块,fp16的设置通过参数中的**_args_to_kwargs()
进行传递。
def _args_to_kwargs():
args = get_args()
common_kwargs = {
"params_dtype": args.params_dtype,
"use_cpu_initialization": args.use_cpu_initialization,
"perform_initialization": args.perform_initialization,
"gradient_accumulation_fusion": args.gradient_accumulation_fusion,
"sequence_parallel_enabled": args.sequence_parallel,
}
return common_kwargs
对于self.core_attention
部分,fp16的设置是在CoreAttention
的__init__
中self.fp16 = args.fp16
。
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
...
在ParallelAttention
模块本身中fp16会影响推理部分
class ParallelAttention(MegatronModule):
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
...
self.params_dtype = args.params_dtype
...
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None,
rotary_pos_emb=None):
...
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
...
- ParallelAttention模块在__init__初始化时会设置参数类型self.params_dtype为fp16。
- _allocate_memory中会用torch.empty创建用于推理的大buffer,类型是fp16。
- 当传入inference_params时,forward函数中会调用_allocate_memory。
当设了fp16以后,在CoreAttention
的forward计算的input就是fp16类型,在init中设置fp16 flag主要是用于计算中用到的FusedScaleMaskSoftmax
模块的输出结果类型转换。
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
...
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
...
当FusedScaleMaskSoftmax
执行时,kernel支持fp16时会直接调用fusion算子forward_fused_softmax
;对于kernel不支持的规模,会调用forward_torch_softmax
进行模拟,输出的类型就根据self.input_in_float16
来进行cast转换。
class FusedScaleMaskSoftmax(nn.Module):
...
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
在ColumnParallelLinear
初始化时创建Parameter中的类型直接按params_dtype(即fp16)
来设。
class ColumnParallelLinear(torch.nn.Module):
def __init__(self, ...,
params_dtype=torch.float32,
...,
):
...
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=params_dtype))
...
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
...
以gpt2模型为例,在megatron/model/gpt_model.py
文件中的post_language_model_processing
函数, 如果指定了fp16_lm_cross_entropy
,那么在计算cross entropy
时会直接使用fp16的output
计算loss;否则(默认情况)会把output先转为float32
再进行计算loss。
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)