Skip to content

Commit 87d9e35

Browse files
authored
feat: implement column parallel for lm head to improve performance. (#1145)
1 parent 5b3a9a4 commit 87d9e35

File tree

11 files changed

+229
-69
lines changed

11 files changed

+229
-69
lines changed

xllm/core/layers/common/linear.cpp

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,14 @@ torch::Tensor fp8_linear_forward(
181181
} // namespace
182182

183183
ColumnParallelLinearImpl::ColumnParallelLinearImpl(const ModelContext& context)
184-
: ColumnParallelLinearImpl(context.get_model_args().hidden_size(),
185-
context.get_model_args().vocab_size(),
186-
/*bias=*/false,
187-
/*gather_output=*/true,
188-
context.get_quant_args(),
189-
context.get_parallel_args().tp_group_,
190-
context.get_tensor_options()) {}
184+
: ColumnParallelLinearImpl(
185+
context.get_model_args().hidden_size(),
186+
context.get_model_args().vocab_size(),
187+
/*bias=*/false,
188+
/*gather_output=*/true,
189+
QuantArgs{}, // do not use quantization for lm_head
190+
context.get_parallel_args().tp_group_,
191+
context.get_tensor_options()) {}
191192

192193
// Linear layer with column parallelism.
193194
ColumnParallelLinearImpl::ColumnParallelLinearImpl(
@@ -667,17 +668,6 @@ std::optional<torch::Tensor> QKVParallelLinearImpl::get_input_scale() const {
667668
return std::nullopt;
668669
}
669670

670-
// Linear layer with row parallelism.
671-
RowParallelLinearImpl::RowParallelLinearImpl(const ModelContext& context)
672-
: RowParallelLinearImpl(context.get_model_args().hidden_size(),
673-
context.get_model_args().vocab_size(),
674-
/*bias=*/false,
675-
/*input_is_parallelized=*/false,
676-
/*enable_result_reduction=*/true,
677-
context.get_quant_args(),
678-
context.get_parallel_args().tp_group_,
679-
context.get_tensor_options()) {}
680-
681671
// Linear layer with row parallelism.
682672
RowParallelLinearImpl::RowParallelLinearImpl(
683673
int64_t in_features,

xllm/core/layers/common/linear.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,6 @@ TORCH_MODULE(QKVParallelLinear);
198198
// - -
199199
class RowParallelLinearImpl : public torch::nn::Module {
200200
public:
201-
RowParallelLinearImpl(const ModelContext& context);
202-
203201
RowParallelLinearImpl(
204202
int64_t in_features,
205203
int64_t out_features,

xllm/core/layers/common/lm_head.h

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,13 @@ limitations under the License.
2020
namespace xllm {
2121
namespace layer {
2222

23-
class LmHead : public torch::nn::ModuleHolder<RowParallelLinearImpl> {
23+
class LmHead : public torch::nn::ModuleHolder<ColumnParallelLinearImpl> {
2424
public:
25-
using torch::nn::ModuleHolder<RowParallelLinearImpl>::ModuleHolder;
26-
using Impl __attribute__((__unused__)) = RowParallelLinearImpl;
25+
using torch::nn::ModuleHolder<ColumnParallelLinearImpl>::ModuleHolder;
26+
using Impl __attribute__((__unused__)) = ColumnParallelLinearImpl;
2727

2828
LmHead(const ModelContext& context)
29-
: ModuleHolder(std::make_shared<RowParallelLinearImpl>(
30-
// NOTE: Quantization should NOT be used for the final language
31-
// modeling head (lm_head). The output logits must remain in high
32-
// precision (typically bfloat16/float16) for numerical stability
33-
// and correct evaluation of loss and predictions. Always use
34-
// unquantized weights here.
35-
context.get_model_args().hidden_size(),
36-
context.get_model_args().vocab_size(),
37-
/*bias=*/false,
38-
/*input_is_parallelized=*/false,
39-
/*enable_result_reduction=*/true,
40-
QuantArgs{}, // do not use quantization for lm_head!
41-
context.get_parallel_args().tp_group_,
42-
context.get_tensor_options())) {}
29+
: ModuleHolder(std::make_shared<ColumnParallelLinearImpl>(context)) {}
4330
};
4431

4532
} // namespace layer

xllm/core/layers/npu/loader/base_loader.cpp

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,5 +176,120 @@ torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) {
176176
LOG(FATAL) << "Unsupported dtype string: " << dtype_str;
177177
}
178178

179+
at::Tensor BaseLoader::pad_vocab_tensor(const at::Tensor& tensor,
180+
int64_t padded_vocab_size) const {
181+
if (tensor.size(0) >= padded_vocab_size) {
182+
return tensor;
183+
}
184+
at::Tensor padded_tensor =
185+
torch::zeros({padded_vocab_size, tensor.size(1)}, tensor.options());
186+
padded_tensor.slice(0, 0, tensor.size(0)) = tensor;
187+
return padded_tensor;
188+
}
189+
190+
at::Tensor BaseLoader::shard_padded_tensor(const at::Tensor& padded_tensor,
191+
int dim,
192+
int rank,
193+
int world_size) const {
194+
if (world_size <= 1) {
195+
return padded_tensor;
196+
}
197+
auto chunks = padded_tensor.chunk(world_size, dim);
198+
return chunks[rank];
199+
}
200+
201+
void BaseLoader::set_weight_with_padding(const StateDict& state_dict,
202+
const std::string& tensor_name,
203+
int weight_position,
204+
int dim,
205+
int64_t padded_vocab_size,
206+
bool to_host) {
207+
auto device = to_host ? at::kCPU : device_;
208+
for (const auto& [name, tensor] : state_dict) {
209+
if (absl::EndsWith(name, tensor_name)) {
210+
at::Tensor mutable_tensor = tensor;
211+
if (padded_vocab_size > tensor.size(0)) {
212+
mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
213+
}
214+
correct_tensor_dtype(mutable_tensor, tensor_name);
215+
if (to_host) {
216+
at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
217+
} else {
218+
at_weight_tensors_[weight_position] = mutable_tensor.to(device);
219+
}
220+
}
221+
}
222+
}
223+
224+
void BaseLoader::set_weight_with_padding(const StateDict& state_dict,
225+
const std::string& tensor_name,
226+
int weight_position,
227+
int dim,
228+
int rank,
229+
int world_size,
230+
int64_t padded_vocab_size,
231+
bool to_host) {
232+
auto device = to_host ? at::kCPU : device_;
233+
if (world_size <= 1) {
234+
set_weight_with_padding(state_dict,
235+
tensor_name,
236+
weight_position,
237+
dim,
238+
padded_vocab_size,
239+
to_host);
240+
return;
241+
}
242+
for (const auto& [name, tensor] : state_dict) {
243+
if (absl::EndsWith(name, tensor_name)) {
244+
at::Tensor mutable_tensor = tensor;
245+
if (padded_vocab_size > tensor.size(0)) {
246+
// Memory-optimized path for vocabulary dimension sharding
247+
if (dim == 0) {
248+
int64_t shard_size = padded_vocab_size / world_size;
249+
int64_t start_idx = rank * shard_size;
250+
int64_t end_idx = (rank + 1) * shard_size;
251+
if (start_idx >= tensor.size(0)) {
252+
mutable_tensor =
253+
torch::zeros({shard_size, tensor.size(1)}, tensor.options());
254+
} else {
255+
auto valid_part =
256+
tensor.slice(0, start_idx, std::min(end_idx, tensor.size(0)));
257+
if (valid_part.size(0) < shard_size) {
258+
mutable_tensor =
259+
torch::zeros({shard_size, tensor.size(1)}, tensor.options());
260+
mutable_tensor.slice(0, 0, valid_part.size(0)).copy_(valid_part);
261+
} else {
262+
mutable_tensor = valid_part.clone();
263+
}
264+
}
265+
} else {
266+
// Non-vocabulary dimension: use original approach
267+
mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
268+
mutable_tensor =
269+
shard_padded_tensor(mutable_tensor, dim, rank, world_size);
270+
}
271+
} else {
272+
mutable_tensor =
273+
state_dict.get_sharded_tensor(tensor_name, dim, rank, world_size);
274+
}
275+
correct_tensor_dtype(mutable_tensor, tensor_name);
276+
if (to_host) {
277+
at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
278+
} else {
279+
at_weight_tensors_[weight_position] = mutable_tensor.to(device);
280+
}
281+
}
282+
}
283+
}
284+
285+
int64_t BaseLoader::get_padded_vocab_size(const ModelContext& context) const {
286+
int64_t vocab_size = context.get_model_args().vocab_size();
287+
int32_t local_tp_size = dp_local_tp_size_;
288+
if (vocab_size > 0 && local_tp_size > 1 && vocab_size % local_tp_size != 0) {
289+
return ((vocab_size + local_tp_size - 1) / local_tp_size) * local_tp_size;
290+
}
291+
return vocab_size;
292+
}
293+
179294
} // namespace layer
180-
} // namespace xllm
295+
} // namespace xllm

xllm/core/layers/npu/loader/base_loader.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,33 @@ class BaseLoader {
111111
int rank,
112112
int world_size,
113113
bool to_host = false);
114+
115+
void set_weight_with_padding(const StateDict& state_dict,
116+
const std::string& tensor_name,
117+
int weight_position,
118+
int dim,
119+
int64_t padded_vocab_size,
120+
bool to_host = false);
121+
122+
void set_weight_with_padding(const StateDict& state_dict,
123+
const std::string& tensor_name,
124+
int weight_position,
125+
int dim,
126+
int rank,
127+
int world_size,
128+
int64_t padded_vocab_size,
129+
bool to_host = false);
130+
131+
at::Tensor pad_vocab_tensor(const at::Tensor& tensor,
132+
int64_t padded_vocab_size) const;
133+
134+
at::Tensor shard_padded_tensor(const at::Tensor& padded_tensor,
135+
int dim,
136+
int rank,
137+
int world_size) const;
138+
139+
int64_t get_padded_vocab_size(const ModelContext& context) const;
114140
};
115141

116142
} // namespace layer
117-
} // namespace xllm
143+
} // namespace xllm

xllm/core/layers/npu/loader/lm_head_loader.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,32 @@ LmHeadLoader::LmHeadLoader(uint64_t weight_count, const ModelContext& context)
2222
: BaseLoader(weight_count, context) {
2323
auto options = context.get_tensor_options();
2424
at_weight_tensors_[0] = torch::zeros({1}).to(options);
25+
vocab_size_ = context.get_model_args().vocab_size();
26+
padded_vocab_size_ = get_padded_vocab_size(context);
2527
}
2628

2729
void LmHeadLoader::load_state_dict(const StateDict& state_dict) {
28-
if (cp_size_ > 1) {
29-
set_weight(
30-
state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_);
31-
} else if (dp_size_ > 1) {
32-
set_weight(
33-
state_dict, "weight", 0, 1, dp_local_tp_rank_, dp_local_tp_size_);
30+
if (cp_size_ > 1 || dp_size_ > 1) {
31+
set_weight_with_padding(state_dict,
32+
"weight",
33+
0,
34+
0,
35+
dp_local_tp_rank_,
36+
dp_local_tp_size_,
37+
padded_vocab_size_,
38+
false);
39+
} else if (parallel_args_.world_size() > 1) {
40+
set_weight_with_padding(state_dict,
41+
"weight",
42+
0,
43+
0,
44+
parallel_args_.rank(),
45+
parallel_args_.world_size(),
46+
padded_vocab_size_,
47+
false);
3448
} else {
35-
set_weight(state_dict, "weight", 0, 1);
49+
set_weight_with_padding(
50+
state_dict, "weight", 0, 0, padded_vocab_size_, false);
3651
}
3752
}
3853

xllm/core/layers/npu/loader/lm_head_loader.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class LmHeadLoader : public BaseLoader {
2525

2626
void load_state_dict(const StateDict& state_dict) override;
2727
void verify_loaded_weights(const std::string& weight_str) const override;
28+
29+
private:
30+
int64_t vocab_size_ = -1;
31+
int64_t padded_vocab_size_ = -1;
2832
};
2933
} // namespace layer
3034
} // namespace xllm

xllm/core/layers/npu/loader/lm_head_manual_loader.cpp

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,32 @@ LmHeadManualLoader::LmHeadManualLoader(uint64_t weight_count,
2323
: BaseManualLoader(weight_count, context) {
2424
auto options = context.get_tensor_options();
2525
at_weight_tensors_[0] = torch::zeros({1}).to(options);
26+
vocab_size_ = context.get_model_args().vocab_size();
27+
padded_vocab_size_ = get_padded_vocab_size(context);
2628
}
2729

2830
void LmHeadManualLoader::load_state_dict(const StateDict& state_dict) {
29-
if (cp_size_ > 1) {
30-
set_weight(
31-
state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_, true);
32-
} else if (dp_size_ > 1) {
33-
set_weight(
34-
state_dict, "weight", 0, 1, dp_local_tp_rank_, dp_local_tp_size_, true);
31+
if (cp_size_ > 1 || dp_size_ > 1) {
32+
set_weight_with_padding(state_dict,
33+
"weight",
34+
0,
35+
0,
36+
dp_local_tp_rank_,
37+
dp_local_tp_size_,
38+
padded_vocab_size_,
39+
true);
40+
} else if (parallel_args_.world_size() > 1) {
41+
set_weight_with_padding(state_dict,
42+
"weight",
43+
0,
44+
0,
45+
parallel_args_.rank(),
46+
parallel_args_.world_size(),
47+
padded_vocab_size_,
48+
true);
3549
} else {
36-
set_weight(state_dict, "weight", 0, 1, true);
50+
set_weight_with_padding(
51+
state_dict, "weight", 0, 0, padded_vocab_size_, true);
3752
}
3853
}
3954

xllm/core/layers/npu/loader/lm_head_manual_loader.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ class LmHeadManualLoader : public BaseManualLoader {
2828

2929
protected:
3030
void merge_host_at_weights() override;
31+
32+
private:
33+
int64_t vocab_size_ = -1;
34+
int64_t padded_vocab_size_ = -1;
3135
};
3236
} // namespace layer
33-
} // namespace xllm
37+
} // namespace xllm

0 commit comments

Comments (0)