From c409ca0114672ae79a948d79a1517bbbc2294f79 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Mon, 29 May 2023 14:08:21 -0400
Subject: [PATCH] Enable vulkan for RWKV and some misc updates (#264)

This PR enables Vulkan for RWKV and removes Vulkan from local target
detection, since local detection can cause cross-platform issues.
It also introduces a max_gen_len parameter for chat.
---
 build.py                   |  1 +
 cpp/conversation.cc        |  2 +-
 cpp/conversation.h         |  2 +-
 cpp/llm_chat.cc            | 41 +++++++++++++++++++++++---------------
 mlc_llm/utils.py           | 39 ++++++++++++++++++------------------
 tests/cpp/conv_unittest.cc |  2 +-
 6 files changed, 49 insertions(+), 38 deletions(-)

diff --git a/build.py b/build.py
index 7c54fc3365..8bae0c8255 100644
--- a/build.py
+++ b/build.py
@@ -291,6 +291,7 @@ def dump_default_mlc_chat_config(args):
     config["repetition_penalty"] = 1.0
     config["top_p"] = 0.95
     config["mean_gen_len"] = 128
+    config["max_gen_len"] = 512
     config["shift_fill_factor"] = 0.3
     config["tokenizer_files"] = utils.get_tokenizer_files(params_path)
 
diff --git a/cpp/conversation.cc b/cpp/conversation.cc
index 28e752001b..1966f57e24 100644
--- a/cpp/conversation.cc
+++ b/cpp/conversation.cc
@@ -148,7 +148,7 @@ picojson::value Conversation::SerializeToJSON() const {
   return picojson::value(config);
 }
 
-std::string Conversation::SerializeToJSONStr() const { return SerializeToJSON().serialize(true); }
+std::string Conversation::GetConfigJSON() const { return SerializeToJSON().serialize(true); }
 
 }  // namespace llm
 }  // namespace mlc
diff --git a/cpp/conversation.h b/cpp/conversation.h
index 0772bf0d96..b8934a698b 100644
--- a/cpp/conversation.h
+++ b/cpp/conversation.h
@@ -135,7 +135,7 @@ class Conversation {
    * \brief Serialize the Conversation to JSON String.
    * \return A string storing the serialized conversation in JSON format.
    */
-  std::string SerializeToJSONStr() const;
+  std::string GetConfigJSON() const;
 
   /*!
    * \brief Get the entire prompt array
diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc
index 46b7fc4117..9954fa5285 100644
--- a/cpp/llm_chat.cc
+++ b/cpp/llm_chat.cc
@@ -179,6 +179,12 @@ class LLMChat {
     } else {
       CHECK(partial_update) << "Key \"mean_gen_len\" not found.";
     }
+    // NOTE: for backward compatibility,
+    // max_gen_len is optional.
+    if (config.count("max_gen_len")) {
+      CHECK(config["max_gen_len"].is<int64_t>());
+      this->max_gen_len_ = config["max_gen_len"].get<int64_t>();
+    }
     if (config.count("shift_fill_factor")) {
       CHECK(config["shift_fill_factor"].is<double>());
       this->shift_fill_factor_ = config["shift_fill_factor"].get<double>();
@@ -219,18 +225,8 @@ class LLMChat {
     LoadJSONOverride(config_json, partial_update);
   }
 
-  picojson::value SerializeToJSON() const {
-    picojson::object config;
-    config["temperature"] = picojson::value(this->temperature_);
-    config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
-    config["top_p"] = picojson::value(this->top_p_);
-    config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
-    config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
-    config["conv_config"] = this->conversation_.SerializeToJSON();
-    return picojson::value(config);
-  }
-
-  std::string SerializeToJSONStr() const { return SerializeToJSON().serialize(true); }
+  std::string GetConfigJSON() const { return SerializeConfigToJSONValue().serialize(true); }
 
   /*!
    * \brief Reload model, tokenizers and configurations from the specified model path.
@@ -595,6 +591,17 @@ class LLMChat {
   }
 
  private:
+  picojson::value SerializeConfigToJSONValue() const {
+    picojson::object config;
+    config["temperature"] = picojson::value(this->temperature_);
+    config["repetition_penalty"] = picojson::value(this->repetition_penalty_);
+    config["top_p"] = picojson::value(this->top_p_);
+    config["mean_gen_len"] = picojson::value(this->mean_gen_len_);
+    config["max_gen_len"] = picojson::value(this->max_gen_len_);
+    config["shift_fill_factor"] = picojson::value(this->shift_fill_factor_);
+    config["conv_config"] = this->conversation_.SerializeToJSON();
+    return picojson::value(config);
+  }
   /*!
    * \brief Sample output token from logits on device
    */
@@ -662,8 +669,10 @@ class LLMChat {
         }
       }
     }
-    // TODO(mlc-team): add another per convo seq len trigger
-    if (total_seq_len_ >= max_window_size_) {
+
+    if (static_cast<int64_t>(output_ids_.size()) >= max_gen_len_) {
+      stop_triggered_ = true;
+    } else if (total_seq_len_ >= max_window_size_) {
       stop_triggered_ = true;
     }
     if (stop_triggered_) {
@@ -783,7 +792,7 @@ class LLMChat {
   // total sequence len,
   int64_t total_seq_len_{0};
   // max window size, mean generation length
-  int64_t max_window_size_{768}, mean_gen_len_{128};
+  int64_t max_window_size_{768}, mean_gen_len_{128}, max_gen_len_{512};
   // shift window fill factor
   double shift_fill_factor_{0.3};
   // temperature
@@ -927,9 +936,9 @@ class LLMChatModule : public ModuleNode {
     } else if (name == "reset_runtime_stats") {
       return PackedFunc(
           [this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->ResetRuntimeStats(); });
-    } else if (name == "serialize_config") {
+    } else if (name == "get_config_json") {
       return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
-        *rv = GetChat()->SerializeToJSONStr();
+        *rv = GetChat()->GetConfigJSON();
       });
     } else {
       return PackedFunc(nullptr);
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index 2d4fe48ddc..cd7ac40249 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -40,8 +40,8 @@ class Quantization:
     "q0f16": Quantization(
         name="q0f16", mode="no", sym=False, storage_nbit=-1, model_dtype="float16"
     ),
-    "q8f16": Quantization(
-        name="q8f16", mode="uint8", sym=False, storage_nbit=-1, model_dtype="float16"
+    "q8f16_0": Quantization(
+        name="q8f16_0", mode="uint8", sym=False, storage_nbit=-1, model_dtype="float16"
     ),
 }
 
@@ -306,6 +306,7 @@ def _detect_local_vulkan():
             "thread_warp_size": dev.warp_size,
             "supports_float16": 1,
             "supports_int16": 1,
+            "supports_int8": 1,
             "supports_16bit_buffer": 1,
         }
     )
@@ -424,23 +425,23 @@ def compile_metal(src, target):
         args.target = target
         args.target_kind = "iphone"
     elif args.target == "vulkan":
-        target = _detect_local_vulkan()
-        if target is None:
-            print("Cannot detect local Vulkan GPU target! Falling back...")
-            target = tvm.target.Target(
-                tvm.target.Target(
-                    {
-                        "kind": "vulkan",
-                        "max_threads_per_block": 256,
-                        "max_shared_memory_per_block": 32768,
-                        "thread_warp_size": 1,
-                        "supports_float16": 1,
-                        "supports_int16": 1,
-                        "supports_16bit_buffer": 1,
-                    }
-                ),
-                host="llvm",
-            )
+        target = tvm.target.Target(
+            tvm.target.Target(
+                {
+                    "kind": "vulkan",
+                    "max_threads_per_block": 256,
+                    "max_shared_memory_per_block": 32768,
+                    "thread_warp_size": 1,
+                    "supports_float16": 1,
+                    "supports_int16": 1,
+                    "supports_int8": 1,
+                    "supports_8bit_buffer": 1,
+                    "supports_16bit_buffer": 1,
+                    "supports_storage_buffer_storage_class": 1
+                }
+            ),
+            host="llvm",
+        )
         args.target = target
         args.target_kind = args.target.kind.default_keys[0]
     elif args.target == "webgpu":
diff --git a/tests/cpp/conv_unittest.cc b/tests/cpp/conv_unittest.cc
index 4ab6141b6f..214736320d 100644
--- a/tests/cpp/conv_unittest.cc
+++ b/tests/cpp/conv_unittest.cc
@@ -3,7 +3,7 @@
 
 void _TestConversationJSONRoundTrip(std::string templ_name) {
   mlc::llm::Conversation conv = mlc::llm::Conversation::FromTemplate(templ_name);
-  std::string conv_json = conv.SerializeToJSONStr();
+  std::string conv_json = conv.GetConfigJSON();
   mlc::llm::Conversation conv_new;
   conv_new.LoadJSONOverride(conv_json, false);
   ASSERT_EQ(conv, conv_new);