-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathzett.py
More file actions
64 lines (50 loc) · 1.8 KB
/
zett.py
File metadata and controls
64 lines (50 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import logging
from pprint import pformat
import jax
from dataclasses import dataclass, asdict
from transformers import FlaxAutoModelForCausalLM
from tokenkit.hf import get_config
from tokenkit import utils
from tokenkit.byteify import load_byteify_tokenizer
from tokenkit.models import param, sharding
from tokenkit import parse_args
logger = logging.getLogger(__name__)
@dataclass
class ZettArgs:
    """Arguments consumed by ``main`` in this script."""

    # Description of the source model; ``main`` reads its ``tokenizer_name`` and
    # ``pretrained_model_name_or_path`` and forwards all fields to
    # ``param.load_params``.
    source_model: parse_args.ModelArgs
    # Name passed to ``load_byteify_tokenizer`` to obtain the target tokenizer.
    target_tokenizer_name: str
    # Destination handed to every ``save_pretrained`` call.
    output: str
def main(args: ZettArgs) -> None:
    """Rebuild a model's embeddings for a different tokenizer and save the result.

    The source model's parameters are loaded, their embeddings are popped out
    as one stacked array, ``utils.fvt`` is asked how to map them onto the
    target tokenizer, and the re-indexed embeddings are written back. The
    model, its config and the target tokenizer are then saved to
    ``args.output``.

    Args:
        args: Source model description, target tokenizer name and output path.
    """
    logger.info(pformat(args))

    # --- Tokenizers -------------------------------------------------------
    src_tok = load_byteify_tokenizer(args.source_model.tokenizer_name)
    tgt_tok = load_byteify_tokenizer(args.target_tokenizer_name)

    # --- Model skeleton on a CPU mesh --------------------------------------
    cpu_mesh = sharding.get_mesh(devices=jax.devices("cpu"))
    config = get_config(args.source_model.pretrained_model_name_or_path)
    config.mesh = cpu_mesh
    # `_do_init=False`: build the model object without initialising parameters;
    # the real parameters are loaded separately below.
    model = FlaxAutoModelForCausalLM.from_config(
        config, _do_init=False, input_shape=(1, 128)
    )
    # The mesh is only attached for construction and removed again before
    # anything is saved (presumably because it cannot be serialised with the
    # config -- TODO confirm).
    del model.config.mesh

    # --- Parameters and stacked embeddings ---------------------------------
    source_kwargs = asdict(args.source_model)
    params = param.load_params(**source_kwargs)
    stacked, params = param.stack_embeddings(params, config, pop_embeddings=True)

    # --- Map embeddings onto the target vocabulary -------------------------
    # NOTE(review): `fvt` presumably stands for fast vocabulary transfer;
    # confirm the exact contract in tokenkit.utils.
    fvt_result = utils.fvt(src_tok, tgt_tok, stacked)
    diff_embeddings, original_to_new_indices, diff_indices = fvt_result

    # Gather existing rows into their new positions, then overwrite the rows
    # listed in `diff_indices` with the embeddings returned for them.
    remapped = stacked[original_to_new_indices]
    has_diff_rows = len(diff_indices) > 0
    if has_diff_rows:
        remapped[diff_indices] = diff_embeddings

    params = param.assign_embeddings(params, remapped, config)

    # --- Persist model, config and tokenizer to the same location ----------
    model.save_pretrained(args.output, params=params)
    for artifact in (config, tgt_tok):
        artifact.save_pretrained(args.output)
if __name__ == "__main__":
    # Script entry point: let tokenkit's parser build the arguments from the
    # ZettArgs schema, then run the embedding transfer.
    cli_args = parse_args.parse_args(ZettArgs)
    main(cli_args)