class FusedMoEWithLoRA(BaseLayerWithLoRA):
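    """LoRA wrapper around a FusedMoE layer.

    Stacks per-expert LoRA A/B weights for the w13 (gate/up) and w2 (down)
    projections and hands them to a LoRA-aware fused-experts kernel through
    a MoELoRAContext.
    """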
def __init__(self, base_layer: FusedMoE) -> None:
super().__init__()
self.base_layer = base_layer
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
assert not self.base_layer.quant_method.is_monolithic, (
"Monolithic kernels are not supported for Fused MoE LoRA."
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = _get_lora_device(base_layer)
# For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
# since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
self.base_layer.ensure_moe_quant_config_init()
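        # Obtain a modular MoE kernel: reuse the quant method's own kernel
        # when it exposes one, otherwise assemble a FusedMoEKernel from a
        # no-DP/EP prepare-finalize stage and the selected GEMM impl.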
if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
moe_kernel = self.base_layer.quant_method.moe_kernel
# Don't let the kernel own shared experts so the runner can
# overlap them with routed experts via a separate CUDA stream.
moe_kernel.shared_experts = None
else:
prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
moe_kernel = FusedMoEKernel(
prepare_finalize,
self.base_layer.quant_method.select_gemm_impl(
prepare_finalize, self.base_layer
),
)
assert moe_kernel.supports_lora(), (
f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. "
"For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
"(auto selects Triton automatically when LoRA is enabled). "
"For quantized MoE, mix LoRAExpertsMixin into the experts class "
"and consume self._lora_context in apply()."
)
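        # Keep a direct handle to the experts impl and swap the layer's
        # quant method for a modular wrapper that drives this kernel.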
self._fused_experts = moe_kernel.fused_experts
self.base_layer._replace_quant_method(
FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel)
)
def _build_lora_context(self):
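        """Bundle the stacked LoRA weights and routing metadata into the
        MoELoRAContext consumed by the LoRA-aware fused-experts kernel."""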
return MoELoRAContext(
w13_lora_a_stacked=self.w13_lora_a_stacked,
w13_lora_b_stacked=self.w13_lora_b_stacked,
w2_lora_a_stacked=self.w2_lora_a_stacked,
w2_lora_b_stacked=self.w2_lora_b_stacked,
adapter_enabled=self.adapter_enabled,
max_loras=self.max_loras,
top_k=self.base_layer.top_k,
w13_num_slices=self._w13_slices,
fully_sharded=self.fully_sharded,
tp_rank=self.tp_rank,
tp_size=self.tp_size,
local_num_experts=self.base_layer.local_num_experts,
punica_wrapper=self.punica_wrapper,
use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER),
)
def _create_lora_a_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
):
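        # Each w13 slice is stacked as
        # (max_loras, num_experts, rank, hidden_size); with fully sharded
        # LoRA the rank dim is split across TP ranks (S-LoRA style).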
self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank
if not self.fully_sharded
else divide(lora_config.max_lora_rank, self.tp_size),
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self._w13_slices)
)
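        # w2 LoRA A is stacked as
        # (max_loras, num_experts, rank, intermediate_size_per_partition);
        # its input dim follows the TP shard of the w2 weight.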
self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
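        # Each w13 slice is stacked as
        # (max_loras, num_experts, intermediate_size_per_partition, rank);
        # the output dim follows the TP shard of the w13 weight.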
self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self._w13_slices)
)
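        # w2 LoRA B is stacked as (max_loras, num_experts, hidden_size, rank);
        # with fully sharded LoRA the output dim is split across TP ranks.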
self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
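        # One enable flag per LoRA slot; the extra (+1) slot presumably
        # serves as the no-adapter sentinel.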
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
        # These flattened per-expert views are used by
        # 'LoRALayerWeights.create_dummy_lora_weights' to create dummy
        # LoRA weights.
        # TODO: Optimize this section.
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.local_num_experts):
# For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
# For non-gated MoE: up_proj (w1), down_proj (w2)
self.lora_a_stacked.append(
self.w13_lora_a_stacked[0][lora_id][experts_id]
)
self.lora_a_stacked.append(
self.w2_lora_a_stacked[0][lora_id][experts_id]
)
self.lora_b_stacked.append(
self.w13_lora_b_stacked[0][lora_id][experts_id]
)
self.lora_b_stacked.append(
self.w2_lora_b_stacked[0][lora_id][experts_id]
)
# Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
if self._w13_slices == 2:
self.lora_a_stacked.append(
self.w13_lora_a_stacked[1][lora_id][experts_id]
)
self.lora_b_stacked.append(
self.w13_lora_b_stacked[1][lora_id][experts_id]
)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1 or not self.fully_sharded:
return w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
shard_size = self.w13_lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
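        """Slice W13 B along the output (intermediate) dim to match the
        layer's TP shard."""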
if self.tp_size == 1:
return w13_lora_b
        # w13_lora_b shape: (num_experts, output_size, rank)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w13_lora_b[:, start_idx:end_idx, :]
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1:
return w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1 or not self.fully_sharded:
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
shard_size = self.w2_lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_b[:, start_idx:end_idx, :]
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
for pos in range(self._w13_slices):
self.w13_lora_a_stacked[pos][index] = 0
self.w13_lora_b_stacked[pos][index] = 0
self.w2_lora_a_stacked[0][index] = 0
self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
# Make mypy happy
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
self.reset_lora(index)
self.adapter_enabled[index] = 1
num_experts = self.w13_lora_a_stacked[0].shape[1]
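        # Adapters are packed as gate_proj (w1), down_proj (w2), up_proj (w3).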
w1_lora_a, w2_lora_a, w3_lora_a = lora_a
w1_lora_b, w2_lora_b, w3_lora_b = lora_b
assert (
num_experts
== w1_lora_a.shape[0]
== w2_lora_a.shape[0]
== w3_lora_a.shape[0]
)
        sliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
        sliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
        self.w13_lora_a_stacked[0][
            index, :, : sliced_w1_lora_a.shape[1], : sliced_w1_lora_a.shape[2]
        ].copy_(sliced_w1_lora_a, non_blocking=True)
        self.w13_lora_b_stacked[0][
            index, :, : sliced_w1_lora_b.shape[1], : sliced_w1_lora_b.shape[2]
        ].copy_(sliced_w1_lora_b, non_blocking=True)
# Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
if self._w13_slices == 2:
            sliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
            sliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
            self.w13_lora_a_stacked[1][
                index, :, : sliced_w3_lora_a.shape[1], : sliced_w3_lora_a.shape[2]
            ].copy_(sliced_w3_lora_a, non_blocking=True)
            self.w13_lora_b_stacked[1][
                index, :, : sliced_w3_lora_b.shape[1], : sliced_w3_lora_b.shape[2]
            ].copy_(sliced_w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[0][
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
].copy_(sliced_w2_lora_a, non_blocking=True)
self.w2_lora_b_stacked[0][
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(sliced_w2_lora_b, non_blocking=True)
def set_mapping(self, punica_wrapper):
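        """Bind the punica wrapper, then push a fresh MoELoRAContext into
        the fused-experts kernel."""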
super().set_mapping(punica_wrapper)
self._fused_experts.set_lora_context(self._build_lora_context())
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
@property
def quant_method(self):
return self.base_layer.quant_method
@property
def is_internal_router(self) -> bool:
return self.base_layer.is_internal_router
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# source_layer is FusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2