From e24336b5a772ab3aa6ad83527b880f9e5050ea2a Mon Sep 17 00:00:00 2001 From: Megha Agarwal <16129366+megha95@users.noreply.github.com> Date: Wed, 27 Mar 2024 13:01:46 -0700 Subject: [PATCH] [Model] Add support for DBRX (#3660) --- README.md | 1 + docs/source/models/supported_models.rst | 4 + requirements.txt | 1 + vllm/config.py | 5 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/dbrx.py | 421 ++++++++++++++++++++ vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/dbrx.py | 277 +++++++++++++ 9 files changed, 713 insertions(+) create mode 100644 vllm/model_executor/models/dbrx.py create mode 100644 vllm/transformers_utils/configs/dbrx.py diff --git a/README.md b/README.md index 9d3f7422..5fad3370 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) +- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.) - DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) - Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index bc7a21e7..7f80831b 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -27,6 +27,10 @@ Alongside each architecture, we include some popular models that use it. - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. - ✅︎ + * - :code:`DbrxForCausalLM` + - DBRX + - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc. + - * - :code:`DeciLMForCausalLM` - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. diff --git a/requirements.txt b/requirements.txt index 6d75067b..a85f5d2c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ prometheus_client >= 0.18.0 pynvml == 11.5.0 triton >= 2.1.0 outlines == 0.0.34 +tiktoken == 0.6.0 # Required for DBRX tokenizer diff --git a/vllm/config.py b/vllm/config.py index 3ef9497e..baa37cda 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -277,6 +277,11 @@ class ModelConfig: # Currently, tensor parallelism is not supported in this case. 
return 1 + # For DBRX and MPT + if self.hf_config.model_type in ["dbrx", "mpt"]: + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) + attributes = [ # For Falcon: "n_head_kv", diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index d5613168..edec642e 100755 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -17,6 +17,7 @@ _MODELS = { "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py new file mode 100644 index 00000000..14c0fece --- /dev/null +++ b/vllm/model_executor/models/dbrx.py @@ -0,0 +1,421 @@ +# coding=utf-8 +from typing import List, Optional + +import torch +import torch.nn as nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.linear import (LinearMethodBase, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.dbrx import DbrxConfig + + +class DbrxRouter(nn.Module): + """A Router implementation for DBRX that returns logits for each expert + per token. + """ + + def __init__( + self, + config: DbrxConfig, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.ffn_config.moe_num_experts + self.d_model = config.d_model + self.layer = ReplicatedLinear( + self.d_model, + self.num_total_experts, + bias=False, + params_dtype=params_dtype, + linear_method=None, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_logits, _ = self.layer(hidden_states) + return router_logits + + +class DbrxExperts(nn.Module): + """A tensor-parallel MoE implementation for DBRX. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
+ """ + + def __init__( + self, + config: DbrxConfig, + linear_method: Optional[LinearMethodBase] = None, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.ffn_config.moe_num_experts + self.top_k = config.ffn_config.moe_top_k + self.d_model = config.d_model + self.intermediate_size = (config.ffn_config.ffn_hidden_size // + self.tp_size) + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.router = DbrxRouter(config, self.params_dtype) + self.ws = nn.Parameter( + torch.empty( + self.num_total_experts, + 2 * self.intermediate_size, + self.d_model, + device="cuda", + dtype=self.params_dtype, + )) + self.w2s = nn.Parameter( + torch.empty( + self.num_total_experts, + self.d_model, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype, + )) + + set_weight_attrs( + self.ws, + { + "weight_loader": self.weight_loader, + }, + ) + set_weight_attrs( + self.w2s, + { + "weight_loader": self.weight_loader, + }, + ) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + # DBRX uses GLU for each experts. + # GLU has 3 linear layers: w1, v1 and w2. + if weight_name.endswith("w1"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ) + param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :] + if weight_name.endswith("v1"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ) + param_data[:, + shard_size:2 * shard_size, :] = loaded_weight[:, + shard, :] + if weight_name.endswith("w2"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ).transpose(1, 2) + param_data[:] = loaded_weight[:, :, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.d_model) + # router_logits: (num_tokens, n_experts) + router_logits = self.router(hidden_states) + final_hidden_states = fused_moe( + hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=True, + inplace=True, + ) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class DbrxAttention(nn.Module): + + def __init__( + self, + config: DbrxConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.d_model = config.d_model + self.total_num_heads = config.n_heads + self.head_dim = self.d_model // self.total_num_heads + self.total_num_kv_heads = config.attn_config.kv_n_heads + self.clip_qkv = config.attn_config.clip_qkv + self.rope_theta = config.attn_config.rope_theta + self.max_position = config.max_seq_len + + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( + self.d_model, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.out_proj = RowParallelLinear( + self.d_model, + self.d_model, + bias=False, + linear_method=linear_method, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, 
+ max_position=self.max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + self.tp_size = tp_world_size + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + ) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + hidden_states, _ = self.out_proj(attn_output) + return hidden_states + + +class DbrxFusedNormAttention(nn.Module): + + def __init__( + self, + config: DbrxConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.d_model = config.d_model + self.attn = DbrxAttention(config, linear_method) + self.norm_1 = nn.LayerNorm(self.d_model) + self.norm_2 = nn.LayerNorm(self.d_model) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm_1(hidden_states) + x = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + x + residual = hidden_states + hidden_states = self.norm_2(hidden_states) + return hidden_states, residual + + +class DbrxBlock(nn.Module): + + def __init__( + self, + config: DbrxConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method) + self.ffn = DbrxExperts(config, linear_method) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states, residual = self.norm_attn_norm( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = self.ffn(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class DbrxModel(nn.Module): + + def __init__( + self, + config: DbrxConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + ) + self.blocks = nn.ModuleList( + [DbrxBlock(config, linear_method) for _ in range(config.n_layers)]) + self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) + for module in self.modules(): + if 
hasattr(module, "bias") and isinstance(module.bias, + nn.Parameter): + # Remove the bias term in Linear and LayerNorm. + module.register_parameter("bias", None) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + for i in range(len(self.blocks)): + block = self.blocks[i] + hidden_states = block( + position_ids, + hidden_states, + kv_caches[i], + attn_metadata, + ) + hidden_states = self.norm_f(hidden_states) + return hidden_states + + +class DbrxForCausalLM(nn.Module): + + def __init__( + self, + config: DbrxConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.unpadded_vocab_size = config.vocab_size + self.transformer = DbrxModel(config, linear_method) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head.weight, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + expert_params_mapping = [( + "ws" if weight_name in ["w1", "v1"] else "w2s", + f"experts.mlp.{weight_name}", + ) for weight_name in ["w1", "v1", "w2"]] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + for param_name, weight_name in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, weight_name) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index c34ee10b..8a6ba6c5 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -6,6 +6,7 @@ from vllm.transformers_utils.configs import * _CONFIG_REGISTRY = { "chatglm": ChatGLMConfig, + "dbrx": DbrxConfig, "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 22220852..0e486928 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,4 +1,5 @@ from 
vllm.transformers_utils.configs.chatglm import ChatGLMConfig +from vllm.transformers_utils.configs.dbrx import DbrxConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. @@ -8,6 +9,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig __all__ = [ "ChatGLMConfig", + "DbrxConfig", "MPTConfig", "RWConfig", "JAISConfig", diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py new file mode 100644 index 00000000..3a19af71 --- /dev/null +++ b/vllm/transformers_utils/configs/dbrx.py @@ -0,0 +1,277 @@ +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py +"""Dbrx configuration.""" + +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class DbrxAttentionConfig(PretrainedConfig): + """Configuration class for Dbrx Attention. + + [`DbrxAttention`] class. It is used to instantiate attention layers + according to the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + attn_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention layers. + clip_qkv (`float`, *optional*, defaults to None): + If not `None`, clip the queries, keys, and values in the attention layer to this value. + kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. + rope_theta (float): The base frequency for rope. + """ + + def __init__( + self, + attn_pdrop: float = 0, + clip_qkv: Optional[float] = None, + kv_n_heads: int = 1, + rope_theta: float = 10000.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.attn_pdrop = attn_pdrop + self.clip_qkv = clip_qkv + self.kv_n_heads = kv_n_heads + self.rope_theta = rope_theta + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["attn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxFFNConfig(PretrainedConfig): + """Configuration class for Dbrx FFN. + + [`DbrxFFN`] class. It is used to instantiate feedforward layers according to + the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + Args: + ffn_act_fn (dict, optional): A dict specifying activation function for the FFN. + The dict should have a key 'name' with the value being the name of + the activation function along with any additional keyword arguments. + ffn_hidden_size (int, optional): The hidden size of the feedforward network. + moe_num_experts (int, optional): The number of experts in the mixture of experts layer. + moe_top_k (int, optional): The number of experts to use in the mixture of experts layer. + moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer. + moe_loss_weight (float, optional): The loss weight for the mixture of experts layer. + moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights. + uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment. + This should only be used for benchmarking purposes. + """ + + def __init__( + self, + ffn_act_fn: Optional[dict] = None, + ffn_hidden_size: int = 3584, + moe_num_experts: int = 4, + moe_top_k: int = 1, + moe_jitter_eps: Optional[float] = None, + moe_loss_weight: float = 0.01, + moe_normalize_expert_weights: Optional[float] = 1, + uniform_expert_assignment: bool = False, + **kwargs: Any, + ): + super().__init__() + if ffn_act_fn is None: + ffn_act_fn = {"name": "silu"} + self.ffn_act_fn = ffn_act_fn + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_jitter_eps = moe_jitter_eps + self.moe_loss_weight = moe_loss_weight + self.moe_normalize_expert_weights = moe_normalize_expert_weights + self.uniform_expert_assignment = uniform_expert_assignment + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["ffn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxConfig(PretrainedConfig): + """Configuration class for Dbrx. + + [`DbrxModel`]. It is used to instantiate a Dbrx model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + d_model (`int`, *optional*, defaults to 6144): + Dimensionality of the embeddings and hidden states. + n_heads (`int`, *optional*, defaults to 48): + Number of attention heads for each attention layer in the Transformer encoder. + n_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + max_seq_len (`int`, *optional*, defaults to 32768): + The maximum sequence length of the model. + vocab_size (`int`, *optional*, defaults to 100352): + Vocabulary size of the Dbrx model. 
Defines the maximum number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DbrxModel`]. + resid_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability applied to the attention output before combining with residual. + emb_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the embedding layer. + attn_config (`dict`, *optional*): + A dictionary used to configure the model's attention module. + ffn_config (`dict`, *optional*): + A dictionary used to configure the model's FFN module. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + + Example: + ```python + >>> from transformers import DbrxConfig, DbrxModel + + >>> # Initializing a Dbrx configuration + >>> configuration = DbrxConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DbrxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dbrx" + attribute_map = { + "num_attention_heads": "n_heads", + "hidden_size": "d_model", + "num_hidden_layers": "n_layers", + "max_position_embeddings": "max_seq_len", + } + + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + max_seq_len: int = 2048, + vocab_size: int = 32000, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + attn_config: Optional[DbrxAttentionConfig] = None, + ffn_config: Optional[DbrxFFNConfig] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + output_router_logits: bool = False, + router_aux_loss_coef: float = 0.05, + **kwargs: Any, + ): + if attn_config is None: + self.attn_config = DbrxAttentionConfig() + elif isinstance(attn_config, dict): + self.attn_config = DbrxAttentionConfig(**attn_config) + else: + self.attn_config = attn_config + + if ffn_config is None: + self.ffn_config = DbrxFFNConfig() + elif isinstance(ffn_config, dict): + self.ffn_config = DbrxFFNConfig(**ffn_config) + else: + self.ffn_config = ffn_config + + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.use_cache = use_cache + self.initializer_range = initializer_range + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models." + ) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + )
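The DbrxConfig class above is what the five-line change in vllm/config.py keys off: DBRX reports its grouped-query KV head count under attn_config.kv_n_heads rather than as a top-level attribute. Below is a small sketch of how that lookup resolves; the head counts are toy values, and it leans on the attribute_map defined above together with transformers' PretrainedConfig handling of it.

```python
# Sketch only: how the new kv-head lookup in vllm/config.py resolves for DBRX
# (toy head counts; relies on DbrxConfig.attribute_map and PretrainedConfig).
from vllm.transformers_utils.configs.dbrx import DbrxConfig

cfg = DbrxConfig(n_heads=16, attn_config={"kv_n_heads": 4})

assert cfg.num_attention_heads == 16  # attribute_map: num_attention_heads -> n_heads
kv_heads = getattr(cfg.attn_config, "kv_n_heads", cfg.num_attention_heads)
assert kv_heads == 4                  # grouped-query attention: 4 KV heads
```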
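A note on the densest part of the new model file: DbrxExperts.weight_loader reshapes each expert's fused GLU checkpoint tensors (w1, v1, w2) and keeps only the current tensor-parallel rank's slice. The sketch below replays that slicing arithmetic on toy tensors outside of vLLM; the sizes and helper variable names are illustrative, not from the patch, but the indexing mirrors the loader above.

```python
# Sketch only: replay DbrxExperts.weight_loader's slicing for one TP rank with toy sizes.
import torch

num_experts = 4      # config.ffn_config.moe_num_experts
d_model = 8          # config.d_model
ffn_hidden = 12      # config.ffn_config.ffn_hidden_size (full, unsharded)
tp_size = 2          # tensor-parallel world size
tp_rank = 1          # rank whose shard we materialize
shard_size = ffn_hidden // tp_size   # self.intermediate_size in the loader
rows = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# Checkpoint tensors arrive flattened as (num_experts * ffn_hidden, d_model).
w1 = torch.randn(num_experts * ffn_hidden, d_model)
v1 = torch.randn(num_experts * ffn_hidden, d_model)
w2 = torch.randn(num_experts * ffn_hidden, d_model)

# ws stacks this rank's w1 and v1 slices along dim 1, matching
# param_data[:, 0:shard_size, :] and param_data[:, shard_size:2*shard_size, :].
ws = torch.empty(num_experts, 2 * shard_size, d_model)
ws[:, :shard_size, :] = w1.reshape(num_experts, ffn_hidden, d_model)[:, rows, :]
ws[:, shard_size:, :] = v1.reshape(num_experts, ffn_hidden, d_model)[:, rows, :]

# w2 is transposed first, so this rank's shard is taken along the last dimension.
w2s = w2.reshape(num_experts, ffn_hidden, d_model).transpose(1, 2)[:, :, rows]

assert ws.shape == (num_experts, 2 * shard_size, d_model)
assert w2s.shape == (num_experts, d_model, shard_size)
```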
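DbrxExperts.forward then delegates routing to fused_moe with renormalize=True. As a rough plain-PyTorch paraphrase of that routing step (assuming the kernel's softmax-then-top-k behavior; the shapes are made up), the per-token expert weights are computed like this:

```python
# Sketch only: the routing math that fused_moe(..., renormalize=True) performs,
# written out for a few tokens with plain PyTorch.
import torch

num_tokens, num_experts, top_k = 5, 16, 4   # DBRX ships with 16 experts, top-4 routing
router_logits = torch.randn(num_tokens, num_experts)  # output of DbrxRouter

routing_weights = torch.softmax(router_logits, dim=-1)
topk_weights, topk_ids = routing_weights.topk(top_k, dim=-1)
# renormalize=True: each token's selected expert weights sum to 1 again.
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
```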
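With the registry entry, config class, and tiktoken dependency in place, the model loads through vLLM's existing entry points. A minimal usage sketch follows, assuming the offline LLM API; the tensor_parallel_size and the prompt are placeholders, the full checkpoints need several GPUs, and trust_remote_code may be needed for the tokenizer that the requirements.txt change supports.

```python
# Sketch only: loading DBRX through vLLM's offline LLM API once this patch is installed.
from vllm import LLM, SamplingParams

llm = LLM(
    model="databricks/dbrx-instruct",
    tensor_parallel_size=8,     # placeholder; adjust to your hardware
    trust_remote_code=True,     # the DBRX tokenizer is tiktoken-based remote code
)
outputs = llm.generate(
    ["What is a mixture-of-experts model?"],
    SamplingParams(temperature=0.7, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```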