Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 47 additions & 22 deletions tinker_cookbook/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class Message(TypedDict):
content: str
tool_calls: NotRequired[list[ToolCall]]
thinking: NotRequired[str]
trainable: NotRequired[bool]


class TrainOnWhat(StrEnum):
Expand All @@ -42,6 +43,7 @@ class TrainOnWhat(StrEnum):
ALL_MESSAGES = "all_messages"
ALL_TOKENS = "all_tokens"
ALL_USER_AND_SYSTEM_MESSAGES = "all_user_and_system_messages"
CUSTOMIZED = "customized"


class Renderer:
Expand Down Expand Up @@ -101,6 +103,10 @@ def build_supervised_example(
train_on_what: an enum that controls how the weights are assigned to the tokens.
- TrainOnWhat.LAST_ASSISTANT_MESSAGE: only the last assistant message is used for training
- TrainOnWhat.ALL_ASSISTANT_MESSAGES: all assistant messages are used for training
- TrainOnWhat.ALL_MESSAGES: all messages are used for training
- TrainOnWhat.ALL_TOKENS: all tokens are used for training
- TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES: all user and system messages are used for training
- TrainOnWhat.CUSTOMIZED: each message has a trainable field, and the weights are assigned based on the trainable field
messages: a list of messages to render.

Returns:
Expand All @@ -110,29 +116,48 @@ def build_supervised_example(
"""
tokens_weights = [(token, 0) for token in start_tokens]
for idx, message in enumerate(messages[:-1]):
ob_part, action_part, action_tail = render_message(idx, message)
if train_on_what == TrainOnWhat.LAST_ASSISTANT_MESSAGE:
tokens_weights.extend([(token, 0) for token in ob_part + action_part])
elif train_on_what == TrainOnWhat.ALL_ASSISTANT_MESSAGES:
tokens_weights += [(token, 0) for token in ob_part]
# TODO: look at the previous action tail and its overlap with the current action part
# and put weight of 1 on those tokens too.
is_assistant = message["role"] == "assistant"
tokens_weights += [(token, int(is_assistant)) for token in action_part]
elif train_on_what == TrainOnWhat.ALL_MESSAGES:
tokens_weights += [(token, 0) for token in ob_part]
tokens_weights += [(token, 1) for token in action_part]
elif train_on_what == TrainOnWhat.ALL_TOKENS:
tokens_weights += [(token, 1) for token in ob_part + action_part]
elif train_on_what == TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES:
tokens_weights += [(token, 0) for token in ob_part]
is_user_or_system = message["role"] in ["user", "system"]
tokens_weights += [(token, int(is_user_or_system)) for token in action_part]
if train_on_what == TrainOnWhat.CUSTOMIZED:
assert "trainable" in message, (
"When using CUSTOMIZED train_on_what, each message must have a trainable field: True if loss is applied on this message, False otherwise"
)
else:
raise ValueError(f"Unknown train_on_what: {train_on_what}")
ob_part, action_part, action_tail = render_message(len(messages) - 1, messages[-1])
tokens_weights.extend([(token, 0) for token in ob_part])
tokens_weights.extend([(token, 1) for token in action_part + action_tail])
assert "trainable" not in message, (
"When using non-CUSTOMIZED train_on_what, each message must not have a trainable field. Either change train_on_what to CUSTOMIZED or remove the trainable field from the message"
)

is_last_message = idx == len(messages) - 1
is_assistant = message["role"] == "assistant"
is_user_or_system = message["role"] in ["user", "system"]

# only apply weight to observation part if train_on_what is ALL_TOKENS
ob_part, action_part, action_tail = render_message(idx, message)
ob_weight = int(train_on_what == TrainOnWhat.ALL_TOKENS)
tokens_weights += [(token, ob_weight) for token in ob_part]

action_tokens = action_part
# action tail is effectively the stop_token and the start token for the next turn
# e.g. \n\nUser:
if is_last_message:
action_tokens += action_tail

match train_on_what:
case TrainOnWhat.LAST_ASSISTANT_MESSAGE:
action_has_weight = is_last_message and is_assistant
case TrainOnWhat.ALL_ASSISTANT_MESSAGES:
action_has_weight = is_assistant
case TrainOnWhat.ALL_MESSAGES:
action_has_weight = True
case TrainOnWhat.ALL_TOKENS:
action_has_weight = True
case TrainOnWhat.ALL_USER_AND_SYSTEM_MESSAGES:
action_has_weight = is_user_or_system
case TrainOnWhat.CUSTOMIZED:
action_has_weight = message.get("trainable", False)
case _:
raise ValueError(f"Unknown train_on_what: {train_on_what}")

tokens_weights += [(token, int(action_has_weight)) for token in action_tokens]

tokens, weights = zip(*tokens_weights, strict=True)
return torch.tensor(tokens), torch.tensor(weights)

Expand Down