Commit 18b6a68 (parent: 3abd7a0)
Showing 1 changed file with 157 additions and 0 deletions.
import torch
import torch.nn as nn
from model import Transformer
from config import LlamaConfig
from sentencepiece import SentencePieceProcessor

import time
import json
from tqdm import tqdm
from pathlib import Path
from typing import Optional

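# Note: Transformer and LlamaConfig are assumed to be the model and config
# classes defined in this repository's model.py and config.py.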
class LLaMA:

    def __init__(self, model: Transformer, tokenizer: SentencePieceProcessor, model_args: LlamaConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.args = model_args

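    # Factory: loads the checkpoint and tokenizer, builds the config, and
    # returns a ready-to-use LLaMA instance.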
    @staticmethod
    def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_len: int, max_batch_size: int, device: str):
        prev_time = time.time()
        if load_model:
            checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
            assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}"
            ckpt_path = checkpoints[0]
            print(f'Loading checkpoint "{ckpt_path}"')
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s")
            prev_time = time.time()
        with open(Path(checkpoints_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: LlamaConfig = LlamaConfig(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            device=device,
            **params
        )

        tokenizer = SentencePieceProcessor()
        tokenizer.load(tokenizer_path)
        model_args.vocab_size = tokenizer.vocab_size()

        # Create the model weights in half precision on GPU and bfloat16 on CPU
        if device == "cuda":
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        else:
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = Transformer(model_args).to(device)

        if load_model:
            # The checkpoint stores precomputed RoPE frequencies, but this
            # Transformer computes them itself, so drop the key before a
            # strict state-dict load.
            del checkpoint['rope.freqs']
            model.load_state_dict(checkpoint, strict=True)
            print(f"Loaded state dict in {time.time() - prev_time:.2f}s")

        return LLaMA(model, tokenizer, model_args)

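    # Batched text completion with temperature scaling and nucleus (top-p)
    # sampling; at most max_gen_len new tokens are generated per prompt.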
    def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None):
        if max_gen_len is None:
            max_gen_len = self.args.max_seq_len - 1
        # Convert each prompt into tokens
        prompt_tokens = [self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts]
        # Make sure the batch size is not too large
        batch_size = len(prompt_tokens)
        assert batch_size <= self.args.max_batch_size, f"batch size must be less than or equal to {self.args.max_batch_size}"
        max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
        # Make sure the prompt length is not larger than the maximum sequence length
        assert max_prompt_len <= self.args.max_seq_len, f"prompt length must be less than or equal to {self.args.max_seq_len}"
        total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)

        # Create the tensor that will contain the generated tokens, along with the initial prompt tokens
        pad_id = self.tokenizer.pad_id()
        tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=self.args.device)
        for k, t in enumerate(prompt_tokens):
            # Populate the initial tokens with the prompt tokens
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=self.args.device)

        eos_reached = torch.tensor([False] * batch_size, device=self.args.device)
        prompt_tokens_mask = tokens != pad_id  # True for prompt positions, False for padding
        cur_iterator = tqdm(range(1, total_len), desc="Generating tokens")
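        # Decode one position at a time: only the token at cur_pos - 1 is fed to
        # the model, which is assumed to maintain an internal KV cache indexed by
        # cur_pos (see model.py).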
        for cur_pos in cur_iterator:
            with torch.no_grad():
                logits = self.model.forward(tokens[:, cur_pos-1:cur_pos], cur_pos)
            if temperature > 0:
                # Temperature scaling followed by nucleus (top-p) sampling
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = self._sample_top_p(probs, top_p)
            else:
                # Greedy decoding: pick the most likely token
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # Only replace the token if it is a padding position, not a prompt token
            next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token
            # EOS only counts when produced in a generated (non-prompt) position
            eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id())
            if all(eos_reached):
                break

        out_tokens = []
        out_text = []
        for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()):
            # Truncate at the first EOS token, if one was generated
            if self.tokenizer.eos_id() in current_prompt_tokens:
                eos_idx = current_prompt_tokens.index(self.tokenizer.eos_id())
                current_prompt_tokens = current_prompt_tokens[:eos_idx]
            out_tokens.append(current_prompt_tokens)
            out_text.append(self.tokenizer.decode(current_prompt_tokens))
        return (out_tokens, out_text)

    def _sample_top_p(self, probs, p):
        # Sort the probabilities in descending order
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Mask out tokens whose preceding cumulative probability already exceeds p
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        # Renormalize the surviving probabilities and sample from them
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        # Map the sampled position back to the original vocabulary index
        next_token = torch.gather(probs_idx, -1, next_token)
        return next_token
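    # Worked example for _sample_top_p (illustrative): with sorted probs
    # [0.6, 0.25, 0.1, 0.05] and p = 0.7, the shifted cumulative sums are
    # [0.0, 0.6, 0.85, 0.95], so the last two entries are masked out and the
    # sample is drawn from [0.6, 0.25] renormalized to roughly [0.71, 0.29].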


if __name__ == '__main__':
    torch.manual_seed(0)

    allow_cuda = False
    device = 'cuda' if torch.cuda.is_available() and allow_cuda else 'cpu'

    prompts = [
        "Simply put, the theory of relativity states that ",
        "If Google was an Italian company founded in Milan, it would",
        # Few shot prompt
        """Translate English to French:
        sea otter => loutre de mer
        peppermint => menthe poivrée
        plush girafe => girafe peluche
        cheese =>""",
        # Zero shot prompt
        """Tell me if the following person is actually Doraemon disguised as human:
        Name: Umar Jamil
        Decision:
        """
    ]

    # The checkpoints directory must contain at least one *.pth checkpoint and a
    # params.json; tokenizer.model is the SentencePiece model distributed with
    # the Llama 2 weights.
    model = LLaMA.build(
        checkpoints_dir='llama-2-7b/',
        tokenizer_path='tokenizer.model',
        load_model=True,
        max_seq_len=1024,
        max_batch_size=len(prompts),
        device=device
    )

    out_tokens, out_texts = model.text_completion(prompts, max_gen_len=64)
    assert len(out_texts) == len(prompts)
    for i in range(len(out_texts)):
        print(f'{out_texts[i]}')
        print('-' * 50)
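# Note: allow_cuda is left False above, so inference runs on CPU by default.
# Enabling it assumes a CUDA build of PyTorch and enough GPU memory to hold the
# 7B weights in half precision (roughly 14 GB).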