From 4bd415a06815c76e62cbf0a75d43cf8e8e42cd79 Mon Sep 17 00:00:00 2001
From: James Montgomery
Date: Mon, 12 Aug 2024 00:39:38 -0400
Subject: [PATCH 1/2] Update embeddings to new endpoint.

---
 include/ollama.hpp      | 19 ++++++++++---------
 singleheader/ollama.hpp | 19 ++++++++++---------
 test/test.cpp           |  3 +--
 3 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/include/ollama.hpp b/include/ollama.hpp
index 73f0e90..fc540ed 100644
--- a/include/ollama.hpp
+++ b/include/ollama.hpp
@@ -263,13 +263,14 @@ namespace ollama
         request(): json() {}
         ~request(){};
 
-        static ollama::request from_embedding(const std::string& name, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+        static ollama::request from_embedding(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate=true, const std::string& keep_alive_duration="5m")
         {
             ollama::request request(message_type::embedding);
 
-            request["model"] = name;
-            request["prompt"] = prompt;
+            request["model"] = model;
+            request["input"] = input;
             if (options!=nullptr) request["options"] = options["options"];
+            request["truncate"] = truncate;
             request["keep_alive"] = keep_alive_duration;
 
             return request;
@@ -295,7 +296,7 @@ namespace ollama
 
             if (type==message_type::generation && json_data.contains("response")) simple_string=json_data["response"].get<std::string>();
             else
-            if (type==message_type::embedding && json_data.contains("embedding")) simple_string=json_data["embedding"].get<std::string>();
+            if (type==message_type::embedding && json_data.contains("embeddings")) simple_string=json_data["embeddings"].get<std::string>();
             else
             if (type==message_type::chat && json_data.contains("message")) simple_string=json_data["message"]["content"].get<std::string>();
 
@@ -715,15 +716,15 @@ class Ollama
         return false;
     }
 
-    ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
    {
-        ollama::request request = ollama::request::from_embedding(model, prompt, options, keep_alive_duration);
+        ollama::request request = ollama::request::from_embedding(model, input, options, truncate, keep_alive_duration);
         ollama::response response;
 
         std::string request_string = request.dump();
         if (ollama::log_requests) std::cout << request_string << std::endl;
 
-        if (auto res = cli->Post("/api/embeddings", request_string, "application/json"))
+        if (auto res = cli->Post("/api/embed", request_string, "application/json"))
         {
             if (ollama::log_replies) std::cout << res->body << std::endl;
 
@@ -885,9 +886,9 @@ namespace ollama
         return ollama.push_model(model, allow_insecure);
     }
 
-    inline ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    inline ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
     {
-        return ollama.generate_embeddings(model, prompt, options, keep_alive_duration);
+        return ollama.generate_embeddings(model, input, options, truncate, keep_alive_duration);
     }
 
     inline void setReadTimeout(const int& seconds)
diff --git a/singleheader/ollama.hpp b/singleheader/ollama.hpp
index 88ac415..d1aff5d 100644
--- a/singleheader/ollama.hpp
+++ b/singleheader/ollama.hpp
@@ -35053,13 +35053,14 @@ namespace ollama
         request(): json() {}
         ~request(){};
 
-        static ollama::request from_embedding(const std::string& name, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+        static ollama::request from_embedding(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate=true, const std::string& keep_alive_duration="5m")
         {
             ollama::request request(message_type::embedding);
 
-            request["model"] = name;
-            request["prompt"] = prompt;
+            request["model"] = model;
+            request["input"] = input;
             if (options!=nullptr) request["options"] = options["options"];
+            request["truncate"] = truncate;
             request["keep_alive"] = keep_alive_duration;
 
             return request;
@@ -35085,7 +35086,7 @@ namespace ollama
 
             if (type==message_type::generation && json_data.contains("response")) simple_string=json_data["response"].get<std::string>();
             else
-            if (type==message_type::embedding && json_data.contains("embedding")) simple_string=json_data["embedding"].get<std::string>();
+            if (type==message_type::embedding && json_data.contains("embeddings")) simple_string=json_data["embeddings"].get<std::string>();
             else
             if (type==message_type::chat && json_data.contains("message")) simple_string=json_data["message"]["content"].get<std::string>();
 
@@ -35505,15 +35506,15 @@ class Ollama
         return false;
     }
 
-    ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
    {
-        ollama::request request = ollama::request::from_embedding(model, prompt, options, keep_alive_duration);
+        ollama::request request = ollama::request::from_embedding(model, input, options, truncate, keep_alive_duration);
         ollama::response response;
 
         std::string request_string = request.dump();
         if (ollama::log_requests) std::cout << request_string << std::endl;
 
-        if (auto res = cli->Post("/api/embeddings", request_string, "application/json"))
+        if (auto res = cli->Post("/api/embed", request_string, "application/json"))
         {
             if (ollama::log_replies) std::cout << res->body << std::endl;
 
@@ -35675,9 +35676,9 @@ namespace ollama
         return ollama.push_model(model, allow_insecure);
     }
 
-    inline ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    inline ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
     {
-        return ollama.generate_embeddings(model, prompt, options, keep_alive_duration);
+        return ollama.generate_embeddings(model, input, options, truncate, keep_alive_duration);
     }
 
     inline void setReadTimeout(const int& seconds)
diff --git a/test/test.cpp b/test/test.cpp
index 2e4e71f..09b0ef2 100644
--- a/test/test.cpp
+++ b/test/test.cpp
@@ -242,9 +242,8 @@ TEST_SUITE("Ollama Tests") {
         options["num_predict"] = 18;
 
         ollama::response response = ollama::generate_embeddings("llama3:8b", "Why is the sky blue?");
-        //std::cout << response << std::endl;
 
-        CHECK(response.as_json().contains("embedding") == true);
+        CHECK(response.as_json().contains("embeddings") == true);
     }
 
     TEST_CASE("Enable Debug Logging") {

From 0aaa3bdb4e310b732ab9048030fa14b378ff6371 Mon Sep 17 00:00:00 2001
From: James Montgomery
Date: Mon, 12 Aug 2024 01:40:48 -0400
Subject: [PATCH 2/2] Update tests to not rely on exact comparisons.

---
 test/test.cpp | 55 +++++++++++++++++++++++++++++------------------------------
 1 file changed, 25 insertions(+), 30 deletions(-)

diff --git a/test/test.cpp b/test/test.cpp
index 09b0ef2..10cf894 100644
--- a/test/test.cpp
+++ b/test/test.cpp
@@ -12,6 +12,8 @@
 // Note that this is static. We will use these options for other generations.
 static ollama::options options;
 
+static std::string test_model = "llama3:8b", image_test_model = "llava";
+
 TEST_SUITE("Ollama Tests") {
 
     TEST_CASE("Initialize Options") {
@@ -52,19 +54,19 @@ TEST_SUITE("Ollama Tests") {
 
     TEST_CASE("Load Model") {
 
-        CHECK( ollama::load_model("llama3:8b") );
+        CHECK( ollama::load_model(test_model) );
     }
 
     TEST_CASE("Pull, Copy, and Delete Models") {
 
         // Pull a model by specifying a model name.
-        CHECK( ollama::pull_model("llama3:8b") == true );
+        CHECK( ollama::pull_model(test_model) == true );
 
         // Copy a model by specifying a source model and destination model name.
-        CHECK( ollama::copy_model("llama3:8b", "llama3_copy") ==true );
+        CHECK( ollama::copy_model(test_model, test_model+"_copy") ==true );
 
         // Delete a model by specifying a model name.
-        CHECK( ollama::delete_model("llama3_copy") == true );
+        CHECK( ollama::delete_model(test_model+"_copy") == true );
     }
 
     TEST_CASE("Model Info") {
@@ -81,7 +83,7 @@ TEST_SUITE("Ollama Tests") {
 
         // List the models available locally in the ollama server
         std::vector<std::string> models = ollama::list_models();
-        bool contains_model = (std::find(models.begin(), models.end(), "llama3:8b") != models.end() );
+        bool contains_model = (std::find(models.begin(), models.end(), test_model) != models.end() );
 
         CHECK( contains_model );
     }
@@ -101,12 +103,9 @@ TEST_SUITE("Ollama Tests") {
 
     TEST_CASE("Basic Generation") {
 
-        ollama::response response = ollama::generate("llama3:8b", "Why is the sky blue?", options);
-        //std::cout << response << std::endl;
-
-        std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
+        ollama::response response = ollama::generate(test_model, "Why is the sky blue?", options);
 
-        CHECK(response.as_simple_string() == expected_response);
+        CHECK( response.as_json().contains("response") == true );
     }
 
 
@@ -124,11 +123,11 @@ TEST_SUITE("Ollama Tests") {
     TEST_CASE("Streaming Generation") {
 
         std::function<void(const ollama::response&)> response_callback = on_receive_response;
-        ollama::generate("llama3:8b", "Why is the sky blue?", response_callback, options);
+        ollama::generate(test_model, "Why is the sky blue?", response_callback, options);
 
         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
 
-        CHECK( streamed_response == expected_response );
+        CHECK( streamed_response != "" );
     }
 
     TEST_CASE("Non-Singleton Generation") {
@@ -136,23 +135,22 @@ TEST_SUITE("Ollama Tests") {
         Ollama my_ollama_server("http://localhost:11434");
 
         // You can use all of the same functions from this instanced version of the class.
-        ollama::response response = my_ollama_server.generate("llama3:8b", "Why is the sky blue?", options);
-        //std::cout << response << std::endl;
+        ollama::response response = my_ollama_server.generate(test_model, "Why is the sky blue?", options);
 
         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
 
-        CHECK(response.as_simple_string() == expected_response);
+        CHECK(response.as_json().contains("response") == true);
     }
 
     TEST_CASE("Single-Message Chat") {
 
         ollama::message message("user", "Why is the sky blue?");
 
-        ollama::response response = ollama::chat("llama3:8b", message, options);
+        ollama::response response = ollama::chat(test_model, message, options);
 
         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
 
-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }
 
     TEST_CASE("Multi-Message Chat") {
@@ -163,11 +161,11 @@ TEST_SUITE("Ollama Tests") {
 
         ollama::messages messages = {message1, message2, message3};
 
-        ollama::response response = ollama::chat("llama3:8b", messages, options);
+        ollama::response response = ollama::chat(test_model, messages, options);
 
         std::string expected_response = "";
 
-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }
 
     TEST_CASE("Chat with Streaming Response") {
@@ -182,7 +180,7 @@ TEST_SUITE("Ollama Tests") {
 
         ollama::message message("user", "Why is the sky blue?");
 
-        ollama::chat("llama3:8b", message, response_callback, options);
+        ollama::chat(test_model, message, response_callback, options);
 
         CHECK(streamed_response!="");
     }
@@ -195,12 +193,9 @@ TEST_SUITE("Ollama Tests") {
 
         ollama::image image = ollama::image::from_file("llama.jpg");
 
-        //ollama::images images={image};
-
-        ollama::response response = ollama::generate("llava", "What do you see in this image?", options, image);
-        std::string expected_response = " The image features a large, fluffy white llama";
+        ollama::response response = ollama::generate(image_test_model, "What do you see in this image?", options, image);
 
-        CHECK(response.as_simple_string() == expected_response);
+        CHECK( response.as_json().contains("response") == true );
     }
 
     TEST_CASE("Generation with Multiple Images") {
@@ -214,10 +209,10 @@ TEST_SUITE("Ollama Tests") {
 
         ollama::images images={image, base64_image};
 
-        ollama::response response = ollama::generate("llava", "What do you see in this image?", options, images);
+        ollama::response response = ollama::generate(image_test_model, "What do you see in this image?", options, images);
         std::string expected_response = " The image features a large, fluffy white and gray llama";
 
-        CHECK(response.as_simple_string() == expected_response);
+        CHECK(response.as_json().contains("response") == true);
     }
 
     TEST_CASE("Chat with Image") {
@@ -230,18 +225,18 @@ TEST_SUITE("Ollama Tests") {
 
         // We can optionally include images with each message. Vision-enabled models will be able to utilize these.
         ollama::message message_with_image("user", "What do you see in this image?", image);
-        ollama::response response = ollama::chat("llava", message_with_image, options);
+        ollama::response response = ollama::chat(image_test_model, message_with_image, options);
 
         std::string expected_response = " The image features a large, fluffy white llama";
 
-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }
 
     TEST_CASE("Embedding Generation") {
 
         options["num_predict"] = 18;
 
-        ollama::response response = ollama::generate_embeddings("llama3:8b", "Why is the sky blue?");
+        ollama::response response = ollama::generate_embeddings(test_model, "Why is the sky blue?");
 
         CHECK(response.as_json().contains("embeddings") == true);
     }