3737
3838ConfigType : TypeAlias = "dict[str, Any]"
3939
40- # Environment setup adapted from HF transformers
4140@_operator_call
4241def _jinja_env () -> ImmutableSandboxedEnvironment :
42+ # Environment setup adapted from HF transformers
4343 def raise_exception (message : str ) -> NoReturn :
4444 raise jinja2 .exceptions .TemplateError (message )
4545
@@ -56,15 +56,17 @@ def strftime_now(fmt: str) -> str:
5656 return env
5757
5858
class Message(TypedDict):
    """A message in a chat with a GPT4All model."""

    role: str     # who produced the message, e.g. "system", "user", or "assistant"
    content: str  # the message text
6365
class _ChatSession(NamedTuple):
    """Internal state of an active chat session (see ``chat_session()``)."""

    template: jinja2.Template  # compiled chat template used to render prompts
    template_source: str       # the original template text the Template was built from
    history: list[Message]     # mutable message log, exposed via current_chat_session
6870
6971
7072class Embed4All :
@@ -195,7 +197,8 @@ class GPT4All:
195197 """
196198
    # Heuristic for old-style system prompts: matches legacy "### System" /
    # "System:" headers, model-specific control tokens (<|im_start|>, header
    # ids, <|SYSTEM_TOKEN|>), and llama-style <<SYS>> markers — text that
    # should no longer appear in a system_message passed to chat_session().
    RE_LEGACY_SYSPROMPT = re.compile(
        r"(?:^|\s)(?:### *System\b|S(?:ystem|YSTEM):)|"
        r"<\|(?:im_(?:start|end)|(?:start|end)_header_id|eot_id|SYSTEM_TOKEN)\|>|<<SYS>>",
        re.MULTILINE,
    )
201204
@@ -244,7 +247,7 @@ def __init__(
244247 """
245248
246249 self .model_type = model_type
247- self ._chat_session : ChatSession | None = None
250+ self ._chat_session : _ChatSession | None = None
248251
249252 device_init = None
250253 if sys .platform == "darwin" :
@@ -303,11 +306,12 @@ def device(self) -> str | None:
303306 return self .model .device
304307
305308 @property
306- def current_chat_session (self ) -> list [MessageType ] | None :
309+ def current_chat_session (self ) -> list [Message ] | None :
310+ """The message history of the current chat session."""
307311 return None if self ._chat_session is None else self ._chat_session .history
308312
309313 @current_chat_session .setter
310- def current_chat_session (self , history : list [MessageType ]) -> None :
314+ def current_chat_session (self , history : list [Message ]) -> None :
311315 if self ._chat_session is None :
312316 raise ValueError ("current_chat_session may only be set when there is an active chat session" )
313317 self ._chat_session .history [:] = history
@@ -585,13 +589,13 @@ def _callback_wrapper(token_id: int, response: str) -> bool:
585589 last_msg_rendered = prompt
586590 if self ._chat_session is not None :
587591 session = self ._chat_session
588- def render (messages : list [MessageType ]) -> str :
592+ def render (messages : list [Message ]) -> str :
589593 return session .template .render (
590594 messages = messages ,
591595 add_generation_prompt = True ,
592596 ** self .model .special_tokens_map ,
593597 )
594- session .history .append (MessageType (role = "user" , content = prompt ))
598+ session .history .append (Message (role = "user" , content = prompt ))
595599 prompt = render (session .history )
596600 if len (session .history ) > 1 :
597601 last_msg_rendered = render (session .history [- 1 :])
@@ -606,20 +610,14 @@ def render(messages: list[MessageType]) -> str:
606610 def stream () -> Iterator [str ]:
607611 yield from self .model .prompt_model_streaming (prompt , _callback_wrapper , ** generate_kwargs )
608612 if self ._chat_session is not None :
609- self ._chat_session .history .append (MessageType (role = "assistant" , content = full_response ))
613+ self ._chat_session .history .append (Message (role = "assistant" , content = full_response ))
610614 return stream ()
611615
612616 self .model .prompt_model (prompt , _callback_wrapper , ** generate_kwargs )
613617 if self ._chat_session is not None :
614- self ._chat_session .history .append (MessageType (role = "assistant" , content = full_response ))
618+ self ._chat_session .history .append (Message (role = "assistant" , content = full_response ))
615619 return full_response
616620
617- @classmethod
618- def is_legacy_chat_template (cls , tmpl : str ) -> bool :
619- """A fairly reliable heuristic for detecting templates that don't look like Jinja templates."""
620- return bool (re .search (r"%[12]\b" , tmpl ) or not cls .RE_JINJA_LIKE .search (tmpl )
621- or not re .search (r"\bcontent\b" , tmpl ))
622-
623621 @contextmanager
624622 def chat_session (
625623 self ,
@@ -632,10 +630,14 @@ def chat_session(
632630 Context manager to hold an inference optimized chat session with a GPT4All model.
633631
634632 Args:
635- system_message: An initial instruction for the model, None to use the model default, or False to disable. Defaults to None.
633+ system_message: An initial instruction for the model, None to use the model default, or False to disable.
634+ Defaults to None.
636635 chat_template: Jinja template for the conversation, or None to use the model default. Defaults to None.
637- """
636+ warn_legacy: Whether to warn about legacy system prompts or prompt templates. Defaults to True.
638637
638+ Raises:
639+ ValueError: If no valid chat template was found.
640+ """
639641 if system_message is None :
640642 system_message = self .config .get ("systemMessage" , False )
641643 elif system_message is not False and warn_legacy and (m := self .RE_LEGACY_SYSPROMPT .search (system_message )):
@@ -662,7 +664,7 @@ def chat_session(
662664 msg += " If this is a built-in model, consider setting allow_download to True."
663665 raise ValueError (msg ) from None
664666 raise
665- elif warn_legacy and self .is_legacy_chat_template (chat_template ):
667+ elif warn_legacy and self ._is_legacy_chat_template (chat_template ):
666668 print (
667669 "Warning: chat_session() was passed a chat template that is not in Jinja format. Old-style prompt "
668670 "templates are no longer supported.\n To disable this warning, pass warn_legacy=False." ,
@@ -671,8 +673,8 @@ def chat_session(
671673
672674 history = []
673675 if system_message is not False :
674- history .append (MessageType (role = "system" , content = system_message ))
675- self ._chat_session = ChatSession (
676+ history .append (Message (role = "system" , content = system_message ))
677+ self ._chat_session = _ChatSession (
676678 template = _jinja_env .from_string (chat_template ),
677679 template_source = chat_template ,
678680 history = history ,
@@ -692,6 +694,12 @@ def list_gpus() -> list[str]:
692694 """
693695 return LLModel .list_gpus ()
694696
697+ @classmethod
698+ def _is_legacy_chat_template (cls , tmpl : str ) -> bool :
699+ # check if tmpl does not look like a Jinja template
700+ return bool (re .search (r"%[12]\b" , tmpl ) or not cls .RE_JINJA_LIKE .search (tmpl )
701+ or not re .search (r"\bcontent\b" , tmpl ))
702+
695703
696704def append_extension_if_missing (model_name ):
697705 if not model_name .endswith ((".bin" , ".gguf" )):
0 commit comments