Commit

add gist model generation utils to library
uSaiPrashanth committed Jul 4, 2024
1 parent 12be67c commit 9056c97
Showing 18 changed files with 287 additions and 206 deletions.
71 changes: 69 additions & 2 deletions cache.py
@@ -20,7 +20,7 @@ def add_cache_arguments(parser: argparse.ArgumentParser):
help="Cache size per layer. If len < n layers, the values are tiled. Must have len divisible by n layers. \
If 0 < x <= 1, it is percent of |prompt| + max new tokens. Otherwise, if > 1, its the maximum size.",
)
strategies = ["full", "random", "window", "scissor", "l2", "fastgen"]
strategies = ["full", "random", "window", "scissor", "l2", "fastgen", "gist"]
debug_strategies = [f"debug_{strategy}" for strategy in strategies]
strategies.extend(debug_strategies)

@@ -105,11 +105,14 @@ def add_cache_arguments(parser: argparse.ArgumentParser):


def cache_compatibility(args):
if args.cache_strategy == "full":
    if args.cache_strategy in ("full", "gist"):
        # Neither full nor gist compresses the cache, so --max_cache_length must be [1.0] (same size as prompt + max_new_tokens)
        assert all(
            [l == 1.0 for l in args.max_cache_length]
        ), "Full and gist cache strategies only support max_cache_length=1.0."

    if args.cache_strategy == "gist":
        assert "gist" in str(
            args.checkpoint_path
        ), "Gist cache requires a gist-finetuned checkpoint (checkpoint path must contain 'gist')."

# Attention-based eviction policies must use an attention-based prompt compressor
if args.cache_strategy in {"scissor"}:
@@ -461,6 +464,68 @@ def mark_global_tokens(self, num_total_insertions: int) -> bool:
self.pos[:, :, :num_to_mark] = self.max_seq_length
return num_to_mark == self.global_tokens

class KVCacheGist(KVCache):
relevant_kwargs = [
'gist_token_id',
'max_cache_length'
]
def __init__(
self, max_batch_size, n_heads, head_dim, dtype=torch.bfloat16, **kwargs
):
# Gist does not use additional compression strategies
self.prompt_compression_strategy = None
self.global_tokens = 0 # No global tokens for gist cache

self.gist_token_id = kwargs.pop('gist_token_id')
super().__init__(
max_batch_size, n_heads, head_dim, dtype, head_specific=False, **kwargs
)
self.prefill_attn_callback = {
"func": self.profile_and_update,
"kwargs": {},
}
        self.register_buffer(
            "ids",  # Original token id of each cache slot; needed to locate gist tokens when building the gist mask for batched inputs.
            torch.full(
                (
                    max_batch_size,
                    self.max_cache_length,
                ),
                -1,
                dtype=torch.int,
            ),
        )


def _update(self, input_pos, k_val, v_val, input_ids=None):
# input_pos: [S], k_val: [B, H, S, D], input_ids: [B, S]

self.fill_contiguous(input_pos, k_val, v_val)
self.ids[:, self.cache_cts[0]:self.cache_cts[0]+input_ids.shape[-1]] = input_ids
return input_pos.shape[-1]

    def profile_and_update(self, input_pos, input_ids, k_val, v_val, attn):
        assert self.is_prefill(), "Should only be profiling during prefill stage."

        # Use the earliest gist-token position across the batch so no gist token is dropped.
        gist_pos = torch.where(input_ids == self.gist_token_id)[-1].min().cpu().item()
        # Keep only the gist token and everything after it; pre-gist tokens are never cached.
        input_pos = input_pos[gist_pos:]
        input_ids = input_ids[:, gist_pos:]
        k_val = k_val[:, :, gist_pos:, :]
        v_val = v_val[:, :, gist_pos:, :]

        self.fill_contiguous(input_pos, k_val, v_val)
        self.ids[:, self.cache_cts[0] : self.cache_cts[0] + input_ids.shape[-1]] = input_ids
        self.cache_cts[0] = input_pos.shape[-1]

    def return_kv_cache(self):
        k, v, mask = super().return_kv_cache()
        # Build a [B, H, 1, S] boolean mask that hides every cache slot before each sequence's gist token.
        mask_shape = (k.shape[0], k.shape[1], 1, k.shape[-2])
        gist_mask = torch.ones(mask_shape, dtype=torch.bool, device=k.device)
        gist_token_positions = torch.stack(torch.where(self.ids == self.gist_token_id)).T
        for batch_idx, token_pos in gist_token_positions:
            gist_mask[batch_idx, :, :, :token_pos] = False
        return k, v, gist_mask
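
As a standalone illustration of the masking rule return_kv_cache implements above, the sketch below builds a broadcastable boolean mask from cached token ids so that every query is blocked from attending to cache slots that precede the sequence's gist token. This is a minimal toy, not the library's API; the helper name and the [B, 1, 1, S] shape are simplifications.

import torch

def build_gist_mask(cache_ids: torch.Tensor, gist_token_id: int) -> torch.Tensor:
    # cache_ids: [B, S] token ids stored alongside the KV cache (-1 marks empty slots).
    B, S = cache_ids.shape
    mask = torch.ones(B, 1, 1, S, dtype=torch.bool)  # True = attendable
    for b, pos in torch.stack(torch.where(cache_ids == gist_token_id)).T.tolist():
        mask[b, :, :, :pos] = False  # hide everything before the gist token
    return mask

ids = torch.tensor([[5, 7, 9, 42, 11, -1]])  # 42 stands in for the gist token id
print(build_gist_mask(ids, gist_token_id=42))
# Slots 0-2 (before the gist token) are masked out; slots 3-5 stay visible.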

class KVCacheFull(KVCache):
def __init__(
@@ -1258,6 +1323,8 @@ def get_cache_constructor(cache_strategy):
cls = KVCacheScissorhands
elif cache_strategy == "fastgen":
cls = KVCacheFastGen
elif cache_strategy == "gist":
cls = KVCacheGist
elif cache_strategy.startswith("debug"):
cache_strategy = re.sub(r"debug_+", "", cache_strategy).strip()
relevant_kwargs = get_cache_constructor(cache_strategy)[1]
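
For context, get_cache_constructor (extended above) is the lookup that maps a --cache_strategy string to a cache class plus its relevant kwargs. A hedged usage sketch, assuming cache.py is importable as cache and that the function returns a (class, relevant_kwargs) pair as the debug branch suggests:

from cache import get_cache_constructor

cls, relevant_kwargs = get_cache_constructor("gist")
# Expected: cls is KVCacheGist and relevant_kwargs includes "gist_token_id" and "max_cache_length".
print(cls.__name__, relevant_kwargs)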
13 changes: 9 additions & 4 deletions generate.py
@@ -6,6 +6,7 @@
import sys
import time
import contextlib
import json
from pathlib import Path
from typing import Optional

@@ -84,6 +85,8 @@ def main(

tokenizer = get_tokenizer(tokenizer_path, checkpoint_path, is_chat=is_chat)

gist_token_id = tokenizer.gist_token_id() if hasattr(tokenizer, "gist_token_id") else None

inputs = [encode(tokenizer, prompt, device=device, is_chat=is_chat)]

terminator_ids = tokenizer.get_terminator_ids()
@@ -124,6 +127,7 @@ def main(
inputs[0],
max_new_tokens=max_new_tokens,
terminator_ids=terminator_ids,
gist_token_id=gist_token_id,
feed_long_prompts=feed_long_prompts,
)
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -159,8 +163,8 @@ def main(
parser.add_argument(
"--prompt",
type=str,
default="long_prompt_short_output.txt",
help="Input prompt. If it ends in .txt, we will load the prompt from the ./prompts dir.",
default="long_prompt_short_output.json",
help="Input prompt. If it ends in .json, we will load the prompt from the ./prompts dir.",
)
parser.add_argument(
"--max_new_tokens", type=int, default=512, help="Maximum number of new tokens."
@@ -171,10 +175,10 @@ def main(

args = parser.parse_args()

if args.prompt.endswith(".txt"):
if args.prompt.endswith(".json"):
prompt_fn = Path(__file__).resolve().parent / "prompts" / args.prompt
with open(prompt_fn) as fd:
args.prompt = fd.read().strip()
args.prompt = json.load(fd)

cache_compatibility(args)

@@ -188,3 +192,4 @@ def main(
args.device,
cache_kwargs=vars(args),
)

11 changes: 11 additions & 0 deletions generation_utils.py
@@ -104,6 +104,7 @@ def prefill(
x: torch.Tensor,
input_pos: torch.Tensor,
next_token: torch.Tensor = None,
    gist_token_id: Optional[int] = None,
**sampling_kwargs,
) -> torch.Tensor:
# input_pos: [B, S]
@@ -113,6 +114,11 @@
.unsqueeze(0)
.to(x.device)
)
    if gist_token_id is not None:
        # For every gist token, block later query positions from attending to anything before it.
        gist_token_positions = torch.stack(torch.where(x == gist_token_id)).T
        for batch_idx, token_pos in gist_token_positions:
            causal_mask[batch_idx, :, token_pos + 1 :, :token_pos] = False

logits = model(x, input_pos, mask=causal_mask)
return greedy(logits, next_token)

@@ -240,6 +246,9 @@ def setup_caches(
"punctuation": tokenizer.punctuation_ids(),
}

if "gist" in cache_kwargs["cache_strategy"]:
cache_kwargs["gist_token_id"] = tokenizer.gist_token_id()

with torch.device(device):
model.setup_caches(max_batch_size=1, **cache_kwargs)

@@ -258,6 +267,7 @@ def generate(
prompt: torch.Tensor,
max_new_tokens: int,
terminator_ids: Optional[list] = None,
    gist_token_id: Optional[int] = None,
feed_long_prompts: bool = False,
**sampling_kwargs,
) -> torch.Tensor:
@@ -295,6 +305,7 @@ def generate(
prompt.view(1, -1),
input_pos,
next_token=None if prefix is None else prefix[0].view(1),
gist_token_id=gist_token_id,
**sampling_kwargs,
)
next_token = ret[0].clone()
12 changes: 11 additions & 1 deletion model.py
@@ -97,7 +97,7 @@ def from_name(cls, name: str):
),
"stories15M": dict(n_layer=6, n_head=6, dim=288),
"stories110M": dict(n_layer=12, n_head=12, dim=768),
"Meta-Llama-3-8B-Instruct": dict(
"Meta-Llama-3-8B": dict(
block_size=8192,
n_layer=32,
n_head=32,
@@ -145,6 +145,16 @@ def from_name(cls, name: str):
norm_eps=1e-6,
max_length=32768,
),
"Meta-Llama-3-8B-gist-finetune": dict(
block_size=8192,
n_layer=32,
n_head=32,
n_local_heads=8,
dim=4096,
intermediate_size=14336,
vocab_size=128257,
        rope_base=500000,
),
}
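
A note on the Llama-3 entries above: vocab_size=128257 in the gist config is the base 128,256-token Llama 3 vocabulary plus one extra gist token. The rename to "Meta-Llama-3-8B" alongside the new "Meta-Llama-3-8B-gist-finetune" key presumably relies on from_name resolving a checkpoint path by substring match and preferring the longest matching key; that matching behavior is an assumption here, and the helper below is only an illustrative sketch of the rule, not the repository's code.

def resolve_config(checkpoint_name: str, config_names: list[str]) -> str:
    # Longest config key that appears in the checkpoint name wins (assumed behavior).
    matches = [c for c in config_names if c in checkpoint_name]
    return max(matches, key=len)

names = ["Meta-Llama-3-8B", "Meta-Llama-3-8B-gist-finetune"]
print(resolve_config("checkpoints/Meta-Llama-3-8B-gist-finetune/model.pth", names))
# -> Meta-Llama-3-8B-gist-finetune; a plain Meta-Llama-3-8B path resolves to the base entry.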


4 changes: 4 additions & 0 deletions prompts/long_prompt_long_output.json
@@ -0,0 +1,4 @@
{
"instruction": "You are an architect tasked with drawing up plans for a modern residential house.\n\nArchitectural Plan Creation Instructions\n\nObjective:\nCreate a comprehensive set of architectural plans for a modern residential house. The plans should include detailed layouts, elevations, sections, and necessary annotations to guide the construction process. The design should focus on functionality, aesthetics, sustainability, and compliance with local building codes.\n\nRequirements:\n\nGeneral Layout:\n\nTotal area: Approximately 2,500 square feet.\nNumber of floors: Two.\nNumber of bedrooms: Four (including a master suite).\nNumber of bathrooms: Three full bathrooms and one half bathroom.\nCommon areas: Open-plan kitchen, dining area, living room, and a study/office.\nAdditional spaces: Laundry room, garage (for two cars), storage rooms, and a small basement.\nSite Plan:\n\nInclude property boundaries, adjacent streets, and any existing structures.\nShow the placement of the house, driveway, pathways, garden, and outdoor living spaces (e.g., patio, deck).\nInclude landscaping elements like trees, shrubs, and lawn areas.\nFloor Plans:\n\nGround Floor: Include entryway, living spaces, kitchen, one bedroom (guest room), one full bathroom, and access to the garage.\nSecond Floor: Include master suite with attached bathroom and walk-in closet, two additional bedrooms, one full bathroom, and a study/office.\nIndicate all door and window placements, furniture layouts, and circulation paths.\nElevations:\n\nProvide front, rear, and side elevations.\nShow the external appearance, including the roof design, facade materials, window and door placements, and any architectural features (e.g., balconies, porches).\nSections:\n\nInclude at least two sections (one longitudinal and one cross-sectional) showing internal details.\nHighlight the relationship between different floors and ceiling heights.\nShow structural elements like beams, columns, and floor slabs.\nRoof Plan:\n\nIndicate the roof slope, materials, drainage system, and any roof features (e.g., skylights, chimneys).\nElectrical and Plumbing Plans:\n\nShow the layout of electrical outlets, switches, lighting fixtures, and major appliances.\nInclude the plumbing layout for water supply and drainage, showing the location of pipes, fixtures, and connections.\nMaterials and Finishes:\n\nSpecify the materials for walls, floors, ceilings, and roofs.\nInclude details on interior and exterior finishes (e.g., paint, tiles, cladding).\nSustainability Features:\n\nIncorporate energy-efficient systems (e.g., HVAC, solar panels).\nUse sustainable building materials.\nPlan for natural lighting and ventilation.\nInclude rainwater harvesting and greywater recycling systems if possible.\nCompliance:\n\nEnsure the design complies with local building codes and regulations.\nInclude necessary annotations and notes for construction guidelines.\n\n",
"input": "You must return the following:\n- Include a detailed list of materials and specifications.\n- Add a cover sheet with project title, address, date, and designer's name.\n- Add a sheet for each component with detailed plans.\n- Ensure all documents are clearly labeled and organized."
}
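
Tying this to the generate.py change above: prompts that end in .json are now parsed with json.load, so the prompt arrives as a dict rather than a raw string. A small sketch of reading this file, under the assumption that the "instruction"/"input" keys shown here are what the downstream encode step consumes:

import json
from pathlib import Path

prompt_fn = Path("prompts") / "long_prompt_long_output.json"
with open(prompt_fn) as fd:
    prompt = json.load(fd)

print(sorted(prompt.keys()))       # ['input', 'instruction']
print(prompt["instruction"][:60])  # first characters of the task description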
63 changes: 0 additions & 63 deletions prompts/long_prompt_long_output.txt

This file was deleted.

