Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored May 2, 2023
1 parent 1a003c8 commit 7f349a4
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions cpp/cli_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ std::string DetectDeviceName(std::string device_name) {
return device_name;
}

DLDevice GetDevice(const std::string& device_name) {
if (device_name == "cuda") return DLDevice{kDLCUDA, 0};
if (device_name == "metal") return DLDevice{kDLMetal, 0};
if (device_name == "vulkan") return DLDevice{kDLVulkan, 0};
if (device_name == "opencl") return DLDevice{kDLOpenCL, 0};
DLDevice GetDevice(const std::string& device_name, int device_id) {
if (device_name == "cuda") return DLDevice{kDLCUDA, device_id};
if (device_name == "metal") return DLDevice{kDLMetal, device_id};
if (device_name == "vulkan") return DLDevice{kDLVulkan, device_id};
if (device_name == "opencl") return DLDevice{kDLOpenCL, device_id};
LOG(FATAL) << "Do not recognize device name " << device_name;
return DLDevice{kDLCPU, 0};
}
Expand Down Expand Up @@ -230,6 +230,7 @@ int main(int argc, char* argv[]) {
argparse::ArgumentParser args("mlc_chat");

args.add_argument("--device-name").default_value("auto");
args.add_argument("--device_id").default_value(0);
args.add_argument("--artifact-path").default_value("dist");
args.add_argument("--model").default_value("vicuna-v1-7b");
args.add_argument("--dtype").default_value("auto");
Expand All @@ -245,7 +246,8 @@ int main(int argc, char* argv[]) {
}

std::string device_name = DetectDeviceName(args.get<std::string>("--device-name"));
DLDevice device = GetDevice(device_name);
int device_id = args.get<int>("--device_id");
DLDevice device = GetDevice(device_name, device_id);
std::string artifact_path = args.get<std::string>("--artifact-path");
std::string model = args.get<std::string>("--model");
std::string dtype = args.get<std::string>("--dtype");
Expand Down

0 comments on commit 7f349a4

Please sign in to comment.