Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 58 additions & 11 deletions src/nemotron/data_prep/core/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,42 @@ def find_last_user_message_end(
Returns:
Character position where last user message ends.

Raises:
ValueError: If ``messages`` contains no user message, or if the last
user message is the final message in the conversation (no
assistant turn follows). Both cases would otherwise crash with
``ValueError`` / ``IndexError``; we raise typed errors with
informative messages so upstream filters can categorize them.

Note:
Exact port of materialize.py::find_last_user_message_end()
Exact port of materialize.py::find_last_user_message_end(), with
added input guards. The actual rendering pipeline is preserved
byte-for-byte to keep output identical to the original on rows it
processes successfully.
"""
# Find the last user message index
last_user_idx = max(i for i, msg in enumerate(messages) if msg["role"] == "user")
# Guard: at least one user message must exist. The original
# implementation crashes with ``ValueError: max() iterable argument is
# empty`` here; we raise a typed error instead so callers can
# distinguish this from other failures.
user_idxs = [i for i, msg in enumerate(messages) if msg["role"] == "user"]
if not user_idxs:
raise ValueError("conversation has no user message")
last_user_idx = user_idxs[-1]

# Guard: there must be a message after the last user turn so the
# ``messages[last_user_idx + 1]`` access below is safe. The original
# implementation crashes with ``IndexError`` for trailing-user
# conversations; we raise a typed error instead.
if last_user_idx + 1 >= len(messages):
raise ValueError(
"conversation ends with a user message; no assistant response to render"
)

# Render up to the last user message (inclusive)
# Render up to the last user message (inclusive). Mirrors the original
# exactly -- no rstrip here, because there is no prefix-mismatch signal
# at this layer to gate a fallback on. Stripping unconditionally would
# shift chunk boundaries on every row, including ones the original
# processed cleanly.
if enable_thinking and (
"reasoning_content" not in messages[last_user_idx + 1]
or messages[last_user_idx + 1]["reasoning_content"] == ""
Expand Down Expand Up @@ -152,7 +181,13 @@ def split_template_into_messages(
# Get first "message": if starting from last user, this includes all prior turns
if start_from_last_user:
system_end = full_template.find("<|im_end|>\n") + len("<|im_end|>\n")
last_user_idx = max(i for i, msg in enumerate(messages) if msg["role"] == "user")
# Guard mirrors find_last_user_message_end: original crashes here
# with ``ValueError: max() iterable argument is empty`` when the
# conversation has no user message.
user_idxs = [i for i, msg in enumerate(messages) if msg["role"] == "user"]
if not user_idxs:
raise ValueError("conversation has no user message")
last_user_idx = user_idxs[-1]
last_user_pos = find_last_user_message_end(
messages, tokenizer, enable_thinking=enable_thinking, tools=tools
)
Expand Down Expand Up @@ -207,15 +242,27 @@ def split_template_into_messages(
chat_template_kwargs={"enable_thinking": enable_thinking},
)

# Verify incremental rendering matches full template. Strict check
# first -- this preserves byte-identical chunk boundaries with the
# original implementation on every row it processed successfully.
if template_up_to_here != full_template[: len(template_up_to_here)]:
# Trailing-whitespace fallback for issue #184: some chat
# templates emit a newline after the generation prompt that
# does not appear at the matching position in the full
# rendering. rstrip lets us recover those rows instead of
# filtering them out, while staying a no-op for templates that
# already line up exactly.
stripped = template_up_to_here.rstrip()
if stripped and stripped == full_template[: len(stripped)]:
template_up_to_here = stripped
else:
raise ValueError(
f"Template mismatch at message {i}: incremental rendering doesn't match full"
)

current_pos = len(template_up_to_here)
chunk_text = full_template[previous_pos:current_pos]

# Verify incremental rendering matches full template
if template_up_to_here != full_template[:current_pos]:
raise ValueError(
f"Template mismatch at message {i}: incremental rendering doesn't match full"
)

result.append({"role": messages[i]["role"], "content": chunk_text})
previous_pos = current_pos

Expand Down
251 changes: 251 additions & 0 deletions tests/data_prep/test_chat_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for chat_template incremental rendering.

Covers:

* Issue #184 regression: tokenizer chat templates that append an extra
trailing newline after the generation prompt no longer cause rows to be
filtered out (the rstrip fallback recovers them).
* Semantic preservation: rows the original ``materialize.py`` processed
successfully produce byte-identical chunks. The rstrip fallback only fires
when the strict prefix check fails, so well-behaved templates are
unaffected.
* Input guards: empty / no-user / trailing-user conversations now raise
typed ``ValueError`` instead of crashing with bare ``IndexError`` /
``ValueError`` from internal indexing.
"""

from __future__ import annotations

import pytest

from nemotron.data_prep.core.chat_template import (
find_last_user_message_end,
split_template_into_messages,
)


class _BuggyTokenizer:
"""Fake tokenizer that reproduces the issue #184 pathology.

``apply_chat_template`` renders messages naturally for the full template
but, when ``add_generation_prompt=True``, appends an extra trailing
newline after ``<|im_start|>assistant\\n``. This is the exact shape that
causes the prefix check in ``split_template_into_messages`` to fail
before the conditional-rstrip fallback was added.
"""

def apply_chat_template(
self,
messages: list[dict],
tokenize: bool = False,
add_generation_prompt: bool = False,
tools: list | None = None,
chat_template_kwargs: dict | None = None,
) -> str:
out = ""
for m in messages:
out += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
if add_generation_prompt:
# The bug: an extra \n that does NOT appear in the full template
# at the corresponding position.
out += "<|im_start|>assistant\n\n"
return out


class _CleanTokenizer:
"""Fake tokenizer with no trailing-whitespace pathology.

Used to confirm the conditional-rstrip fallback does NOT fire for
well-behaved templates -- chunk boundaries must be byte-identical to
the original implementation.
"""

def apply_chat_template(
self,
messages: list[dict],
tokenize: bool = False,
add_generation_prompt: bool = False,
tools: list | None = None,
chat_template_kwargs: dict | None = None,
) -> str:
out = ""
for m in messages:
out += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
if add_generation_prompt:
out += "<|im_start|>assistant\n"
return out


@pytest.fixture
def messages() -> list[dict]:
return [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "Hello"},
]


# =============================================================================
# Issue #184 -- trailing newline fallback
# =============================================================================


class TestIssue184TrailingNewline:
"""Regression tests for issue #184."""

def test_buggy_template_does_not_filter_row(self, messages: list[dict]) -> None:
"""Trailing-newline pathology must NOT raise after the fallback.

Before the fix this raised ``ValueError: Template mismatch at
message 0`` and the whole conversation was dropped from the SFT
dataset.
"""
chunks = split_template_into_messages(
messages,
_BuggyTokenizer(),
start_from_last_user=False,
enable_thinking=False,
)

assert [c["role"] for c in chunks] == ["user", "assistant"]

def test_buggy_template_chunks_reconstruct_full(self, messages: list[dict]) -> None:
"""Recovered chunks must concatenate back to the full template.

Conditional rstrip means the fallback fires on exactly one
iteration (the buggy boundary). The next iteration's strict check
passes, so ``current_pos`` lands on the natural end of the full
template -- joined chunks reconstruct ``full`` exactly, with no
trailing-whitespace drift.
"""
tokenizer = _BuggyTokenizer()
full = tokenizer.apply_chat_template(messages, add_generation_prompt=False)

chunks = split_template_into_messages(
messages,
tokenizer,
start_from_last_user=False,
enable_thinking=False,
)

assert "".join(c["content"] for c in chunks) == full


# =============================================================================
# Semantic preservation -- byte-identical output for clean templates
# =============================================================================


class TestSemanticPreservation:
"""The fallback must NOT fire for templates the original handled cleanly."""

def test_clean_template_chunks_reconstruct_full_exactly(
self,
messages: list[dict],
) -> None:
"""Joined chunks must equal ``full`` byte-for-byte (no rstrip drift)."""
tokenizer = _CleanTokenizer()
full = tokenizer.apply_chat_template(messages, add_generation_prompt=False)

chunks = split_template_into_messages(
messages,
tokenizer,
start_from_last_user=False,
enable_thinking=False,
)

assert [c["role"] for c in chunks] == ["user", "assistant"]
assert "".join(c["content"] for c in chunks) == full

def test_clean_template_preserves_trailing_newline(
self,
messages: list[dict],
) -> None:
"""The final chunk's content must keep the natural trailing newline.

Regression guard for the previous unconditional-rstrip implementation,
which dropped the trailing ``\\n`` of the last chunk on every row.
"""
tokenizer = _CleanTokenizer()
chunks = split_template_into_messages(
messages,
tokenizer,
start_from_last_user=False,
enable_thinking=False,
)

assert chunks[-1]["content"].endswith("<|im_end|>\n")


# =============================================================================
# Input guards -- typed errors instead of bare IndexError / ValueError
# =============================================================================


class TestInputGuards:
"""Edge cases that previously crashed must now raise informative errors."""

def test_no_user_message_raises_value_error(self) -> None:
"""A conversation with no user turn must raise a clear ``ValueError``.

The original implementation crashed with the opaque
``ValueError: max() iterable argument is empty``.
"""
messages = [
{"role": "system", "content": "sys"},
{"role": "assistant", "content": "Hi"},
]
with pytest.raises(ValueError, match="no user message"):
find_last_user_message_end(
messages,
_CleanTokenizer(),
enable_thinking=False,
)

def test_trailing_user_message_raises_value_error(self) -> None:
"""A conversation ending with a user turn must raise a clear error.

The original implementation crashed with ``IndexError`` while
accessing ``messages[last_user_idx + 1]``.
"""
messages = [
{"role": "system", "content": "sys"},
{"role": "user", "content": "hi"},
]
with pytest.raises(ValueError, match="ends with a user message"):
find_last_user_message_end(
messages,
_CleanTokenizer(),
enable_thinking=False,
)

def test_split_template_propagates_no_user_guard(self) -> None:
"""``split_template_into_messages`` should also surface the guard.

The ``start_from_last_user=True`` path delegates to
``find_last_user_message_end``; the typed error must propagate.
"""
messages = [
{"role": "system", "content": "sys"},
{"role": "assistant", "content": "Hi"},
]
with pytest.raises(ValueError, match="no user message"):
split_template_into_messages(
messages,
_CleanTokenizer(),
start_from_last_user=True,
enable_thinking=False,
)
Loading