@@ -176,5 +176,120 @@ torch::Dtype BaseLoader::string2dtype(const std::string& dtype_str) {
176176 LOG (FATAL) << " Unsupported dtype string: " << dtype_str;
177177}
178178
179+ at::Tensor BaseLoader::pad_vocab_tensor (const at::Tensor& tensor,
180+ int64_t padded_vocab_size) const {
181+ if (tensor.size (0 ) >= padded_vocab_size) {
182+ return tensor;
183+ }
184+ at::Tensor padded_tensor =
185+ torch::zeros ({padded_vocab_size, tensor.size (1 )}, tensor.options ());
186+ padded_tensor.slice (0 , 0 , tensor.size (0 )) = tensor;
187+ return padded_tensor;
188+ }
189+
190+ at::Tensor BaseLoader::shard_padded_tensor (const at::Tensor& padded_tensor,
191+ int dim,
192+ int rank,
193+ int world_size) const {
194+ if (world_size <= 1 ) {
195+ return padded_tensor;
196+ }
197+ auto chunks = padded_tensor.chunk (world_size, dim);
198+ return chunks[rank];
199+ }
200+
201+ void BaseLoader::set_weight_with_padding (const StateDict& state_dict,
202+ const std::string& tensor_name,
203+ int weight_position,
204+ int dim,
205+ int64_t padded_vocab_size,
206+ bool to_host) {
207+ auto device = to_host ? at::kCPU : device_;
208+ for (const auto & [name, tensor] : state_dict) {
209+ if (absl::EndsWith (name, tensor_name)) {
210+ at::Tensor mutable_tensor = tensor;
211+ if (padded_vocab_size > tensor.size (0 )) {
212+ mutable_tensor = pad_vocab_tensor (tensor, padded_vocab_size);
213+ }
214+ correct_tensor_dtype (mutable_tensor, tensor_name);
215+ if (to_host) {
216+ at_host_weight_tensors_[weight_position] = mutable_tensor.to (device);
217+ } else {
218+ at_weight_tensors_[weight_position] = mutable_tensor.to (device);
219+ }
220+ }
221+ }
222+ }
223+
// Sharded variant: loads every state-dict entry whose key ends with
// `tensor_name`, pads the vocab dimension to `padded_vocab_size` when
// needed, takes this rank's shard along `dim`, corrects the dtype, and
// stores the result at `weight_position` (host cache when `to_host` is
// set, otherwise device_).
void BaseLoader::set_weight_with_padding(const StateDict& state_dict,
                                         const std::string& tensor_name,
                                         int weight_position,
                                         int dim,
                                         int rank,
                                         int world_size,
                                         int64_t padded_vocab_size,
                                         bool to_host) {
  auto device = to_host ? at::kCPU : device_;
  if (world_size <= 1) {
    // Single rank: nothing to shard, delegate to the unsharded overload.
    set_weight_with_padding(state_dict,
                            tensor_name,
                            weight_position,
                            dim,
                            padded_vocab_size,
                            to_host);
    return;
  }
  for (const auto& [name, tensor] : state_dict) {
    if (absl::EndsWith(name, tensor_name)) {
      at::Tensor mutable_tensor = tensor;
      if (padded_vocab_size > tensor.size(0)) {
        // Memory-optimized path for vocabulary dimension sharding:
        // materialize only this rank's shard instead of building the full
        // padded tensor and chunking it.
        if (dim == 0) {
          // Assumes padded_vocab_size is a multiple of world_size (see
          // get_padded_vocab_size) — TODO confirm for all call sites.
          int64_t shard_size = padded_vocab_size / world_size;
          int64_t start_idx = rank * shard_size;
          int64_t end_idx = (rank + 1) * shard_size;
          if (start_idx >= tensor.size(0)) {
            // This rank's shard lies entirely in the zero-padding region.
            mutable_tensor =
                torch::zeros({shard_size, tensor.size(1)}, tensor.options());
          } else {
            auto valid_part =
                tensor.slice(0, start_idx, std::min(end_idx, tensor.size(0)));
            if (valid_part.size(0) < shard_size) {
              // Shard straddles the padding boundary: copy the valid rows
              // and leave the tail zeroed.
              mutable_tensor =
                  torch::zeros({shard_size, tensor.size(1)}, tensor.options());
              mutable_tensor.slice(0, 0, valid_part.size(0)).copy_(valid_part);
            } else {
              // Shard is fully inside the checkpoint tensor; clone so the
              // stored weight does not alias the source storage.
              mutable_tensor = valid_part.clone();
            }
          }
        } else {
          // Non-vocabulary dimension: use original approach (pad dim 0
          // first, then chunk along `dim`).
          mutable_tensor = pad_vocab_tensor(tensor, padded_vocab_size);
          mutable_tensor =
              shard_padded_tensor(mutable_tensor, dim, rank, world_size);
        }
      } else {
        // No padding needed; let the state dict produce the shard.
        // NOTE(review): this passes `tensor_name` (the suffix), not the
        // full matched `name` — confirm get_sharded_tensor accepts the
        // suffix form.
        mutable_tensor =
            state_dict.get_sharded_tensor(tensor_name, dim, rank, world_size);
      }
      correct_tensor_dtype(mutable_tensor, tensor_name);
      if (to_host) {
        at_host_weight_tensors_[weight_position] = mutable_tensor.to(device);
      } else {
        at_weight_tensors_[weight_position] = mutable_tensor.to(device);
      }
    }
  }
}
284+
285+ int64_t BaseLoader::get_padded_vocab_size (const ModelContext& context) const {
286+ int64_t vocab_size = context.get_model_args ().vocab_size ();
287+ int32_t local_tp_size = dp_local_tp_size_;
288+ if (vocab_size > 0 && local_tp_size > 1 && vocab_size % local_tp_size != 0 ) {
289+ return ((vocab_size + local_tp_size - 1 ) / local_tp_size) * local_tp_size;
290+ }
291+ return vocab_size;
292+ }
293+
179294} // namespace layer
180- } // namespace xllm
295+ } // namespace xllm
0 commit comments