Enable vulkan for RWKV and some misc updates (mlc-ai#264)
This PR enables Vulkan for RWKV and removes the local Vulkan detection from the
target setup, since detection can cause cross-platform issues.
It also introduces a max_gen_len parameter for chat.
tqchen authored May 29, 2023
1 parent 60e2176 commit c409ca0
Showing 6 changed files with 49 additions and 38 deletions.
1 change: 1 addition & 0 deletions build.py
@@ -291,6 +291,7 @@ def dump_default_mlc_chat_config(args):
     config["repetition_penalty"] = 1.0
     config["top_p"] = 0.95
     config["mean_gen_len"] = 128
+    config["max_gen_len"] = 512
     config["shift_fill_factor"] = 0.3
     config["tokenizer_files"] = utils.get_tokenizer_files(params_path)
 
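For context, the defaults written by dump_default_mlc_chat_config now carry the new cap. Below is a minimal Python sketch of the resulting dictionary, using only values visible in the hunk above; the json dump and any file name such as mlc-chat-config.json are assumptions, not part of this diff.

# Sketch of the default chat config after this change; values come from the
# diff above, the surrounding dump/serialization code is assumed.
import json

default_config = {
    "repetition_penalty": 1.0,
    "top_p": 0.95,
    "mean_gen_len": 128,
    "max_gen_len": 512,  # new: upper bound on tokens generated per reply
    "shift_fill_factor": 0.3,
}
print(json.dumps(default_config, indent=2))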
2 changes: 1 addition & 1 deletion cpp/conversation.cc
@@ -148,7 +148,7 @@ picojson::value Conversation::SerializeToJSON() const {
   return picojson::value(config);
 }
 
-std::string Conversation::SerializeToJSONStr() const { return SerializeToJSON().serialize(true); }
+std::string Conversation::GetConfigJSON() const { return SerializeToJSON().serialize(true); }
 
 } // namespace llm
 } // namespace mlc
2 changes: 1 addition & 1 deletion cpp/conversation.h
@@ -135,7 +135,7 @@ class Conversation {
    * \brief Serialize the Conversation to JSON String.
    * \return A string storing the serialized conversation in JSON format.
    */
-  std::string SerializeToJSONStr() const;
+  std::string GetConfigJSON() const;
 
   /*!
    * \brief Get the entire prompt array
41 changes: 25 additions & 16 deletions cpp/llm_chat.cc
@@ -179,6 +179,12 @@ class LLMChat {
     } else {
       CHECK(partial_update) << "Key \"mean_gen_len\" not found.";
     }
+    // NOTE: max_gen_len is optional so that configs generated
+    // before this change keep working (backward compatibility).
+    if (config.count("max_gen_len")) {
+      CHECK(config["max_gen_len"].is<int64_t>());
+      this->max_gen_len_ = config["max_gen_len"].get<int64_t>();
+    }
     if (config.count("shift_fill_factor")) {
       CHECK(config["shift_fill_factor"].is<double>());
       this->shift_fill_factor_ = config["shift_fill_factor"].get<double>();
@@ -219,18 +225,8 @@ class LLMChat {
     LoadJSONOverride(config_json, partial_update);
   }
 
-  picojson::value SerializeToJSON() const {
-    picojson::object config;
-    config["temperature"] = picojson::value(this->temperature_);
-    config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
-    config["top_p"] = picojson::value(this->top_p_);
-    config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
-    config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
-    config["conv_config"] = this->conversation_.SerializeToJSON();
-    return picojson::value(config);
-  }
-
-  std::string SerializeToJSONStr() const { return SerializeToJSON().serialize(true); }
+  std::string GetConfigJSON() const { return SerializeConfigToJSONValue().serialize(true); }
 
   /*!
    * \brief Reload model, tokenizers and configurations from the specified model path.
@@ -595,6 +591,17 @@ class LLMChat {
   }
 
  private:
+  picojson::value SerializeConfigToJSONValue() const {
+    picojson::object config;
+    config["temperature"] = picojson::value(this->temperature_);
+    config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
+    config["top_p"] = picojson::value(this->top_p_);
+    config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
+    config["max_gen_len"] = picojson::value(this->max_gen_len_);
+    config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
+    config["conv_config"] = this->conversation_.SerializeToJSON();
+    return picojson::value(config);
+  }
   /*!
    * \brief Sample output token from logits on device
    */
@@ -662,8 +669,10 @@ class LLMChat {
         }
       }
     }
-    // TODO(mlc-team): add another per convo seq len trigger
-    if (total_seq_len_ >= max_window_size_) {
+
+    if (static_cast<int64_t>(output_ids_.size()) >= max_gen_len_) {
+      stop_triggered_ = true;
+    } else if (total_seq_len_ >= max_window_size_) {
       stop_triggered_ = true;
     }
     if (stop_triggered_) {
@@ -783,7 +792,7 @@ class LLMChat {
   // total sequence len,
   int64_t total_seq_len_{0};
   // max window size, mean generation length
-  int64_t max_window_size_{768}, mean_gen_len_{128};
+  int64_t max_window_size_{768}, mean_gen_len_{128}, max_gen_len_{512};
   // shift window fill factor
   double shift_fill_factor_{0.3};
   // temperature
@@ -927,9 +936,9 @@ class LLMChatModule : public ModuleNode {
   } else if (name == "reset_runtime_stats") {
     return PackedFunc(
         [this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->ResetRuntimeStats(); });
-  } else if (name == "serialize_config") {
+  } else if (name == "get_config_json") {
     return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
-      *rv = GetChat()->SerializeToJSONStr();
+      *rv = GetChat()->GetConfigJSON();
     });
   } else {
     return PackedFunc(nullptr);
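The behavioral change in llm_chat.cc is easiest to read as a stopping rule: a reply now ends either once it has produced max_gen_len tokens or once the total sequence length reaches the model window. The Python sketch below mirrors that rule; the function name and defaults are illustrative, and the real logic is the C++ shown above.

# Illustrative sketch of the stop condition introduced in this commit.
def should_stop(output_ids, total_seq_len, max_gen_len=512, max_window_size=768):
    # Per-reply cap introduced by this PR.
    if len(output_ids) >= max_gen_len:
        return True
    # Existing window-size cap.
    if total_seq_len >= max_window_size:
        return True
    return False

# Example: the 512th generated token triggers the new per-reply cap.
assert should_stop(list(range(512)), total_seq_len=600)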
39 changes: 20 additions & 19 deletions mlc_llm/utils.py
@@ -40,8 +40,8 @@ class Quantization:
     "q0f16": Quantization(
         name="q0f16", mode="no", sym=False, storage_nbit=-1, model_dtype="float16"
     ),
-    "q8f16": Quantization(
-        name="q8f16", mode="uint8", sym=False, storage_nbit=-1, model_dtype="float16"
+    "q8f16_0": Quantization(
+        name="q8f16_0", mode="uint8", sym=False, storage_nbit=-1, model_dtype="float16"
     ),
 }
 
@@ -306,6 +306,7 @@ def _detect_local_vulkan():
             "thread_warp_size": dev.warp_size,
             "supports_float16": 1,
             "supports_int16": 1,
+            "supports_int8": 1,
             "supports_16bit_buffer": 1,
         }
     )
@@ -424,23 +425,23 @@ def compile_metal(src, target):
         args.target = target
         args.target_kind = "iphone"
     elif args.target == "vulkan":
-        target = _detect_local_vulkan()
-        if target is None:
-            print("Cannot detect local Vulkan GPU target! Falling back...")
-            target = tvm.target.Target(
-                tvm.target.Target(
-                    {
-                        "kind": "vulkan",
-                        "max_threads_per_block": 256,
-                        "max_shared_memory_per_block": 32768,
-                        "thread_warp_size": 1,
-                        "supports_float16": 1,
-                        "supports_int16": 1,
-                        "supports_16bit_buffer": 1,
-                    }
-                ),
-                host="llvm",
-            )
+        target = tvm.target.Target(
+            tvm.target.Target(
+                {
+                    "kind": "vulkan",
+                    "max_threads_per_block": 256,
+                    "max_shared_memory_per_block": 32768,
+                    "thread_warp_size": 1,
+                    "supports_float16": 1,
+                    "supports_int16": 1,
+                    "supports_int8": 1,
+                    "supports_8bit_buffer": 1,
+                    "supports_16bit_buffer": 1,
+                    "supports_storage_buffer_storage_class": 1
+                }
+            ),
+            host="llvm",
+        )
         args.target = target
         args.target_kind = args.target.kind.default_keys[0]
     elif args.target == "webgpu":
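Because the Vulkan target is now hard-coded instead of detected from the local GPU, the same target spec is produced on every host, which is what avoids the cross-platform issues mentioned in the commit message. Below is a small sketch of building and inspecting that fixed target with TVM; it assumes a TVM build recent enough to accept these Vulkan attributes and is illustrative only, not part of the diff.

# Build the fixed Vulkan target from the diff above and inspect it.
import tvm

target = tvm.target.Target(
    {
        "kind": "vulkan",
        "max_threads_per_block": 256,
        "max_shared_memory_per_block": 32768,
        "thread_warp_size": 1,
        "supports_float16": 1,
        "supports_int16": 1,
        "supports_int8": 1,
        "supports_8bit_buffer": 1,
        "supports_16bit_buffer": 1,
        "supports_storage_buffer_storage_class": 1,
    },
    host="llvm",
)
print(target.kind.name)       # vulkan
print(target.host.kind.name)  # llvm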
2 changes: 1 addition & 1 deletion tests/cpp/conv_unittest.cc
@@ -3,7 +3,7 @@
 
 void _TestConversationJSONRoundTrip(std::string templ_name) {
   mlc::llm::Conversation conv = mlc::llm::Conversation::FromTemplate(templ_name);
-  std::string conv_json = conv.SerializeToJSONStr();
+  std::string conv_json = conv.GetConfigJSON();
   mlc::llm::Conversation conv_new;
   conv_new.LoadJSONOverride(conv_json, false);
   ASSERT_EQ(conv, conv_new);