vllm.lora.layers

Modules:

base
base_linear
column_parallel_linear
fused_moe
logits_processor
replicated_linear
row_parallel_linear
utils

BaseLayerWithLoRA

Bases: Module

Source code in vllm/lora/layers/base.py
class BaseLayerWithLoRA(nn.Module):
    @overload
    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]: ...
    @overload
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: ...
    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    @overload
    def slice_lora_b(
        self, lora_b: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]: ...
    @overload
    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: ...
    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
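
The interface above only declares the contract; the concrete layers below implement it with fused Punica kernels. For orientation, here is a minimal, self-contained PyTorch sketch of the underlying LoRA update (illustrative only, not the vLLM implementation), using the shape conventions of this page: lora_a is (rank, input_size) and lora_b is (output_size, rank).

import torch

def lora_linear(
    x: torch.Tensor,          # (num_tokens, input_size)
    weight: torch.Tensor,     # (output_size, input_size) base weight
    lora_a: torch.Tensor,     # (rank, input_size)
    lora_b: torch.Tensor,     # (output_size, rank)
    scaling: float = 1.0,
) -> torch.Tensor:
    base = x @ weight.t()                  # base projection
    delta = (x @ lora_a.t()) @ lora_b.t()  # low-rank shrink then expand
    return base + scaling * delta

x = torch.randn(4, 16)
out = lora_linear(x, torch.randn(32, 16), torch.randn(8, 16), torch.randn(32, 8))
assert out.shape == (4, 32)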

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/base.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer."""
    raise NotImplementedError

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/base.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices."""
    ...

reset_lora

reset_lora(index: int)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/base.py
def reset_lora(self, index: int):
    """Resets the lora weights at index back to 0."""
    ...

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/base.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    ...

slice_lora_a

slice_lora_a(
    lora_a: list[Tensor | None],
) -> list[Tensor | None]
slice_lora_a(lora_a: Tensor) -> Tensor
slice_lora_a(
    lora_a: Tensor | list[Tensor | None],
) -> Tensor | list[Tensor | None]

Slice lora a if splitting for tensor parallelism.

Source code in vllm/lora/layers/base.py
def slice_lora_a(
    self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora a if splitting for tensor parallelism."""
    ...

slice_lora_b

slice_lora_b(
    lora_b: list[Tensor | None],
) -> list[Tensor | None]
slice_lora_b(lora_b: Tensor) -> Tensor
slice_lora_b(
    lora_b: Tensor | list[Tensor | None],
) -> Tensor | list[Tensor | None]

Slice lora b if splitting with tensor parallelism.

Source code in vllm/lora/layers/base.py
def slice_lora_b(
    self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora b if splitting with tensor parallelism."""
    ...

ColumnParallelLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

LoRA on top of a ColumnParallelLinear layer. LoRA B is sliced for tensor parallelism. There are two possible types for the base_layer:

1. ColumnParallelLinear, e.g. dense_h_to_4h in FalconForCausalLM.
2. MergedColumnParallelLinear, e.g. gate_up_proj in Phi3ForCausalLM.

Source code in vllm/lora/layers/column_parallel_linear.py
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g. `dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g. `gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = isinstance(base_layer, MergedColumnParallelLinear)
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = lora_b.shape[0] // 2

            left_weight = lora_b[
                self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, :
            ]
            right_weight = lora_b[
                offset + self.tp_rank * shard_size : offset
                + (self.tp_rank + 1) * shard_size,
                :,
            ]
            lora_b = torch.cat([left_weight, right_weight], dim=0)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            shard_size = self.output_size
            start_idx = self.tp_rank * shard_size
            end_idx = (self.tp_rank + 1) * shard_size
            lora_b = lora_b[start_idx:end_idx, :]
        return lora_b

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.base_layer.return_bias:
            return output

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
            return True
        if isinstance(source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear)):
            if len(packed_modules_list) != 1:
                return False
            # Exclude layers with 3+ output sizes - those are handled by
            # MergedColumnParallelLinearVariableSliceWithLoRA since this
            # class's slice_lora_b assumes exactly 2 slices.
            return not (
                hasattr(source_layer, "output_sizes")
                and len(source_layer.output_sizes) >= 3
            )
        return False
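
To make the two branches of slice_lora_b above concrete, here is a toy walk-through with made-up sizes (illustrative only, not vLLM code): a plain ColumnParallelLinear keeps one contiguous shard of LoRA B per tensor-parallel rank, while a merged gate/up projection takes the matching shard from each half and concatenates them.

import torch

tp_size, tp_rank = 2, 1
full_output, rank_dim = 8, 4
shard = full_output // tp_size          # output_size_per_partition
lora_b = torch.randn(full_output, rank_dim)

# ColumnParallelLinear: one contiguous shard of rows.
col_slice = lora_b[tp_rank * shard : (tp_rank + 1) * shard, :]

# MergedColumnParallelLinear: lora_b stacks [gate; up], so slice each half.
half, sub_shard = full_output // 2, shard // 2
left = lora_b[tp_rank * sub_shard : (tp_rank + 1) * sub_shard, :]
right = lora_b[half + tp_rank * sub_shard : half + (tp_rank + 1) * sub_shard, :]
merged_slice = torch.cat([left, right], dim=0)

assert col_slice.shape == merged_slice.shape == (shard, rank_dim)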

forward

forward(
    input_: Tensor,
) -> Tensor | tuple[Tensor, Tensor | None]

Forward of ColumnParallelLinear

Parameters:

input_ (Tensor, required): Tensor whose last dimension is input_size.

Returns:

Tensor | tuple[Tensor, Tensor | None]:
  • output
  • bias
Source code in vllm/lora/layers/column_parallel_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of ColumnParallelLinear

    Args:
        input_: Tensor whose last dimension is `input_size`.

    Returns:
        - output
        - bias
    """
    bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

    # Matrix multiply.
    output_parallel = self.apply(input_, bias)
    if self.base_layer.gather_output and self.tp_size > 1:
        # All-gather across the partitions.
        output = tensor_model_parallel_all_gather(output_parallel)
    else:
        output = output_parallel

    if not self.base_layer.return_bias:
        return output

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
    return output, output_bias

ColumnParallelLinearWithShardedLoRA

Bases: ColumnParallelLinearWithLoRA

Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
    # their `lora_a` and `lora_b` have different sharding patterns. After
    # completing the `lora_a` GEMM, a gather operation is performed.
    # Therefore, the sharding of `lora_a` only needs to correspond with the
    # gather operation.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        lora_a = lora_a[start_idx : start_idx + shard_size, :]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
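
A small shape sketch of the rank-dimension sharding used above (assumed toy sizes, not vLLM code): with fully sharded LoRA, each tensor-parallel rank holds max_lora_rank / tp_size rows of LoRA A, and the partial shrink results are gathered before the expand GEMM, as the comment in the source explains.

import torch

tp_size, tp_rank = 4, 2
max_lora_rank, input_size = 16, 64
rank_shard = max_lora_rank // tp_size             # rows of LoRA A kept on this rank

lora_a = torch.randn(max_lora_rank, input_size)   # (rank, input_size)
start = tp_rank * rank_shard
lora_a_local = lora_a[start : start + rank_shard, :]
assert lora_a_local.shape == (rank_shard, input_size)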

FusedMoE3DWithLoRA

Bases: FusedMoEWithLoRA

Source code in vllm/lora/layers/fused_moe.py
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    def __init__(self, base_layer):
        super().__init__(base_layer)
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras, lora_config):
        self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""

        assert isinstance(model_config, PretrainedConfig)
        self._base_model = model_config.architectures[0]
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        # HACK: Currently, only GPT-OSS is in interleaved order
        if self._base_model == "GptOssForCausalLM":
            # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
            # are stored in interleaved order, so the corresponding LoRA weights
            # must be sliced accordingly.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )
        else:
            slice_size = w13_lora_b.shape[1] // 2
            w1_lora_b = w13_lora_b[:, :slice_size, :]
            w3_lora_b = w13_lora_b[:, slice_size:, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        w13_lora_a, w2_lora_a = lora_a
        w13_lora_b, w2_lora_b = lora_b

        sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
        sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
        ].copy_(sliced_w13_lora_a, non_blocking=True)
        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
        ].copy_(sliced_w13_lora_b, non_blocking=True)
        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    @property
    def w13_input_size(self):
        """
        Full size
        """
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
    def w13_output_size(self):
        """
        Full size
        """
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
    def w2_input_size(self):
        """
        Full size
        """
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full size
        """
        return self.base_layer.hidden_size

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # source_layer is FusedMoE
        return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
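
The _slice_w13_b override above handles two LoRA B layouts for the fused gate/up projection. A toy comparison of the contiguous layout and the interleaved (GPT-OSS style) layout, using made-up sizes rather than vLLM code:

import torch

num_experts, rank = 2, 3
intermediate_per_partition, tp_size, tp_rank = 2, 2, 1
full_intermediate = intermediate_per_partition * tp_size
start = tp_rank * intermediate_per_partition
end = (tp_rank + 1) * intermediate_per_partition

w13 = torch.randn(num_experts, 2 * full_intermediate, rank)

# Contiguous layout: rows are [w1 block | w3 block].
w1, w3 = w13[:, :full_intermediate, :], w13[:, full_intermediate:, :]
contiguous = torch.cat([w1[:, start:end, :], w3[:, start:end, :]], dim=1)

# Interleaved layout (GPT-OSS): w1 rows at even indices, w3 rows at odd ones.
w1_i, w3_i = w13[:, ::2, :], w13[:, 1::2, :]
interleaved = torch.stack(
    [w1_i[:, start:end, :], w3_i[:, start:end, :]], dim=2
).flatten(1, 2)

assert contiguous.shape == interleaved.shape == (
    num_experts, 2 * intermediate_per_partition, rank
)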

w13_input_size property

w13_input_size

Full size

w13_output_size property

w13_output_size

Full size

w2_input_size property

w2_input_size

Full size

w2_output_size property

w2_output_size

Full size

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer."""
    # source_layer is FusedMoE
    return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices."""

    assert isinstance(model_config, PretrainedConfig)
    self._base_model = model_config.architectures[0]
    self.max_loras = lora_config.max_loras
    self.fully_sharded = lora_config.fully_sharded_loras

    self.adapter_enabled = torch.tensor(
        [0] * (max_loras + 1), dtype=torch.int, device=self.device
    )

    self._create_lora_a_weights(max_loras, lora_config)
    self._create_lora_b_weights(max_loras, lora_config)

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)
    assert len(lora_a) == len(lora_b) == 2

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    w13_lora_a, w2_lora_a = lora_a
    w13_lora_b, w2_lora_b = lora_b

    sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
    sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
    ].copy_(sliced_w13_lora_a, non_blocking=True)
    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
    ].copy_(sliced_w13_lora_b, non_blocking=True)
    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

FusedMoEWithLoRA

Bases: BaseLayerWithLoRA

Source code in vllm/lora/layers/fused_moe.py
class FusedMoEWithLoRA(BaseLayerWithLoRA):
    def __init__(self, base_layer: FusedMoE) -> None:
        super().__init__()
        self.base_layer = base_layer

        assert not self.base_layer.use_ep, (
            "EP support for Fused MoE LoRA is not implemented yet."
        )
        assert not self.base_layer.quant_method.is_monolithic, (
            "Monolithic kernels are not supported for Fused MoE LoRA."
        )
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.device = _get_lora_device(base_layer)
        # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
        # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
        self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1

        self.base_layer.ensure_moe_quant_config_init()
        if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
            moe_kernel = self.base_layer.quant_method.moe_kernel
            # Don't let the kernel own shared experts so the runner can
            # overlap them with routed experts via a separate CUDA stream.
            moe_kernel.shared_experts = None
        else:
            prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
            moe_kernel = FusedMoEKernel(
                prepare_finalize,
                self.base_layer.quant_method.select_gemm_impl(
                    prepare_finalize, self.base_layer
                ),
            )
        assert moe_kernel.supports_lora(), (
            f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. "
            "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
            "(auto selects Triton automatically when LoRA is enabled). "
            "For quantized MoE, mix LoRAExpertsMixin into the experts class "
            "and consume self._lora_context in apply()."
        )
        self._fused_experts = moe_kernel.fused_experts
        self.base_layer._replace_quant_method(
            FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel)
        )

    def _build_lora_context(self):
        return MoELoRAContext(
            w13_lora_a_stacked=self.w13_lora_a_stacked,
            w13_lora_b_stacked=self.w13_lora_b_stacked,
            w2_lora_a_stacked=self.w2_lora_a_stacked,
            w2_lora_b_stacked=self.w2_lora_b_stacked,
            adapter_enabled=self.adapter_enabled,
            max_loras=self.max_loras,
            top_k=self.base_layer.top_k,
            w13_num_slices=self._w13_slices,
            fully_sharded=self.fully_sharded,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            local_num_experts=self.base_layer.local_num_experts,
            punica_wrapper=self.punica_wrapper,
            use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER),
        )

    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ):
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank
                    if not self.fully_sharded
                    else divide(lora_config.max_lora_rank, self.tp_size),
                    self.base_layer.hidden_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    lora_config.max_lora_rank,
                    self.base_layer.intermediate_size_per_partition,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.intermediate_size_per_partition,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.base_layer.local_num_experts,
                    self.base_layer.hidden_size
                    if not self.fully_sharded
                    else divide(self.base_layer.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)
        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create dummy LoRA weights.
        # TODO Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for experts_id in range(self.base_layer.local_num_experts):
                # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
                # For non-gated MoE: up_proj (w1), down_proj (w2)
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[0][lora_id][experts_id]
                )
                self.lora_a_stacked.append(
                    self.w2_lora_a_stacked[0][lora_id][experts_id]
                )

                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[0][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w2_lora_b_stacked[0][lora_id][experts_id]
                )

                # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
                if self._w13_slices == 2:
                    self.lora_a_stacked.append(
                        self.w13_lora_a_stacked[1][lora_id][experts_id]
                    )
                    self.lora_b_stacked.append(
                        self.w13_lora_b_stacked[1][lora_id][experts_id]
                    )

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w13_lora_a

        # w13_lora_a shape (num_experts,rank,input_size)
        current_lora_rank = w13_lora_a.shape[1]
        assert current_lora_rank % self.tp_size == 0
        # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
        shard_size = self.w13_lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return w13_lora_a[:, start_idx:end_idx, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w13_lora_b[:, start_idx:end_idx, :]

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape (num_experts,rank,input_size)
        shard_size = self.base_layer.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts,output_size,rank)
        shard_size = self.w2_lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_b[:, start_idx:end_idx, :]

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        for pos in range(self._w13_slices):
            self.w13_lora_a_stacked[pos][index] = 0
            self.w13_lora_b_stacked[pos][index] = 0

        self.w2_lora_a_stacked[0][index] = 0
        self.w2_lora_b_stacked[0][index] = 0
        self.adapter_enabled[index] = 0

    #

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        num_experts = self.w13_lora_a_stacked[0].shape[1]

        w1_lora_a, w2_lora_a, w3_lora_a = lora_a
        w1_lora_b, w2_lora_b, w3_lora_b = lora_b
        assert (
            num_experts
            == w1_lora_a.shape[0]
            == w2_lora_a.shape[0]
            == w3_lora_a.shape[0]
        )

        sliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
        sliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w1_lora_a.shape[1], : sliced_w1_lora_a.shape[2]
        ].copy_(sliced_w1_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w1_lora_b.shape[1], : sliced_w1_lora_b.shape[2]
        ].copy_(sliced_w1_lora_b, non_blocking=True)

        # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
        if self._w13_slices == 2:
            sliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
            sliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

            self.w13_lora_a_stacked[1][
                index, :, : sliced_w3_lora_a.shape[1], : sliced_w3_lora_a.shape[2]
            ].copy_(sliced_w3_lora_a, non_blocking=True)

            self.w13_lora_b_stacked[1][
                index, :, : sliced_w3_lora_b.shape[1], : sliced_w3_lora_b.shape[2]
            ].copy_(sliced_w3_lora_b, non_blocking=True)

        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    def set_mapping(self, punica_wrapper):
        super().set_mapping(punica_wrapper)
        self._fused_experts.set_lora_context(self._build_lora_context())

    def forward(self, *args, **kwargs):
        return self.base_layer.forward(*args, **kwargs)

    @property
    def quant_method(self):
        return self.base_layer.quant_method

    @property
    def is_internal_router(self) -> bool:
        return self.base_layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""

        # source_layer is FusedMoE
        return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
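
For orientation, the stacked buffers allocated by _create_lora_a_weights and _create_lora_b_weights above have the per-slice shapes below. This is a shape-only sketch with assumed toy sizes (single TP rank, fully_sharded_loras disabled), not vLLM code.

import torch

max_loras, local_num_experts = 2, 4
max_lora_rank, hidden_size, intermediate_per_partition = 8, 64, 32
w13_slices = 2   # gated MoE: one slice for gate_proj (w1), one for up_proj (w3)

w13_lora_a_stacked = tuple(
    torch.zeros(max_loras, local_num_experts, max_lora_rank, hidden_size)
    for _ in range(w13_slices)
)
w13_lora_b_stacked = tuple(
    torch.zeros(max_loras, local_num_experts, intermediate_per_partition, max_lora_rank)
    for _ in range(w13_slices)
)
w2_lora_a_stacked = (
    torch.zeros(max_loras, local_num_experts, max_lora_rank, intermediate_per_partition),
)
w2_lora_b_stacked = (
    torch.zeros(max_loras, local_num_experts, hidden_size, max_lora_rank),
)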

_slice_w13_a

_slice_w13_a(w13_lora_a: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1 or not self.fully_sharded:
        return w13_lora_a

    # w13_lora_a shape (num_experts,rank,input_size)
    current_lora_rank = w13_lora_a.shape[1]
    assert current_lora_rank % self.tp_size == 0
    # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
    shard_size = self.w13_lora_a_stacked[0].shape[2]
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size
    return w13_lora_a[:, start_idx:end_idx, :]

_slice_w2_a

_slice_w2_a(w2_lora_a: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1:
        return w2_lora_a
    # w2_lora_a shape (num_experts,rank,input_size)
    shard_size = self.base_layer.intermediate_size_per_partition
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size

    return w2_lora_a[:, :, start_idx:end_idx]

_slice_w2_b

_slice_w2_b(w2_lora_b: Tensor) -> Tensor

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1 or not self.fully_sharded:
        return w2_lora_b
    # Based on S-LoRA, we slice W2 B along the hidden_size dim.
    # w2_lora_b shape (num_experts,output_size,rank)
    shard_size = self.w2_lora_b_stacked[0].shape[2]
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size

    return w2_lora_b[:, start_idx:end_idx, :]

can_replace_layer classmethod

can_replace_layer(
    source_layer: Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer."""

    # source_layer is FusedMoE
    return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices."""
    self.max_loras = lora_config.max_loras
    self.fully_sharded = lora_config.fully_sharded_loras

    self.adapter_enabled = torch.tensor(
        [0] * (max_loras + 1), dtype=torch.int, device=self.device
    )

    self._create_lora_a_weights(max_loras, lora_config)
    self._create_lora_b_weights(max_loras, lora_config)
    # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
    # to create dummy LoRA weights.
    # TODO Optimize this section
    self.lora_a_stacked = []
    self.lora_b_stacked = []
    for lora_id in range(max_loras):
        for experts_id in range(self.base_layer.local_num_experts):
            # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
            # For non-gated MoE: up_proj (w1), down_proj (w2)
            self.lora_a_stacked.append(
                self.w13_lora_a_stacked[0][lora_id][experts_id]
            )
            self.lora_a_stacked.append(
                self.w2_lora_a_stacked[0][lora_id][experts_id]
            )

            self.lora_b_stacked.append(
                self.w13_lora_b_stacked[0][lora_id][experts_id]
            )
            self.lora_b_stacked.append(
                self.w2_lora_b_stacked[0][lora_id][experts_id]
            )

            # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
            if self._w13_slices == 2:
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[1][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[1][lora_id][experts_id]
                )

reset_lora

reset_lora(index: int)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/fused_moe.py
def reset_lora(self, index: int):
    """Resets the lora weights at index back to 0."""
    for pos in range(self._w13_slices):
        self.w13_lora_a_stacked[pos][index] = 0
        self.w13_lora_b_stacked[pos][index] = 0

    self.w2_lora_a_stacked[0][index] = 0
    self.w2_lora_b_stacked[0][index] = 0
    self.adapter_enabled[index] = 0

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    num_experts = self.w13_lora_a_stacked[0].shape[1]

    w1_lora_a, w2_lora_a, w3_lora_a = lora_a
    w1_lora_b, w2_lora_b, w3_lora_b = lora_b
    assert (
        num_experts
        == w1_lora_a.shape[0]
        == w2_lora_a.shape[0]
        == w3_lora_a.shape[0]
    )

    sliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
    sliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : sliced_w1_lora_a.shape[1], : sliced_w1_lora_a.shape[2]
    ].copy_(sliced_w1_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : sliced_w1_lora_b.shape[1], : sliced_w1_lora_b.shape[2]
    ].copy_(sliced_w1_lora_b, non_blocking=True)

    # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
    if self._w13_slices == 2:
        sliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
        sliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

        self.w13_lora_a_stacked[1][
            index, :, : sliced_w3_lora_a.shape[1], : sliced_w3_lora_a.shape[2]
        ].copy_(sliced_w3_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[1][
            index, :, : sliced_w3_lora_b.shape[1], : sliced_w3_lora_b.shape[2]
        ].copy_(sliced_w3_lora_b, non_blocking=True)

    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

LogitsProcessorWithLoRA

Bases: BaseLayerWithLoRA

LoRA wrapper for LogitsProcessor, with extra logic to handle the application of the LoRA adapter and added LoRA vocabulary.

Parameters:

base_layer (LogitsProcessor, required): LogitsProcessor layer.
hidden_size (int, required): hidden size of the model.
dtype (dtype, required): data type of the model.
device (device, required): device of the model.
sharded_to_full_mapping (list[int] | None, required): index mapping from sharded vocab to full vocab, received from base_layer.get_sharded_to_full_mapping(). If None, no reindexing will be done.
Source code in vllm/lora/layers/logits_processor.py
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
        sharded_to_full_mapping: list[int] | None,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        # TODO: Verify if this condition can be further relaxed
        if self.base_layer.vocab_size > 258048:
            raise ValueError("When using LoRA, vocab size must be <= 258048")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping, device=self.device, dtype=torch.long
            )
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        self.reset_lora(index)
        self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: torch.Tensor | None = None,
    ) -> torch.Tensor | None:
        # Get the logits for the next tokens.
        if hasattr(lm_head, "base_layer"):
            actual_lm_head = lm_head.base_layer
        else:
            actual_lm_head = lm_head
        logits = actual_lm_head.quant_method.apply(actual_lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
            logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
        )

        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, : self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
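
To illustrate the reindexing performed in _get_logits above: after the tensor-parallel gather, vocab columns are ordered shard by shard, and the precomputed mapping restores a 1:1 correspondence between column index and token_id. A toy numeric example (the shard layout follows the comment in the source; the mapping here is derived on the fly rather than taken from get_sharded_to_full_mapping):

import torch

# token_id owning each gathered-logit column for org_vocab_size=4,
# added_vocab_size=2, padded shard size 4, tp_size=2 (-1 means padding).
gathered_token_ids = torch.tensor([0, 1, 4, -1, 2, 3, 5, -1])

# Gather-style mapping: reindexed column i reads the gathered column that
# holds token_id i, with padding columns pushed to the end.
keys = gathered_token_ids.clone()
keys[keys < 0] = 10**6
mapping = torch.argsort(keys)

reindexed = gathered_token_ids[mapping]
assert reindexed.tolist() == [0, 1, 2, 3, 4, 5, -1, -1]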

MergedColumnParallelLinearVariableSliceWithLoRA

Bases: MergedColumnParallelLinearWithLoRA

MergedColumnParallelLinear with variable number of slices (3+).

This handles cases where the checkpoint has a single weight for the whole module (not split into slices), but the layer itself has multiple slices.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearVariableSliceWithLoRA(
    MergedColumnParallelLinearWithLoRA
):
    """MergedColumnParallelLinear with variable number of slices (3+).

    This handles cases where the checkpoint has a single weight for the whole
    module (not split into slices), but the layer itself has multiple slices.
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Support MergedColumnParallelLinear with 3 or more slices
        # (2 slices are handled by MergedColumnParallelLinearWithLoRA)
        if not isinstance(
            source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear)
        ):
            return False

        # If packed_modules_list has 3+ items, use this class
        if len(packed_modules_list) >= 3:
            return True

        # If packed_modules_list has exactly 2 items, let
        # MergedColumnParallelLinearWithLoRA handle it
        if len(packed_modules_list) == 2:
            return False

        # If packed_modules_list is empty or has 1 item,
        # check the layer's output_sizes.
        # This handles cases where the checkpoint has a single weight
        # but the layer has multiple slices (3+)
        return (
            hasattr(source_layer, "output_sizes")
            and len(source_layer.output_sizes) >= 3
        )

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Override to handle single tensor weights
        that need to be split into slices."""
        self.reset_lora(index)

        # Handle case where checkpoint has single tensor weights
        # lora_a shape: (rank, input_size) - same for all slices, duplicate it
        if isinstance(lora_a, torch.Tensor):
            lora_a = [lora_a] * self.n_slices

        # lora_b shape: (total_output_size, rank) -
        # split along dim 0 based on output_sizes
        if isinstance(lora_b, torch.Tensor):
            output_sizes = self.base_layer.output_sizes
            lora_b_list = []
            start_idx = 0
            for output_size in output_sizes:
                end_idx = start_idx + output_size
                lora_b_list.append(lora_b[start_idx:end_idx, :])
                start_idx = end_idx
            lora_b = lora_b_list

        # Now call parent's set_lora which expects lists
        super().set_lora(index, lora_a, lora_b)
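
A toy illustration of the splitting performed in set_lora above (assumed sizes, not vLLM code): a single checkpoint LoRA B of shape (total_output_size, rank) is cut along dim 0 according to the layer's output_sizes, while LoRA A is simply reused for every slice.

import torch

rank, input_size = 4, 16
output_sizes = [6, 6, 2]            # e.g. a 3-slice merged projection
lora_a = torch.randn(rank, input_size)
lora_b = torch.randn(sum(output_sizes), rank)

lora_a_list = [lora_a] * len(output_sizes)             # duplicated per slice
lora_b_list = list(torch.split(lora_b, output_sizes, dim=0))

assert [t.shape[0] for t in lora_b_list] == output_sizes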

set_lora

set_lora(
    index: int,
    lora_a: Tensor | list[Tensor],
    lora_b: Tensor | list[Tensor],
)

Override to handle single tensor weights that need to be split into slices.

Source code in vllm/lora/layers/column_parallel_linear.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Override to handle single tensor weights
    that need to be split into slices."""
    self.reset_lora(index)

    # Handle case where checkpoint has single tensor weights
    # lora_a shape: (rank, input_size) - same for all slices, duplicate it
    if isinstance(lora_a, torch.Tensor):
        lora_a = [lora_a] * self.n_slices

    # lora_b shape: (total_output_size, rank) -
    # split along dim 0 based on output_sizes
    if isinstance(lora_b, torch.Tensor):
        output_sizes = self.base_layer.output_sizes
        lora_b_list = []
        start_idx = 0
        for output_size in output_sizes:
            end_idx = start_idx + output_size
            lora_b_list.append(lora_b[start_idx:end_idx, :])
            start_idx = end_idx
        lora_b = lora_b_list

    # Now call parent's set_lora which expects lists
    super().set_lora(index, lora_a, lora_b)

MergedColumnParallelLinearWithLoRA

Bases: ColumnParallelLinearWithLoRA

ColumnParallelLinear layer that is composed of 2 sublayers (slices) packed together (e.g. gate_proj + up_proj -> gate_up_proj).

This means we have 2 LoRAs, each applied to one half of the layer.

Both slices must have the same size.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(
        self, base_layer: MergedColumnParallelLinear | QKVParallelLinear
    ) -> None:
        super().__init__(base_layer)
        # There are two LoRA layers
        # the output_sizes in MergedColumnParallelLinear is not sharded by tp
        # we need to divide it by the tp_size to get correct slices size
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes
        )
        self.n_slices = len(self.output_slices)
        self.output_ids = (self.tp_rank,) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        The main reason for overriding this function is to enhance code
        maintainability.
        """
        self.lora_config = lora_config

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank
            if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size)
        )

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for output_size in self.output_slices
        )

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        return lora_a

    def slice_lora_b(
        self, lora_b: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        sliced_lora_b = [None] * self.n_slices
        for i, (shard_id, shard_size) in enumerate(
            zip(self.output_ids, self.output_slices)
        ):
            if (lora_b_i := lora_b[i]) is not None:
                sliced_lora_b[i] = lora_b_i[
                    shard_size * shard_id : shard_size * (shard_id + 1), :
                ]
        return sliced_lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1]
                ].copy_(lora_a_i, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
                ].copy_(lora_b_i, non_blocking=True)

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
        # Effectively unsharded subclasses can safely reuse their custom
        # forward() implementation before applying the LoRA delta.
        if (
            self.tp_size == 1
            and type(self.base_layer) is not merged_cls
            and type(self.base_layer).forward is not merged_cls.forward
        ):
            return self._apply_base_forward(x)
        return _mcp_apply(x, bias, self)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
        decorate: bool = True,
    ) -> bool:
        merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
        if not isinstance(source_layer, merged_cls) or len(packed_modules_list) != 2:
            return False

        tp_size = getattr(source_layer, "tp_size", 1)
        if type(source_layer) is merged_cls:
            if not decorate:
                return True
            return not lora_config.fully_sharded_loras or tp_size == 1

        # Only support effectively unsharded subclasses here. Sharded
        # subclasses may have custom communication semantics that the generic
        # merged-column LoRA path does not know how to preserve.
        return tp_size == 1

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

The main reason for overriding this function is to enhance code maintainability.

Source code in vllm/lora/layers/column_parallel_linear.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """
    The main reason for overriding this function is to enhance code
    maintainability.
    """
    self.lora_config = lora_config

    lora_a_output_size_per_partition = (
        lora_config.max_lora_rank
        if not lora_config.fully_sharded_loras
        else divide(lora_config.max_lora_rank, self.tp_size)
    )

    self.lora_a_stacked = tuple(
        torch.zeros(
            max_loras,
            1,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        for _ in range(self.n_slices)
    )
    self.lora_b_stacked = tuple(
        torch.zeros(
            max_loras,
            1,
            output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        for output_size in self.output_slices
    )
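
As a rough sketch of the buffers this allocates (the sizes below are hypothetical, not taken from any model config): one A buffer per slice, sized by the possibly rank-sharded LoRA rank and the input size, and one B buffer per output slice, sized by that slice and the full LoRA rank.

# Hypothetical sizes, for illustration only.
max_loras = 4
max_lora_rank = 16
input_size = 4096
output_slices = (11008, 11008)   # e.g. an unsharded merged gate_proj + up_proj
tp_size = 1
fully_sharded_loras = False      # if True, the A rank dim is divided by tp_size

lora_a_rank_dim = (
    max_lora_rank if not fully_sharded_loras else max_lora_rank // tp_size
)

# One A buffer per slice: (max_loras, 1, rank, input_size)
lora_a_shapes = [(max_loras, 1, lora_a_rank_dim, input_size) for _ in output_slices]
# One B buffer per slice: (max_loras, 1, output_size, max_lora_rank)
lora_b_shapes = [(max_loras, 1, out, max_lora_rank) for out in output_slices]

print(lora_a_shapes)  # [(4, 1, 16, 4096), (4, 1, 16, 4096)]
print(lora_b_shapes)  # [(4, 1, 11008, 16), (4, 1, 11008, 16)]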

MergedColumnParallelLinearWithShardedLoRA

Bases: MergedColumnParallelLinearWithLoRA

Differs from MergedColumnParallelLinearWithLoRA by slicing the LoRA A's also.

Based on S-LoRA, slicing happens along the rank dim.
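
For intuition, a minimal sketch of the rank-dimension split (the sizes and tensors below are hypothetical, not taken from vLLM): with a LoRA rank of 16 and a tensor-parallel size of 4, each rank keeps a contiguous 4-row slice of every sublora's A matrix.

import torch

# Hypothetical sizes, for illustration only.
max_lora_rank, input_size, tp_size = 16, 4096, 4
shard = max_lora_rank // tp_size              # rows of A kept per tensor-parallel rank

lora_a = [
    torch.randn(max_lora_rank, input_size),   # sublora 0, e.g. gate_proj
    torch.randn(max_lora_rank, input_size),   # sublora 1, e.g. up_proj
]

tp_rank = 2
start = tp_rank * shard
sliced = [a[start : start + shard, :] for a in lora_a]
print([t.shape for t in sliced])              # [torch.Size([4, 4096]), torch.Size([4, 4096])]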

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        # NOTE: lora_a contains 2 subloras, and each sublora could be None.
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][output_start_idx : output_start_idx + output_shard_size, :]
            if lora_a[0] is not None
            else None,
            lora_a[1][output_start_idx : output_start_idx + output_shard_size, :]
            if lora_a[1] is not None
            else None,
        ]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

MergedQKVParallelLinearWithLoRA

Bases: MergedColumnParallelLinearWithLoRA

MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj).

This means we have 3 LoRAs, each applied to one slice of the layer.

Q slice may have different shape than K and V slices (which both have the same shape).
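
As a rough, hypothetical illustration of why the slices differ in size (the head counts below are made up, not read from any model config): with grouped-query attention, the per-rank Q shard is larger than the per-rank K and V shards, and output_ids records which shard of each block this rank owns.

# Hypothetical GQA configuration, for illustration only.
total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4

num_heads = total_num_heads // tp_size                         # 8 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)           # 2 kv heads per rank
num_kv_head_replicas = max(1, tp_size // total_num_kv_heads)   # 1, no replication here

q_proj_shard_size = num_heads * head_size                      # 1024
kv_proj_shard_size = num_kv_heads * head_size                  # 256

tp_rank = 3
output_slices = (q_proj_shard_size, kv_proj_shard_size, kv_proj_shard_size)  # (1024, 256, 256)
output_ids = (tp_rank, tp_rank // num_kv_head_replicas, tp_rank // num_kv_head_replicas)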

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # There are three LoRA layers.
        self.n_slices = len(self.base_layer.output_sizes)

        self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
        self.kv_proj_shard_size = (
            self.base_layer.num_kv_heads * self.base_layer.head_size
        )
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        self.output_ids = (
            self.q_shard_id,
            self.kv_shard_id,
            self.kv_shard_id,
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        The main reason for overloading this function is to handle inconsistent
        weight dimensions in qkv lora.
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3

create_lora_weights

create_lora_weights(
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None

The main reason for overloading this function is to handle inconsistent weight dimensions in qkv lora.

Source code in vllm/lora/layers/column_parallel_linear.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """
    The main reason for overloading this function is to handle inconsistent
    weight dimensions in qkv lora.
    """
    super().create_lora_weights(max_loras, lora_config, model_config)

MergedQKVParallelLinearWithShardedLoRA

Bases: MergedQKVParallelLinearWithLoRA

Differs from MergedQKVParallelLinearWithLoRA by slicing the LoRA A's also.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
    """
    Differs from MergedQKVParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
            if lora_a[0] is not None
            else None,
            lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
            if lora_a[1] is not None
            else None,
            lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
            if lora_a[2] is not None
            else None,
        ]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

QKVParallelLinearWithLoRA

Bases: ColumnParallelLinearWithLoRA

ColumnParallelLinear layer that is specifically designed for qkv_proj. Certain models, such as chatglm3 and baichuan-7b, only contain a single LoRA within their qkv_proj layer.

During inference with Tensor Parallel, the weights of lora_b must be accurately partitioned according to the respective ranks.

Q slice may have different shape than K and V slices (which both have the same shape).
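
A minimal sketch of that partitioning (hypothetical sizes; the authoritative logic is slice_lora_b below): the Q, K and V blocks are located by their total, unsharded offsets inside the single stacked lora_b, and each rank keeps only its own shard of every block.

import torch

# Hypothetical sizes, for illustration only (num_kv_head_replicas assumed to be 1).
head_size, total_num_heads, total_num_kv_heads, tp_size, lora_rank = 128, 32, 8, 4, 16

q_total = total_num_heads * head_size        # 4096: full (unsharded) Q block
kv_total = total_num_kv_heads * head_size    # 1024: full K block (V has the same size)
q_shard = q_total // tp_size                 # 1024 rows of Q kept per rank
kv_shard = kv_total // tp_size               # 256 rows of K/V kept per rank

lora_b = torch.randn(q_total + 2 * kv_total, lora_rank)  # single stacked qkv LoRA B

tp_rank = 1                                  # doubles as the q/kv shard id here
q = lora_b[q_shard * tp_rank : q_shard * (tp_rank + 1)]
k_off = q_total
k = lora_b[k_off + kv_shard * tp_rank : k_off + kv_shard * (tp_rank + 1)]
v_off = q_total + kv_total
v = lora_b[v_off + kv_shard * tp_rank : v_off + kv_shard * (tp_rank + 1)]

sliced = torch.cat([q, k, v], dim=0)         # shape: (1536, 16)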

Source code in vllm/lora/layers/column_parallel_linear.py
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.q_proj_total_size = (
            self.base_layer.total_num_heads * self.base_layer.head_size
        )
        self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
        self.kv_proj_shard_size = (
            self.base_layer.num_kv_heads * self.base_layer.head_size
        )
        self.kv_proj_total_size = (
            self.base_layer.total_num_kv_heads * self.base_layer.head_size
        )
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[
            self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
            * (self.q_shard_id + 1),
            :,
        ]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[
            k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
            + self.kv_proj_shard_size * (self.kv_shard_id + 1),
            :,
        ]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[
            v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
            + self.kv_proj_shard_size * (self.kv_shard_id + 1),
            :,
        ]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
        return lora_b

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1

QKVParallelLinearWithShardedLoRA

Bases: QKVParallelLinearWithLoRA

Differs from QKVParallelLinearWithLoRA by slicing the LoRA A's also.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
    """
    Differs from QKVParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        lora_a = lora_a[start_idx : start_idx + shard_size, :]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

ReplicatedLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

Source code in vllm/lora/layers/replicated_linear.py
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(
            base_layer,
        )
        # To ensure interface compatibility, n_slices is always set to 1.
        self.output_size = self.base_layer.output_size
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None

        if not self.base_layer.return_bias:
            return output

        return output, output_bias

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        # ReplicatedLinear subclasses such as GateLinear override forward() to
        # dispatch custom kernels and/or adjust the output dtype. Apply LoRA on
        # top of the actual base-layer output instead of bypassing that path.
        return self._apply_base_forward(x)

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return isinstance(source_layer, maybe_get_oot_by_class(ReplicatedLinear))

    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting for tensor parallelism."""
        return lora_a

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism."""
        return lora_b

forward

forward(
    input_: Tensor,
) -> Tensor | tuple[Tensor, Tensor | None]

Forward of ReplicatedLinearWithLoRA

Parameters:

    input_ (Tensor, required): Tensor whose last dimension is input_size.

Returns:

    Tensor | tuple[Tensor, Tensor | None]:
        - output
        - bias
Source code in vllm/lora/layers/replicated_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of ReplicatedLinearWithLoRA

    Args:
        input_: Tensor whose last dimension is `input_size`.

    Returns:
        - output
        - bias
    """
    bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

    # Matrix multiply.
    output = self.apply(input_, bias)

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None

    if not self.base_layer.return_bias:
        return output

    return output, output_bias

slice_lora_a

slice_lora_a(
    lora_a: Tensor | list[Tensor | None],
) -> Tensor | list[Tensor | None]

Slice lora a if splitting for tensor parallelism.

Source code in vllm/lora/layers/replicated_linear.py
def slice_lora_a(
    self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora a if splitting for tensor parallelism."""
    return lora_a

slice_lora_b

slice_lora_b(
    lora_b: Tensor | list[Tensor | None],
) -> Tensor | list[Tensor | None]

Slice lora b if splitting with tensor parallelism.

Source code in vllm/lora/layers/replicated_linear.py
def slice_lora_b(
    self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora b if splitting with tensor parallelism."""
    return lora_b

RowParallelLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

Source code in vllm/lora/layers/row_parallel_linear.py
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        # reset input_size
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        # There is only one LoRA layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.input_size
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_a = lora_a[:, start_idx:end_idx]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            split_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = split_input[self.tp_rank].contiguous()

        # Matrix multiply.
        bias_ = (
            None
            if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
            else self.base_layer.bias
        )
        output_parallel = self.apply(input_parallel, bias_)
        if self.base_layer.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        if not self.base_layer.return_bias:
            return output

        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is maybe_get_oot_by_class(RowParallelLinear)

forward

forward(
    input_: Tensor,
) -> Tensor | tuple[Tensor, Tensor | None]

Forward of RowParallelLinear

Parameters:

    input_ (Tensor, required): tensor whose last dimension is input_size. If
    input_is_parallel is set, then the last dimension is input_size // tp_size.

Returns:

    Tensor | tuple[Tensor, Tensor | None]:
        - output
        - bias
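
For example (hypothetical sizes): with an input_size of 4096 and a tensor-parallel size of 2, a caller that already holds parallel input passes a 2048-wide tensor, while an unsplit 4096-wide tensor is chunked inside forward, roughly as follows.

import torch

# Hypothetical sizes, for illustration only.
input_size, tp_size, tp_rank = 4096, 2, 0

x_full = torch.randn(8, input_size)           # input_is_parallel = False
chunks = torch.chunk(x_full, tp_size, dim=-1)
x_parallel = chunks[tp_rank].contiguous()     # shape: (8, 2048)

# With input_is_parallel = True the caller passes x_parallel directly.
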
Source code in vllm/lora/layers/row_parallel_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of RowParallelLinear

    Args:
        input_: tensor whose last dimension is `input_size`. If
                `input_is_parallel` is set, then the last dimension
                is `input_size // tp_size`.

    Returns:
        - output
        - bias
    """
    # set up backprop all-reduce.
    if self.base_layer.input_is_parallel:
        input_parallel = input_
    else:
        # TODO: simplify code below
        split_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size
        )
        input_parallel = split_input[self.tp_rank].contiguous()

    # Matrix multiply.
    bias_ = (
        None
        if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
        else self.base_layer.bias
    )
    output_parallel = self.apply(input_parallel, bias_)
    if self.base_layer.reduce_results and self.tp_size > 1:
        output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
    if not self.base_layer.return_bias:
        return output

    return output, output_bias

RowParallelLinearWithShardedLoRA

Bases: RowParallelLinearWithLoRA

Differs from RowParallelLinearWithLoRA by slicing the LoRA B's also.

Based on S-LoRA, slicing happens along the output dim. This yields a combined partial sum from the row parallel base layer and column partitioned output from the LoRA.
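
To make that fusion concrete, here is a minimal sketch (hypothetical sizes and random tensors, assuming the shrink buffer has already been all-reduced across ranks): each rank's base layer produces a full-width partial sum, its LoRA B shard writes only into that rank's own output slice, and a single summation across ranks then completes both at once.

import torch

# Hypothetical sizes, for illustration only.
output_size, lora_rank, tp_size, tokens = 4096, 16, 4, 8
shard = output_size // tp_size                       # LoRA B rows kept per rank

def per_rank_output(tp_rank: int) -> torch.Tensor:
    base_partial = torch.randn(tokens, output_size)  # row-parallel partial sum
    lora_buffer = torch.randn(tokens, lora_rank)     # x @ A.T, already all-reduced
    lora_b_shard = torch.randn(shard, lora_rank)     # this rank's slice of B
    out = base_partial.clone()
    # Add the column-partitioned LoRA output into this rank's slice only.
    start = tp_rank * shard
    out[:, start : start + shard] += lora_buffer @ lora_b_shard.T
    return out

# One all-reduce (a sum over ranks) finishes both the base layer and the LoRA.
full_output = sum(per_rank_output(r) for r in range(tp_size))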

Source code in vllm/lora/layers/row_parallel_linear.py
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[start_idx:end_idx, :]
        return lora_b

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
        buffer = torch.zeros(
            (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
            dtype=torch.float32,
            device=x.device,
        )

        shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
            buffer, x, self.lora_a_stacked, 1.0
        )
        if not current_platform.can_update_inplace():
            buffer = shrunk_buffer
        if self.tp_size > 1:
            buffer = tensor_model_parallel_all_reduce(buffer)

        # Following S-LoRA, this allows fusing the all_gather and all_reduce
        # by adding the column-partitioned LoRA output to a slice of the
        # output tensor, which is a partial sum due to row parallelism. All
        # that remains is a standard all_reduce. Users should be aware,
        # though, that the output is not the same as a normal row_parallel
        # output; it must be reduced before being used.
        # NOTE: offsets are based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        offset_start = self.tp_rank * shard_size
        lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
            self.output_slices,
            offset_start=offset_start,
            add_input=True,
        )

        if not current_platform.can_update_inplace():
            output = lora_output

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )