Skip to content

vllm.v1.attention.backends.mla.flashmla_sparse

MIN_HEADS_FOR_BF16_PREFILL module-attribute

MIN_HEADS_FOR_BF16_PREFILL = 32

NOTE: FlashMLA Sparse uses an fp8 cache with the following format

For DeepSeek V3.2, in the "FP8 with scale" format, each token's KV cache is 656 bytes, structured as:

- First 512 bytes: the "quantized NoPE" part, containing 512 float8_e4m3 values.
- Next 16 bytes: scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
- Last 128 bytes: the "RoPE" part, containing 64 bfloat16 values. This part is not quantized, to preserve accuracy.

For DeepSeek V4, in the "FP8 with scale" format, each token's KV cache is 584 bytes, structured as:

- First 448 bytes: the "quantized NoPE" part, containing 448 float8_e4m3 values.
- Next 128 bytes: the "RoPE" part, containing 64 bfloat16 values. This part is not quantized, to preserve accuracy.
- Last 8 bytes: scale factors, containing 7 ue8m0 values plus 1 byte of padding. The first ue8m0 is the scale for the first 64 float8_e4m3 values, the second for the next 64, and so on.

FlashMLASparseImpl

Bases: SparseMLAAttentionImpl[FlashMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
    @staticmethod
    def _compute_fp8_decode_padded_heads(num_heads: int) -> int:
        # FP8 decode kernel only supports h_q = 64 or 128
        # Compute padded head count for decode
        return 64 if num_heads <= 64 else 128

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indices_buffer: torch.Tensor | None = None,
        indexer: "Indexer | None" = None,
        **mla_args,
    ) -> None:
        """Record the attention geometry and reserve the kernel workspaces.

        ``alibi_slopes``, ``sliding_window``, ``logits_soft_cap``,
        ``attn_type``, ``kv_sharing_target_layer_name`` and the
        ``topk_indices_buffer`` argument are accepted but never read here:
        the top-k index buffer is taken from ``indexer`` instead, which is
        therefore required. ``mla_args`` must contain ``"kv_lora_rank"``.

        A query-concatenation buffer is always reserved; with the
        ``"fp8_ds_mla"`` cache dtype a BF16 prefill workspace is reserved in
        the same ``get_simultaneous`` call.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_lora_rank: int = mla_args["kv_lora_rank"]
        # Stored exactly as passed in (unlike ``self.scale``, not coerced).
        self.softmax_scale = scale

        assert indexer is not None
        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer

        # Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell
        if current_platform.is_device_capability_family(100):
            self.prefill_padding = 128
        else:
            self.prefill_padding = 64
        self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

        vllm_config = get_current_vllm_config()
        q_concat_shape = (
            vllm_config.scheduler_config.max_num_batched_tokens,
            num_heads,
            head_size,
        )
        q_concat_request = (q_concat_shape, torch.bfloat16)

        uses_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"
        if is_quantized_kv_cache(kv_cache_dtype):
            assert uses_fp8_ds_mla, (
                "FlashMLA Sparse Attention backend fp8 only supports "
                "fp8_ds_mla kv-cache dtype"
            )

        if not uses_fp8_ds_mla:
            # Non-FP8 cache: only the query concat buffer is needed.
            (self.q_concat_buffer,) = current_workspace_manager().get_simultaneous(
                q_concat_request,
            )
            return

        # Reserve workspace during initialization
        assert vllm_config is not None and vllm_config.model_config is not None
        self.prefill_workspace_shape = (
            get_prefill_workspace_size(vllm_config.model_config.max_model_len),
            head_size,
        )
        prefill_request = (self.prefill_workspace_shape, torch.bfloat16)
        self.q_concat_buffer, self.prefill_bf16_workspace = (
            current_workspace_manager().get_simultaneous(
                q_concat_request,
                prefill_request,
            )
        )

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run sparse attention for the non-FP8 KV-cache path.

        The per-request top-k indices are first translated by
        ``triton_convert_req_index_to_global_index`` (no prefill workspace is
        involved on this path), then ``q``, the cache and the translated
        indices are handed to the BF16 sparse kernel wrapper.
        """
        num_topk = topk_indices.shape[1]
        global_topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=num_topk,
        )
        return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, global_topk_indices)

    def _forward_fp8_kv_separate_prefill_decode(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """FP8-cache forward path that handles decode and prefill separately.

        The first ``num_decode_tokens`` rows of ``q`` are decode tokens and go
        through the FP8 kernel wrapper directly on the quantized cache.
        Prefill tokens are processed chunk by chunk: each chunk's KV entries
        are gathered and up-converted into ``self.prefill_bf16_workspace``
        and then attended with the BF16 kernel wrapper.

        Args:
            q: Query tensor, one row per token.
            kv_c_and_k_pe_cache: The FP8 KV cache.
            topk_indices: Per-request top-k indices, one row per token.
            attn_metadata: Metadata whose ``fp8_extra_metadata`` must be a
                ``FlashMLASparseMetadata.FP8SeparatePrefillDecode``.

        Returns:
            The attention output. For mixed / pure-prefill batches it has
            shape ``(num_actual_tokens, num_heads, kv_lora_rank)``.
        """
        fp8_metadata = attn_metadata.fp8_extra_metadata
        assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeparatePrefillDecode)
        num_decodes = fp8_metadata.num_decodes
        num_decode_tokens = fp8_metadata.num_decode_tokens
        num_prefill_tokens = fp8_metadata.num_prefill_tokens

        prefill_request_ids = None
        prefill_workspace_starts = None
        has_prefill_workspace = fp8_metadata.prefill is not None
        if fp8_metadata.prefill is not None:
            prefill_request_ids = fp8_metadata.prefill.request_ids
            prefill_workspace_starts = fp8_metadata.prefill.workspace_starts

        # Convert per-request indices to global slots (decode) or workspace
        # offsets (prefill).
        # For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
        # For BF16 cache: always use global cache slots (no workspace)
        # prefill_workspace_starts has been adjusted in-place per chunk so
        # prefill indices automatically come out chunk-local
        topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=topk_indices.shape[1],
            HAS_PREFILL_WORKSPACE=has_prefill_workspace,
            prefill_workspace_request_ids=prefill_request_ids,
            prefill_workspace_starts=prefill_workspace_starts,
        )

        def _fp8_decode(
            q: torch.Tensor,
            topk_indices: torch.Tensor,
        ) -> torch.Tensor:
            """Run the FP8 kernel on decode tokens only."""
            # Reshape q: (num_decode_tokens, num_heads, head_dim)
            #         -> (num_decodes, seq_len, num_heads, head_dim)
            q = reshape_query_for_spec_decode(q, num_decodes)
            seq_len = q.shape[1]
            # Reshape topk_indices: (num_decode_tokens, topk)
            #                    -> (num_decodes, seq_len, topk)
            topk_indices = topk_indices.view(num_decodes, seq_len, -1)
            assert fp8_metadata.decode is not None
            attn_out, _ = self._fp8_flash_mla_kernel(
                q=q,
                kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
                topk_indices=topk_indices,
                kernel_metadata=fp8_metadata.decode.kernel_metadata,
            )
            # Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
            #              -> (num_decode_tokens, num_heads, head_dim_v)
            return reshape_attn_output_for_spec_decode(attn_out)

        # Pure decode: direct call without allocation
        if num_decode_tokens > 0 and num_prefill_tokens == 0:
            assert fp8_metadata.decode is not None
            return _fp8_decode(q, topk_indices)

        # Mixed or pure prefill: allocate output tensor. ``new_empty``
        # already inherits dtype and device from ``q``.
        attn_out = q.new_empty(
            (attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank)
        )

        if num_decode_tokens > 0:
            attn_out[:num_decode_tokens] = _fp8_decode(
                q[:num_decode_tokens],
                topk_indices[:num_decode_tokens],
            )

        assert fp8_metadata.prefill is not None
        for chunk in fp8_metadata.prefill.chunks:
            # Gather this chunk's KV entries from the FP8 cache into the
            # leading rows of the shared BF16 workspace.
            chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
            ops.cp_gather_and_upconvert_fp8_kv_cache(
                kv_c_and_k_pe_cache,
                chunk_workspace,
                chunk.block_table,
                chunk.seq_lens,
                chunk.workspace_starts,
                len(chunk.block_table),
            )

            attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
                q[chunk.tokens_slice],
                chunk_workspace,
                topk_indices[chunk.tokens_slice],
            )

        return attn_out

    def _forward_fp8_kv_mixed_batch(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run the FP8 kernel once over every token in the batch.

        Decode and prefill tokens are not separated: all of them are
        presented to the FP8 kernel wrapper as a single batch entry, so the
        BF16 prefill kernel (with its head-padding overhead when
        ``num_heads`` is small) is never invoked. ``forward_mqa`` picks this
        path when ``attn_metadata.fp8_use_mixed_batch`` is set.
        """
        # Convert per-request indices to global cache slots; no prefill
        # workspace is used on this path.
        num_topk = topk_indices.shape[1]
        global_topk_indices = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=num_topk,
        )

        assert attn_metadata.fp8_extra_metadata is not None
        assert isinstance(
            attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
        )
        kernel_metadata = attn_metadata.fp8_extra_metadata

        # Add a leading batch dim: (T, H, D) -> (1, T, H, D) and
        # (T, topk) -> (1, T, topk).
        batched_q = q.unsqueeze(0)
        batched_indices = global_topk_indices.unsqueeze(0)
        batched_out, _ = self._fp8_flash_mla_kernel(
            q=batched_q,
            kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
            topk_indices=batched_indices,
            kernel_metadata=kernel_metadata,
        )

        # Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
        return batched_out.squeeze(0)

    def _fp8_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Call ``flash_mla_with_kvcache`` on the FP8 cache, padding heads.

        When ``q`` carries fewer heads than ``self.fp8_decode_padded_heads``
        it is copied into a zero-filled tensor with the padded head count
        before the call, and the extra heads are sliced off the output again.

        Returns:
            The ``(out, lse)`` pair produced by ``flash_mla_with_kvcache``;
            only ``out`` is sliced back to the real head count.
        """
        # q shape: (batch, seq_len, num_heads, head_dim)
        actual_num_heads = q.shape[2]
        padded_num_heads = self.fp8_decode_padded_heads
        needs_padding = actual_num_heads < padded_num_heads

        # Pad query if needed (kernel only supports h_q = 64 or 128)
        if needs_padding:
            logger.warning_once(
                f"Padding num_heads from {actual_num_heads} to "
                f"{padded_num_heads} for FP8 sparse decode kernel"
            )
            padded_shape = (q.shape[0], q.shape[1], padded_num_heads, q.shape[3])
            widened_q = q.new_zeros(padded_shape)
            widened_q[:, :, :actual_num_heads, :] = q
            q = widened_q

        # Reinterpret the cache as raw bytes with an extra axis before the
        # last dimension, as passed to the kernel's ``k_cache`` argument.
        cache_as_bytes = kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2)
        out, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=cache_as_bytes,
            block_table=kernel_metadata.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=kernel_metadata.cache_lens,
            tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
            is_fp8_kvcache=True,
            indices=topk_indices,
            softmax_scale=self.softmax_scale,
        )

        # Slice output back to actual head count if we padded
        if needs_padding:
            out = out[:, :, :actual_num_heads, :]
        return out, lse

    def _bf16_flash_mla_kernel(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Call ``flash_mla_sparse_fwd``, padding the head axis when required.

        If ``self.num_heads`` is not a multiple of ``self.prefill_padding``,
        ``q`` is copied into a tensor with ``self.prefill_padding`` heads.
        The extra heads are left uninitialized; their outputs are discarded
        by the final slice back to ``self.num_heads``.
        """
        num_tokens = q.shape[0]
        # Flatten the cache to (-1, 1, last_dim) for the kernel.
        flat_cache = kv_c_and_k_pe_cache.view(-1, 1, kv_c_and_k_pe_cache.shape[-1])

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on hopper and 128 on blackwell
        if self.num_heads % self.prefill_padding != 0:
            assert self.prefill_padding % self.num_heads == 0
            logger.warning_once(
                f"Padding num_heads from {self.num_heads} to "
                f"{self.prefill_padding} for BF16 sparse prefill kernel"
            )
            widened_q = q.new_empty((num_tokens, self.prefill_padding, q.shape[2]))
            widened_q[:, : self.num_heads, :] = q
            q = widened_q

        kernel_result = flash_mla_sparse_fwd(
            q,
            flat_cache,
            topk_indices.view(num_tokens, 1, -1),
            self.softmax_scale,
        )
        # First element is the attention output; drop any padded heads.
        return kernel_result[0][:, : self.num_heads, :]

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Dispatch sparse MQA attention to the path matching the cache dtype.

        Args:
            q: Either an already-concatenated query tensor or a
                ``(ql_nope, q_pe)`` tuple, which is concatenated into
                ``self.q_concat_buffer`` first.
            kv_c_and_k_pe_cache: The KV cache.
            attn_metadata: Metadata for the current batch.
            layer: Not read by this method.

        Returns:
            ``(attn_out, None)``.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        # Concatenate q if it's a tuple (ql_nope, q_pe)
        if isinstance(q, tuple):
            ql_nope, q_pe = q
            q = self.q_concat_buffer[: ql_nope.shape[0]]
            ops.concat_mla_q(ql_nope, q_pe, q)

        # One row of top-k indices per actual token.
        assert self.topk_indices_buffer is not None
        topk_indices = self.topk_indices_buffer[: q.shape[0]]

        if self.kv_cache_dtype != "fp8_ds_mla":
            forward_fn = self._forward_bf16_kv
        elif attn_metadata.fp8_use_mixed_batch:
            forward_fn = self._forward_fp8_kv_mixed_batch
        else:
            forward_fn = self._forward_fp8_kv_separate_prefill_decode

        attn_out = forward_fn(q, kv_c_and_k_pe_cache, topk_indices, attn_metadata)
        return attn_out, None

_forward_fp8_kv_mixed_batch

_forward_fp8_kv_mixed_batch(
    q: Tensor,
    kv_c_and_k_pe_cache: Tensor,
    topk_indices: Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> Tensor

Mixed batch FP8 forward path that treats all tokens as one batch.

This is equivalent to main branch's approach and avoids the BF16 prefill kernel which has head padding overhead when num_heads is small. Used when use_mixed_batch is True.

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _forward_fp8_kv_mixed_batch(
    self,
    q: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
    topk_indices: torch.Tensor,
    attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
    """Run the FP8 kernel once over every token in the batch.

    Decode and prefill tokens are not separated: all of them are
    presented to the FP8 kernel wrapper as a single batch entry, so the
    BF16 prefill kernel (with its head-padding overhead when ``num_heads``
    is small) is never invoked. Used when ``fp8_use_mixed_batch`` is set
    on the metadata.
    """
    # Convert per-request indices to global cache slots; no prefill
    # workspace is used on this path.
    num_topk = topk_indices.shape[1]
    global_topk_indices = triton_convert_req_index_to_global_index(
        attn_metadata.req_id_per_token,
        attn_metadata.block_table,
        topk_indices,
        BLOCK_SIZE=attn_metadata.block_size,
        NUM_TOPK_TOKENS=num_topk,
    )

    assert attn_metadata.fp8_extra_metadata is not None
    assert isinstance(
        attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
    )
    kernel_metadata = attn_metadata.fp8_extra_metadata

    # Add a leading batch dim: (T, H, D) -> (1, T, H, D) and
    # (T, topk) -> (1, T, topk).
    batched_q = q.unsqueeze(0)
    batched_indices = global_topk_indices.unsqueeze(0)
    batched_out, _ = self._fp8_flash_mla_kernel(
        q=batched_q,
        kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
        topk_indices=batched_indices,
        kernel_metadata=kernel_metadata,
    )

    # Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
    return batched_out.squeeze(0)

FlashMLASparseMetadata dataclass

Bases: AttentionMetadata

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
# Per-batch attention metadata consumed by the FlashMLA sparse forward paths.
# The FP8-specific payload lives in `fp8_extra_metadata`; its concrete type
# depends on whether the mixed-batch or the separate prefill/decode path is
# used (see `fp8_use_mixed_batch`).
@dataclass
class FlashMLASparseMetadata(AttentionMetadata):
    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor

    # block_table, req_id_per_token and block_size are the inputs handed to
    # triton_convert_req_index_to_global_index when translating the
    # per-request top-k indices.
    block_table: torch.Tensor
    req_id_per_token: torch.Tensor
    block_size: int = 64
    topk_tokens: int = 2048

    # Arguments forwarded to flash_mla_with_kvcache for the FP8 kernel:
    # scheduler_metadata -> tile_scheduler_metadata,
    # dummy_block_table  -> block_table,
    # cache_lens         -> cache_seqlens.
    @dataclass
    class FP8KernelMetadata:
        scheduler_metadata: FlashMLASchedMeta
        dummy_block_table: torch.Tensor
        cache_lens: torch.Tensor

    # Metadata for the FP8 path that runs decode tokens through the FP8
    # kernel and prefill tokens through the BF16 kernel on an up-converted
    # workspace.
    @dataclass
    class FP8SeparatePrefillDecode:
        @dataclass
        class Decode:
            seq_lens: torch.Tensor
            # Passed unchanged to the FP8 kernel wrapper for decode tokens.
            kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
            decode_query_len: int  # needed for reshape in spec decode

        @dataclass
        class Prefill:
            # Sequence lengths (context + query) for prefill requests
            # Shape: [num_prefill_reqs]
            seq_lens: torch.Tensor

            # Request ID for each token: -1 for decode tokens, request index
            # (0, 1, 2, ...) for prefill tokens.
            # Shape: [num_actual_tokens]
            request_ids: torch.Tensor

            # Workspace start offsets for all prefill requests
            # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
            # 0-indexed within each chunk. Used to map prefill tokens to
            # workspace offsets; passed to
            # triton_convert_req_index_to_global_index as
            # prefill_workspace_starts.
            workspace_starts: torch.Tensor

            @dataclass
            class Chunk:
                """Metadata for a chunk of prefill requests.

                Prefill requests may be chunked to fit within the fixed workspace size.
                """

                # seq_lens, block_table and workspace_starts are passed to
                # ops.cp_gather_and_upconvert_fp8_kv_cache for this chunk.
                seq_lens: torch.Tensor
                # Selects this chunk's rows of q, the top-k indices and the
                # output tensor.
                tokens_slice: slice
                block_table: torch.Tensor
                # NOTE(review): presumably the index of the chunk's first
                # prefill request; not read by the forward path — confirm
                # against the metadata builder.
                req_start_idx: int
                workspace_starts: torch.Tensor
                # Number of leading workspace rows this chunk occupies.
                chunk_tot_seqlen: int

            chunks: list[Chunk]

        num_prefills: int = 0
        num_decodes: int = 0
        num_prefill_tokens: int = 0
        # The first num_decode_tokens rows of the batch are the decode tokens.
        num_decode_tokens: int = 0

        decode: Decode | None = None
        prefill: Prefill | None = None

    # FP8KernelMetadata when fp8_use_mixed_batch is True,
    # FP8SeparatePrefillDecode for the separate prefill/decode path.
    fp8_extra_metadata: FP8SeparatePrefillDecode | FP8KernelMetadata | None = None
    # Selects the mixed-batch FP8 forward path in forward_mqa.
    fp8_use_mixed_batch: bool = False

    # Pre-computed C128A metadata (DeepseekV4 only, compress_ratio == 128).
    # Decode: global slot ids + valid-entry counts (fused from positions).
    c128a_global_decode_topk_indices: torch.Tensor | None = None
    c128a_decode_topk_lens: torch.Tensor | None = None
    # Prefill: local topk indices (used by combine_topk_swa_indices).
    c128a_prefill_topk_indices: torch.Tensor | None = None

FP8SeparatePrefillDecode dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
# Metadata for the FP8 path that runs decode tokens through the FP8 kernel
# and prefill tokens through the BF16 kernel on an up-converted workspace.
@dataclass
class FP8SeparatePrefillDecode:
    @dataclass
    class Decode:
        seq_lens: torch.Tensor
        # Passed unchanged to the FP8 kernel wrapper for decode tokens.
        kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
        decode_query_len: int  # needed for reshape in spec decode

    @dataclass
    class Prefill:
        # Sequence lengths (context + query) for prefill requests
        # Shape: [num_prefill_reqs]
        seq_lens: torch.Tensor

        # Request ID for each token: -1 for decode tokens, request index
        # (0, 1, 2, ...) for prefill tokens.
        # Shape: [num_actual_tokens]
        request_ids: torch.Tensor

        # Workspace start offsets for all prefill requests
        # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
        # 0-indexed within each chunk. Used to map prefill tokens to
        # workspace offsets; passed to
        # triton_convert_req_index_to_global_index as
        # prefill_workspace_starts.
        workspace_starts: torch.Tensor

        @dataclass
        class Chunk:
            """Metadata for a chunk of prefill requests.

            Prefill requests may be chunked to fit within the fixed workspace size.
            """

            # seq_lens, block_table and workspace_starts are passed to
            # ops.cp_gather_and_upconvert_fp8_kv_cache for this chunk.
            seq_lens: torch.Tensor
            # Selects this chunk's rows of q, the top-k indices and the
            # output tensor.
            tokens_slice: slice
            block_table: torch.Tensor
            # NOTE(review): presumably the index of the chunk's first
            # prefill request; not read by the forward path — confirm
            # against the metadata builder.
            req_start_idx: int
            workspace_starts: torch.Tensor
            # Number of leading workspace rows this chunk occupies.
            chunk_tot_seqlen: int

        chunks: list[Chunk]

    num_prefills: int = 0
    num_decodes: int = 0
    num_prefill_tokens: int = 0
    # The first num_decode_tokens rows of the batch are the decode tokens.
    num_decode_tokens: int = 0

    decode: Decode | None = None
    prefill: Prefill | None = None

Prefill dataclass

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
# Prefill-side metadata for the separate prefill/decode FP8 path: per-request
# tensors plus the list of chunks the prefill requests were split into.
@dataclass
class Prefill:
    # Sequence lengths (context + query) for prefill requests
    # Shape: [num_prefill_reqs]
    seq_lens: torch.Tensor

    # Request ID for each token: -1 for decode tokens, request index
    # (0, 1, 2, ...) for prefill tokens.
    # Shape: [num_actual_tokens]
    request_ids: torch.Tensor

    # Workspace start offsets for all prefill requests
    # Shape: [num_prefill_reqs], adjusted in-place per chunk to be
    # 0-indexed within each chunk. Used to map prefill tokens to workspace
    # offsets; passed to triton_convert_req_index_to_global_index as
    # prefill_workspace_starts.
    workspace_starts: torch.Tensor

    @dataclass
    class Chunk:
        """Metadata for a chunk of prefill requests.

        Prefill requests may be chunked to fit within the fixed workspace size.
        """

        # seq_lens, block_table and workspace_starts are passed to
        # ops.cp_gather_and_upconvert_fp8_kv_cache for this chunk.
        seq_lens: torch.Tensor
        # Selects this chunk's rows of q, the top-k indices and the output
        # tensor.
        tokens_slice: slice
        block_table: torch.Tensor
        # NOTE(review): presumably the index of the chunk's first prefill
        # request; not read by the forward path — confirm against the
        # metadata builder.
        req_start_idx: int
        workspace_starts: torch.Tensor
        # Number of leading workspace rows this chunk occupies.
        chunk_tot_seqlen: int

    # Processed one at a time by the prefill loop of the forward pass.
    chunks: list[Chunk]
Chunk dataclass

Metadata for a chunk of prefill requests.

Prefill requests may be chunked to fit within the fixed workspace size.

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
@dataclass
class Chunk:
    """Metadata for a chunk of prefill requests.

    Prefill requests may be chunked to fit within the fixed workspace size.
    """

    # seq_lens, block_table and workspace_starts are passed to
    # ops.cp_gather_and_upconvert_fp8_kv_cache for this chunk.
    seq_lens: torch.Tensor
    # Selects this chunk's rows of q, the top-k indices and the output tensor.
    tokens_slice: slice
    block_table: torch.Tensor
    # NOTE(review): presumably the index of the chunk's first prefill
    # request; not read by the forward path — confirm against the metadata
    # builder.
    req_start_idx: int
    workspace_starts: torch.Tensor
    # Number of leading workspace rows this chunk occupies.
    chunk_tot_seqlen: int

FlashMLASparseMetadataBuilder

Bases: AttentionMetadataBuilder[FlashMLASparseMetadata]

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
    """Builds per-step ``FlashMLASparseMetadata`` for the FlashMLA sparse backend.

    The constructor pre-allocates device buffers (constant top-k and
    max-model-len tensors, a request-id-per-token buffer and, for DeepSeek V4,
    a compressed slot-mapping buffer and C128A top-k buffers). ``build`` and
    its helpers then slice or write into those buffers instead of allocating
    new tensors on every step.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        """Cache config-derived constants and pre-allocate device buffers.

        Args:
            kv_cache_spec: KV cache spec. ``block_size`` is read when building
                metadata; for DeepSeek V4 ``compress_ratio`` and
                ``storage_block_size`` are read as well.
            layer_names: Stored on the instance as ``self.layer_names``.
            vllm_config: Source of the model, cache, parallel and scheduler
                configuration used below.
            device: Device on which every pre-allocated tensor is created.
        """
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        cache_config = vllm_config.cache_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self.device = device

        # Classify single-token queries (plus num_speculative_tokens via
        # supports_spec_as_decode=True) as decodes; longer queries go to
        # prefill.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)

        sm_count = num_compute_units(device.index)

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        # FP8 decode kernel only supports h_q = 64 or 128, so we need to pad
        self.fp8_decode_padded_heads = (
            FlashMLASparseImpl._compute_fp8_decode_padded_heads(self.num_heads)
        )

        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        # Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
        self.topk_tokens_tensor = torch.full(
            (max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
        )
        # Shape: [max_num_seqs], all elements = max_model_len
        self.max_model_len_tensor = torch.full(
            (max_num_seqs,),
            self.model_config.max_model_len,
            device=device,
            dtype=torch.int32,
        )
        # this is ignored by `flash_mla_with_kvcache` if indices not None
        self.dummy_block_table = torch.empty(
            (max_num_seqs, 1), dtype=torch.int32, device=self.device
        )

        # Equation taken from FlashMLA/csrc/api/sparse_decode.h
        # For sparse FP8 decode, the formula depends on architecture:
        # - SM90 (Hopper): num_sm_parts = num_sms / s_q / (h_q/64)
        # - SM100 (Blackwell head64/head64x2): num_sm_parts = num_sms / s_q
        # - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2
        # For max buffer size, use s_q = 1 (the case that produces largest output)
        # Use padded head count since that's what will be passed to the kernel
        h_q = self.fp8_decode_padded_heads
        if current_platform.is_device_capability_family(100):
            # SM100 head64 or head64x2 uses full SM count. Per the formulas
            # above, head128 needs only half of that, so the full SM count
            # is an upper bound for it as well.
            max_num_sm_parts = sm_count
        else:
            # SM90 uses h_q/64 divisor
            max_num_sm_parts = sm_count // max(1, h_q // 64)
        self.tile_scheduler_metadata_buffer = torch.empty(
            # TileSchedulerMetaDataSize = 8
            # see: FlashMLA/csrc/params.h
            (max_num_sm_parts, 8),
            dtype=torch.int32,
            device=device,
        )
        # Sized for per-request batching (num_decodes + 1)
        self.num_splits_buffer = torch.empty(
            (max_num_seqs + 1,),
            dtype=torch.int32,
            device=device,
        )
        # One slot per schedulable token; `build` zero-fills it and then
        # writes the request index of every token of the current batch.
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

        # DeepseekV4: has compress_ratios in hf_config.
        hf_config = vllm_config.model_config.hf_config
        self.is_deepseek_v4 = (
            hasattr(hf_config, "compress_ratios") and len(hf_config.compress_ratios) > 0
        )
        # Stays 1 for non-V4 models, so every `compress_ratio > 1` /
        # `compress_ratio == 128` check elsewhere is false for them.
        self.compress_ratio = 1
        if self.is_deepseek_v4:
            assert hasattr(self.kv_cache_spec, "compress_ratio")
            self.compress_ratio = self.kv_cache_spec.compress_ratio
            # Pre-allocate compressed slot mapping buffer for CUDA graph
            # address stability when compress_ratio > 1.
            if self.compress_ratio > 1:
                max_num_batched_tokens = (
                    vllm_config.scheduler_config.max_num_batched_tokens
                )
                self.compressed_slot_mapping_buffer = torch.empty(
                    max_num_batched_tokens,
                    dtype=torch.int64,
                    device=self.device,
                )

            # Pre-allocate C128A topk buffers for CUDA graph address stability.
            if self.compress_ratio == 128:
                max_num_batched_tokens = (
                    vllm_config.scheduler_config.max_num_batched_tokens
                )
                # Pad to B_TOPK alignment (128 covers both h_q=64 B_TOPK=64 and
                # h_q=128 B_TOPK=128). FlashMLA decode asserts extra_topk % B_TOPK
                # == 0; unaligned widths (e.g. 17 = ceil(2136/128)) crash the
                # sm100 head64 kernel. Padded slots stay -1 and decode_lens caps
                # them via topk_length, so the pad is a no-op at kernel level.
                # Mirrors _SPARSE_PREFILL_TOPK_ALIGNMENT in cache_utils.py.
                _C128A_TOPK_ALIGNMENT = 128
                # Number of compressed entries a max-length sequence produces
                # (ceil(max_model_len / compress_ratio)) ...
                c128a_max_compressed = cdiv(
                    self.model_config.max_model_len, self.compress_ratio
                )
                # ... rounded up to the next multiple of the alignment.
                c128a_max_compressed = (
                    cdiv(c128a_max_compressed, _C128A_TOPK_ALIGNMENT)
                    * _C128A_TOPK_ALIGNMENT
                )
                # Stored so _build_c128a_metadata passes it as the kernel's
                # max_compressed_tokens, matching the buffer stride. Otherwise
                # the kernel's default 8192 iterates past row width and spills
                # writes into adjacent rows (present in both decode and prefill
                # branches of _build_c128a_topk_metadata_kernel).
                self.c128a_max_compressed = c128a_max_compressed
                self.c128a_global_decode_buffer = torch.empty(
                    (max_num_batched_tokens, c128a_max_compressed),
                    dtype=torch.int32,
                    device=self.device,
                )
                self.c128a_decode_lens_buffer = torch.empty(
                    max_num_batched_tokens,
                    dtype=torch.int32,
                    device=self.device,
                )
                self.c128a_prefill_buffer = torch.empty(
                    (max_num_batched_tokens, c128a_max_compressed),
                    dtype=torch.int32,
                    device=self.device,
                )

    def _build_fp8_mixed_decode_prefill(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8KernelMetadata":
        """Build FP8 metadata treating all tokens as one mixed batch.

        This matches main branch's approach and avoids the BF16 prefill kernel
        which has head padding overhead when num_heads is small (high TP case).

        Args:
            common_attn_metadata: Batch metadata; only ``num_actual_tokens``
                is read here.

        Returns:
            Kernel metadata whose ``cache_lens`` and ``dummy_block_table`` are
            single-row slices of the pre-allocated tensors.
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Use padded head count since that's what the kernel will see
        padded_heads = self.fp8_decode_padded_heads

        # Build metadata for all tokens as a single batch
        scheduler_metadata, _ = get_mla_metadata(
            cache_seqlens=self.topk_tokens_tensor[:1],  # Single batch
            num_q_tokens_per_head_k=num_tokens * padded_heads,
            topk=self.topk_tokens,
            num_heads_q=padded_heads,
            num_heads_k=1,
            is_fp8_kvcache=True,
        )

        fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
            scheduler_metadata=scheduler_metadata,
            cache_lens=self.max_model_len_tensor[:1],
            dummy_block_table=self.dummy_block_table[:1],
        )

        return fp8_metadata

    def _build_fp8_separate_prefill_decode(
        self,
        common_attn_metadata: CommonAttentionMetadata,
    ) -> "FlashMLASparseMetadata.FP8SeparatePrefillDecode":
        """Build FP8 metadata with decode and prefill requests kept apart.

        The batch is split with ``split_decodes_and_prefills`` (uniform decode
        query length required). When prefill requests are present, ``prefill``
        is filled with per-token request ids, workspace start offsets and a
        list of chunks sized to the prefill workspace. When decode requests
        are present, ``decode`` is filled with kernel metadata sliced from the
        pre-allocated tensors.

        Args:
            common_attn_metadata: Batch metadata for the current step.

        Returns:
            The populated ``FP8SeparatePrefillDecode`` instance.
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold or 1,
                require_uniform=True,
            )
        )

        FP8Meta = FlashMLASparseMetadata.FP8SeparatePrefillDecode
        fp8_metadata = FP8Meta(
            num_decodes=num_decodes,
            num_prefills=num_prefills,
            num_decode_tokens=num_decode_tokens,
            num_prefill_tokens=num_prefill_tokens,
        )

        # Extract prefill sequence lengths (context + query, not just query)
        # Decode requests come first in the batch, prefill requests follow
        prefill_seq_lens = None
        prefill_request_id = None
        prefill_workspace_starts = None
        prefill_chunks = None

        # For pure decode batches, prefill_request_id will be None
        # For mixed batches, it will have -1 for decode and request_id for prefill
        if num_prefills > 0:
            # Upper bound is exact for prefill rows (the `[num_decodes:]`
            # slice below), so no D2H sync is needed.
            seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound
            assert seq_lens_cpu is not None
            seq_lens = common_attn_metadata.seq_lens
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

            prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
            prefill_seq_lens = seq_lens[num_decodes:]

            # Build prefill_request_id: -1 for decode, request index for
            # prefill. This enables a single
            # convert_logical_index_to_physical_index call for all tokens
            prefill_request_id = torch.full(
                (num_tokens,), -1, dtype=torch.int32, device=self.device
            )
            # Map prefill tokens to their request IDs (0, 1, 2, ...)
            for req_idx in range(num_prefills):
                # Get query token range for this prefill request
                global_req_idx = num_decodes + req_idx
                req_query_start = query_start_loc_cpu[global_req_idx]
                req_query_end = query_start_loc_cpu[global_req_idx + 1]
                prefill_request_id[req_query_start:req_query_end] = req_idx

            # will be adjusted by chunk loop
            # Starts as the exclusive prefix sum of the prefill sequence
            # lengths: entry 0 stays 0, entry i is sum(seq_lens[:i]).
            prefill_workspace_starts_cpu = torch.zeros(
                num_prefills, dtype=torch.int32, pin_memory=True
            )
            prefill_workspace_starts_cpu[1:] = torch.cumsum(
                prefill_seq_lens_cpu[:-1], dim=0
            )
            # populated by non-blocking copy after prefill_workspace_starts_cpu is
            # updated by each chunk
            prefill_workspace_starts = torch.empty(
                num_prefills, dtype=torch.int32, device=self.device
            )

            # Chunk prefill requests to fit within workspace size
            max_prefill_buffer_size = get_prefill_workspace_size(
                self.vllm_config.model_config.max_model_len
            )
            chunk_bounds = split_prefill_chunks(
                prefill_seq_lens_cpu, max_prefill_buffer_size
            )

            prefill_chunks = []
            for chunk_start, chunk_end in chunk_bounds:
                # Adjust workspace_starts in-place per chunk to be
                # 0-indexed within each chunk
                # Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
                #   Initial: workspace_starts=[0,10,25,45]
                #   After:   workspace_starts=[0,10,0,20]
                #           (chunk 0 starts at 0, chunk 1 starts at 0)
                offset = prefill_workspace_starts_cpu[chunk_start].item()
                prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset

                chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
                # NOTE(review): `.sum()` yields a 0-dim CPU tensor, while the
                # `chunk_tot_seqlen` field is annotated as `int` -- confirm
                # that every consumer accepts a tensor here.
                chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
                # Query-token range covered by this chunk; prefill requests
                # sit after the `num_decodes` decode requests.
                token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
                token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
                tokens_slice = slice(token_start, token_end)

                # Create chunk view of gpu tensor
                # (a view of `prefill_workspace_starts`, which is only filled
                # by the copy issued after this loop)
                chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
                chunk_block_table = common_attn_metadata.block_table_tensor[
                    num_decodes + chunk_start : num_decodes + chunk_end
                ]

                prefill_chunks.append(
                    FP8Meta.Prefill.Chunk(
                        seq_lens=chunk_seq_lens,
                        tokens_slice=tokens_slice,
                        block_table=chunk_block_table,
                        req_start_idx=chunk_start,
                        workspace_starts=chunk_workspace_starts,
                        chunk_tot_seqlen=chunk_tot_seqlen,
                    )
                )

            # Single host-to-device copy once every chunk has rebased its
            # slice of the pinned CPU tensor.
            prefill_workspace_starts.copy_(
                prefill_workspace_starts_cpu, non_blocking=True
            )

            fp8_metadata.prefill = FP8Meta.Prefill(
                seq_lens=prefill_seq_lens,
                request_ids=prefill_request_id,
                workspace_starts=prefill_workspace_starts,
                chunks=prefill_chunks,
            )

        if num_decodes > 0:
            # Compute decode_query_len for spec decode (uniform due to require_uniform)
            query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
            decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

            # NOTE(review): get_mla_metadata is called without arguments
            # here, whereas _build_fp8_mixed_decode_prefill passes the padded
            # head count, topk, etc. -- confirm the no-argument form is
            # intended for the decode path.
            scheduler_metadata, _ = get_mla_metadata()

            kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=scheduler_metadata,
                dummy_block_table=self.dummy_block_table[:num_decodes],
                cache_lens=self.max_model_len_tensor[:num_decodes],
            )
            fp8_metadata.decode = FP8Meta.Decode(
                seq_lens=common_attn_metadata.seq_lens[:num_decodes],
                kernel_metadata=kernel_meta,
                decode_query_len=decode_query_len,
            )

        return fp8_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashMLASparseMetadata:
        """Assemble ``FlashMLASparseMetadata`` for the current batch.

        Args:
            common_prefix_len: Not read by this builder.
            common_attn_metadata: Batch metadata for the current step.
            fast_build: Not read by this builder.

        Returns:
            Metadata holding the per-token request ids, the (possibly
            compressed) slot mapping, optional FP8 extra metadata and, for
            DeepSeek V4 with ``compress_ratio == 128``, the C128A top-k fields.
        """
        cm = common_attn_metadata
        num_tokens = cm.num_actual_tokens
        # Expand query_start_loc into one request index per token:
        # request i is repeated (start[i + 1] - start[i]) times.
        starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
        seg_lengths = np.diff(starts)
        req_id_per_token = np.repeat(
            np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
        )
        # Zero-fill for cudagraphs
        self.req_id_per_token_buffer.fill_(0)
        self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
            torch.from_numpy(req_id_per_token), non_blocking=True
        )
        # From here on the name refers to a view of the persistent device
        # buffer, not to the NumPy array built above.
        req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

        slot_mapping = cm.slot_mapping
        if self.compress_ratio > 1:
            # compress_ratio > 1 only happens for DeepSeek V4, in which case
            # __init__ has allocated compressed_slot_mapping_buffer.
            slot_mapping = get_compressed_slot_mapping(
                common_attn_metadata.num_actual_tokens,
                common_attn_metadata.query_start_loc,
                common_attn_metadata.seq_lens,
                common_attn_metadata.block_table_tensor.clamp(min=0),
                int(self.kv_cache_spec.storage_block_size),
                self.compress_ratio,
                out=self.compressed_slot_mapping_buffer,
            )

        fp8_extra_metadata: (
            FlashMLASparseMetadata.FP8SeparatePrefillDecode
            | FlashMLASparseMetadata.FP8KernelMetadata
            | None
        ) = None
        fp8_use_mixed_batch = (
            self.num_heads < MIN_HEADS_FOR_BF16_PREFILL and not self.is_deepseek_v4
        )
        # DeepseekV4 has its own attention impl (DeepseekV4MLAAttention) that does not
        # consume fp8_extra_metadata. Skipping the build here avoids a
        # forced D2H sync on seq_lens that would otherwise fire on every
        # prefill-bearing step, lifting GPU utilization on long-prefill
        # workloads (e.g. LongBench) from ~83% to ~100%.
        if self.use_fp8_kv_cache and not self.is_deepseek_v4:
            if fp8_use_mixed_batch:
                fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
            else:
                fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)

        # Pre-compute C128A topk indices for DeepseekV4.
        c128a_fields = {}
        if self.is_deepseek_v4 and self.compress_ratio == 128:
            c128a_fields = self._build_c128a_metadata(cm, req_id_per_token)

        metadata = FlashMLASparseMetadata(
            num_reqs=cm.num_reqs,
            max_query_len=cm.max_query_len,
            max_seq_len=cm.max_seq_len,
            num_actual_tokens=cm.num_actual_tokens,
            query_start_loc=cm.query_start_loc,
            slot_mapping=slot_mapping,
            block_table=cm.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
            fp8_use_mixed_batch=fp8_use_mixed_batch,
            **c128a_fields,
        )

        return metadata

    def _build_c128a_metadata(
        self,
        cm: CommonAttentionMetadata,
        req_id_per_token: torch.Tensor,
    ) -> dict[str, torch.Tensor | None]:
        """Pre-compute C128A topk indices for DeepseekV4 (compress_ratio == 128).

        ``build`` only calls this when ``compress_ratio == 128``, which is
        also the only case in which ``__init__`` allocates the C128A buffers
        used below.

        Args:
            cm: Batch metadata; ``positions`` must not be ``None``.
            req_id_per_token: Request index of every token in the batch.

        Returns:
            Extra keyword arguments for ``FlashMLASparseMetadata``:
            ``c128a_global_decode_topk_indices`` and ``c128a_decode_topk_lens``
            when there are decode tokens, ``c128a_prefill_topk_indices`` when
            there are prefill tokens. Empty when the batch has no tokens.
        """
        # Must match SWA's decode split (no `require_uniform=True`) so
        # `c128a_global_decode_topk_indices.shape[0]` lines up with q in
        # `_forward_decode`. The per-token C128A kernel handles non-uniform
        # query lengths.
        (num_decodes, _, num_decode_tokens, num_prefill_tokens) = (
            split_decodes_and_prefills(
                cm,
                decode_threshold=self.reorder_batch_threshold or 1,
            )
        )

        num_total = num_decode_tokens + num_prefill_tokens
        if num_total == 0:
            return {}

        assert cm.positions is not None, (
            "positions is required for C128A metadata build"
        )
        # Block size expressed in compressed entries rather than raw tokens.
        block_size = self.kv_cache_spec.block_size // self.compress_ratio
        global_decode, decode_lens, prefill_local = build_c128a_topk_metadata(
            cm.positions[:num_total],
            self.compress_ratio,
            num_decode_tokens,
            req_id_per_token,
            cm.block_table_tensor[:num_decodes],
            block_size,
            cm.slot_mapping,
            self.c128a_global_decode_buffer,
            self.c128a_decode_lens_buffer,
            self.c128a_prefill_buffer,
            max_compressed_tokens=self.c128a_max_compressed,
        )

        result: dict[str, torch.Tensor | None] = {}
        if num_decode_tokens > 0:
            # Insert a singleton middle dimension:
            # [num_decode_tokens, 1, row width].
            result["c128a_global_decode_topk_indices"] = global_decode.view(
                num_decode_tokens, 1, -1
            )
            result["c128a_decode_topk_lens"] = decode_lens
        if num_prefill_tokens > 0:
            result["c128a_prefill_topk_indices"] = prefill_local
        return result
_build_c128a_metadata

_build_c128a_metadata(
    cm: CommonAttentionMetadata, req_id_per_token: Tensor
) -> dict[str, Tensor | None]

Pre-compute C128A topk indices for DeepseekV4 (compress_ratio >= 128).

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _build_c128a_metadata(
    self,
    cm: CommonAttentionMetadata,
    req_id_per_token: torch.Tensor,
) -> dict[str, torch.Tensor | None]:
    """Compute the C128A top-k index tensors for a DeepseekV4 batch.

    Args:
        cm: Batch metadata; ``positions`` must not be ``None``.
        req_id_per_token: Request index of every token in the batch.

    Returns:
        A dict with ``c128a_global_decode_topk_indices`` and
        ``c128a_decode_topk_lens`` when the batch contains decode tokens, and
        ``c128a_prefill_topk_indices`` when it contains prefill tokens. The
        dict is empty when the batch has no tokens at all.
    """
    # `require_uniform=True` is deliberately not passed: the split has to
    # agree with SWA's decode split so that the first dimension of
    # `c128a_global_decode_topk_indices` lines up with q in
    # `_forward_decode`. Non-uniform query lengths are handled by the
    # per-token C128A kernel.
    split = split_decodes_and_prefills(
        cm,
        decode_threshold=self.reorder_batch_threshold or 1,
    )
    num_decodes, _, n_decode_tok, n_prefill_tok = split

    n_tok = n_decode_tok + n_prefill_tok
    if not n_tok:
        return {}

    assert cm.positions is not None, (
        "positions is required for C128A metadata build"
    )

    # Block size expressed in compressed entries rather than raw tokens.
    compressed_block_size = self.kv_cache_spec.block_size // self.compress_ratio
    decode_indices, decode_lens, prefill_indices = build_c128a_topk_metadata(
        positions=cm.positions[:n_tok],
        compress_ratio=self.compress_ratio,
        num_decode_tokens=n_decode_tok,
        token_to_req_indices=req_id_per_token,
        block_table=cm.block_table_tensor[:num_decodes],
        block_size=compressed_block_size,
        slot_mapping=cm.slot_mapping,
        global_decode_buffer=self.c128a_global_decode_buffer,
        decode_lens_buffer=self.c128a_decode_lens_buffer,
        prefill_buffer=self.c128a_prefill_buffer,
        max_compressed_tokens=self.c128a_max_compressed,
    )

    fields: dict[str, torch.Tensor | None] = {}
    if n_decode_tok > 0:
        # Insert a singleton middle dimension: [n_decode_tok, 1, row width].
        reshaped = decode_indices.view(n_decode_tok, 1, -1)
        fields["c128a_global_decode_topk_indices"] = reshaped
        fields["c128a_decode_topk_lens"] = decode_lens
    if n_prefill_tok > 0:
        fields["c128a_prefill_topk_indices"] = prefill_indices
    return fields

_build_fp8_mixed_decode_prefill

_build_fp8_mixed_decode_prefill(
    common_attn_metadata: CommonAttentionMetadata,
) -> FP8KernelMetadata

Build FP8 metadata treating all tokens as one mixed batch.

This matches main branch's approach and avoids the BF16 prefill kernel which has head padding overhead when num_heads is small (high TP case).

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def _build_fp8_mixed_decode_prefill(
    self,
    common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8KernelMetadata":
    """Build FP8 kernel metadata with every token in one mixed batch.

    Same approach as the main branch. It sidesteps the BF16 prefill kernel,
    whose head padding is costly when num_heads is small (high TP case).

    Args:
        common_attn_metadata: Batch metadata; only ``num_actual_tokens`` is
            read here.

    Returns:
        Kernel metadata whose ``cache_lens`` and ``dummy_block_table`` are
        single-row slices of the pre-allocated tensors.
    """
    total_q_tokens = common_attn_metadata.num_actual_tokens

    # The kernel runs with the padded head count, so schedule with it too.
    h_q = self.fp8_decode_padded_heads

    # The whole step is described as a single batch entry, hence the
    # one-element slices here and below.
    sched_meta, _ = get_mla_metadata(
        cache_seqlens=self.topk_tokens_tensor[:1],
        num_q_tokens_per_head_k=total_q_tokens * h_q,
        topk=self.topk_tokens,
        num_heads_q=h_q,
        num_heads_k=1,
        is_fp8_kvcache=True,
    )

    return FlashMLASparseMetadata.FP8KernelMetadata(
        scheduler_metadata=sched_meta,
        cache_lens=self.max_model_len_tensor[:1],
        dummy_block_table=self.dummy_block_table[:1],
    )

build_c128a_topk_metadata

build_c128a_topk_metadata(
    positions: Tensor,
    compress_ratio: int,
    num_decode_tokens: int,
    token_to_req_indices: Tensor,
    block_table: Tensor,
    block_size: int,
    slot_mapping: Tensor,
    global_decode_buffer: Tensor,
    decode_lens_buffer: Tensor,
    prefill_buffer: Tensor,
    max_compressed_tokens: int = 8192,
) -> tuple[Tensor, Tensor, Tensor]

Single kernel for all C128A tokens (decode + prefill).

Decode tokens: position → block_table lookup → global slot ids + topk_lens. Prefill tokens: position → local indices [0, ..., n-1, -1, ...].

Writes into pre-allocated buffers for CUDA graph address stability. Returns slices of the buffers.

Source code in vllm/v1/attention/backends/mla/flashmla_sparse.py
def build_c128a_topk_metadata(
    positions: torch.Tensor,
    compress_ratio: int,
    num_decode_tokens: int,
    token_to_req_indices: torch.Tensor,
    block_table: torch.Tensor,
    block_size: int,
    slot_mapping: torch.Tensor,
    global_decode_buffer: torch.Tensor,
    decode_lens_buffer: torch.Tensor,
    prefill_buffer: torch.Tensor,
    max_compressed_tokens: int = 8192,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fill the C128A top-k buffers for decode and prefill tokens in one launch.

    Decode tokens: position → block_table lookup → global slot ids + topk_lens.
    Prefill tokens: position → local indices [0, ..., n-1, -1, ...].

    Results are written into the caller-provided, pre-allocated buffers so
    that their addresses stay stable for CUDA graphs.

    Returns:
        ``(global_decode, decode_lens, prefill_local)``: leading-row slices of
        ``global_decode_buffer``, ``decode_lens_buffer`` and
        ``prefill_buffer``. The first two have ``num_decode_tokens`` rows, the
        last one has ``positions.shape[0] - num_decode_tokens`` rows.
    """
    total_tokens = positions.shape[0]

    # The views are taken up front so the empty-batch path returns them too.
    outputs = (
        global_decode_buffer[:num_decode_tokens],
        decode_lens_buffer[:num_decode_tokens],
        prefill_buffer[: total_tokens - num_decode_tokens],
    )

    if total_tokens != 0:
        # The launch grid has one entry per token.
        grid = (total_tokens,)
        _build_c128a_topk_metadata_kernel[grid](
            global_decode_buffer,
            global_decode_buffer.stride(0),
            decode_lens_buffer,
            prefill_buffer,
            prefill_buffer.stride(0),
            positions,
            compress_ratio,
            max_compressed_tokens,
            num_decode_tokens,
            token_to_req_indices,
            block_table,
            block_table.stride(0),
            block_size,
            slot_mapping,
            BLOCK_SIZE=1024,
        )

    return outputs