class DeepGemmFP4Experts(mk.FusedMoEExpertsModular):
    """DeepGemm-based fused MoE expert implementation for FP4 weights.

    Uses m_grouped_fp8_fp4_gemm_nt_contiguous with FP8 activations and
    MXFP4 (FP4 E2M1 packed as uint8) weights. Requires SM100+ (Blackwell).
    """

    # FP8 activation quantization group size along K. Hardcoded since the
    # mxfp4_w4a8 quant config does not set a block_shape on the activation
    # descriptor.
    _ACT_BLOCK_K = 128
    # FP4 weight quantization group size along K.
    _WEIGHT_BLOCK_K = 32
    def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
        """Validate the quantization scheme and cache the FC1 clamp limit.

        Only static MXFP4 weights with group-wise (neither per-act-token
        nor per-out-channel) activation quantization are accepted.
        """
        super().__init__(moe_config=moe_config, quant_config=quant_config)
        assert quant_config.weight_quant_dtype == "mxfp4"
        assert not quant_config.per_act_token_quant
        assert not quant_config.per_out_ch_quant
        # Passed to the fused SiLU+quant kernel in _act_mul_quant.
        self.gemm1_clamp_limit = quant_config.gemm1_clamp_limit
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
@staticmethod
def _supports_current_device() -> bool:
from vllm.platforms import current_platform
return (
is_deep_gemm_supported()
and current_platform.is_device_capability_family(100)
)
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [
(kMxfp4Static, kFp8Dynamic128Sym),
]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP]
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return not (
moe_parallel_config.use_fi_nvl_two_sided_kernels
or moe_parallel_config.use_fi_nvl_one_sided_kernels
)
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
block_m = get_mk_alignment_for_contiguous_layout()[0]
M_sum = compute_aligned_M(
M, topk, local_num_experts, block_m, expert_tokens_meta
)
assert M_sum % block_m == 0
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M_sum, max(activation_out_dim, K))
workspace2 = (M_sum, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
def _act_mul_quant(
self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
) -> tuple[torch.Tensor, torch.Tensor]:
block_k = self._ACT_BLOCK_K
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation)
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
assert activation == MoEActivation.SILU
return fused_silu_mul_fp8_quant_packed(
input=input,
output_q=output,
group_size=block_k,
clamp_limit=self.gemm1_clamp_limit,
)
if activation == MoEActivation.SILU:
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
return silu_mul_per_token_group_quant_fp8_colmajor(
input=input,
output=output,
use_ue8m0=use_ue8m0,
)
act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
return per_token_group_quant_fp8(
act_out, block_k, column_major_scales=True, out_q=output
)
    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the full expert computation for pre-quantized FP8 activations.

        Pipeline: permute tokens into expert-contiguous order -> grouped
        FC1 GEMM (FP8 x FP4) -> fused activation + FP8 requant -> grouped
        FC2 GEMM -> unpermute and topk-weighted reduce into ``output``.

        ``workspace13`` and ``workspace2`` are ping-ponged as scratch
        buffers between stages, so the statement order here matters:
        each buffer is fully consumed before it is overwritten.
        """
        # Activations must already be FP8-quantized with group scales;
        # no dynamic FC2 input scale is accepted.
        assert a1q_scale is not None
        assert a2_scale is None
        assert self.w1_scale is not None
        assert self.w2_scale is not None
        a1q = hidden_states
        _, N, _ = w1.size()
        # K comes from activations (full hidden dim), not from w1 which is
        # packed FP4 (E, N, K//2).
        K = a1q.size(1)
        local_num_experts = w1.size(0)
        if global_num_experts == -1:
            global_num_experts = local_num_experts
        # Padded row count so each expert's token segment is aligned to the
        # contiguous-layout block size required by the grouped GEMM.
        M_sum = compute_aligned_M(
            M=topk_ids.size(0),
            num_topk=topk_ids.size(1),
            local_num_experts=local_num_experts,
            alignment=get_mk_alignment_for_contiguous_layout()[0],
            expert_tokens_meta=expert_tokens_meta,
        )
        # Reinterpret workspace13 as FP8 storage for the permuted input.
        a1q_perm = _resize_cache(
            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
        )
        # Gather tokens into expert-contiguous order; inv_perm undoes this
        # at the end, expert_ids labels each aligned row block.
        a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
            aq=a1q,
            aq_scale=a1q_scale,
            topk_ids=topk_ids,
            local_num_experts=local_num_experts,
            expert_map=expert_map,
            expert_tokens_meta=expert_tokens_meta,
            aq_out=a1q_perm,
        )
        assert a1q.size(0) == M_sum
        # FC1: FP8 activations x FP4 weights
        # DeepGEMM 2.4.2 requires FP4-packed weights as int8 (kPackedFP4).
        mm1_out = _resize_cache(workspace2, (M_sum, N))
        m_grouped_fp8_fp4_gemm_nt_contiguous(
            (a1q, a1q_scale),
            (w1.view(torch.int8), self.w1_scale),
            mm1_out,
            expert_ids,
            recipe_a=(1, self._ACT_BLOCK_K),
            recipe_b=(1, self._WEIGHT_BLOCK_K),
        )
        # SwiGLU activation + FP8 requant. workspace13 (which held a1q_perm,
        # now consumed by FC1) is reused for the quantized activation output.
        activation_out_dim = self.adjust_N_for_activation(N, activation)
        quant_out = _resize_cache(
            workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
        )
        a2q, a2q_scale = self._act_mul_quant(
            input=mm1_out.view(-1, N), output=quant_out, activation=activation
        )
        # FC2: FP8 activations x FP4 weights. workspace2 (FC1 output, now
        # consumed by the activation step) is reused for the FC2 output.
        mm2_out = _resize_cache(workspace2, (M_sum, K))
        m_grouped_fp8_fp4_gemm_nt_contiguous(
            (a2q, a2q_scale),
            (w2.view(torch.int8), self.w2_scale),
            mm2_out,
            expert_ids,
            recipe_a=(1, self._ACT_BLOCK_K),
            recipe_b=(1, self._WEIGHT_BLOCK_K),
        )
        # If router weights were already applied to the input, reduce with
        # unit weights so they are not applied twice.
        if apply_router_weight_on_input:
            topk_weights = torch.ones_like(topk_weights)
        # Scatter rows back to their original token order and do the
        # topk-weighted sum directly into the caller-provided output.
        deepgemm_unpermute_and_reduce(
            a=mm2_out,
            topk_ids=topk_ids,
            topk_weights=topk_weights,
            inv_perm=inv_perm,
            expert_map=expert_map,
            output=output,
        )