mirror of
https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp
synced 2025-12-18 12:59:20 +08:00
fix indexer rope
This commit is contained in:
@@ -2,7 +2,6 @@ import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Optional, Literal
|
||||
|
||||
from einops import rearrange
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
@@ -282,6 +281,7 @@ class RMSNorm(nn.Module):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
# rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient.
|
||||
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
||||
|
||||
def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
|
||||
@@ -315,6 +315,7 @@ class LayerNorm(nn.Module):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
# layernorm in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenient.
|
||||
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
|
||||
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
|
||||
|
||||
@@ -403,7 +404,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
|
||||
"""
|
||||
Applies rotary positional embeddings to the input tensor.
|
||||
|
||||
@@ -415,9 +416,14 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
torch.Tensor: Tensor with rotary embeddings applied.
|
||||
"""
|
||||
dtype = x.dtype
|
||||
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
|
||||
shape = x.shape
|
||||
if not interleaved:
|
||||
x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
|
||||
x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
|
||||
y = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
if not interleaved:
|
||||
y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
|
||||
return y.to(dtype)
|
||||
|
||||
|
||||
@@ -441,7 +447,8 @@ class Indexer(torch.nn.Module):
|
||||
self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
|
||||
self.wk = Linear(self.dim, self.head_dim)
|
||||
self.k_norm = LayerNorm(self.head_dim)
|
||||
self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
|
||||
# weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenient.
|
||||
self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
|
||||
self.softmax_scale = self.head_dim ** -0.5
|
||||
self.scale_fmt = args.scale_fmt
|
||||
|
||||
@@ -453,14 +460,16 @@ class Indexer(torch.nn.Module):
|
||||
bsz, seqlen, _ = x.size()
|
||||
end_pos = start_pos + seqlen
|
||||
q = self.wq_b(qr)
|
||||
q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim)
|
||||
q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
|
||||
q_pe = apply_rotary_emb(q_pe, freqs_cis)
|
||||
# rope in indexer is not interleaved
|
||||
q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
|
||||
q = torch.cat([q_pe, q_nope], dim=-1)
|
||||
k = self.wk(x)
|
||||
k = self.k_norm(k)
|
||||
k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
|
||||
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
|
||||
# rope in indexer is not interleaved
|
||||
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
|
||||
k = torch.cat([k_pe, k_nope], dim=-1)
|
||||
q = rotate_activation(q)
|
||||
k = rotate_activation(k)
|
||||
@@ -468,7 +477,7 @@ class Indexer(torch.nn.Module):
|
||||
k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
|
||||
self.k_cache[:bsz, start_pos:end_pos] = k_fp8
|
||||
self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
|
||||
weights = self.weights_proj(x) * self.n_heads ** -0.5
|
||||
weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
|
||||
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
|
||||
index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
|
||||
if mask is not None:
|
||||
@@ -524,6 +533,7 @@ class MLA(nn.Module):
|
||||
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
|
||||
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
|
||||
self.softmax_scale = self.qk_head_dim ** -0.5
|
||||
self.scale_fmt = args.scale_fmt
|
||||
if args.max_seq_len > args.original_seq_len:
|
||||
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
|
||||
self.softmax_scale = self.softmax_scale * mscale * mscale
|
||||
@@ -558,6 +568,9 @@ class MLA(nn.Module):
|
||||
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||
kv = self.kv_norm(kv)
|
||||
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
|
||||
# we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
|
||||
kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
|
||||
kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
|
||||
self.kv_cache[:bsz, start_pos:end_pos] = kv
|
||||
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
|
||||
if mask is not None: # MHA prefill
|
||||
@@ -566,7 +579,7 @@ class MLA(nn.Module):
|
||||
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
|
||||
scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
|
||||
scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)
|
||||
|
||||
# indexer
|
||||
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
|
||||
@@ -574,24 +587,24 @@ class MLA(nn.Module):
|
||||
index_mask += mask
|
||||
scores += index_mask.unsqueeze(2)
|
||||
|
||||
scores = scores.softmax(dim=-1, dtype=torch.float32)
|
||||
x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
|
||||
else: # MHA decode
|
||||
scores = scores.softmax(dim=-1)
|
||||
x = torch.einsum("bsht,bthd->bshd", scores, v)
|
||||
else: # MQA decode
|
||||
if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
|
||||
self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
|
||||
wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
|
||||
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
|
||||
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
|
||||
scores = (torch.einsum("bshc,btc->bsht", q_nope.float(), self.kv_cache[:bsz, :end_pos].float()) +
|
||||
torch.einsum("bshr,btr->bsht", q_pe.float(), self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale
|
||||
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
|
||||
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
|
||||
|
||||
# indexer
|
||||
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
|
||||
index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
|
||||
scores += index_mask.unsqueeze(2)
|
||||
|
||||
scores = scores.softmax(dim=-1, dtype=torch.float32)
|
||||
x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
|
||||
scores = scores.softmax(dim=-1)
|
||||
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
|
||||
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
|
||||
x = self.wo(x.flatten(2))
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user