This repository was archived by the owner on Jul 13, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtask2.py
More file actions
118 lines (107 loc) · 3.33 KB
/
task2.py
File metadata and controls
118 lines (107 loc) · 3.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
from pathlib import Path
from argparse import ArgumentParser, Namespace
import torch
from miditok import REMI, TokenizerConfig
from transformers import AutoModelForCausalLM, GenerationConfig
from tqdm import tqdm
from src.constants import CONFIG_FILE, CKPT_FILE
from src.utils import (
get_file_paths,
get_device,
get_trucated_idx,
load_config,
save_json,
truncate_to_nbars,
generate_tokens,
filter_invalid_tokens,
)
def parse_arguments(argv: "list[str] | None" = None) -> Namespace:
    """Parse the command-line options for Task 2.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is exactly what a plain
            ``parse_arguments()`` call has always done. Passing an explicit
            list allows the parser to be driven programmatically.

    Returns:
        Namespace with the attributes ``prompt_song_path``, ``ckpt_path``,
        ``num_velocities``, ``n_target_bar``, ``max_length``, ``top_k``,
        ``temperature``, ``repetition_penalty`` and ``output_folder``.
    """
    parser = ArgumentParser(description="Task 2")
    parser.add_argument(
        "--prompt_song_path",
        type=str,
        default="prompt_song",
        help="path of prompt song",
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="checkpoints/11-08-23-02-35",
        help="path of checkpoint",
    )
    parser.add_argument(
        "--num_velocities",
        type=int,
        default=16,
        help="num_velocities value given to the tokenizer config",
    )
    parser.add_argument(
        "--n_target_bar",
        type=int,
        default=32,
        help="number of target bars",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=1024,
        help="max_length value given to the generation config",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=5,
        help="top_k value given to the generation config",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.2,
        help="temperature value given to the generation config",
    )
    parser.add_argument(
        "--repetition_penalty",
        type=float,
        default=1.2,
        help="repetition_penalty value given to the generation config",
    )
    parser.add_argument(
        "--output_folder",
        type=str,
        default="results",
        help="root folder under which the generated files are saved",
    )
    return parser.parse_args(argv)
def main() -> None:
    """Run Task 2: continue every prompt song and write the results as MIDI.

    The function parses the CLI options, creates
    ``<output_folder>/<checkpoint folder name>/task2``, builds the REMI
    tokenizer, prepares the prompts with ``truncate_to_nbars(..., num_bar=8)``,
    restores the model weights from the checkpoint, stores the run
    configuration next to the outputs and finally writes one
    ``song_<i>.mid`` file per prompt (``i`` starts at 1).

    Raises:
        ValueError: If the tokenizer vocabulary has no token whose name
            contains ``"Bar"``.
    """
    args = parse_arguments()

    save_folder = Path(args.output_folder, Path(args.ckpt_path).name, "task2")
    os.makedirs(save_folder, exist_ok=True)

    ckpt_config = load_config(Path(args.ckpt_path, CONFIG_FILE))

    # NOTE(review): `params` is handed to TokenizerConfig here, not to REMI(...).
    # MidiTok documents `params` as an argument of the tokenizer constructor,
    # so tokenizer.json is presumably not being loaded by this call - confirm
    # the intent before changing it; behaviour is kept as-is.
    tokenizer_config = TokenizerConfig(
        num_velocities=args.num_velocities,
        use_chords=True,
        use_tempos=True,
        use_programs=True,
        params=Path(args.ckpt_path, "tokenizer.json"),
    )
    tokenizer = REMI(tokenizer_config)

    # Take the first vocabulary entry whose name contains "Bar"; fail with a
    # readable message instead of an IndexError when there is none.
    bar_token = next(
        (token_id for name, token_id in tokenizer.vocab.items() if "Bar" in name),
        None,
    )
    if bar_token is None:
        raise ValueError('tokenizer vocabulary contains no "Bar" token')

    midi_paths = get_file_paths(args.prompt_song_path)
    truncated_midi_tokens = truncate_to_nbars(midi_paths, tokenizer, num_bar=8)

    model = AutoModelForCausalLM.from_pretrained(ckpt_config.model_name)
    # Load the tensors onto the CPU first: without map_location torch.load
    # tries to restore them on the device they were saved from, which fails
    # when that device is unavailable on this machine. The model is moved to
    # the selected device right below.
    checkpoint = torch.load(
        Path(args.ckpt_path, CKPT_FILE),
        map_location="cpu",
        weights_only=True,
    )
    model.load_state_dict(checkpoint["model"])
    device = get_device()
    model.to(device)
    model.eval()

    generation_config = GenerationConfig(
        max_length=args.max_length,
        do_sample=True,
        top_k=args.top_k,
        temperature=args.temperature,
        # The model's EOS id doubles as the padding id.
        pad_token_id=model.config.eos_token_id,
        repetition_penalty=args.repetition_penalty,
    )

    # Keep the CLI options and the checkpoint config alongside the outputs.
    save_json(vars(args) | {"checkpoint": ckpt_config}, Path(save_folder, CONFIG_FILE))

    # Sampling needs no gradients; this is harmless if generate_tokens already
    # disables them internally.
    with torch.no_grad():
        for i, data in enumerate(tqdm(truncated_midi_tokens), start=1):
            generated_tokens = generate_tokens(
                data, model, device, args.n_target_bar, bar_token, generation_config
            )
            truncated_idx = get_trucated_idx(generated_tokens, tokenizer, args.n_target_bar)
            # The slice keeps the token at truncated_idx itself.
            valid_tokens = filter_invalid_tokens(
                generated_tokens[: truncated_idx + 1], tokenizer
            )
            generated_midi = tokenizer.decode(valid_tokens)
            generated_midi.dump_midi(Path(save_folder, f"song_{i}.mid"))


if __name__ == "__main__":
    main()