fix indexer rope

GeeeekExplorer
2025-11-17 12:45:31 +08:00
parent 9d2f599343
commit 1938e9df3d


@@ -2,7 +2,6 @@ import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
-from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
@@ -282,6 +281,7 @@ class RMSNorm(nn.Module):
super().__init__()
self.dim = dim
self.eps = eps
+# rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenience.
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
@@ -315,6 +315,7 @@ class LayerNorm(nn.Module):
super().__init__()
self.dim = dim
self.eps = eps
+# layernorm in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience.
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
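
A note on the two comments just added: keeping the norm parameters in fp32 while the checkpoint stores them in bf16 works without extra code, because load_state_dict copies each checkpoint tensor into the existing parameter storage and upcasts to the parameter's dtype. A minimal standalone sketch (not code from this repo) of that behavior:

import torch
from torch import nn

# Toy module that keeps its scale in fp32, mirroring the RMSNorm/LayerNorm above.
class ToyNorm(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))

m = ToyNorm(4)
# Simulate a bf16 checkpoint entry; load_state_dict copies it into the fp32 parameter.
m.load_state_dict({"weight": torch.full((4,), 0.5, dtype=torch.bfloat16)})
print(m.weight.dtype)  # torch.float32: the parameter keeps its dtype, the values are upcast
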
@@ -403,7 +404,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
return freqs_cis
-def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor, interleaved: bool = True) -> torch.Tensor:
"""
Applies rotary positional embeddings to the input tensor.
@@ -415,9 +416,14 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
torch.Tensor: Tensor with rotary embeddings applied.
"""
dtype = x.dtype
-x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
+shape = x.shape
+if not interleaved:
+    x = x.view(*shape[:-1], 2, -1).transpose(-1, -2).contiguous()
+x = torch.view_as_complex(x.float().view(*shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
y = torch.view_as_real(x * freqs_cis).flatten(3)
+if not interleaved:
+    y = torch.cat([y[..., 0::2], y[..., 1::2]], dim=-1)
return y.to(dtype)
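
For readers less familiar with the two rotary layouts: with interleaved=True (the default, used by the main attention path), channel 2i is rotated together with channel 2i+1; with interleaved=False (the indexer's convention), channel i in the first half is rotated together with channel i + d/2. The new branches above simply reshuffle between the two layouts before and after the complex multiply. A small standalone illustration of that reshuffle (toy values, not code from this commit):

import torch

d = 8
x = torch.arange(d)

# Half-split (non-interleaved) layout: channel i pairs with channel i + d/2.
# The interleaved=False path first reshuffles it into the interleaved pairing:
pairs_first = x.view(2, d // 2).transpose(0, 1).reshape(d)
print(pairs_first)  # tensor([0, 4, 1, 5, 2, 6, 3, 7]) -> pairs (0,4), (1,5), (2,6), (3,7)

# After the rotation the components come back interleaved; splitting even and odd positions
# (the torch.cat([y[..., 0::2], y[..., 1::2]]) above) restores the half-split order:
restored = torch.cat([pairs_first[0::2], pairs_first[1::2]])
print(restored)     # tensor([0, 1, 2, 3, 4, 5, 6, 7])
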
@@ -441,7 +447,8 @@ class Indexer(torch.nn.Module):
self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
self.wk = Linear(self.dim, self.head_dim)
self.k_norm = LayerNorm(self.head_dim)
-self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
+# weights_proj in the checkpoint is stored in bf16, while the parameters here are stored in fp32 for convenience.
+self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.float32)
self.softmax_scale = self.head_dim ** -0.5
self.scale_fmt = args.scale_fmt
@@ -453,14 +460,16 @@ class Indexer(torch.nn.Module):
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
q = self.wq_b(qr)
-q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim)
+q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
q_pe, q_nope = torch.split(q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
-q_pe = apply_rotary_emb(q_pe, freqs_cis)
+# rope in indexer is not interleaved
+q_pe = apply_rotary_emb(q_pe, freqs_cis, False)
q = torch.cat([q_pe, q_nope], dim=-1)
k = self.wk(x)
k = self.k_norm(k)
k_pe, k_nope = torch.split(k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
-k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
+# rope in indexer is not interleaved
+k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis, False).squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1)
q = rotate_activation(q)
k = rotate_activation(k)
@@ -468,7 +477,7 @@ class Indexer(torch.nn.Module):
k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
self.k_cache[:bsz, start_pos:end_pos] = k_fp8
self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
-weights = self.weights_proj(x) * self.n_heads ** -0.5
+weights = self.weights_proj(x.float()) * self.n_heads ** -0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
index_score = fp8_index(q_fp8.contiguous(), weights, self.k_cache[:bsz, :end_pos].contiguous(), self.k_scale_cache[:bsz, :end_pos].contiguous())
if mask is not None:
@@ -524,6 +533,7 @@ class MLA(nn.Module):
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.softmax_scale = self.qk_head_dim ** -0.5
+self.scale_fmt = args.scale_fmt
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale
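
The two context lines above are the YaRN-style temperature correction: when the model runs past its original context length, the softmax scale is multiplied by mscale squared so the attention logits are not flattened by the RoPE rescaling. A quick numeric sketch with illustrative values (the concrete mscale, rope_factor and head dimension come from the model config, not from this commit):

import math

qk_head_dim = 192        # illustrative: qk_nope_head_dim + qk_rope_head_dim
rope_factor = 40.0       # illustrative YaRN extension factor
args_mscale = 1.0        # illustrative

softmax_scale = qk_head_dim ** -0.5                       # ~0.0722
mscale = 0.1 * args_mscale * math.log(rope_factor) + 1.0  # ~1.369
softmax_scale = softmax_scale * mscale * mscale           # ~0.135
print(softmax_scale)
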
@@ -558,6 +568,9 @@ class MLA(nn.Module):
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv = self.kv_norm(kv)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+# we use fp8 kv cache in actual deployment, so here we simulate the precision by casting kv to fp8 and then back to bf16.
+kv_fp8, kv_scale = act_quant(kv, block_size, self.scale_fmt)
+kv = (kv_fp8.view(-1, block_size).float() * kv_scale.view(-1, 1)).to(kv.dtype).view_as(kv)
self.kv_cache[:bsz, start_pos:end_pos] = kv
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
if mask is not None: # MHA prefill
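
The three added lines make the reference path reproduce the deployed fp8 KV-cache numerics: kv is quantized per block and immediately dequantized, so the bf16 cache carries exactly the precision an fp8 cache would. A self-contained sketch of that round trip, assuming a per-block amax/448 scale and float8_e4m3 storage (the real act_quant kernel and its scale_fmt handling may differ):

import torch

def blockwise_fp8_roundtrip(x: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # Stand-in for act_quant followed by the dequant line above; illustration only.
    orig_shape, orig_dtype = x.shape, x.dtype
    xb = x.reshape(-1, block_size).float()
    scale = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / 448.0  # 448 = max of float8_e4m3fn
    x_fp8 = (xb / scale).to(torch.float8_e4m3fn)                         # quantize
    return (x_fp8.float() * scale).to(orig_dtype).reshape(orig_shape)    # dequantize back to bf16

kv = torch.randn(2, 4, 512, dtype=torch.bfloat16)
kv_sim = blockwise_fp8_roundtrip(kv)   # bf16 values restricted to fp8 precision
print((kv - kv_sim).abs().max())
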
@@ -566,7 +579,7 @@ class MLA(nn.Module):
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
+scores = torch.einsum("bshd,bthd->bsht", q, k).mul_(self.softmax_scale)
# indexer
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
@@ -574,24 +587,24 @@ class MLA(nn.Module):
index_mask += mask
scores += index_mask.unsqueeze(2)
-scores = scores.softmax(dim=-1, dtype=torch.float32)
-x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
-else: # MHA decode
+scores = scores.softmax(dim=-1)
+x = torch.einsum("bsht,bthd->bshd", scores, v)
+else: # MQA decode
if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
-scores = (torch.einsum("bshc,btc->bsht", q_nope.float(), self.kv_cache[:bsz, :end_pos].float()) +
-torch.einsum("bshr,btr->bsht", q_pe.float(), self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale
+scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
+torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
# indexer
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
index_mask = torch.full((bsz, 1, end_pos), float("-inf"), device=x.device).scatter_(-1, topk_indices, 0)
scores += index_mask.unsqueeze(2)
-scores = scores.softmax(dim=-1, dtype=torch.float32)
-x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
+scores = scores.softmax(dim=-1)
+x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))
return x
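
A closing note on both branches: topk_indices from the indexer is turned into an additive mask with a single scatter_, so after the softmax only the selected keys carry probability mass. A toy, self-contained illustration of that masking trick (shapes and values made up):

import torch

bsz, seqlen, n_heads, end_pos = 1, 1, 2, 6
scores = torch.randn(bsz, seqlen, n_heads, end_pos)   # (b, s, h, t) attention logits
topk_indices = torch.tensor([[[0, 2, 5]]])            # (b, s, k) keys chosen by the indexer

# -inf everywhere, 0 at the selected keys; added to the logits it silences the rest.
index_mask = torch.full((bsz, seqlen, end_pos), float("-inf")).scatter_(-1, topk_indices, 0)
scores = scores + index_mask.unsqueeze(2)             # broadcast over the head dim
probs = scores.softmax(dim=-1)
print(probs[0, 0, 0])                                 # non-selected keys get probability 0
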