Add files using upload-large-folder tool

This commit is contained in:
Shirong Ma
2025-09-22 12:59:34 +00:00
committed by system
parent 4d02631794
commit 846b34eb0f
50 changed files with 358875 additions and 0 deletions

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

68
config.json Normal file
View File

@@ -0,0 +1,68 @@
{
"architectures": [
"DeepseekV3ForCausalLM"
],
"attention_bias": false,
"attention_dropout": 0.0,
"attn_module_list_cfg": [],
"auto_map": {
"AutoConfig": "configuration_deepseek.DeepseekV3Config",
"AutoModel": "modeling_deepseek.DeepseekV3Model",
"AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
},
"bos_token_id": 0,
"eos_token_id": 1,
"ep_size": 1,
"first_k_dense_replace": 3,
"hidden_act": "silu",
"hidden_size": 7168,
"initializer_range": 0.02,
"intermediate_size": 18432,
"kv_lora_rank": 512,
"max_position_embeddings": 163840,
"model_type": "deepseek_v3",
"moe_intermediate_size": 2048,
"moe_layer_freq": 1,
"n_group": 8,
"n_routed_experts": 256,
"n_shared_experts": 1,
"norm_topk_prob": true,
"num_attention_heads": 128,
"num_experts_per_tok": 8,
"num_hidden_layers": 61,
"num_key_value_heads": 128,
"num_nextn_predict_layers": 1,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"quantization_config": {
"activation_scheme": "dynamic",
"fmt": "e4m3",
"quant_method": "fp8",
"weight_block_size": [
128,
128
]
},
"rms_norm_eps": 1e-06,
"rope_scaling": {
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"rope_theta": 10000,
"routed_scaling_factor": 2.5,
"scoring_func": "sigmoid",
"tie_word_embeddings": false,
"topk_group": 4,
"topk_method": "noaux_tc",
"torch_dtype": "bfloat16",
"transformers_version": "4.44.2",
"use_cache": true,
"v_head_dim": 128,
"vocab_size": 129280
}

199
configuration_deepseek.py Normal file
View File

@@ -0,0 +1,199 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class DeepseekV3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DeepSeek-V3.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 129280):
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`DeepseekV3Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 1407):
Dimension of the MoE representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_nextn_predict_layers (`int`, *optional*, defaults to 1):
Number of nextn predict layers in the DeepSeekV3 Model.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
n_shared_experts (`int`, *optional*, defaults to None):
Number of shared experts, None means dense model.
n_routed_experts (`int`, *optional*, defaults to None):
Number of routed experts, None means dense model.
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
Scaling factor or routed experts.
topk_method (`str`, *optional*, defaults to `gready`):
Topk method used in routed gate.
n_group (`int`, *optional*, defaults to None):
Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to None):
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
num_experts_per_tok (`int`, *optional*, defaults to None):
Number of selected experts, None means dense model.
moe_layer_freq (`int`, *optional*, defaults to 1):
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
first_k_dense_replace (`int`, *optional*, defaults to 0):
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
\--k dense layers--/
norm_topk_prob (`bool`, *optional*, defaults to False):
Whether to normalize the weights of the routed experts.
scoring_func (`str`, *optional*, defaults to 'softmax'):
Method of computing expert weights.
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
Auxiliary loss weight coefficient.
seq_aux = (`bool`, *optional*, defaults to True):
Whether to compute the auxiliary loss for each individual sample.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import DeepseekV3Model, DeepseekV3Config
>>> # Initializing a Deepseek-V3 style configuration
>>> configuration = DeepseekV3Config()
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "deepseek_v3"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=129280,
hidden_size=7168,
intermediate_size=18432,
moe_intermediate_size = 2048,
num_hidden_layers=61,
num_nextn_predict_layers=1,
num_attention_heads=128,
num_key_value_heads=128,
n_shared_experts = 1,
n_routed_experts = 256,
ep_size = 1,
routed_scaling_factor = 2.5,
kv_lora_rank = 512,
q_lora_rank = 1536,
qk_rope_head_dim = 64,
v_head_dim = 128,
qk_nope_head_dim = 128,
topk_method = 'noaux_tc',
n_group = 8,
topk_group = 4,
num_experts_per_tok = 8,
moe_layer_freq = 1,
first_k_dense_replace = 3,
norm_topk_prob = True,
scoring_func = 'sigmoid',
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=0,
eos_token_id=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_nextn_predict_layers = num_nextn_predict_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

9
generation_config.json Normal file
View File

@@ -0,0 +1,9 @@
{
"_from_model_config": true,
"bos_token_id": 0,
"eos_token_id": 1,
"do_sample": true,
"temperature": 0.6,
"top_p": 0.95,
"transformers_version": "4.46.3"
}

13
inference/README.md Normal file
View File

@@ -0,0 +1,13 @@
# DeepSeek V3.1
First convert huggingface model weight files to the format of this project.
```bash
export EXPERTS=256
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
```
Then chat with DeepSeek model at will!
```bash
export CONFIG=config_671B_v3.1.json
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive --temperature {T}
```

View File

@@ -0,0 +1,23 @@
{
"vocab_size": 129280,
"dim": 7168,
"inter_dim": 18432,
"moe_inter_dim": 2048,
"n_layers": 61,
"n_dense_layers": 3,
"n_heads": 128,
"n_routed_experts": 256,
"n_shared_experts": 1,
"n_activated_experts": 8,
"n_expert_groups": 8,
"n_limited_groups": 4,
"route_scale": 2.5,
"score_func": "sigmoid",
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"dtype": "fp8",
"scale_fmt": "ue8m0"
}

96
inference/convert.py Normal file
View File

@@ -0,0 +1,96 @@
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
mapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
"q_proj": ("wq", 0),
"q_a_proj": ("wq_a", None),
"q_a_layernorm": ("q_norm", None),
"q_b_proj": ("wq_b", 0),
"kv_a_proj_with_mqa": ("wkv_a", None),
"kv_a_layernorm": ("kv_norm", None),
"kv_b_proj": ("wkv_b", 0),
"o_proj": ("wo", 1),
"gate": ("gate", None),
"gate_proj": ("w1", 0),
"down_proj": ("w2", 1),
"up_proj": ("w3", 0),
"norm": ("norm", None),
"lm_head": ("head", 0),
"scale": ("scale", None),
}
def main(hf_ckpt_path, save_path, n_experts, mp):
"""
Converts and saves model checkpoint files into a specified format.
Args:
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
save_path (str): Path to the directory where the converted checkpoint files will be saved.
n_experts (int): Total number of experts in the model.
mp (int): Model parallelism factor.
Returns:
None
"""
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping, f"Key {key} not found in mapping"
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
elif dim is not None:
assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
os.makedirs(save_path, exist_ok=True)
for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)

186
inference/generate.py Normal file
View File

@@ -0,0 +1,186 @@
import os
import json
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
def sample(logits, temperature: float = 1.0):
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
Returns:
torch.Tensor: The sampled token.
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
@torch.inference_mode()
def generate(
model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
Args:
model (Transformer): The transformer model used for token generation.
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
max_new_tokens (int): The maximum number of new tokens to generate.
eos_id (int): The end-of-sequence token ID.
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
"""
Main function to load the model and perform interactive or batch text generation.
Args:
ckpt_path (str): Path to the model checkpoint directory.
config (str): Path to the model configuration file.
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
"""
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(33377335)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
print("load model")
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
print("I'm DeepSeek 👋")
if interactive:
messages = []
while True:
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
with open(input_file) as f:
prompts = f.read().split("\n\n")
assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()
if __name__ == "__main__":
"""
Command-line interface for distributed text generation.
Arguments:
--ckpt-path (str): Path to the model checkpoint directory.
--config (str): Path to the model configuration file.
--input-file (str, optional): File containing prompts for batch processing.
--interactive (bool, optional): Enable interactive mode for generating text.
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
--temperature (float, optional): Temperature for sampling. Defaults to 0.2.
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.6)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)

274
inference/kernel.py Normal file
View File

@@ -0,0 +1,274 @@
import torch
import tilelang
import tilelang.language as T
from typing import Tuple, Optional
tilelang.set_log_level("WARNING")
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}
FP8 = "float8_e4m3"
BF16 = "bfloat16"
FP32 = "float32"
def fast_log2_ceil(x):
bits_x = T.reinterpret("uint32", x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False
):
M = T.symbolic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale else 2
blk_m = 32
group_size = 128
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
s_local = T.alloc_fragment((blk_m,), scale_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(
x_local[i, j] / s_local[i], fp8_min, fp8_max
)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = s_local[i]
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(
x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), "Input tensor must be contiguous"
assert x.size(-1) % block_size == 0, (
f"Last dimension size must be divisible by block_size (block_size={block_size})"
)
N = x.size(-1)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
return y, s
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
assert out_dtype in [BF16, "float32"]
M = T.symbolic("M")
group_size = 128
block_M = 32
block_N = 128
block_K = 128
@T.prim_func
def fp8_gemm_kernel_(
A: T.Tensor[(M, K), FP8],
B: T.Tensor[(N, K), FP8],
C: T.Tensor[(M, N), out_dtype],
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
A_shared = T.alloc_shared((block_M, block_K), FP8)
B_shared = T.alloc_shared((block_N, block_K), FP8)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
Scale_C_shared = T.alloc_shared((block_M), FP32)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
# Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
# Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory
Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return fp8_gemm_kernel_
def fp8_gemm(
a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor
) -> torch.Tensor:
"""
Perform a matrix multiplication using FP8 precision.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
assert a_s.is_contiguous() and b_s.is_contiguous(), (
"Scaling factor tensors must be contiguous"
)
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
kernel = fp8_gemm_kernel(N, K)
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
return c
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
b = T.symbolic("b")
m = T.symbolic("m")
n = T.symbolic("n")
blk_n1 = 512
blk_n2 = 128
@T.prim_func
def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
q_smem = T.alloc_shared((h, d), FP8)
T.copy(q[i_b, i_m, 0, 0], q_smem)
q_s_frag = T.alloc_fragment(h, FP32)
T.copy(q_s[i_b, i_m, 0], q_s_frag)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
logits = T.alloc_fragment((blk_n2, h), FP32)
T.gemm(
k_smem,
q_smem,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
logits_sum = T.alloc_fragment(blk_n2, FP32)
T.reduce_sum(logits, logits_sum, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
return fp8_index_kernel_
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor:
"""
Perform index score using FP8 precision.
Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
fp8 q @ fp8 k -> fp32 logits
relu(fp32 logits) * q_s (weights) -> fp32 logits
fp32 logits -> fp32 logits_sum
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
"""
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)

822
inference/model.py Normal file
View File

@@ -0,0 +1,822 @@
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from kernel import act_quant, fp8_gemm
world_size = 1
rank = 0
block_size = 128
@dataclass
class ModelArgs:
"""
Data class for defining model arguments and hyperparameters.
Attributes:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
scale_fmt (Optional[str]): Format for quantization scale.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
moe_inter_dim (int): Intermediate dimension for MoE layers.
n_layers (int): Number of transformer layers.
n_dense_layers (int): Number of dense layers in the model.
n_heads (int): Number of attention heads.
n_routed_experts (int): Number of routed experts for MoE layers.
n_shared_experts (int): Number of shared experts for MoE layers.
n_activated_experts (int): Number of activated experts in MoE layers.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
route_scale (float): Scaling factor for routing scores.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
beta_slow (int): Slow beta correction factor.
mscale (float): Scaling factor for extended attention.
"""
max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
scale_fmt: Optional[str] = None
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
moe_inter_dim: int = 1408
n_layers: int = 27
n_dense_layers: int = 1
n_heads: int = 16
# moe
n_routed_experts: int = 64
n_shared_experts: int = 2
n_activated_experts: int = 6
n_expert_groups: int = 1
n_limited_groups: int = 1
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 1.
# mla
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
# yarn
original_seq_len: int = 4096
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.
class ParallelEmbedding(nn.Module):
"""
Embedding layer with parallelism support across distributed processes.
Args:
vocab_size (int): Vocabulary size.
dim (int): Embedding dimension.
"""
def __init__(self, vocab_size: int, dim: int):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
self.part_vocab_size = (vocab_size // world_size)
self.vocab_start_idx = rank * self.part_vocab_size
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for parallel embedding layer.
Args:
x (torch.Tensor): Input tensor containing token indices.
Returns:
torch.Tensor: Embedded representations.
Raises:
ValueError: If `world_size` is not defined.
"""
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
x[mask] = 0
y = F.embedding(x, self.weight)
if world_size > 1:
y[mask] = 0
dist.all_reduce(y)
return y
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None,
scale_fmt: Optional[str] = None) -> torch.Tensor:
"""
Applies a linear transformation to the incoming data: y = xA^T + b.
This function supports specialized implementations based on quantization
and tensor formats.
Args:
x (torch.Tensor): The input tensor.
weight (torch.Tensor): The weight tensor. It may be quantized and
requires dequantization for certain cases.
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
scale_fmt (Optional[str]): The format of scaling factors.
Returns:
torch.Tensor: The result of the linear transformation, which may involve
quantization-aware computations depending on the input parameters.
Notes:
- If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
is used for computation.
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
"""
assert bias is None
if weight.dtype != torch.float8_e4m3fn:
return F.linear(x, weight)
else:
x, scale = act_quant(x, block_size, scale_fmt)
return fp8_gemm(x, scale, weight, weight.scale)
class Linear(nn.Module):
"""
Custom linear layer with support for quantized weights and optional bias.
Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
dtype = torch.bfloat16
scale_fmt: Optional[str] = None
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size - 1) // block_size
scale_in_features = (in_features + block_size - 1) // block_size
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
else:
self.register_parameter("scale", None)
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the custom linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor after linear computation.
"""
return linear(x, self.weight, self.bias, self.scale_fmt)
class ColumnParallelLinear(Linear):
"""
Linear layer with column parallelism, splitting output features across distributed processes.
Args:
in_features (int): Number of input features.
out_features (int): Total number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for column parallel linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with column-parallel computation.
"""
y = linear(x, self.weight, self.bias, self.scale_fmt)
return y
class RowParallelLinear(Linear):
"""
Linear layer with row parallelism, splitting input features across distributed processes.
Args:
in_features (int): Total number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, reduce_output = True, dtype = None):
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
self.part_in_features = in_features // world_size
self.reduce_output = reduce_output
super().__init__(self.part_in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for row parallel linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with row-parallel computation.
"""
y = linear(x, self.weight, None, self.scale_fmt)
if self.reduce_output and world_size > 1:
y = y.float()
dist.all_reduce(y)
if self.bias is not None:
y += self.bias
return y.type_as(x)
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization (RMSNorm).
Args:
dim (int): Dimension of the input tensor.
eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
"""
Forward pass for RMSNorm.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Normalized tensor with the same shape as input.
"""
dtype = x.dtype
if residual is None:
x = x.float()
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return (self.weight * x).to(dtype)
else:
x = residual = x.float() + residual.float()
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return (self.weight * x).to(dtype), residual.to(dtype)
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
"""
Precomputes frequency-based complex exponential values for rotary positional embeddings.
Args:
args (ModelArgs): Model arguments containing positional embedding parameters.
Returns:
torch.Tensor: Precomputed complex exponential values for positional embeddings.
"""
dim = args.qk_rope_head_dim
seqlen = args.max_seq_len
beta_fast = args.beta_fast
beta_slow = args.beta_slow
base = args.rope_theta
factor = args.rope_factor
def find_correction_dim(num_rotations, dim, base, max_seq_len):
"""
Computes the correction dimension for a given number of rotations in the rotary positional embedding.
Args:
num_rotations (float): Number of rotations to compute the correction for.
dim (int): Dimensionality of the embedding space.
base (float): Base value for the exponential computation.
max_seq_len (int): Maximum sequence length.
Returns:
float: The correction dimension based on the input parameters.
"""
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
"""
Computes the range of correction dimensions for rotary positional embeddings.
Args:
low_rot (float): Lower bound for the number of rotations.
high_rot (float): Upper bound for the number of rotations.
dim (int): Dimensionality of the embedding space.
base (float): Base value for the exponential computation.
max_seq_len (int): Maximum sequence length.
Returns:
Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
return max(low, 0), min(high, dim-1)
def linear_ramp_factor(min, max, dim):
"""
Computes a linear ramp function used to smooth values between a minimum and maximum range.
Args:
min (float): Minimum value for the ramp function.
max (float): Maximum value for the ramp function.
dim (int): Dimensionality of the ramp tensor.
Returns:
torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
clamped to the range [0, 1].
"""
if min == max:
max += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
if seqlen > args.original_seq_len:
low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
freqs = freqs / factor * (1 - smooth) + freqs * smooth
t = torch.arange(seqlen)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""
Applies rotary positional embeddings to the input tensor.
Args:
x (torch.Tensor): Input tensor with positional embeddings to be applied.
freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
Returns:
torch.Tensor: Tensor with rotary embeddings applied.
"""
dtype = x.dtype
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
y = torch.view_as_real(x * freqs_cis).flatten(3)
return y.to(dtype)
def weight_dequant(weight, scale):
original_shape = weight.shape
assert weight.dim() == 2
weight = weight.view(original_shape[0] // block_size, block_size, original_shape[1] // block_size, block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(original_shape[0] // block_size, original_shape[1] // block_size, block_size, block_size).transpose(1, 2).contiguous().view(original_shape)
return weight
class MLA(nn.Module):
"""
Multi-Head Latent Attention (MLA) Layer.
Attributes:
dim (int): Dimensionality of the input features.
n_heads (int): Number of attention heads.
n_local_heads (int): Number of local attention heads for distributed systems.
q_lora_rank (int): Rank for low-rank query projection.
kv_lora_rank (int): Rank for low-rank key/value projection.
qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
qk_head_dim (int): Total dimensionality of query/key projections.
v_head_dim (int): Dimensionality of value projections.
softmax_scale (float): Scaling factor for softmax in attention computation.
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
self.n_local_heads = args.n_heads // world_size
self.q_lora_rank = args.q_lora_rank
self.kv_lora_rank = args.kv_lora_rank
self.qk_nope_head_dim = args.qk_nope_head_dim
self.qk_rope_head_dim = args.qk_rope_head_dim
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
self.v_head_dim = args.v_head_dim
if self.q_lora_rank == 0:
self.wq = ColumnParallelLinear(self.dim, self.n_heads * self.qk_head_dim)
else:
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.softmax_scale = self.qk_head_dim ** -0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
self.dequant_wkv_b = None
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
"""
Forward pass for the Multi-Head Latent Attention (MLA) Layer.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
start_pos (int): Starting position in the sequence for caching.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
Returns:
torch.Tensor: Output tensor with the same shape as the input.
"""
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
if self.q_lora_rank == 0:
q = self.wq(x)
else:
q = self.wq_b(self.q_norm(self.wq_a(x)))
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
kv = self.wkv_a(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv = self.kv_norm(kv)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
self.kv_cache[:bsz, start_pos:end_pos] = kv
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
if mask is not None: # MHA prefill
q = torch.cat([q_nope, q_pe], dim=-1)
kv = self.wkv_b(kv)
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
scores = torch.einsum("bshd,bthd->bsht", q, k) * self.softmax_scale
scores += mask.unsqueeze(1)
scores = scores.softmax(dim=-1, dtype=torch.float32)
x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
else: # MHA decode
if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
scores = scores.softmax(dim=-1, dtype=torch.float32)
x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))
return x
class MLP(nn.Module):
"""
Multi-Layer Perceptron (MLP) used as a feed-forward layer.
Attributes:
w1 (nn.Module): Linear layer for input-to-hidden transformation.
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
"""
Initializes the MLP layer.
Args:
dim (int): Input and output dimensionality.
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.w1 = ColumnParallelLinear(dim, inter_dim)
self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
self.w3 = ColumnParallelLinear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MLP layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after MLP computation.
"""
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
class Gate(nn.Module):
"""
Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
Attributes:
dim (int): Dimensionality of input features.
topk (int): Number of top experts activated for each input.
n_groups (int): Number of groups for routing.
topk_groups (int): Number of groups to route inputs to.
score_func (str): Scoring function ('softmax' or 'sigmoid').
route_scale (float): Scaling factor for routing weights.
weight (torch.nn.Parameter): Learnable weights for the gate.
bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Gate module.
Args:
args (ModelArgs): Model arguments containing gating parameters.
"""
super().__init__()
self.dim = args.dim
self.topk = args.n_activated_experts
self.n_groups = args.n_expert_groups
self.topk_groups = args.n_limited_groups
self.score_func = args.score_func
self.route_scale = args.route_scale
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.float32)) if self.dim == 7168 else None
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for the gating mechanism.
Args:
x (torch.Tensor): Input tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
"""
scores = linear(x.float(), self.weight.float())
if self.score_func == "softmax":
scores = scores.softmax(dim=-1)
else:
scores = scores.sigmoid()
original_scores = scores
if self.bias is not None:
scores = scores + self.bias
if self.n_groups > 1:
scores = scores.view(x.size(0), self.n_groups, -1)
if self.bias is None:
group_scores = scores.amax(dim=-1)
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
indices = scores.topk(self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
if self.score_func == "sigmoid":
weights /= weights.sum(dim=-1, keepdim=True)
weights *= self.route_scale
return weights, indices
class Expert(nn.Module):
"""
Expert layer for Mixture-of-Experts (MoE) models.
Attributes:
w1 (nn.Module): Linear layer for input-to-hidden transformation.
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int):
"""
Initializes the Expert layer.
Args:
dim (int): Input and output dimensionality.
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.w1 = Linear(dim, inter_dim)
self.w2 = Linear(inter_dim, dim)
self.w3 = Linear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the Expert layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert computation.
"""
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
class MoE(nn.Module):
"""
Mixture-of-Experts (MoE) module.
Attributes:
dim (int): Dimensionality of input features.
n_routed_experts (int): Total number of experts in the model.
n_local_experts (int): Number of experts handled locally in distributed systems.
n_activated_experts (int): Number of experts activated for each input.
gate (nn.Module): Gating mechanism to route inputs to experts.
experts (nn.ModuleList): List of expert modules.
shared_experts (nn.Module): Shared experts applied to all inputs.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the MoE module.
Args:
args (ModelArgs): Model arguments containing MoE parameters.
"""
super().__init__()
self.dim = args.dim
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(args)
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)])
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MoE module.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert routing and computation.
"""
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x, dtype=torch.float32)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, top, None]
y += self.shared_experts(x)
if world_size > 1:
dist.all_reduce(y)
return y.type_as(x).view(shape)
class Block(nn.Module):
"""
Transformer block combining attention and feed-forward layers.
Attributes:
attn (nn.Module): Attention layer (MLA).
ffn (nn.Module): Feed-forward network (MLP or MoE).
attn_norm (nn.Module): Layer normalization for attention.
ffn_norm (nn.Module): Layer normalization for feed-forward network.
"""
def __init__(self, layer_id: int, args: ModelArgs):
"""
Initializes the Transformer block.
Args:
layer_id (int): Layer index in the transformer.
args (ModelArgs): Model arguments containing block parameters.
"""
super().__init__()
self.attn = MLA(args)
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
self.attn_norm = RMSNorm(args.dim)
self.ffn_norm = RMSNorm(args.dim)
def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
"""
Forward pass for the Transformer block.
Args:
x (torch.Tensor): Input tensor.
start_pos (int): Starting position in the sequence.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
Returns:
torch.Tensor: Output tensor after block computation.
"""
if residual is None:
x, residual = self.attn_norm(x), x
else:
x, residual = self.attn_norm(x, residual)
x = self.attn(x, start_pos, freqs_cis, mask)
x, residual = self.ffn_norm(x, residual)
x = self.ffn(x)
return x, residual
class Transformer(nn.Module):
"""
Transformer model with positional embeddings, multiple layers, and output projection.
Attributes:
max_seq_len (int): Maximum sequence length for the transformer.
embed (nn.Module): Embedding layer for input tokens.
layers (torch.nn.ModuleList): List of transformer blocks.
norm (nn.Module): Layer normalization applied after all blocks.
head (nn.Module): Output projection layer mapping to vocabulary size.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Transformer model.
Args:
args (ModelArgs): Model arguments containing transformer parameters.
"""
global world_size, rank
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
Linear.scale_fmt = args.scale_fmt
super().__init__()
self.max_seq_len = args.max_seq_len
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
self.norm = RMSNorm(args.dim)
# lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
"""
Forward pass for the Transformer model.
Args:
tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
seqlen = tokens.size(1)
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
h, residual = self.embed(tokens), None
for layer in self.layers:
h, residual = layer(h, residual, start_pos, freqs_cis, mask)
h, _ = self.norm(h, residual)
logits = self.head(h[:, -1].float())
if world_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
dist.all_gather(all_logits, logits)
logits = torch.cat(all_logits, dim=-1)
return logits
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.manual_seed(0)
args = ModelArgs()
x = torch.randint(0, args.vocab_size, (2, 128))
model = Transformer(args)
print(model(x).size())

View File

@@ -0,0 +1,4 @@
torch
transformers
safetensors
tilelang==0.1.6.post1

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6140e58b5e9c4c2bded0b7a5c13c1a5601eeda5977b4d1d85c63be4fde00c48d
size 5234241729

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:18056b701a75d11ec123573be29414222c22b1b7b64f814d98ec03df739a7030
size 4302383966

BIN
model-00003-of-000163.safetensors LFS Normal file

Binary file not shown.

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:69efff813da01cca14d97f4f0a9629da7c9dfb3d0517290e4b4a063ab0c24ce2
size 4302382760

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:83dd2b06fc9a1c507b30715fc5930a09be483d004afc0a1b9080ae4f4b1de08b
size 4302384154

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1641062ff86a782238632431b038047e9e2a909573fa04e6b20bd613dde5b70
size 4372106366

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:27a947b03266f71f9b01b25fd007ee0a5b5ca97271ee4abb60e8f2da22de9f41
size 4306080097

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b8f219fc5097f1c7ae9bd74d9128a9d631141d79541e3c3ffc040845b7a70b9f
size 4302384356

BIN
model-00009-of-000163.safetensors LFS Normal file

Binary file not shown.

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:87e6176a77199fba2f9418d7549f0a1da111585909ed482d2f7a1598bce08aad
size 4302383960

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:548c7ce458921ac0707cbfe8a4483caf82b39c485cfd69ab49cca0de393d221e
size 4302384375

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bb6adeb76bbf794e90d3f04ca5289b77b5a4b5c28b01ebe9d98a839b0d6fa17
size 1321612611

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c657d8f1cceb557dce5459e4635dc77cb0c04fdaaecdfa318729aae66e238ad0
size 4302321338

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0897fca2011cc006cfbf131b1a7edc5b7455af82b735ed775d5f5072e3a41edd
size 4302384328

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:469d5a28be557bcc8f4d328bbc9670093a2f059590f486c5d1919ded3fe2b7f0
size 4302382982

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4ff6716edb896cefa13c32e420f9787c56ab46a4fe9fa1875a95f640e5615e8
size 4302383932

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ff57950923a935810125b6358b3c8eed9351570a8acbaeaae61cb653eb166d8f
size 4302384377

BIN
model-00018-of-000163.safetensors LFS Normal file

Binary file not shown.

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74c2b80c31d669d196476d5100c297b76d66a1a3a50f2347685a957c944eff4a
size 4302384124

BIN
model-00020-of-000163.safetensors LFS Normal file

Binary file not shown.

BIN
model-00021-of-000163.safetensors LFS Normal file

Binary file not shown.

BIN
model-00022-of-000163.safetensors LFS Normal file

Binary file not shown.

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3dcfaeb9a21afdc824b745e0cabb52d2f5fdea7129b04a7b09863af7aea992f1
size 4302383572

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e3f77ec7623b9237737690970fbe5b30314641c39573f9f82600b7170ec4d27
size 4302384504

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:27f060310a2586c53eaf84bdea780e03c69e87f483e113a33f22a3e148c9ec0d
size 4302384961

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba9d0b302fa7a5f8132ba634fdbc05931e0770c9369eccf3eeec99bab171ee1c
size 4302383384

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:90763946ff5ba7c5ab8145f0ba0e1155f8ff6123c8762114ce226f52c495b20c
size 4302384692

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:92394afebca8edff5ec040dbd38bc7d9e38d92bb421a86cd6ef29b49d24f394f
size 4302384963

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:77036922f1abda7d779c1787c001540d2cebd57914b703d207a8411a6cac32e6
size 4302383212

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f136b1509d30045c6c0db6dbd44007d90c6836353f93b42e33a925c6e7fe047
size 4302384884

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a93e7c380c5df0caa2ef52bffb1c1169e7c3dd1b6430d772e8a4ba54e077361
size 4302383588

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50edb64a3101ad5c67c6faeacdef9dc263094e7d1c21b0ac35f09875798b3380
size 4302384488

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2bc08322d21ff7038d0f5e1bba6f9ce9588d75632652beab2636430995a93c81
size 4302384963

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:001cb7d92e64af6d7244abbca174dbe8670a70d35e3eacfe23e690c8fbd7ca38
size 1747446144

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f08b2003c703566a172cb8fcf49a994c8902e2f190bcd3833eb763894fc637f8
size 4302321911

91998
model.safetensors.index.json Normal file

File diff suppressed because it is too large Load Diff

1848
modeling_deepseek.py Normal file

File diff suppressed because it is too large Load Diff

263174
tokenizer.json Normal file

File diff suppressed because it is too large Load Diff

35
tokenizer_config.json Normal file
View File

@@ -0,0 +1,35 @@
{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": {
"__type": "AddedToken",
"content": "<begin▁of▁sentence>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": false,
"eos_token": {
"__type": "AddedToken",
"content": "<end▁of▁sentence>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"legacy": true,
"model_max_length": 131072,
"pad_token": {
"__type": "AddedToken",
"content": "<end▁of▁sentence>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"sp_model_kwargs": {},
"unk_token": null,
"tokenizer_class": "LlamaTokenizerFast",
"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% if not thinking is defined %}{% set thinking = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<User>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- if ns.is_last_user %}{{'<Assistant></think>'}}{%- endif %}{%- set ns.is_last_user = false -%}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<tool▁calls▁begin><tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}{%- else %}{{message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'<tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments'] + '<tool▁call▁end>'}}{%- endif %}{%- endfor %}{{'<tool▁calls▁end><end▁of▁sentence>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}{%- if ns.is_last_user %}{{'<Assistant>'}}{%- if message['prefix'] is defined and message['prefix'] and thinking %}{{'<think>'}} {%- else %}{{'</think>'}}{%- endif %}{%- endif %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{message['content'] + '<end▁of▁sentence>'}}{%- set ns.is_tool = false -%}{%- else %}{%- set content = message['content'] -%}{%- if '</think>' in content %}{%- set content = content.split('</think>', 1)[1] -%}{%- endif %}{{content + '<end▁of▁sentence>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}{%- endif %}{%- endfor -%}{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}{{'<Assistant>'}}{%- if not thinking %}{{'</think>'}}{%- else %}{{'<think>'}}{%- endif %}{% endif %}"
}