import math

import torch
from torch import nn

# NOTE: the imports below are assumptions about the surrounding Diff-Transformer
# codebase: apply_rotary_emb and RMSNorm are expected to come from companion
# modules in the same package, and flash_attn_func from the customized
# flash-attention build (https://aka.ms/flash-diff) that supports different
# qk/v head dimensions.
from flash_attn import flash_attn_func
from .kernel.rotary import apply_rotary_emb
try:
    from apex.normalization import FusedRMSNorm as RMSNorm
except ModuleNotFoundError:
    from .rms_norm import RMSNorm


def init_method(tensor, **kwargs):
    # In-place Kaiming-uniform initialization; extra keyword arguments are accepted but ignored.
    nn.init.kaiming_uniform_(tensor, a=math.sqrt(5))
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
    bs, n_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )
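# Shape sanity check for repeat_kv (illustrative sketch, not part of the original file):
#
#   x = torch.randn(2, 4, 8, 16)              # (bs, n_kv_heads, slen, head_dim)
#   y = repeat_kv(x, n_rep=3)                  # -> (2, 12, 8, 16)
#   assert torch.equal(y, torch.repeat_interleave(x, repeats=3, dim=1))
#
# Each kv head is repeated n_rep times consecutively along the head dimension,
# matching torch.repeat_interleave as the docstring states.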
def lambda_init_fn(depth):
    return 0.8 - 0.6 * math.exp(-0.3 * depth)
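# For reference, the schedule above yields lambda_init ~= 0.200 at depth 0,
# ~= 0.356 at depth 1, ~= 0.727 at depth 7, and approaches 0.8 for deep layers
# (assuming `depth` is the 0-based layer index passed to the class below).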
class MultiheadFlashDiff1(nn.Module):
    """
    (Recommended) DiffAttn implemented with FlashAttention, for packages that support different qk/v dimensions
    e.g., our customized-flash-attention (https://aka.ms/flash-diff) and xformers (https://github.com/facebookresearch/xformers)
    """
    def __init__(
        self,
        args,
        embed_dim,
        depth,
        num_heads,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        self.num_heads = num_heads // args.model_parallel_size
        self.num_kv_heads = (
            args.decoder_kv_attention_heads // args.model_parallel_size
            if args.decoder_kv_attention_heads is not None
            else num_heads // args.model_parallel_size
        )
        self.n_rep = self.num_heads // self.num_kv_heads

        # Each differential head uses a pair of query/key components of size head_dim
        # and a value of size 2 * head_dim, so embed_dim = num_heads * 2 * head_dim.
        self.head_dim = embed_dim // num_heads // 2
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # Reparameterized lambda: lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        # Per-head RMSNorm over the 2 * head_dim output of each differential head, without learnable affine.
        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=False)

    def forward(
        self,
        x,
        rel_pos,
        attn_mask=None,
    ):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)
        k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)
        v = v.view(bsz, src_len, self.num_kv_heads, 2 * self.head_dim)

        # rel_pos carries the rotary (cos, sin) tables; rotary embeddings are applied to q and k only.
        q = apply_rotary_emb(q, *rel_pos, interleaved=True)
        k = apply_rotary_emb(k, *rel_pos, interleaved=True)
        offset = src_len - tgt_len  # unused in this flash-attention path
        q = q.reshape(bsz, tgt_len, self.num_heads, 2, self.head_dim)
        k = k.reshape(bsz, src_len, self.num_kv_heads, 2, self.head_dim)
        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        attn1 = flash_attn_func(q1, k1, v, causal=True)
        attn2 = flash_attn_func(q2, k2, v, causal=True)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        # Differential attention: subtract the second attention output, scaled by lambda, from the first.
        attn = attn1 - lambda_full * attn2
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)
        attn = attn.reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)

        attn = self.out_proj(attn)
        return attn
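

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumptions: the file
    # is run inside its package (for the relative imports above), a CUDA device is
    # available, and the customized flash-attention build that accepts different
    # qk/v head dimensions is installed. The `args` namespace only stands in for
    # the config fields this class actually reads.
    from types import SimpleNamespace

    args = SimpleNamespace(model_parallel_size=1, decoder_kv_attention_heads=4)
    embed_dim, num_heads, depth = 1024, 16, 0
    attn = MultiheadFlashDiff1(args, embed_dim, depth, num_heads).cuda().half()

    bsz, seq_len = 2, 128
    head_dim = embed_dim // num_heads // 2  # 32: per-component head dimension
    x = torch.randn(bsz, seq_len, embed_dim, device="cuda", dtype=torch.float16)

    # Rotary (cos, sin) tables; the (seq_len, head_dim // 2) shape follows the
    # flash-attention rotary convention and is an assumption about this codebase.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
    angles = torch.outer(torch.arange(seq_len, device="cuda").float(), inv_freq)
    rel_pos = (angles.cos().half(), angles.sin().half())

    out = attn(x, rel_pos)
    print(out.shape)  # expected: torch.Size([2, 128, 1024])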