Support new OS models: Zephyr and Yi #392
base: develop
Changes from 55 commits
New file (+128 lines) — the `Yi` HuggingFace model wrapper:

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple

from confection import SimpleFrozenDict

from ...compat import Literal, transformers
from ...registry.util import registry
from .base import HuggingFace


class Yi(HuggingFace):
    MODEL_NAMES = Literal[  # noqa: F722
        "Yi-34B",
        "Yi-34B-chat-8bits",
        "Yi-6B-chat",
        "Yi-6B",
        "Yi-6B-200K",
        "Yi-34B-chat",
        "Yi-34B-chat-4bits",
        "Yi-34B-200K",
    ]

    def __init__(
        self,
        name: MODEL_NAMES,
        config_init: Optional[Dict[str, Any]],
        config_run: Optional[Dict[str, Any]],
        context_length: int,
    ):
        self._tokenizer: Optional["transformers.AutoTokenizer"] = None
        self._is_instruct = "instruct" in name
        super().__init__(
            name=name,
            config_init=config_init,
            config_run=config_run,
            context_length=context_length,
        )

        assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)

        # Instantiate GenerationConfig object from config dict.
        self._hf_config_run = transformers.GenerationConfig.from_pretrained(
            self._name, **self._config_run
        )
        # To avoid deprecation warning regarding usage of `max_length`.
        self._hf_config_run.max_new_tokens = self._hf_config_run.max_length

    def init_model(self) -> Any:
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
            self._name, use_fast=False
        )
        init_cfg = self._config_init
        device: Optional[str] = None
        if "device" in init_cfg:
            device = init_cfg.pop("device")

        model = transformers.AutoModelForCausalLM.from_pretrained(
            self._name, **init_cfg, resume_download=True
        ).eval()
        if device:
            model.to(device)

        return model

    @property
    def hf_account(self) -> str:
        return "01-ai"

    def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:  # type: ignore[override]
        assert hasattr(self._model, "generate")
        assert hasattr(self._tokenizer, "apply_chat_template")
        assert self._tokenizer

        responses: List[List[str]] = []

        for prompts_for_doc in prompts:
            prompts_for_doc = list(prompts_for_doc)

            tokenized_input_ids = [
                self._tokenizer.apply_chat_template(
                    conversation=[{"role": "user", "content": prompt}],
                    tokenize=True,
                    add_generation_prompt=True,
                    return_tensors="pt",
                )
                for prompt in prompts_for_doc
            ]
            tokenized_input_ids = [
                tp.to(self._model.device) for tp in tokenized_input_ids
            ]

            responses.append(
                [
                    self._tokenizer.decode(
                        self._model.generate(
                            input_ids=tok_ii, generation_config=self._hf_config_run
                        )[:, tok_ii.shape[1] :][0],
                        skip_special_tokens=True,
                    ).strip("\n")
                    for tok_ii in tokenized_input_ids
                ]
            )

        return responses

    @staticmethod
    def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
        default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
        return {**default_cfg_init, **{"torch_dtype": "auto"}}, default_cfg_run


@registry.llm_models("spacy.Yi.v1")
def yi_hf(
    name: Yi.MODEL_NAMES,
    config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
    config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Yi:
    """Generates Yi instance that can execute a set of prompts and return the raw responses.
    name (Literal): Name of the Yi model. Has to be one of Yi.get_model_names().
    config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
    config_run (Optional[Dict[str, Any]]): HF config for running the model.
    RETURNS (Yi): Yi instance that can execute a set of prompts and return the raw responses.
    """
    return Yi(
        name=name,
        config_init=config_init,
        config_run=config_run,
        context_length=200000 if "200K" in name else 32000,
    )
```
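As a usage sketch (not part of the diff, and mirroring the test added later in this PR), the new `spacy.Yi.v1` entry can be selected from an `llm` pipe config; the NoOp task is only a placeholder, and a CUDA GPU plus a model download are assumed:

```python
import spacy

# Sketch only: mirrors _PIPE_CFG from the new test file further down in this PR.
# Needs a CUDA GPU and downloads 01-ai/Yi-6B-chat from the Hugging Face Hub.
nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.NoOp.v1"},  # placeholder task
        "model": {"@llm_models": "spacy.Yi.v1", "name": "Yi-6B-chat"},
    },
)
doc = nlp("This is a test.")  # the task builds the prompts; Yi answers them
```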
New file (+101 lines) — the `Zephyr` HuggingFace model wrapper:

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple

from confection import SimpleFrozenDict

from ...compat import Literal, transformers
from ...registry.util import registry
from .base import HuggingFace


class Zephyr(HuggingFace):
    MODEL_NAMES = Literal["zephyr-7b-beta"]  # noqa: F722

    def __init__(
        self,
        name: MODEL_NAMES,
        config_init: Optional[Dict[str, Any]],
        config_run: Optional[Dict[str, Any]],
        context_length: int,
    ):
        super().__init__(
            name=name,
            config_init=config_init,
            config_run=config_run,
            context_length=context_length,
        )

        # Instantiate GenerationConfig object from config dict.
        self._hf_config_run = transformers.GenerationConfig.from_pretrained(
            self._name, **self._config_run
        )
        # To avoid deprecation warning regarding usage of `max_length`.
        self._hf_config_run.max_new_tokens = self._hf_config_run.max_length

    def init_model(self) -> Any:
        return transformers.pipeline(
            "text-generation",
            model=self._name,
            return_full_text=False,
            **self._config_init
        )

    @property
    def hf_account(self) -> str:
        return "HuggingFaceH4"

    def __call__(self, prompts: Iterable[Iterable[str]]) -> Iterable[Iterable[str]]:  # type: ignore[override]
        responses: List[List[str]] = []

        for prompts_for_doc in prompts:
            formatted_prompts_for_doc = [
                self._model.tokenizer.apply_chat_template(
                    [{"role": "user", "content": prompt}],
                    tokenize=False,
                    add_generation_prompt=False,
                )
                for prompt in prompts_for_doc
            ]

            responses.append(
                [
                    self._model(prompt, generation_config=self._hf_config_run)[0][
                        "generated_text"
                    ]
                    .replace("<|assistant|>", "")
                    .strip("\n")
                    for prompt in formatted_prompts_for_doc
                ]
            )

        return responses

    @staticmethod
    def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
        default_cfg_init, default_cfg_run = HuggingFace.compile_default_configs()
        return default_cfg_init, {
            **default_cfg_run,
            **{
                "max_new_tokens": 256,
                "do_sample": True,
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.95,
            },
        }


@registry.llm_models("spacy.Zephyr.v1")
def zephyr_hf(
    name: Zephyr.MODEL_NAMES,
    config_init: Optional[Dict[str, Any]] = SimpleFrozenDict(),
    config_run: Optional[Dict[str, Any]] = SimpleFrozenDict(),
) -> Zephyr:
    """Generates Zephyr instance that can execute a set of prompts and return the raw responses.
    name (Literal): Name of the Zephyr model. Has to be one of Zephyr.get_model_names().
    config_init (Optional[Dict[str, Any]]): HF config for initializing the model.
    config_run (Optional[Dict[str, Any]]): HF config for running the model.
    RETURNS (Zephyr): Zephyr instance that can execute a set of prompts and return the raw responses.
    """
    return Zephyr(
        name=name, config_init=config_init, config_run=config_run, context_length=8000
    )
```
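As quick orientation (not part of the diff), a minimal sketch of how the factory registered above could be looked up and called directly; the absolute import path is inferred from the relative import in this file, and the prompt is illustrative:

```python
# Sketch only: look up the registered factory and run one batch of prompts
# (one inner list of prompts per doc, matching the __call__ signature above).
from spacy_llm.registry.util import registry  # same registry the new module imports

zephyr_factory = registry.llm_models.get("spacy.Zephyr.v1")
zephyr = zephyr_factory(name="zephyr-7b-beta")  # downloads HuggingFaceH4/zephyr-7b-beta
responses = zephyr([["Write one sentence about spaCy."]])
print(list(responses)[0][0])
```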
Changed file — the entity linker task's `reduce_shards_to_doc`:

```diff
@@ -206,4 +206,10 @@ def reduce_shards_to_doc(task: EntityLinkerTask, shards: Iterable[Doc]) -> Doc:
     RETURNS (Doc): Fused doc instance.
     """
     # Entities are additive, so we can just merge shards.
-    return Doc.from_docs(list(shards), ensure_whitespace=True)
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            category=UserWarning,
+            message=".*Skipping .* while merging docs.",
+        )
+        return Doc.from_docs(list(shards), ensure_whitespace=True)
```
Comment on lines -209 to +215:

> Not sure where this edit is coming from?

> It's just a drive-by because I noticed the warnings filter is missing here 🙃 I can move this into a separate PR, if you mind having it in here.

> It's all just a bit confusing with the huge (mostly unrelated) git history etc - I do in general appreciate more "atomic" PRs ;-)

> I know 🫣
>
> Yeah, I don't know why that's the case. The branches should all be updated.
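For readers unfamiliar with the warning being silenced, here is a standalone sketch (an assumption-based reproduction, not code from this PR) of the suppression pattern the hunk above applies around `Doc.from_docs`:

```python
import warnings

import spacy
from spacy.tokens import Doc

# Doc.from_docs can emit a UserWarning along the lines of
# "Skipping ... while merging docs" when some per-doc data cannot be carried over;
# the regex below is the same one used in the diff above.
nlp = spacy.blank("en")
docs = [nlp("First shard."), nlp("Second shard.")]
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore", category=UserWarning, message=".*Skipping .* while merging docs."
    )
    merged = Doc.from_docs(docs, ensure_whitespace=True)
print(merged.text)
```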
Changed config file — the `[components.llm]` block after the change (hunk `@@ -27,6 +27,7 @@`):

```ini
[components.llm]
factory = "llm"
save_io = True

[components.llm.task]
@llm_tasks = "spacy.NoOp.v1"
```
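For context on `save_io`: with it enabled, the `llm` component keeps the raw prompt and response on each doc. A rough sketch of inspecting that data follows; the `doc.user_data["llm_io"]` layout is taken from the spacy-llm docs and is an assumption here, as is the choice of model and task:

```python
import spacy

# Sketch: enable save_io on the llm component and inspect the stored I/O afterwards.
nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "save_io": True,  # same setting the config hunk above turns on
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
        "model": {"@llm_models": "spacy.Zephyr.v1", "name": "zephyr-7b-beta"},
    },
)
doc = nlp("This is a test.")
print(doc.user_data["llm_io"]["llm"])  # expected: dict holding the prompt and response
```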
New file (+68 lines) — GPU tests for the Yi model:

```python
import copy

import pytest
import spacy
from confection import Config  # type: ignore[import]
from thinc.compat import has_torch_cuda_gpu

from ...compat import torch

_PIPE_CFG = {
    "model": {
        "@llm_models": "spacy.Yi.v1",
        "name": "Yi-6B-chat",
    },
    "task": {"@llm_tasks": "spacy.NoOp.v1"},
}

_NLP_CONFIG = """

[nlp]
lang = "en"
pipeline = ["llm"]
batch_size = 128

[components]

[components.llm]
factory = "llm"

[components.llm.task]
@llm_tasks = "spacy.NoOp.v1"

[components.llm.model]
@llm_models = "spacy.Yi.v1"
name = "Yi-6B"
"""


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init():
    """Test initialization and simple run."""
    nlp = spacy.blank("en")
    cfg = copy.deepcopy(_PIPE_CFG)
    nlp.add_pipe("llm", config=cfg)
    nlp("This is a test.")
    torch.cuda.empty_cache()


@pytest.mark.gpu
@pytest.mark.skip(reason="CI runner needs more GPU memory")
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_init_from_config():
    orig_config = Config().from_str(_NLP_CONFIG)
    nlp = spacy.util.load_model_from_config(orig_config, auto_fill=True)
    assert nlp.pipe_names == ["llm"]
    torch.cuda.empty_cache()


@pytest.mark.gpu
@pytest.mark.skipif(not has_torch_cuda_gpu, reason="needs GPU & CUDA")
def test_invalid_model():
    orig_config = Config().from_str(_NLP_CONFIG)
    config = copy.deepcopy(orig_config)
    config["components"]["llm"]["model"]["name"] = "x"
    with pytest.raises(ValueError, match="unexpected value; permitted"):
        spacy.util.load_model_from_config(config, auto_fill=True)
    torch.cuda.empty_cache()
```