Skip to content

vllm.lora.layers.fused_moe

FusedMoE3DWithLoRA

Bases: FusedMoEWithLoRA

Source code in vllm/lora/layers/fused_moe.py
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    """FusedMoE LoRA layer whose adapter supplies w13 as one fused tensor.

    ``set_lora`` expects exactly two LoRA A and two LoRA B tensors, ordered
    (w13, w2), and keeps the fused w13 weights in a single stacked slot.
    """

    def __init__(self, base_layer):
        super().__init__(base_layer)
        # w13 is stored as one fused slice here, replacing the slice count
        # chosen by the parent constructor.
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras, lora_config):
        """Allocate the zero-filled stacked LoRA B buffers for w13 and w2."""
        layer = self.base_layer
        rank = lora_config.max_lora_rank
        alloc = {"dtype": lora_config.lora_dtype, "device": self.device}

        # The fused w13 output dim is twice the per-partition intermediate
        # size.
        w13_shape = (
            max_loras,
            layer.local_num_experts,
            layer.intermediate_size_per_partition * 2,
            rank,
        )
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(w13_shape, **alloc) for _ in range(self._w13_slices)
        )

        # With fully sharded LoRA, the w2 output dim is divided by tp_size.
        if self.fully_sharded:
            w2_rows = divide(layer.hidden_size, self.tp_size)
        else:
            w2_rows = layer.hidden_size
        w2_shape = (max_loras, layer.local_num_experts, w2_rows, rank)
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(w2_shape, **alloc),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Record the LoRA settings and allocate the stacked LoRA buffers.

        ``model_config`` must be a ``PretrainedConfig``; its first
        architecture name is stored and later consulted by ``_slice_w13_b``.
        """
        assert isinstance(model_config, PretrainedConfig)
        architectures = model_config.architectures
        self._base_model = architectures[0]
        self.fully_sharded = lora_config.fully_sharded_loras
        self.max_loras = lora_config.max_loras

        # One int flag per adapter index, max_loras + 1 entries, all 0.
        enabled_flags = [0] * (max_loras + 1)
        self.adapter_enabled = torch.tensor(
            enabled_flags, dtype=torch.int, device=self.device
        )

        for allocate in (self._create_lora_a_weights, self._create_lora_b_weights):
            allocate(max_loras, lora_config)

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        """Return this TP rank's rows of the fused w13 LoRA B tensor.

        The w1 and w3 parts are each sliced to the rank's shard and then
        recombined in the layout they arrived in.
        """
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        per_rank = self.base_layer.intermediate_size_per_partition
        first_row = self.tp_rank * per_rank
        rows = slice(first_row, first_row + per_rank)

        # HACK: Currently, only GPT-OSS is in interleaved order
        if self._base_model != "GptOssForCausalLM":
            # w1 occupies the first half of the output dim, w3 the second.
            half = w13_lora_b.shape[1] // 2
            w1_shard = w13_lora_b[:, :half, :][:, rows, :]
            w3_shard = w13_lora_b[:, half:, :][:, rows, :]
            return torch.cat([w1_shard, w3_shard], dim=1)

        # For models like GPT-OSS, w1 (gate_proj) and w3 (up_proj) rows
        # alternate, so de-interleave, shard, then re-interleave.
        w1_shard = w13_lora_b[:, ::2, :][:, rows, :]
        w3_shard = w13_lora_b[:, 1::2, :][:, rows, :]
        paired = torch.stack([w1_shard, w3_shard], dim=2)
        return paired.flatten(1, 2)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Load one adapter's (w13, w2) LoRA tensors into slot ``index``."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        # Clear the slot first, then mark it as enabled.
        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        w13_a, w2_a = lora_a
        w13_b, w2_b = lora_b

        w13_a_shard = self._slice_w13_a(w13_a)
        w13_b_shard = self._slice_w13_b(w13_b)
        w2_a_shard = self._slice_w2_a(w2_a)
        w2_b_shard = self._slice_w2_b(w2_b)

        copy_plan = (
            (self.w13_lora_a_stacked, w13_a_shard),
            (self.w2_lora_a_stacked, w2_a_shard),
            (self.w13_lora_b_stacked, w13_b_shard),
            (self.w2_lora_b_stacked, w2_b_shard),
        )
        for stacked, shard in copy_plan:
            # Only the leading region matching the shard's last two dims is
            # written; the rest of the slot keeps its reset value.
            target = stacked[0][index, :, : shard.shape[1], : shard.shape[2]]
            target.copy_(shard, non_blocking=True)

    @property
    def w13_input_size(self):
        """
        Full size: last dim of the stacked w13 LoRA A buffer.
        """
        stacked_a = self.w13_lora_a_stacked[0]
        return stacked_a.shape[-1]

    @property
    def w13_output_size(self):
        """
        Full size: second-to-last dim of the stacked w13 LoRA B buffer
        multiplied by tp_size.
        """
        local_rows = self.w13_lora_b_stacked[0].shape[-2]
        return local_rows * self.tp_size

    @property
    def w2_input_size(self):
        """
        Full size: last dim of the stacked w2 LoRA A buffer multiplied by
        tp_size.
        """
        local_cols = self.w2_lora_a_stacked[0].shape[-1]
        return local_cols * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full size: the base layer's hidden size.
        """
        layer = self.base_layer
        return layer.hidden_size

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True for a FusedMoE layer with exactly one packed module."""
        # source_layer is FusedMoE
        if not isinstance(source_layer, FusedMoE):
            return False
        return len(packed_modules_list) == 1

w13_input_size property

w13_input_size

Full size

w13_output_size property

w13_output_size

Full size

w2_input_size property

w2_input_size

Full size

w2_output_size property

w2_output_size

Full size

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Return True for a FusedMoE layer with exactly one packed module."""
    # source_layer is FusedMoE
    if not isinstance(source_layer, FusedMoE):
        return False
    return len(packed_modules_list) == 1

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Record the LoRA settings and allocate the stacked LoRA buffers.

    ``model_config`` must be a ``PretrainedConfig``; its first architecture
    name is stored on the layer as ``_base_model``.
    """
    assert isinstance(model_config, PretrainedConfig)
    architectures = model_config.architectures
    self._base_model = architectures[0]
    self.fully_sharded = lora_config.fully_sharded_loras
    self.max_loras = lora_config.max_loras

    # One int flag per adapter index, max_loras + 1 entries, all 0.
    enabled_flags = [0] * (max_loras + 1)
    self.adapter_enabled = torch.tensor(
        enabled_flags, dtype=torch.int, device=self.device
    )

    for allocate in (self._create_lora_a_weights, self._create_lora_b_weights):
        allocate(max_loras, lora_config)

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)
    assert len(lora_a) == len(lora_b) == 2

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    w13_lora_a, w2_lora_a = lora_a
    w13_lora_b, w2_lora_b = lora_b

    sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
    sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
    ].copy_(sliced_w13_lora_a, non_blocking=True)
    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
    ].copy_(sliced_w13_lora_b, non_blocking=True)
    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

FusedMoEWithLoRA

Bases: BaseLayerWithLoRA

Source code in vllm/lora/layers/fused_moe.py
class FusedMoEWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a ``FusedMoE`` layer.

    Holds stacked LoRA A/B buffers for the w13 and w2 projections and hands
    them to the layer's fused-experts implementation as a ``MoELoRAContext``.
    """

    def __init__(self, base_layer: FusedMoE) -> None:
        super().__init__()
        self.base_layer = base_layer
        layer = self.base_layer

        assert not layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )
        assert not layer.quant_method.is_monolithic, (
            "Monolithic kernels are not supported for Fused MoE LoRA."
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.device = _get_lora_device(base_layer)
        # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
        # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
        is_gated = base_layer.moe_config.is_act_and_mul
        self._w13_slices = 2 if is_gated else 1

        layer.ensure_moe_quant_config_init()
        has_internal_kernel = getattr(
            layer.quant_method, "supports_internal_mk", False
        )
        if not has_internal_kernel:
            # Build a modular kernel around the quant method's GEMM impl.
            prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
            experts_impl = layer.quant_method.select_gemm_impl(
                prepare_finalize, layer
            )
            moe_kernel = FusedMoEKernel(prepare_finalize, experts_impl)
        else:
            moe_kernel = layer.quant_method.moe_kernel
            # Don't let the kernel own shared experts so the runner can
            # overlap them with routed experts via a separate CUDA stream.
            moe_kernel.shared_experts = None
        assert moe_kernel.supports_lora(), (
            f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. "
            "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
            "(auto selects Triton automatically when LoRA is enabled). "
            "For quantized MoE, mix LoRAExpertsMixin into the experts class "
            "and consume self._lora_context in apply()."
        )
        self._fused_experts = moe_kernel.fused_experts
        modular_method = FusedMoEModularMethod(layer.quant_method, moe_kernel)
        layer._replace_quant_method(modular_method)

    def _build_lora_context(self):
        """Bundle the current LoRA buffers and settings for the experts."""
        context_fields = dict(
            w13_lora_a_stacked=self.w13_lora_a_stacked,
            w13_lora_b_stacked=self.w13_lora_b_stacked,
            w2_lora_a_stacked=self.w2_lora_a_stacked,
            w2_lora_b_stacked=self.w2_lora_b_stacked,
            adapter_enabled=self.adapter_enabled,
            max_loras=self.max_loras,
            top_k=self.base_layer.top_k,
            w13_num_slices=self._w13_slices,
            fully_sharded=self.fully_sharded,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            local_num_experts=self.base_layer.local_num_experts,
            punica_wrapper=self.punica_wrapper,
            use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER),
        )
        return MoELoRAContext(**context_fields)

    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ):
        """Allocate the zero-filled stacked LoRA A buffers for w13 and w2."""
        layer = self.base_layer
        max_rank = lora_config.max_lora_rank
        alloc = {"dtype": lora_config.lora_dtype, "device": self.device}

        # With fully sharded LoRA, the w13 rank dim is divided by tp_size.
        if self.fully_sharded:
            w13_rank = divide(max_rank, self.tp_size)
        else:
            w13_rank = max_rank
        w13_shape = (
            max_loras,
            layer.local_num_experts,
            w13_rank,
            layer.hidden_size,
        )
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(w13_shape, **alloc) for _ in range(self._w13_slices)
        )

        w2_shape = (
            max_loras,
            layer.local_num_experts,
            max_rank,
            layer.intermediate_size_per_partition,
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(w2_shape, **alloc),
        )

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        """Allocate the zero-filled stacked LoRA B buffers for w13 and w2."""
        layer = self.base_layer
        max_rank = lora_config.max_lora_rank
        alloc = {"dtype": lora_config.lora_dtype, "device": self.device}

        w13_shape = (
            max_loras,
            layer.local_num_experts,
            layer.intermediate_size_per_partition,
            max_rank,
        )
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(w13_shape, **alloc) for _ in range(self._w13_slices)
        )

        # With fully sharded LoRA, the w2 output dim is divided by tp_size.
        if self.fully_sharded:
            w2_rows = divide(layer.hidden_size, self.tp_size)
        else:
            w2_rows = layer.hidden_size
        w2_shape = (max_loras, layer.local_num_experts, w2_rows, max_rank)
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(w2_shape, **alloc),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Record the LoRA settings and allocate the stacked LoRA buffers.

        Also fills ``lora_a_stacked`` / ``lora_b_stacked`` with per-adapter,
        per-expert views into the stacked buffers.
        """
        self.fully_sharded = lora_config.fully_sharded_loras
        self.max_loras = lora_config.max_loras

        # One int flag per adapter index, max_loras + 1 entries, all 0.
        enabled_flags = [0] * (max_loras + 1)
        self.adapter_enabled = torch.tensor(
            enabled_flags, dtype=torch.int, device=self.device
        )

        for allocate in (self._create_lora_a_weights, self._create_lora_b_weights):
            allocate(max_loras, lora_config)

        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create a dummy LoRA weights.
        # TODO Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        a_views, b_views = self.lora_a_stacked, self.lora_b_stacked
        for lora_id in range(max_loras):
            for expert_id in range(self.base_layer.local_num_experts):
                # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
                # For non-gated MoE: up_proj (w1), down_proj (w2)
                a_views.extend(
                    (
                        self.w13_lora_a_stacked[0][lora_id][expert_id],
                        self.w2_lora_a_stacked[0][lora_id][expert_id],
                    )
                )
                b_views.extend(
                    (
                        self.w13_lora_b_stacked[0][lora_id][expert_id],
                        self.w2_lora_b_stacked[0][lora_id][expert_id],
                    )
                )

                # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
                if self._w13_slices != 2:
                    continue
                a_views.append(self.w13_lora_a_stacked[1][lora_id][expert_id])
                b_views.append(self.w13_lora_b_stacked[1][lora_id][expert_id])

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

        Returns the input unchanged unless tp_size != 1 and fully sharded
        LoRA is on; otherwise returns this rank's slice along dim 1.
        """
        if not (self.tp_size != 1 and self.fully_sharded):
            return w13_lora_a

        # w13_lora_a shape (num_experts,rank,input_size)
        adapter_rank = w13_lora_a.shape[1]
        assert adapter_rank % self.tp_size == 0
        # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
        width = self.w13_lora_a_stacked[0].shape[2]
        first = self.tp_rank * width
        return w13_lora_a[:, first : first + width, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        """Return this TP rank's rows (dim 1) of a w13 LoRA B tensor."""
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        width = self.base_layer.intermediate_size_per_partition
        first = self.tp_rank * width
        return w13_lora_b[:, first : first + width, :]

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

        Returns this TP rank's columns (last dim) of a w2 LoRA A tensor.
        """
        if self.tp_size == 1:
            return w2_lora_a

        # w2_lora_a shape (num_experts,rank,input_size)
        width = self.base_layer.intermediate_size_per_partition
        first = self.tp_rank * width
        return w2_lora_a[:, :, first : first + width]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

        Returns the input unchanged unless tp_size != 1 and fully sharded
        LoRA is on; otherwise returns this rank's slice along dim 1.
        """
        if not (self.tp_size != 1 and self.fully_sharded):
            return w2_lora_b

        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts,output_size,rank)
        width = self.w2_lora_b_stacked[0].shape[2]
        first = self.tp_rank * width
        return w2_lora_b[:, first : first + width, :]

    def reset_lora(self, index: int):
        """Zero every LoRA buffer at slot ``index`` and disable the slot."""
        for slice_id in range(self._w13_slices):
            for stacked in (self.w13_lora_a_stacked, self.w13_lora_b_stacked):
                stacked[slice_id][index] = 0

        for stacked in (self.w2_lora_a_stacked, self.w2_lora_b_stacked):
            stacked[0][index] = 0
        self.adapter_enabled[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Load one adapter's (w1, w2, w3) LoRA tensors into slot ``index``."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)

        # Clear the slot first, then mark it as enabled.
        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        expected_experts = self.w13_lora_a_stacked[0].shape[1]

        w1_a, w2_a, w3_a = lora_a
        w1_b, w2_b, w3_b = lora_b
        assert (
            expected_experts
            == w1_a.shape[0]
            == w2_a.shape[0]
            == w3_a.shape[0]
        )

        def store(stacked_slot: torch.Tensor, shard: torch.Tensor) -> None:
            # Write into the leading region matching the shard's last two
            # dims; the rest of the slot keeps its reset value.
            region = stacked_slot[index, :, : shard.shape[1], : shard.shape[2]]
            region.copy_(shard, non_blocking=True)

        w1_a_shard = self._slice_w13_a(w1_a)
        w1_b_shard = self._slice_w13_b(w1_b)
        w2_a_shard = self._slice_w2_a(w2_a)
        w2_b_shard = self._slice_w2_b(w2_b)

        store(self.w13_lora_a_stacked[0], w1_a_shard)
        store(self.w13_lora_b_stacked[0], w1_b_shard)

        # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
        if self._w13_slices == 2:
            w3_a_shard = self._slice_w13_a(w3_a)
            w3_b_shard = self._slice_w13_b(w3_b)
            store(self.w13_lora_a_stacked[1], w3_a_shard)
            store(self.w13_lora_b_stacked[1], w3_b_shard)

        store(self.w2_lora_a_stacked[0], w2_a_shard)
        store(self.w2_lora_b_stacked[0], w2_b_shard)

    def set_mapping(self, punica_wrapper):
        """Store the punica wrapper and refresh the experts' LoRA context."""
        super().set_mapping(punica_wrapper)
        lora_context = self._build_lora_context()
        self._fused_experts.set_lora_context(lora_context)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped layer's ``forward``."""
        layer_forward = self.base_layer.forward
        return layer_forward(*args, **kwargs)

    @property
    def quant_method(self):
        """The wrapped layer's current quant method."""
        layer = self.base_layer
        return layer.quant_method

    @property
    def is_internal_router(self) -> bool:
        """The wrapped layer's ``is_internal_router`` value."""
        layer = self.base_layer
        return layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Return True for a FusedMoE layer with exactly two packed modules."""
        # source_layer is FusedMoE
        if not isinstance(source_layer, FusedMoE):
            return False
        return len(packed_modules_list) == 2

_slice_w13_a

_slice_w13_a(w13_lora_a: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

    Returns the input unchanged unless tp_size != 1 and fully sharded LoRA
    is on; otherwise returns this rank's slice along dim 1.
    """
    if not (self.tp_size != 1 and self.fully_sharded):
        return w13_lora_a

    # w13_lora_a shape (num_experts,rank,input_size)
    adapter_rank = w13_lora_a.shape[1]
    assert adapter_rank % self.tp_size == 0
    # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
    width = self.w13_lora_a_stacked[0].shape[2]
    first = self.tp_rank * width
    return w13_lora_a[:, first : first + width, :]

_slice_w2_a

_slice_w2_a(w2_lora_a: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

    Returns this TP rank's columns (last dim) of a w2 LoRA A tensor.
    """
    if self.tp_size == 1:
        return w2_lora_a

    # w2_lora_a shape (num_experts,rank,input_size)
    width = self.base_layer.intermediate_size_per_partition
    first = self.tp_rank * width
    return w2_lora_a[:, :, first : first + width]

_slice_w2_b

_slice_w2_b(w2_lora_b: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

    Returns the input unchanged unless tp_size != 1 and fully sharded LoRA
    is on; otherwise returns this rank's slice along dim 1.
    """
    if not (self.tp_size != 1 and self.fully_sharded):
        return w2_lora_b

    # Based on S-LoRA, we slice W2 B along the hidden_size dim.
    # w2_lora_b shape (num_experts,output_size,rank)
    width = self.w2_lora_b_stacked[0].shape[2]
    first = self.tp_rank * width
    return w2_lora_b[:, first : first + width, :]

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Return True for a FusedMoE layer with exactly two packed modules."""
    # source_layer is FusedMoE
    if not isinstance(source_layer, FusedMoE):
        return False
    return len(packed_modules_list) == 2

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Record the LoRA settings and allocate the stacked LoRA buffers.

    Also fills ``lora_a_stacked`` / ``lora_b_stacked`` with per-adapter,
    per-expert views into the stacked buffers.
    """
    self.fully_sharded = lora_config.fully_sharded_loras
    self.max_loras = lora_config.max_loras

    # One int flag per adapter index, max_loras + 1 entries, all 0.
    enabled_flags = [0] * (max_loras + 1)
    self.adapter_enabled = torch.tensor(
        enabled_flags, dtype=torch.int, device=self.device
    )

    for allocate in (self._create_lora_a_weights, self._create_lora_b_weights):
        allocate(max_loras, lora_config)

    # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
    # to create a dummy LoRA weights.
    # TODO Optimize this section
    self.lora_a_stacked = []
    self.lora_b_stacked = []
    a_views, b_views = self.lora_a_stacked, self.lora_b_stacked
    for lora_id in range(max_loras):
        for expert_id in range(self.base_layer.local_num_experts):
            # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
            # For non-gated MoE: up_proj (w1), down_proj (w2)
            a_views.extend(
                (
                    self.w13_lora_a_stacked[0][lora_id][expert_id],
                    self.w2_lora_a_stacked[0][lora_id][expert_id],
                )
            )
            b_views.extend(
                (
                    self.w13_lora_b_stacked[0][lora_id][expert_id],
                    self.w2_lora_b_stacked[0][lora_id][expert_id],
                )
            )

            # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
            if self._w13_slices != 2:
                continue
            a_views.append(self.w13_lora_a_stacked[1][lora_id][expert_id])
            b_views.append(self.w13_lora_b_stacked[1][lora_id][expert_id])

reset_lora

reset_lora(index: int)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/fused_moe.py
def reset_lora(self, index: int):
    """Zero every LoRA buffer at slot ``index`` and disable the slot."""
    for slice_id in range(self._w13_slices):
        for stacked in (self.w13_lora_a_stacked, self.w13_lora_b_stacked):
            stacked[slice_id][index] = 0

    for stacked in (self.w2_lora_a_stacked, self.w2_lora_b_stacked):
        stacked[0][index] = 0
    self.adapter_enabled[index] = 0

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    num_experts = self.w13_lora_a_stacked[0].shape[1]

    w1_lora_a, w2_lora_a, w3_lora_a = lora_a
    w1_lora_b, w2_lora_b, w3_lora_b = lora_b
    assert (
        num_experts
        == w1_lora_a.shape[0]
        == w2_lora_a.shape[0]
        == w3_lora_a.shape[0]
    )

    slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
    slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
    ].copy_(slliced_w1_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
    ].copy_(slliced_w1_lora_b, non_blocking=True)

    # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
    if self._w13_slices == 2:
        slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
        slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

        self.w13_lora_a_stacked[1][
            index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
        ].copy_(slliced_w3_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[1][
            index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
        ].copy_(slliced_w3_lora_b, non_blocking=True)

    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)