1.DIFFERENTIAL TRANSFORMER

Problem with the Transformer: it tends to over-attend to irrelevant context, so the attention assigned to negligible tokens drowns out the genuinely relevant information.

Improvements:
1. Replaces conventional softmax attention with differential attention (a minimal sketch of the idea follows this list).
2. Adopts pre-RMSNorm and SwiGLU, the same improvements used in LLaMA.
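
To make point 1 concrete, here is a minimal, illustrative sketch of differential attention for a single head. This is not the paper's code: it omits multi-head handling, RoPE, causal masking, and the per-head GroupNorm that appear in the full implementation further below, and the function name and shapes are my own.

import torch
import torch.nn.functional as F

def diff_attention(q1, q2, k1, k2, v, lam):
    """Differential attention for one head (illustrative only).

    q1/q2, k1/k2: the two halves of the query/key projections, shape (batch, seq, d)
    v: values, shape (batch, seq, d_v); lam: scalar lambda
    """
    scale = q1.size(-1) ** -0.5
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    # Subtracting the second map cancels the common-mode attention noise on
    # irrelevant tokens, so the remaining attention concentrates on relevant ones.
    return (a1 - lam * a2) @ v

The key point is that the two softmax maps assign similar "noise" attention to irrelevant context, so the subtraction cancels it while the peaks on relevant tokens survive.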

Other techniques mentioned in the paper (minimal sketches of RMSNorm and SwiGLU follow this list):
Group Normalization: see the article 全面解读Group Normalization (吴育昕, 何恺明)
RMSNorm: see the article Llama改进之——均方根层归一化RMSNorm
SwiGLU: see the article SwiGLU
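
For reference, minimal sketches of RMSNorm and a SwiGLU feed-forward step. These are illustrative only: the DIFF Transformer repo ships its own RMSNorm, and the constructor signature here merely mirrors how RMSNorm is called in the code below; the weight names in swiglu_ffn are my own.

import torch
import torch.nn.functional as F
from torch import nn

class RMSNorm(nn.Module):
    """Root mean square layer norm: x / RMS(x), optionally scaled by a learned weight."""
    def __init__(self, dim, eps=1e-5, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, x):
        x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * self.weight if self.weight is not None else x

def swiglu_ffn(x, w_gate, w_up, w_down):
    """SwiGLU feed-forward block: (SiLU(x W_gate) * (x W_up)) W_down."""
    return (F.silu(x @ w_gate) * (x @ w_up)) @ w_down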

Code: the implementation below (MultiheadFlashDiff1) calls into the flash_attn package.

import math

import torch
from torch import nn

from flash_attn import flash_attn_func
# apply_rotary_emb and RMSNorm come from the Diff Transformer repo's own modules
# (its rotary kernel and RMSNorm implementation); they are assumed importable here.


def init_method(tensor, **kwargs):
    nn.init.kaiming_uniform_(tensor, a=math.sqrt(5))


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
    bs, n_kv_heads, slen, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, None, :, :]
        .expand(bs, n_kv_heads, n_rep, slen, head_dim)
        .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
    )


def lambda_init_fn(depth):
    # depth-dependent lambda initialization: deeper layers start with a larger lambda_init
    return 0.8 - 0.6 * math.exp(-0.3 * depth)


class MultiheadFlashDiff1(nn.Module):
    """
    (Recommended)
    DiffAttn implemented with FlashAttention, for packages that support different qk/v dimensions
    e.g., our customized-flash-attention (https://aka.ms/flash-diff) and xformers (https://github.com/facebookresearch/xformers)
    """
    def __init__(
        self,
        args,
        embed_dim,
        depth,
        num_heads,
    ):
        super().__init__()
        self.args = args
        self.embed_dim = embed_dim
        # num_heads set to half of Transformer's #heads
        self.num_heads = num_heads // args.model_parallel_size
        self.num_kv_heads = (
            args.decoder_kv_attention_heads // args.model_parallel_size
            if args.decoder_kv_attention_heads is not None
            else num_heads // args.model_parallel_size
        )
        self.n_rep = self.num_heads // self.num_kv_heads

        self.head_dim = embed_dim // num_heads // 2  # head dim is halved: the two halves drive the two attention maps
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim // self.n_rep, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.lambda_init = lambda_init_fn(depth)  # different initialization value for each layer depth
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))

        # per-head RMSNorm without affine parameters, playing the role of the paper's GroupNorm
        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=False)

    def forward(
        self,
        x,
        rel_pos,
        attn_mask=None,
    ):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)
        k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)
        v = v.view(bsz, src_len, self.num_kv_heads, 2 * self.head_dim)
        # rotary position embedding (RoPE); rel_pos is the (cos, sin) pair
        q = apply_rotary_emb(q, *rel_pos, interleaved=True)
        k = apply_rotary_emb(k, *rel_pos, interleaved=True)

        offset = src_len - tgt_len
        q = q.reshape(bsz, tgt_len, self.num_heads, 2, self.head_dim)
        k = k.reshape(bsz, src_len, self.num_kv_heads, 2, self.head_dim)
        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        # two softmax attention maps over the same value vectors
        attn1 = flash_attn_func(q1, k1, v, causal=True)
        attn2 = flash_attn_func(q2, k2, v, causal=True)

        # lambda is re-parameterized as exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn = attn1 - lambda_full * attn2   # differential attention

        attn = self.subln(attn)              # per-head normalization (GroupNorm)
        attn = attn * (1 - self.lambda_init)
        attn = attn.reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)

        attn = self.out_proj(attn)
        return attn
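
A hypothetical usage sketch, not taken from the repo. It assumes the customized flash-attention build linked in the docstring (https://aka.ms/flash-diff) is installed, since stock flash_attn_func does not accept the different qk/v head dimensions used here, plus a CUDA device, fp16 weights, and an args object carrying the two fields the module reads; the rotary (cos, sin) tables follow a standard RoPE recipe rather than the repo's own helper.

import torch
from types import SimpleNamespace

# Hypothetical config object; only the fields read by MultiheadFlashDiff1 are set.
args = SimpleNamespace(model_parallel_size=1, decoder_kv_attention_heads=None)

embed_dim, num_heads, depth = 1024, 8, 0    # num_heads = half of a vanilla Transformer's head count
attn = MultiheadFlashDiff1(args, embed_dim, depth, num_heads).cuda().half()

bsz, seqlen = 2, 128
head_dim = embed_dim // num_heads // 2      # 64, matching self.head_dim
x = torch.randn(bsz, seqlen, embed_dim, device="cuda", dtype=torch.float16)

# Rotary tables (cos, sin) of shape (seqlen, head_dim // 2), a standard RoPE construction.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device="cuda").float() / head_dim))
freqs = torch.outer(torch.arange(seqlen, device="cuda").float(), inv_freq)
rel_pos = (freqs.cos().half(), freqs.sin().half())

out = attn(x, rel_pos)                      # (bsz, seqlen, embed_dim)
print(out.shape)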

2.HYPER-CONNECTIONS

Proposes hyper-connections to address the seesaw effect in residual connections between gradient vanishing and representation collapse.

I have not fully understood this paper yet.

3.Reverse Modeling in Large Language Models