diff --git a/src/nemotron/data_prep/core/chat_template.py b/src/nemotron/data_prep/core/chat_template.py index d940dd6ac..f23f82058 100644 --- a/src/nemotron/data_prep/core/chat_template.py +++ b/src/nemotron/data_prep/core/chat_template.py @@ -79,13 +79,42 @@ def find_last_user_message_end( Returns: Character position where last user message ends. + Raises: + ValueError: If ``messages`` contains no user message, or if the last + user message is the final message in the conversation (no + assistant turn follows). Both cases would otherwise crash with + ``ValueError`` / ``IndexError``; we raise typed errors with + informative messages so upstream filters can categorize them. + Note: - Exact port of materialize.py::find_last_user_message_end() + Exact port of materialize.py::find_last_user_message_end(), with + added input guards. The actual rendering pipeline is preserved + byte-for-byte to keep output identical to the original on rows it + processes successfully. """ - # Find the last user message index - last_user_idx = max(i for i, msg in enumerate(messages) if msg["role"] == "user") + # Guard: at least one user message must exist. The original + # implementation crashes with ``ValueError: max() iterable argument is + # empty`` here; we raise a typed error instead so callers can + # distinguish this from other failures. + user_idxs = [i for i, msg in enumerate(messages) if msg["role"] == "user"] + if not user_idxs: + raise ValueError("conversation has no user message") + last_user_idx = user_idxs[-1] + + # Guard: there must be a message after the last user turn so the + # ``messages[last_user_idx + 1]`` access below is safe. The original + # implementation crashes with ``IndexError`` for trailing-user + # conversations; we raise a typed error instead. + if last_user_idx + 1 >= len(messages): + raise ValueError( + "conversation ends with a user message; no assistant response to render" + ) - # Render up to the last user message (inclusive) + # Render up to the last user message (inclusive). Mirrors the original + # exactly -- no rstrip here, because there is no prefix-mismatch signal + # at this layer to gate a fallback on. Stripping unconditionally would + # shift chunk boundaries on every row, including ones the original + # processed cleanly. if enable_thinking and ( "reasoning_content" not in messages[last_user_idx + 1] or messages[last_user_idx + 1]["reasoning_content"] == "" @@ -152,7 +181,13 @@ def split_template_into_messages( # Get first "message": if starting from last user, this includes all prior turns if start_from_last_user: system_end = full_template.find("<|im_end|>\n") + len("<|im_end|>\n") - last_user_idx = max(i for i, msg in enumerate(messages) if msg["role"] == "user") + # Guard mirrors find_last_user_message_end: original crashes here + # with ``ValueError: max() iterable argument is empty`` when the + # conversation has no user message. + user_idxs = [i for i, msg in enumerate(messages) if msg["role"] == "user"] + if not user_idxs: + raise ValueError("conversation has no user message") + last_user_idx = user_idxs[-1] last_user_pos = find_last_user_message_end( messages, tokenizer, enable_thinking=enable_thinking, tools=tools ) @@ -207,15 +242,27 @@ def split_template_into_messages( chat_template_kwargs={"enable_thinking": enable_thinking}, ) + # Verify incremental rendering matches full template. Strict check + # first -- this preserves byte-identical chunk boundaries with the + # original implementation on every row it processed successfully. + if template_up_to_here != full_template[: len(template_up_to_here)]: + # Trailing-whitespace fallback for issue #184: some chat + # templates emit a newline after the generation prompt that + # does not appear at the matching position in the full + # rendering. rstrip lets us recover those rows instead of + # filtering them out, while staying a no-op for templates that + # already line up exactly. + stripped = template_up_to_here.rstrip() + if stripped and stripped == full_template[: len(stripped)]: + template_up_to_here = stripped + else: + raise ValueError( + f"Template mismatch at message {i}: incremental rendering doesn't match full" + ) + current_pos = len(template_up_to_here) chunk_text = full_template[previous_pos:current_pos] - # Verify incremental rendering matches full template - if template_up_to_here != full_template[:current_pos]: - raise ValueError( - f"Template mismatch at message {i}: incremental rendering doesn't match full" - ) - result.append({"role": messages[i]["role"], "content": chunk_text}) previous_pos = current_pos diff --git a/tests/data_prep/test_chat_template.py b/tests/data_prep/test_chat_template.py new file mode 100644 index 000000000..033a615a5 --- /dev/null +++ b/tests/data_prep/test_chat_template.py @@ -0,0 +1,251 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for chat_template incremental rendering. + +Covers: + +* Issue #184 regression: tokenizer chat templates that append an extra + trailing newline after the generation prompt no longer cause rows to be + filtered out (the rstrip fallback recovers them). +* Semantic preservation: rows the original ``materialize.py`` processed + successfully produce byte-identical chunks. The rstrip fallback only fires + when the strict prefix check fails, so well-behaved templates are + unaffected. +* Input guards: empty / no-user / trailing-user conversations now raise + typed ``ValueError`` instead of crashing with bare ``IndexError`` / + ``ValueError`` from internal indexing. +""" + +from __future__ import annotations + +import pytest + +from nemotron.data_prep.core.chat_template import ( + find_last_user_message_end, + split_template_into_messages, +) + + +class _BuggyTokenizer: + """Fake tokenizer that reproduces the issue #184 pathology. + + ``apply_chat_template`` renders messages naturally for the full template + but, when ``add_generation_prompt=True``, appends an extra trailing + newline after ``<|im_start|>assistant\\n``. This is the exact shape that + causes the prefix check in ``split_template_into_messages`` to fail + before the conditional-rstrip fallback was added. + """ + + def apply_chat_template( + self, + messages: list[dict], + tokenize: bool = False, + add_generation_prompt: bool = False, + tools: list | None = None, + chat_template_kwargs: dict | None = None, + ) -> str: + out = "" + for m in messages: + out += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" + if add_generation_prompt: + # The bug: an extra \n that does NOT appear in the full template + # at the corresponding position. + out += "<|im_start|>assistant\n\n" + return out + + +class _CleanTokenizer: + """Fake tokenizer with no trailing-whitespace pathology. + + Used to confirm the conditional-rstrip fallback does NOT fire for + well-behaved templates -- chunk boundaries must be byte-identical to + the original implementation. + """ + + def apply_chat_template( + self, + messages: list[dict], + tokenize: bool = False, + add_generation_prompt: bool = False, + tools: list | None = None, + chat_template_kwargs: dict | None = None, + ) -> str: + out = "" + for m in messages: + out += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" + if add_generation_prompt: + out += "<|im_start|>assistant\n" + return out + + +@pytest.fixture +def messages() -> list[dict]: + return [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "Hello"}, + ] + + +# ============================================================================= +# Issue #184 -- trailing newline fallback +# ============================================================================= + + +class TestIssue184TrailingNewline: + """Regression tests for issue #184.""" + + def test_buggy_template_does_not_filter_row(self, messages: list[dict]) -> None: + """Trailing-newline pathology must NOT raise after the fallback. + + Before the fix this raised ``ValueError: Template mismatch at + message 0`` and the whole conversation was dropped from the SFT + dataset. + """ + chunks = split_template_into_messages( + messages, + _BuggyTokenizer(), + start_from_last_user=False, + enable_thinking=False, + ) + + assert [c["role"] for c in chunks] == ["user", "assistant"] + + def test_buggy_template_chunks_reconstruct_full(self, messages: list[dict]) -> None: + """Recovered chunks must concatenate back to the full template. + + Conditional rstrip means the fallback fires on exactly one + iteration (the buggy boundary). The next iteration's strict check + passes, so ``current_pos`` lands on the natural end of the full + template -- joined chunks reconstruct ``full`` exactly, with no + trailing-whitespace drift. + """ + tokenizer = _BuggyTokenizer() + full = tokenizer.apply_chat_template(messages, add_generation_prompt=False) + + chunks = split_template_into_messages( + messages, + tokenizer, + start_from_last_user=False, + enable_thinking=False, + ) + + assert "".join(c["content"] for c in chunks) == full + + +# ============================================================================= +# Semantic preservation -- byte-identical output for clean templates +# ============================================================================= + + +class TestSemanticPreservation: + """The fallback must NOT fire for templates the original handled cleanly.""" + + def test_clean_template_chunks_reconstruct_full_exactly( + self, + messages: list[dict], + ) -> None: + """Joined chunks must equal ``full`` byte-for-byte (no rstrip drift).""" + tokenizer = _CleanTokenizer() + full = tokenizer.apply_chat_template(messages, add_generation_prompt=False) + + chunks = split_template_into_messages( + messages, + tokenizer, + start_from_last_user=False, + enable_thinking=False, + ) + + assert [c["role"] for c in chunks] == ["user", "assistant"] + assert "".join(c["content"] for c in chunks) == full + + def test_clean_template_preserves_trailing_newline( + self, + messages: list[dict], + ) -> None: + """The final chunk's content must keep the natural trailing newline. + + Regression guard for the previous unconditional-rstrip implementation, + which dropped the trailing ``\\n`` of the last chunk on every row. + """ + tokenizer = _CleanTokenizer() + chunks = split_template_into_messages( + messages, + tokenizer, + start_from_last_user=False, + enable_thinking=False, + ) + + assert chunks[-1]["content"].endswith("<|im_end|>\n") + + +# ============================================================================= +# Input guards -- typed errors instead of bare IndexError / ValueError +# ============================================================================= + + +class TestInputGuards: + """Edge cases that previously crashed must now raise informative errors.""" + + def test_no_user_message_raises_value_error(self) -> None: + """A conversation with no user turn must raise a clear ``ValueError``. + + The original implementation crashed with the opaque + ``ValueError: max() iterable argument is empty``. + """ + messages = [ + {"role": "system", "content": "sys"}, + {"role": "assistant", "content": "Hi"}, + ] + with pytest.raises(ValueError, match="no user message"): + find_last_user_message_end( + messages, + _CleanTokenizer(), + enable_thinking=False, + ) + + def test_trailing_user_message_raises_value_error(self) -> None: + """A conversation ending with a user turn must raise a clear error. + + The original implementation crashed with ``IndexError`` while + accessing ``messages[last_user_idx + 1]``. + """ + messages = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hi"}, + ] + with pytest.raises(ValueError, match="ends with a user message"): + find_last_user_message_end( + messages, + _CleanTokenizer(), + enable_thinking=False, + ) + + def test_split_template_propagates_no_user_guard(self) -> None: + """``split_template_into_messages`` should also surface the guard. + + The ``start_from_last_user=True`` path delegates to + ``find_last_user_message_end``; the typed error must propagate. + """ + messages = [ + {"role": "system", "content": "sys"}, + {"role": "assistant", "content": "Hi"}, + ] + with pytest.raises(ValueError, match="no user message"): + split_template_into_messages( + messages, + _CleanTokenizer(), + start_from_last_user=True, + enable_thinking=False, + )