Merge pull request #41 from brainlid/me-zephyr-chat-model
Add initial support for Zephyr 7b Beta
brainlid authored Feb 29, 2024
2 parents d665a8f + 1ad6a09 commit ba1efba
Showing 15 changed files with 1,206 additions and 61 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,18 @@
# Changelog

## v0.1.9 (2024-02-29) - The Leap Release!

This adds support for Bumblebee-hosted chat models, making it easy to have conversations with Llama 2, Mistral, and Zephyr LLMs.

See the documentation in `LangChain.ChatModels.ChatBumblebee` for getting started.

NOTE: At this time, none of these models support the `function` ability, so function calling is not yet supported.

This release includes an experimental change for better support of streamed responses that arrive broken up across multiple messages from services like ChatGPT and others.
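
For example, accumulated deltas can be merged into a final message like this (a sketch using the existing `MessageDelta` functions; the `deltas` variable is illustrative):

```elixir
alias LangChain.MessageDelta

# Merge a list of streamed deltas into one accumulated delta, then convert
# the result to a completed message.
merged =
  Enum.reduce(deltas, nil, fn
    delta, nil -> delta
    delta, acc -> MessageDelta.merge_delta(acc, delta)
  end)

{:ok, message} = MessageDelta.to_message(merged)
```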

Other library dependency requirements were relaxed, making it easier to support different versions of libraries like `req` and `nx`.


## v0.1.8 (2024-02-16)

**Breaking change**: `RoutingChain`'s required values changed. Previously, `default_chain` was assigned an `%LLMChain{}` to return when no more specific routes matched.
17 changes: 17 additions & 0 deletions README.md
@@ -163,6 +163,23 @@ For example, if a locally running service provided that feature, the following c
|> LLMChain.run()
```

### Bumblebee Chat Support

Bumblebee-hosted chat models are supported. There is built-in support for Llama 2, Mistral, and Zephyr models.

Currently, function calling is NOT supported with these models.

```elixir
ChatBumblebee.new!(%{
  serving: @serving_name,
  template_format: @template_format,
  receive_timeout: @receive_timeout,
  stream: true
})
```

The `serving` is the module name of the `Nx.Serving` that is hosting the model.

See the `LangChain.ChatModels.ChatBumblebee` documentation for more details.
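
As a sketch of what hosting a model can look like (the Hugging Face repo id, compile options, and the `ZephyrChatModel` serving name below are illustrative assumptions, not requirements), a serving might be configured and supervised like this:

```elixir
# Load the model, tokenizer, and generation config from Hugging Face.
{:ok, model_info} = Bumblebee.load_model({:hf, "HuggingFaceH4/zephyr-7b-beta"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "HuggingFaceH4/zephyr-7b-beta"})

{:ok, generation_config} =
  Bumblebee.load_generation_config({:hf, "HuggingFaceH4/zephyr-7b-beta"})

serving =
  Bumblebee.Text.generation(model_info, tokenizer, generation_config,
    compile: [batch_size: 1, sequence_length: 1024],
    stream: true,
    stream_done: true,
    defn_options: [compiler: EXLA]
  )

# Host the serving under the application's supervision tree so it can be
# referenced by name from `ChatBumblebee`.
children = [
  {Nx.Serving, serving: serving, name: ZephyrChatModel, batch_timeout: 100}
]
```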

## Testing

To run all the tests including the ones that perform live calls against the OpenAI API, use the following command:
257 changes: 257 additions & 0 deletions lib/chat_models/chat_bumblebee.ex
@@ -0,0 +1,257 @@
defmodule LangChain.ChatModels.ChatBumblebee do
@moduledoc """
Represents a chat model hosted by Bumblebee and accessed through an
`Nx.Serving`.

Many types of models can be hosted through Bumblebee, so this attempts to
represent the most common features and provide a single implementation where
possible.

For streaming responses, the Bumblebee serving must be configured with
`stream: true` and should include `stream_done: true` as well.

Example:

    Bumblebee.Text.generation(model_info, tokenizer, generation_config,
      # ...
      stream: true,
      stream_done: true
    )

A non-streaming response is also supported; in that case, a completed
`LangChain.Message` is returned once generation finishes.

The `stream_done` option sends a final message signaling that the stream is
complete; it also includes some token information.

The chat model can be created like this and provided to an LLMChain:

    ChatBumblebee.new!(%{
      serving: @serving_name,
      template_format: @template_format,
      receive_timeout: @receive_timeout,
      stream: true
    })

The `serving` is the module name of the `Nx.Serving` that is hosting the
model.
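
A sketch of use with an `LLMChain` (the `ZephyrChatModel` serving name and
the `run/1` return shape shown here are illustrative assumptions):

    {:ok, chat} =
      ChatBumblebee.new(%{serving: ZephyrChatModel, template_format: :zephyr})

    {:ok, _updated_chain, response} =
      LLMChain.new!(%{llm: chat})
      |> LLMChain.add_message(Message.new_user!("Why is the sky blue?"))
      |> LLMChain.run()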
The supported values for `template_format` are `:inst`, `:im_start`,
`:zephyr`, and `:llama_2`; the formats themselves are provided by
`LangChain.Utils.ChatTemplates`.

Chat models are trained against specific content formats for their messages.
Some models have no special concept of a system message. See the
`LangChain.Utils.ChatTemplates` documentation for specific format examples.

Using the wrong format with a model may result in poor performance or
hallucinations. It will not result in an error.
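
For illustration only (see `LangChain.Utils.ChatTemplates` for the
authoritative templates), the `:zephyr` format renders a conversation
roughly like this:

    <|system|>
    You are a helpful assistant.</s>
    <|user|>
    Why is the sky blue?</s>
    <|assistant|>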
"""
use Ecto.Schema
require Logger
import Ecto.Changeset
import LangChain.Utils.ApiOverride
alias __MODULE__
alias LangChain.ChatModels.ChatModel
alias LangChain.Message
alias LangChain.LangChainError
alias LangChain.Utils
alias LangChain.MessageDelta
alias LangChain.Utils.ChatTemplates

@behaviour ChatModel

@primary_key false
embedded_schema do
# Name of the Nx.Serving to use when working with the LLM.
field :serving, :any, virtual: true

# # What sampling temperature to use, between 0 and 2. Higher values like 0.8
# # will make the output more random, while lower values like 0.2 will make it
# # more focused and deterministic.
# field :temperature, :float, default: 1.0

field :template_format, Ecto.Enum, values: [:inst, :im_start, :zephyr, :llama_2]

# The Bumblebee model may compile differently based on the stream true/false
# option on the serving. Streaming should therefore be enabled on the serving
# itself, while the `stream` option here changes how the received data is
# handled in code. - https://github.com/elixir-nx/bumblebee/issues/295

field :stream, :boolean, default: true

# Seed for randomizing behavior or giving more deterministic output. Helpful
# for testing.
field :seed, :integer, default: nil
end

@type t :: %ChatBumblebee{}

# @type call_response :: {:ok, Message.t() | [Message.t()]} | {:error, String.t()}
# @type callback_data ::
# {:ok, Message.t() | MessageDelta.t() | [Message.t() | MessageDelta.t()]}
# | {:error, String.t()}
@type callback_fn :: (Message.t() | MessageDelta.t() -> any())
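# For illustration (hypothetical, not part of this module), a callback_fn
# might pattern match on what it receives:
#
#     fn
#       %MessageDelta{} = delta -> IO.write(delta.content || "")
#       %Message{} = message -> IO.puts("\nCOMPLETE: #{message.content}")
#     end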

@create_fields [
:serving,
# :temperature,
:seed,
:template_format,
:stream
]
@required_fields [:serving]

@doc """
Set up a ChatBumblebee client configuration.
"""
@spec new(attrs :: map()) :: {:ok, t} | {:error, Ecto.Changeset.t()}
def new(%{} = attrs \\ %{}) do
%ChatBumblebee{}
|> cast(attrs, @create_fields)
|> common_validation()
|> apply_action(:insert)
end

@doc """
Set up a ChatBumblebee client configuration and return it, or raise an error if invalid.
"""
@spec new!(attrs :: map()) :: t() | no_return()
def new!(attrs \\ %{}) do
case new(attrs) do
{:ok, chain} ->
chain

{:error, changeset} ->
raise LangChainError, changeset
end
end

defp common_validation(changeset) do
changeset
|> validate_required(@required_fields)
end

@impl ChatModel
def call(model, prompt, functions \\ [], callback_fn \\ nil)

def call(%ChatBumblebee{} = model, prompt, functions, callback_fn) when is_binary(prompt) do
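# A plain string prompt is wrapped in a default system message plus a user
# message, then handled by the message-list clause below.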
messages = [
Message.new_system!(),
Message.new_user!(prompt)
]

call(model, messages, functions, callback_fn)
end

def call(%ChatBumblebee{} = model, messages, functions, callback_fn)
when is_list(messages) do
if override_api_return?() do
Logger.warning("Found override API response. Will not make live API call.")

case get_api_override() do
{:ok, {:ok, data} = response} ->
# fire callback for fake responses too
Utils.fire_callback(model, data, callback_fn)
response

_other ->
raise LangChainError,
"An unexpected fake API response was set. Should be an `{:ok, value}`"
end
else
try do
# make base api request and perform high-level success/failure checks
case do_serving_request(model, messages, functions, callback_fn) do
{:error, reason} ->
{:error, reason}

parsed_data ->
{:ok, parsed_data}
end
rescue
err in LangChainError ->
{:error, err.message}
end
end
end

@doc false
@spec do_serving_request(t(), [Message.t()], [Function.t()], callback_fn()) ::
list() | struct() | {:error, String.t()}
def do_serving_request(%ChatBumblebee{} = model, messages, _functions, callback_fn) do
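# Format the messages into a single prompt string using the configured chat
# template, then run the prompt through the hosted serving.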
prompt = ChatTemplates.apply_chat_template!(messages, model.template_format)

model.serving
|> Nx.Serving.batched_run(%{text: prompt, seed: model.seed})
|> do_process_response(model, callback_fn)
end

@doc false
def do_process_response(
%{results: [%{text: content, token_summary: _token_summary}]},
%ChatBumblebee{} = model,
callback_fn
)
when is_binary(content) do
case Message.new(%{role: :assistant, status: :complete, content: content}) do
{:ok, message} ->
# execute the callback with the final message
Utils.fire_callback(model, [message], callback_fn)
# return the complete message as a single-item list for compatibility.
[message]

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)
Logger.error("Failed to create non-streamed full message: #{inspect(reason)}")
{:error, reason}
end
end

def do_process_response(stream, %ChatBumblebee{stream: false} = model, callback_fn) do
# Request is to NOT stream. Consume the full stream and format the data as
# though it had not been streamed.
full_data =
Enum.reduce(stream, %{text: "", token_summary: nil}, fn
{:done, token_data}, %{text: text} ->
%{text: text, token_summary: token_data}

data, %{text: text} = acc ->
Map.put(acc, :text, text <> data)
end)

do_process_response(%{results: [full_data]}, model, callback_fn)
end

def do_process_response(stream, %ChatBumblebee{} = model, callback_fn) do
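# Each streamed chunk is either a binary content fragment or, when the
# serving was configured with `stream_done: true`, a final {:done, token_data}
# tuple.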
chunk_processor = fn
{:done, _token_data} ->
final_delta = MessageDelta.new!(%{role: :assistant, status: :complete})
Utils.fire_callback(model, [final_delta], callback_fn)
final_delta

content when is_binary(content) ->
case MessageDelta.new(%{content: content, role: :assistant, status: :incomplete}) do
{:ok, delta} ->
Utils.fire_callback(model, [delta], callback_fn)
delta

{:error, changeset} ->
reason = Utils.changeset_error_to_string(changeset)

Logger.error(
"Failed to process received model's MessageDelta data: #{inspect(reason)}"
)

raise LangChainError, reason
end
end

result =
stream
|> Stream.map(&chunk_processor.(&1))
|> Enum.to_list()

# return a list of a list to mirror the way ChatGPT returns data
[result]
end
end
2 changes: 1 addition & 1 deletion lib/chat_models/chat_open_ai.ex
@@ -27,7 +27,7 @@ defmodule LangChain.ChatModels.ChatOpenAI do
# NOTE: As of gpt-4 and gpt-3.5, only one function_call is issued at a time
# even when multiple requests could be issued based on the prompt.

# allow up to 2 minutes for response.
# allow up to 1 minute for response.
@receive_timeout 60_000

@primary_key false
17 changes: 9 additions & 8 deletions lib/message.ex
@@ -53,14 +53,8 @@ defmodule LangChain.Message do
@type t :: %Message{}
@type status :: :complete | :cancelled | :length

@create_fields [
:role,
:content,
:status,
:function_name,
:arguments,
:index
]
@update_fields [:role, :content, :status, :function_name, :arguments, :index]
@create_fields @update_fields
@required_fields [:role]

@doc """
@@ -89,6 +83,13 @@
end
end

@doc false
def changeset(message, attrs) do
message
|> cast(attrs, @update_fields)
|> common_validations()
end

defp changeset_is_function?(changeset) do
get_field(changeset, :role) == :assistant and
is_binary(get_field(changeset, :function_name)) and
