[WIP] Minimal Tokenizer Implementation #513

Draft
wants to merge 14 commits into main
14 changes: 5 additions & 9 deletions exo/api/chatgpt_api.py
@@ -3,7 +3,6 @@
import asyncio
import json
from pathlib import Path
from transformers import AutoTokenizer
from typing import List, Literal, Union, Dict
from aiohttp import web
import aiohttp_cors
@@ -14,9 +13,8 @@
from exo import DEBUG, VERSION
from exo.download.download_progress import RepoProgressEvent
from exo.helpers import PrefixDict, shutdown
from exo.inference.tokenizers import resolve_tokenizer
from exo.orchestration import Node
from exo.models import build_base_shard, model_cards, get_repo, pretty_name, get_supported_models
from exo.models import build_base_shard, model_cards, pretty_name, get_supported_models
from typing import Callable, Optional

class Message:
@@ -228,7 +226,7 @@ async def handle_post_chat_token_encode(self, request):
data = await request.json()
shard = build_base_shard(self.default_model, self.inference_engine_classname)
messages = [parse_message(msg) for msg in data.get("messages", [])]
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
tokenizer = await self.node.inference_engine.get_tokenizer(shard)
return web.json_response({"length": len(build_prompt(tokenizer, messages)[0])})

async def handle_get_download_progress(self, request):
@@ -257,8 +255,7 @@ async def handle_post_chat_completions(self, request):
{"detail": f"Unsupported model: {chat_request.model} with inference engine {self.inference_engine_classname}. Supported models for this engine: {supported_models}"},
status=400,
)

tokenizer = await resolve_tokenizer(get_repo(shard.model_id, self.inference_engine_classname))
tokenizer = await self.node.inference_engine.get_tokenizer(shard)
if DEBUG >= 4: print(f"Resolved tokenizer: {tokenizer}")

prompt = build_prompt(tokenizer, chat_request.messages)
@@ -307,8 +304,7 @@ async def stream_result(_request_id: str, tokens: List[int], is_finished: bool):
self.prev_token_lens[_request_id] = max(prev_last_tokens_len, len(tokens))
new_tokens = tokens[prev_last_tokens_len:]
finish_reason = None
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if hasattr(tokenizer, "_tokenizer") and isinstance(tokenizer._tokenizer,
AutoTokenizer) else getattr(tokenizer, "eos_token_id", None)
eos_token_id = tokenizer.eos_token_id
if len(new_tokens) > 0 and new_tokens[-1] == eos_token_id:
new_tokens = new_tokens[:-1]
if is_finished:
@@ -354,7 +350,7 @@ def on_result(_request_id: str, tokens: List[int], is_finished: bool):
)

finish_reason = "length"
eos_token_id = tokenizer.special_tokens_map.get("eos_token_id") if isinstance(getattr(tokenizer, "_tokenizer", None), AutoTokenizer) else tokenizer.eos_token_id
eos_token_id = tokenizer.eos_token_id
if DEBUG >= 2: print(f"Checking if end of tokens result {tokens[-1]=} is {eos_token_id=}")
if tokens[-1] == eos_token_id:
tokens = tokens[:-1]
5 changes: 5 additions & 0 deletions exo/inference/mlx/sharded_inference_engine.py
@@ -10,6 +10,7 @@
from exo.download.shard_download import ShardDownloader
import asyncio
from concurrent.futures import ThreadPoolExecutor
from exo.tokenizer.tokenizer import Tokenizer

def sample_logits(
logits: mx.array,
@@ -58,6 +59,10 @@ async def infer_tensor(self, request_id: str, shard: Shard, input_data: np.ndarr
await self.ensure_shard(shard)
output_data: np.ndarray = np.array(await asyncio.get_running_loop().run_in_executor(self.executor, self.model, mx.array(input_data), request_id))
return output_data

async def get_tokenizer(self, shard: Shard) -> Tokenizer:
await self.ensure_shard(shard)
return self.tokenizer

async def ensure_shard(self, shard: Shard):
if self.shard == shard:
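
The get_tokenizer call sites added in exo/api/chatgpt_api.py and exo/main.py imply that every inference engine now exposes this method, not only the MLX engine shown here. Below is a minimal sketch of how the shared base class could declare it; the abstract declaration is an assumption and is not part of this diff, only the MLX implementation above is.

# Sketch only (assumption): an abstract get_tokenizer on the shared engine interface,
# so chatgpt_api.py and main.py can call it regardless of the concrete engine.
from abc import ABC, abstractmethod

from exo.inference.shard import Shard
from exo.tokenizer.tokenizer import Tokenizer

class InferenceEngine(ABC):
  @abstractmethod
  async def get_tokenizer(self, shard: Shard) -> Tokenizer:
    """Ensure the shard is loaded, then return its tokenizer."""
    ...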
8 changes: 4 additions & 4 deletions exo/inference/mlx/sharded_utils.py
@@ -18,10 +18,10 @@
import mlx.nn as nn
from transformers import AutoProcessor

from mlx_lm.tokenizer_utils import load_tokenizer, TokenizerWrapper
from exo.tokenizer.tokenizer import Tokenizer

from exo import DEBUG
from exo.inference.tokenizers import resolve_tokenizer
from exo.tokenizer.tokenizer import resolve_tokenizer
from ..shard import Shard


@@ -174,7 +174,7 @@ async def load_shard(
model_config={},
adapter_path: Optional[str] = None,
lazy: bool = False,
) -> Tuple[nn.Module, TokenizerWrapper]:
) -> Tuple[nn.Module, Tokenizer]:
model = load_model_shard(model_path, shard, lazy, model_config)

# TODO: figure out a generic solution
@@ -184,7 +184,7 @@ async def load_shard(
processor.encode = processor.tokenizer.encode
return model, processor
else:
tokenizer = await resolve_tokenizer(model_path)
tokenizer = resolve_tokenizer(shard.model_id, model_path)
return model, tokenizer


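
load_shard now resolves the tokenizer through the new exo.tokenizer package, passing both the model id and the local snapshot path. The call above fixes the signature; the body below is only a sketch of what that entry point might do, since exo/tokenizer/tokenizer.py itself is not included in this diff.

# Hypothetical sketch matching the resolve_tokenizer(shard.model_id, model_path) call above.
# Tokenizer is imported here only to keep the sketch self-contained; in the PR it presumably
# lives in the same module.
from pathlib import Path

from exo.tokenizer.tokenizer import Tokenizer

def resolve_tokenizer(model_id: str, model_path: Path) -> Tokenizer:
  # model_id allows picking a model-family-specific tokenizer if needed (assumed);
  # model_path points at the downloaded snapshot containing the tokenizer files.
  return Tokenizer(str(model_path))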
43 changes: 24 additions & 19 deletions exo/inference/tokenizers.py
@@ -8,6 +8,7 @@
from exo.download.hf.hf_helpers import get_local_snapshot_dir
from exo.helpers import DEBUG

from exo.tokenizer.tokenizer import Tokenizer

class DummyTokenizer:
def __init__(self):
@@ -40,25 +41,29 @@ async def resolve_tokenizer(model_id: str):


async def _resolve_tokenizer(model_id_or_local_path: Union[str, PathLike]):
try:
if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
if not hasattr(processor, 'eos_token_id'):
processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
if not hasattr(processor, 'encode'):
processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
if not hasattr(processor, 'decode'):
processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
return processor
except Exception as e:
if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
if DEBUG >= 4: print(traceback.format_exc())
# try:
# if DEBUG >= 4: print(f"Trying AutoProcessor for {model_id_or_local_path}")
# processor = AutoProcessor.from_pretrained(model_id_or_local_path, use_fast=True if "Mistral-Large" in f"{model_id_or_local_path}" else False, trust_remote_code=True)
# if not hasattr(processor, 'eos_token_id'):
# processor.eos_token_id = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).eos_token_id
# if not hasattr(processor, 'encode'):
# processor.encode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).encode
# if not hasattr(processor, 'decode'):
# processor.decode = getattr(processor, 'tokenizer', getattr(processor, '_tokenizer', processor)).decode
# return processor
# except Exception as e:
# if DEBUG >= 4: print(f"Failed to load processor for {model_id_or_local_path}. Error: {e}")
# if DEBUG >= 4: print(traceback.format_exc())

# try:
# if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
# return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
# except Exception as e:
# if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
# if DEBUG >= 4: print(traceback.format_exc())

# raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
try:
if DEBUG >= 4: print(f"Trying AutoTokenizer for {model_id_or_local_path}")
return AutoTokenizer.from_pretrained(model_id_or_local_path, trust_remote_code=True)
return Tokenizer(model_id_or_local_path)
except Exception as e:
if DEBUG >= 4: print(f"Failed to load tokenizer for {model_id_or_local_path}. Falling back to tinygrad tokenizer. Error: {e}")
if DEBUG >= 4: print(traceback.format_exc())

raise ValueError(f"[TODO] Unsupported model: {model_id_or_local_path}")
raise ValueError(f"Failed to load tokenizer for {model_id_or_local_path}. Error: {e}")
2 changes: 1 addition & 1 deletion exo/main.py
@@ -172,7 +172,7 @@ async def run_model_cli(node: Node, inference_engine: InferenceEngine, model_nam
if not shard:
print(f"Error: Unsupported model '{model_name}' for inference engine {inference_engine.__class__.__name__}")
return
tokenizer = await resolve_tokenizer(get_repo(shard.model_id, inference_class))
tokenizer = await node.inference_engine.get_tokenizer(shard)
request_id = str(uuid.uuid4())
callback_id = f"cli-wait-response-{request_id}"
callback = node.on_token.register(callback_id)
3 changes: 3 additions & 0 deletions exo/tokenizer/__init__.py
@@ -0,0 +1,3 @@
from .exo_tokenizer import ExoTokenizer

__all__ = ['ExoTokenizer']
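
exo/tokenizer/exo_tokenizer.py itself is not part of the hunks shown above, so the wrapper's exact interface is unknown here. Judging from the call sites in this diff (encode/decode in the API server, eos_token_id for the end-of-sequence checks, construction from a model path), it presumably looks roughly like the sketch below; every name beyond those three is an assumption.

# Hypothetical interface sketch for the new tokenizer wrapper; not taken from the PR.
from typing import Any, Dict, List

class ExoTokenizer:
  eos_token_id: int  # read directly by the end-of-sequence checks in chatgpt_api.py

  def __init__(self, model_path: str):
    """Load tokenizer files (e.g. tokenizer.json) from a local model directory (assumed)."""

  def encode(self, text: str) -> List[int]:
    """Text -> token ids, used via build_prompt by the token-encode endpoint."""

  def decode(self, tokens: List[int]) -> str:
    """Token ids -> text, used when streaming chat completions."""

  def apply_chat_template(self, messages: List[Dict[str, Any]], tokenize: bool = False,
                          add_generation_prompt: bool = True) -> str:
    """Render chat messages into a prompt string (assumed: build_prompt previously relied
    on the Hugging Face chat-template API)."""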