Skip to content

vllm.v1.attention.backends.mamba_attn

M module-attribute

M = TypeVar('M')

BaseMambaAttentionMetadataBuilder

Bases: AttentionMetadataBuilder[M], ABC

Source code in vllm/v1/attention/backends/mamba_attn.py
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    """Abstract base for attention-metadata builders operating on a ``MambaSpec``.

    Construction records the compilation config, derives the largest decode
    batch size eligible for CUDA graph capture, and pre-allocates an int32
    buffer of that length for state indices.
    """

    reorder_batch_threshold: int = 1
    # Declared CUDA graph support level for builders derived from this class.
    cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Set up the builder and its pre-allocated state-index buffer.

        Args:
            kv_cache_spec: Cache spec for the layers; must be a ``MambaSpec``.
            layer_names: Forwarded unchanged to the base builder.
            vllm_config: Supplies ``compilation_config``; the scheduler limits
                are read back through ``self.vllm_config``.
            device: Device on which ``state_indices_tensor`` is allocated.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        assert isinstance(kv_cache_spec, MambaSpec)

        compilation_cfg = vllm_config.compilation_config
        self.compilation_config = compilation_cfg

        # NOTE(review): self.vllm_config is presumably populated by the base
        # initializer above -- confirm against AttentionMetadataBuilder.
        scheduler_cfg = self.vllm_config.scheduler_config

        # The decode batch can exceed neither the scheduler's sequence limit
        # nor the largest captured size.
        max_decode_bs = min(
            scheduler_cfg.max_num_seqs,
            compilation_cfg.max_capture_size,
        )
        self.decode_cudagraph_max_bs = max_decode_bs

        # torch.empty leaves the contents uninitialized; the buffer must be
        # written before it is read.
        self.state_indices_tensor = torch.empty(
            (max_decode_bs,), dtype=torch.int32, device=device
        )

compilation_config instance-attribute

compilation_config = compilation_config

cudagraph_support class-attribute

cudagraph_support: ClassVar[AttentionCGSupport] = (
    UNIFORM_SINGLE_TOKEN_DECODE
)

decode_cudagraph_max_bs instance-attribute

decode_cudagraph_max_bs = min(
    max_num_seqs, max_capture_size
)

reorder_batch_threshold class-attribute instance-attribute

reorder_batch_threshold: int = 1

state_indices_tensor instance-attribute

state_indices_tensor = empty(
    (decode_cudagraph_max_bs,), dtype=int32, device=device
)

__init__

__init__(
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: device,
)
Source code in vllm/v1/attention/backends/mamba_attn.py
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Set up the builder and its pre-allocated state-index buffer.

    Args:
        kv_cache_spec: Cache spec for the layers; must be a ``MambaSpec``.
        layer_names: Forwarded unchanged to the base builder.
        vllm_config: Supplies ``compilation_config``; the scheduler limits
            are read back through ``self.vllm_config``.
        device: Device on which ``state_indices_tensor`` is allocated.
    """
    super().__init__(kv_cache_spec, layer_names, vllm_config, device)

    assert isinstance(kv_cache_spec, MambaSpec)

    compilation_cfg = vllm_config.compilation_config
    self.compilation_config = compilation_cfg

    # NOTE(review): self.vllm_config is presumably populated by the base
    # initializer above -- confirm against AttentionMetadataBuilder.
    scheduler_cfg = self.vllm_config.scheduler_config

    # The decode batch can exceed neither the scheduler's sequence limit
    # nor the largest captured size.
    max_decode_bs = min(
        scheduler_cfg.max_num_seqs,
        compilation_cfg.max_capture_size,
    )
    self.decode_cudagraph_max_bs = max_decode_bs

    # torch.empty leaves the contents uninitialized; the buffer must be
    # written before it is read.
    self.state_indices_tensor = torch.empty(
        (max_decode_bs,), dtype=torch.int32, device=device
    )