diff --git a/examples/number_token_loss_example.py b/examples/number_token_loss_example.py new file mode 100644 index 000000000000..148579383aa3 --- /dev/null +++ b/examples/number_token_loss_example.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +""" +Example script demonstrating the use of Number Token Loss (NTL) in transformers. + +This script shows how to: +1. Use NTL-WAS and NTL-MSE losses with a language model +2. Compare the performance with standard cross-entropy loss +3. Demonstrate the benefits on numerical tasks +""" + +import torch +import torch.nn as nn +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + TrainingArguments, + Trainer, + DataCollatorForLanguageModeling +) +from transformers.loss import ForCausalLMWithNTLWAS, ForCausalLMWithNTLMSE +from datasets import Dataset +import numpy as np + + +def create_math_dataset(num_examples=1000): + """ + Create a simple math dataset for demonstration. + + Args: + num_examples: Number of examples to generate + + Returns: + Dataset with math problems and solutions + """ + examples = [] + + for i in range(num_examples): + # Generate simple arithmetic problems + a = np.random.randint(1, 100) + b = np.random.randint(1, 100) + operation = np.random.choice(['+', '-', '*']) + + if operation == '+': + result = a + b + elif operation == '-': + result = a - b + else: # '*' + result = a * b + + # Create the problem text + problem = f"What is {a} {operation} {b}? The answer is {result}." + examples.append({"text": problem}) + + return Dataset.from_list(examples) + + +def custom_loss_function(loss_name, tokenizer=None, alpha=0.1): + """ + Create a custom loss function based on the specified loss type. + + Args: + loss_name: Type of loss ('ce', 'ntl_was', 'ntl_mse') + tokenizer: Tokenizer for NTL losses + alpha: Weight for NTL loss component + + Returns: + Loss function + """ + if loss_name == 'ce': + # Standard cross-entropy loss + return nn.CrossEntropyLoss() + elif loss_name == 'ntl_was': + # NTL with Wasserstein-1 distance + def ntl_was_loss(logits, labels): + return ForCausalLMWithNTLWAS( + logits, labels, + vocab_size=tokenizer.vocab_size, + tokenizer=tokenizer, + alpha=alpha + ) + return ntl_was_loss + elif loss_name == 'ntl_mse': + # NTL with MSE + def ntl_mse_loss(logits, labels): + return ForCausalLMWithNTLMSE( + logits, labels, + vocab_size=tokenizer.vocab_size, + tokenizer=tokenizer, + alpha=alpha + ) + return ntl_mse_loss + else: + raise ValueError(f"Unknown loss type: {loss_name}") + + +class CustomTrainer(Trainer): + """Custom trainer that supports different loss functions.""" + + def __init__(self, loss_function, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_function = loss_function + + def compute_loss(self, model, inputs, return_outputs=False): + """ + Compute the loss using the custom loss function. 
+ """ + outputs = model(**inputs) + logits = outputs.logits + + # Get labels from inputs + labels = inputs.get("labels") + if labels is None: + # If no labels provided, use input_ids shifted by 1 + labels = inputs["input_ids"].clone() + labels[:, :-1] = inputs["input_ids"][:, 1:] + labels[:, -1] = -100 # Ignore last token + + # Compute loss using custom loss function + if isinstance(self.loss_function, nn.Module): + # Standard loss function (e.g., CrossEntropyLoss) + loss = self.loss_function(logits.view(-1, logits.size(-1)), labels.view(-1)) + else: + # Custom loss function (e.g., NTL) + loss = self.loss_function(logits, labels) + + return (loss, outputs) if return_outputs else loss + + +def main(): + """Main function demonstrating Number Token Loss usage.""" + print("Number Token Loss (NTL) Example") + print("=" * 50) + + # Load tokenizer and model + model_name = "gpt2" # Using GPT-2 for demonstration + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name) + + # Add padding token if not present + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # Create dataset + print("Creating math dataset...") + dataset = create_math_dataset(num_examples=500) + + # Tokenize dataset + def tokenize_function(examples): + return tokenizer(examples["text"], truncation=True, padding=True) + + tokenized_dataset = dataset.map(tokenize_function, batched=True) + + # Data collator + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, # We're doing causal LM, not masked LM + ) + + # Training arguments + training_args = TrainingArguments( + output_dir="./ntl_example_output", + num_train_epochs=1, + per_device_train_batch_size=4, + save_steps=100, + save_total_limit=2, + logging_steps=50, + learning_rate=5e-5, + warmup_steps=100, + remove_unused_columns=False, + ) + + # Test different loss functions + loss_functions = { + 'Cross-Entropy': custom_loss_function('ce'), + 'NTL-WAS': custom_loss_function('ntl_was', tokenizer, alpha=0.1), + 'NTL-MSE': custom_loss_function('ntl_mse', tokenizer, alpha=0.1), + } + + results = {} + + for loss_name, loss_function in loss_functions.items(): + print(f"\nTraining with {loss_name} loss...") + + # Create trainer with custom loss + trainer = CustomTrainer( + loss_function=loss_function, + model=model, + args=training_args, + train_dataset=tokenized_dataset, + data_collator=data_collator, + ) + + # Train the model + trainer.train() + + # Evaluate on a simple test + test_text = "What is 15 + 27? The answer is" + inputs = tokenizer(test_text, return_tensors="pt") + + with torch.no_grad(): + outputs = model(**inputs) + logits = outputs.logits[:, -1, :] # Get logits for last position + probs = torch.softmax(logits, dim=-1) + + # Get top 5 predictions + top_probs, top_indices = torch.topk(probs, 5) + + print(f"\nTop 5 predictions for '{test_text}':") + for i in range(5): + token = tokenizer.decode([top_indices[0][i]]) + prob = top_probs[0][i].item() + print(f" {token}: {prob:.4f}") + + results[loss_name] = { + 'model': model, + 'final_loss': trainer.state.log_history[-1]['train_loss'] if trainer.state.log_history else None + } + + # Print summary + print("\n" + "=" * 50) + print("Training Summary:") + print("=" * 50) + for loss_name, result in results.items(): + print(f"{loss_name}: Final Loss = {result['final_loss']:.4f}") + + print("\nNote: This is a demonstration. 
For real applications:") + print("- Use larger models and datasets") + print("- Tune hyperparameters (alpha, learning rate, etc.)") + print("- Evaluate on proper test sets") + print("- Consider the computational cost of NTL") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py index 75c4cbf3451b..01aca9bf652d 100644 --- a/src/transformers/loss/loss_utils.py +++ b/src/transformers/loss/loss_utils.py @@ -23,6 +23,7 @@ from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss from .loss_grounding_dino import GroundingDinoForObjectDetectionLoss from .loss_rt_detr import RTDetrForObjectDetectionLoss +from .number_token_loss import ForCausalLMWithNTLWAS, ForCausalLMWithNTLMSE def fixed_cross_entropy( @@ -146,6 +147,8 @@ def ForTokenClassification(logits: torch.Tensor, labels, config, **kwargs): LOSS_MAPPING = { "ForCausalLM": ForCausalLMLoss, + "ForCausalLMWithNTLWAS": ForCausalLMWithNTLWAS, + "ForCausalLMWithNTLMSE": ForCausalLMWithNTLMSE, "ForMaskedLM": ForMaskedLMLoss, "ForQuestionAnswering": ForQuestionAnsweringLoss, "ForSequenceClassification": ForSequenceClassificationLoss, diff --git a/src/transformers/loss/number_token_loss.py b/src/transformers/loss/number_token_loss.py new file mode 100644 index 000000000000..177d1db94385 --- /dev/null +++ b/src/transformers/loss/number_token_loss.py @@ -0,0 +1,501 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Number Token Loss (NTL) implementation for transformers. + +This module provides two variants of Number Token Loss: +- NTL-WAS: Uses Wasserstein-1 distance between numerical values +- NTL-MSE: Uses Mean Squared Error between numerical values + +The loss is designed to augment cross-entropy loss for language models when dealing with numerical tokens, +providing a more meaningful loss signal for tokens that represent numbers. +""" + +import re +import math +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def extract_numerical_value(token: str) -> Optional[float]: + """ + Extract numerical value from a token string. 
+ + Args: + token: The token string to extract numerical value from + + Returns: + The numerical value as float, or None if the token is not numerical + + Examples: + >>> extract_numerical_value("123") + 123.0 + >>> extract_numerical_value("3.14") + 3.14 + >>> extract_numerical_value("hello") + None + """ + token_lower = token.lower() + + # Handle special number words + number_words = { + 'zero': 0, 'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5, + 'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10, + 'eleven': 11, 'twelve': 12, 'thirteen': 13, 'fourteen': 14, 'fifteen': 15, + 'sixteen': 16, 'seventeen': 17, 'eighteen': 18, 'nineteen': 19, 'twenty': 20, + 'thirty': 30, 'forty': 40, 'fifty': 50, 'sixty': 60, 'seventy': 70, + 'eighty': 80, 'ninety': 90, 'hundred': 100, 'thousand': 1000, + 'million': 1000000, 'billion': 1000000000 + } + + if token_lower in number_words: + return float(number_words[token_lower]) + + # Handle ordinal words (check before removing suffixes) + ordinal_words = { + 'first': 1, 'second': 2, 'third': 3, 'fourth': 4, 'fifth': 5, + 'sixth': 6, 'seventh': 7, 'eighth': 8, 'ninth': 9, 'tenth': 10, + 'eleventh': 11, 'twelfth': 12, 'thirteenth': 13, 'fourteenth': 14, 'fifteenth': 15, + 'sixteenth': 16, 'seventeenth': 17, 'eighteenth': 18, 'nineteenth': 19, 'twentieth': 20 + } + + if token_lower in ordinal_words: + return float(ordinal_words[token_lower]) + + # Remove common suffixes that might be attached to numbers + token_clean = re.sub(r'(st|nd|rd|th|s)$', '', token_lower) + + # Try to parse as a regular number + try: + # Remove commas from numbers like "1,000" + token_clean = token_clean.replace(',', '') + + # Handle scientific notation + if 'e' in token_clean.lower(): + return float(token_clean) + + # Handle regular numbers (integers and floats) + if re.match(r'^[+-]?\d*\.?\d+$', token_clean): + return float(token_clean) + + # Handle numbers with currency symbols + if re.match(r'^[+-]?[$€£¥₹₽฿₺₴₣₡₱₪₮₩₦₫﷼]?\d*\.?\d+$', token_clean): + # Remove currency symbols and parse + cleaned = re.sub(r'[^0-9+-.]', '', token_clean) + return float(cleaned) + + except (ValueError, TypeError): + pass + + return None + + +def build_token_to_number_map(tokenizer) -> Dict[int, float]: + """ + Build a mapping from token IDs to their numerical values. + + Args: + tokenizer: The tokenizer to extract tokens from + + Returns: + Dictionary mapping token IDs to numerical values + """ + token_to_number = {} + + for token_id in range(tokenizer.vocab_size): + try: + token = tokenizer.convert_ids_to_tokens(token_id) + numerical_value = extract_numerical_value(token) + if numerical_value is not None: + token_to_number[token_id] = numerical_value + except (KeyError, ValueError): + # Skip tokens that can't be converted + continue + + return token_to_number + + +def wasserstein_1_distance_numerical( + pred_dist: torch.Tensor, + target_dist: torch.Tensor, + token_to_number_map: Dict[int, float], + vocab_size: int +) -> torch.Tensor: + """ + Compute Wasserstein-1 distance between predicted and target distributions + for numerical tokens. + + For one-hot target distributions, this computes the expected absolute difference + between the predicted and target numerical values. 
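+    For example, if the target token is "3" and the model assigns probability 0.7 to
+    token "3" and 0.3 to token "5", the distance is 0.7 * |3 - 3| + 0.3 * |5 - 3| = 0.6
+    (the probabilities and tokens here are purely illustrative).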
+
+    Args:
+        pred_dist: Predicted distribution [batch_size, vocab_size]
+        target_dist: Target distribution [batch_size, vocab_size] (one-hot)
+        token_to_number_map: Mapping from token IDs to numerical values
+        vocab_size: Size of the vocabulary
+
+    Returns:
+        Wasserstein-1 distance for each sample in the batch whose target token is numerical
+    """
+    device = pred_dist.device
+
+    # Token IDs and numerical values of the numerical part of the vocabulary
+    num_token_ids = torch.tensor(list(token_to_number_map.keys()), device=device, dtype=torch.long)
+    num_token_values = torch.tensor(list(token_to_number_map.values()), device=device, dtype=pred_dist.dtype)
+
+    # Get target token indices and their numerical values (NaN marks non-numerical targets)
+    target_indices = torch.argmax(target_dist, dim=-1)  # [batch_size]
+    target_values = torch.full_like(target_indices, float('nan'), dtype=pred_dist.dtype)
+    for i, token_id in enumerate(target_indices.tolist()):
+        if token_id in token_to_number_map:
+            target_values[i] = token_to_number_map[token_id]
+
+    # Only positions whose target token is numerical contribute to the distance
+    valid_mask = ~torch.isnan(target_values)
+    if not valid_mask.any():
+        return torch.tensor(0.0, device=device)
+
+    # Probability mass assigned to each numerical token: [num_valid, num_numerical_tokens]
+    num_probs = pred_dist[valid_mask][:, num_token_ids]
+
+    # For a one-hot target, the Wasserstein-1 distance reduces to the expected absolute
+    # difference between the numerical token values and the target value. Probability
+    # mass on non-numerical tokens does not contribute to the distance.
+    abs_diffs = torch.abs(num_token_values.unsqueeze(0) - target_values[valid_mask].unsqueeze(1))
+    was_distances = (num_probs * abs_diffs).sum(dim=-1)
+    return was_distances
+
+
+def ntl_was_loss(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    token_to_number_map: Dict[int, float],
+    vocab_size: int,
+    alpha: float = 0.1,
+    ignore_index: int = -100,
+    **kwargs
+) -> torch.Tensor:
+    """
+    Number Token Loss using Wasserstein-1 distance (NTL-WAS).
+
+    Args:
+        logits: Model logits [batch_size, seq_len, vocab_size]
+        labels: Target labels [batch_size, seq_len]
+        token_to_number_map: Mapping from token IDs to numerical values
+        vocab_size: Size of the vocabulary
+        alpha: Weight for NTL loss (default: 0.1)
+        ignore_index: Index to ignore in loss computation (default: -100)
+        **kwargs: Additional arguments
+
+    Returns:
+        Combined loss (CE + alpha * NTL-WAS)
+    """
+    # Standard cross-entropy loss
+    ce_loss = F.cross_entropy(
+        logits.view(-1, vocab_size),
+        labels.view(-1),
+        ignore_index=ignore_index,
+        reduction='mean'
+    )
+
+    # Create numerical value tensors
+    device = logits.device
+    batch_size, seq_len = labels.shape
+
+    # Initialize numerical values tensor
+    numerical_values = torch.full((batch_size, seq_len), float('nan'), device=device)
+
+    # Fill in numerical values for tokens that have them
+    for token_id, num_value in token_to_number_map.items():
+        mask = (labels == token_id)
+        numerical_values[mask] = num_value
+
+    # Only compute NTL for positions with numerical values
+    numerical_mask = ~torch.isnan(numerical_values)
+
+    if not numerical_mask.any():
+        return ce_loss
+
+    # Extract logits and labels for numerical positions
+    num_logits = logits[numerical_mask]  # [num_numerical, vocab_size]
+    num_labels = labels[numerical_mask]  # [num_numerical]
+
+    # Create target distribution (one-hot)
+    target_dist = F.one_hot(num_labels, num_classes=vocab_size).float()
+
+    # Create predicted distribution (softmax of logits)
+    pred_dist = F.softmax(num_logits, dim=-1)
+
+    # Compute Wasserstein-1 distance
+    was_distances = wasserstein_1_distance_numerical(pred_dist, target_dist, token_to_number_map, vocab_size)
+
+    # Average over numerical positions
+    ntl_loss = was_distances.mean()
+
+    return ce_loss + alpha * ntl_loss
+
+
+def ntl_mse_loss(
+    logits: torch.Tensor,
+    labels: torch.Tensor,
+    token_to_number_map: Dict[int, float],
vocab_size: int, + alpha: float = 0.1, + ignore_index: int = -100, + **kwargs +) -> torch.Tensor: + """ + Number Token Loss using Mean Squared Error (NTL-MSE). + + Args: + logits: Model logits [batch_size, seq_len, vocab_size] + labels: Target labels [batch_size, seq_len] + token_to_number_map: Mapping from token IDs to numerical values + vocab_size: Size of the vocabulary + alpha: Weight for NTL loss (default: 0.1) + ignore_index: Index to ignore in loss computation (default: -100) + **kwargs: Additional arguments + + Returns: + Combined loss (CE + alpha * NTL-MSE) + """ + # Standard cross-entropy loss + ce_loss = F.cross_entropy( + logits.view(-1, vocab_size), + labels.view(-1), + ignore_index=ignore_index, + reduction='mean' + ) + + # Create numerical value tensors + device = logits.device + batch_size, seq_len = labels.shape + + # Initialize numerical values tensor + numerical_values = torch.full((batch_size, seq_len), float('nan'), device=device) + + # Fill in numerical values for tokens that have them + for token_id, num_value in token_to_number_map.items(): + mask = (labels == token_id) + numerical_values[mask] = num_value + + # Only compute NTL for positions with numerical values + numerical_mask = ~torch.isnan(numerical_values) + + if not numerical_mask.any(): + return ce_loss + + # Extract logits and labels for numerical positions + num_logits = logits[numerical_mask] # [num_numerical, vocab_size] + num_labels = labels[numerical_mask] # [num_numerical] + num_values = numerical_values[numerical_mask] # [num_numerical] + + # Create predicted distribution (softmax of logits) + pred_dist = F.softmax(num_logits, dim=-1) + + # Compute MSE between predicted and target numerical values + # For each position, compute the expected numerical value + pred_numerical_values = torch.zeros_like(num_values) + for token_id, num_value in token_to_number_map.items(): + pred_numerical_values += pred_dist[:, token_id] * num_value + + # MSE loss + mse_loss = F.mse_loss(pred_numerical_values, num_values) + + return ce_loss + alpha * mse_loss + + +# Global token_to_number_map cache to avoid rebuilding for each loss computation +_token_to_number_map_cache = {} + + +def get_token_to_number_map(tokenizer): + """ + Get or create token-to-number mapping with caching. + + Args: + tokenizer: The tokenizer to extract tokens from + + Returns: + Dictionary mapping token IDs to numerical values + """ + tokenizer_id = id(tokenizer) + if tokenizer_id not in _token_to_number_map_cache: + _token_to_number_map_cache[tokenizer_id] = build_token_to_number_map(tokenizer) + return _token_to_number_map_cache[tokenizer_id] + + +def ForCausalLMWithNTLWAS( + logits, + labels, + vocab_size: int, + tokenizer=None, + alpha: float = 0.1, + num_items_in_batch: Optional[torch.Tensor] = None, + ignore_index: int = -100, + shift_labels: Optional[torch.Tensor] = None, + **kwargs, +) -> torch.Tensor: + """ + Causal LM loss augmented with Number Token Loss using Wasserstein-1 distance. 
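+    Labels are shifted internally (tokens < n predict n), so callers pass the same
+    unshifted labels they would pass to the standard causal LM loss.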
+ + Args: + logits: Model logits + labels: Target labels + vocab_size: Size of the vocabulary + tokenizer: Tokenizer for extracting numerical tokens + alpha: Weight for NTL loss + num_items_in_batch: Number of items in batch (for compatibility) + ignore_index: Index to ignore in loss computation + shift_labels: Shifted labels (for compatibility) + **kwargs: Additional arguments + + Returns: + Combined loss (CE + alpha * NTL-WAS) + """ + if tokenizer is None: + # Fall back to standard CE loss if no tokenizer provided + from .loss_utils import ForCausalLMLoss + return ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, ignore_index, shift_labels, **kwargs) + + # Handle label shifting for causal LM + if shift_labels is None: + # Shift so that tokens < n predict n + labels = nn.functional.pad(labels, (0, 1), value=ignore_index) + shift_labels = labels[..., 1:].contiguous() + + # Get token-to-number mapping + token_to_number_map = get_token_to_number_map(tokenizer) + + # Compute NTL-WAS loss + return ntl_was_loss( + logits, shift_labels, token_to_number_map, vocab_size, + alpha, ignore_index, **kwargs + ) + + +def ForCausalLMWithNTLMSE( + logits, + labels, + vocab_size: int, + tokenizer=None, + alpha: float = 0.1, + num_items_in_batch: Optional[torch.Tensor] = None, + ignore_index: int = -100, + shift_labels: Optional[torch.Tensor] = None, + **kwargs, +) -> torch.Tensor: + """ + Causal LM loss augmented with Number Token Loss using MSE. + + Args: + logits: Model logits + labels: Target labels + vocab_size: Size of the vocabulary + tokenizer: Tokenizer for extracting numerical tokens + alpha: Weight for NTL loss + num_items_in_batch: Number of items in batch (for compatibility) + ignore_index: Index to ignore in loss computation + shift_labels: Shifted labels (for compatibility) + **kwargs: Additional arguments + + Returns: + Combined loss (CE + alpha * NTL-MSE) + """ + if tokenizer is None: + # Fall back to standard CE loss if no tokenizer provided + from .loss_utils import ForCausalLMLoss + return ForCausalLMLoss(logits, labels, vocab_size, num_items_in_batch, ignore_index, shift_labels, **kwargs) + + # Handle label shifting for causal LM + if shift_labels is None: + # Shift so that tokens < n predict n + labels = nn.functional.pad(labels, (0, 1), value=ignore_index) + shift_labels = labels[..., 1:].contiguous() + + # Get token-to-number mapping + token_to_number_map = get_token_to_number_map(tokenizer) + + # Compute NTL-MSE loss + return ntl_mse_loss( + logits, shift_labels, token_to_number_map, vocab_size, + alpha, ignore_index, **kwargs + ) + + +class NumberTokenLoss: + """ + Number Token Loss class that can be used as a drop-in replacement for standard loss functions. + + This class maintains the token-to-number mapping and provides both NTL-WAS and NTL-MSE variants. + """ + + def __init__( + self, + tokenizer, + variant: str = "was", + alpha: float = 0.1, + ignore_index: int = -100 + ): + """ + Initialize Number Token Loss. + + Args: + tokenizer: The tokenizer to extract numerical tokens from + variant: Loss variant ("was" or "mse") + alpha: Weight for NTL loss + ignore_index: Index to ignore in loss computation + """ + self.token_to_number_map = build_token_to_number_map(tokenizer) + self.variant = variant.lower() + self.alpha = alpha + self.ignore_index = ignore_index + + if self.variant not in ["was", "mse"]: + raise ValueError(f"Unknown variant: {variant}. 
Must be 'was' or 'mse'") + + def __call__( + self, + logits: torch.Tensor, + labels: torch.Tensor, + vocab_size: int, + **kwargs + ) -> torch.Tensor: + """ + Compute the Number Token Loss. + + Args: + logits: Model logits [batch_size, seq_len, vocab_size] + labels: Target labels [batch_size, seq_len] + vocab_size: Size of the vocabulary + **kwargs: Additional arguments + + Returns: + Combined loss (CE + alpha * NTL) + """ + if self.variant == "was": + return ntl_was_loss( + logits, labels, self.token_to_number_map, vocab_size, + self.alpha, self.ignore_index, **kwargs + ) + else: # mse + return ntl_mse_loss( + logits, labels, self.token_to_number_map, vocab_size, + self.alpha, self.ignore_index, **kwargs + ) \ No newline at end of file diff --git a/tests/test_number_token_loss.py b/tests/test_number_token_loss.py new file mode 100644 index 000000000000..108263459c57 --- /dev/null +++ b/tests/test_number_token_loss.py @@ -0,0 +1,227 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for Number Token Loss implementation. +""" + +import unittest +import torch +from transformers import AutoTokenizer +from transformers.loss.number_token_loss import ( + extract_numerical_value, + build_token_to_number_map, + ntl_was_loss, + ntl_mse_loss, + ForCausalLMWithNTLWAS, + ForCausalLMWithNTLMSE, +) + + +class TestNumberTokenLoss(unittest.TestCase): + """Test cases for Number Token Loss.""" + + def test_extract_numerical_value(self): + """Test numerical value extraction from tokens.""" + # Test regular numbers + self.assertEqual(extract_numerical_value("123"), 123.0) + self.assertEqual(extract_numerical_value("3.14"), 3.14) + self.assertEqual(extract_numerical_value("-42"), -42.0) + self.assertEqual(extract_numerical_value("1,000"), 1000.0) + + # Test number words + self.assertEqual(extract_numerical_value("five"), 5.0) + self.assertEqual(extract_numerical_value("twenty"), 20.0) + self.assertEqual(extract_numerical_value("hundred"), 100.0) + + # Test ordinal words + self.assertEqual(extract_numerical_value("first"), 1.0) + self.assertEqual(extract_numerical_value("tenth"), 10.0) + + # Test with suffixes + self.assertEqual(extract_numerical_value("5th"), 5.0) + self.assertEqual(extract_numerical_value("1st"), 1.0) + + # Test non-numerical tokens + self.assertIsNone(extract_numerical_value("hello")) + self.assertIsNone(extract_numerical_value("the")) + self.assertIsNone(extract_numerical_value("")) + + def test_build_token_to_number_map(self): + """Test building token-to-number mapping.""" + # Use a simple tokenizer for testing + tokenizer = AutoTokenizer.from_pretrained("gpt2") + token_to_number = build_token_to_number_map(tokenizer) + + # Should be a dictionary + self.assertIsInstance(token_to_number, dict) + + # Should contain some numerical tokens + self.assertGreater(len(token_to_number), 0) + + # Check that some common numbers are present + # Note: This depends on the specific tokenizer vocabulary + for token_id, num_value in 
token_to_number.items(): + self.assertIsInstance(token_id, int) + self.assertIsInstance(num_value, float) + self.assertGreaterEqual(token_id, 0) + self.assertLess(token_id, tokenizer.vocab_size) + + def test_ntl_was_loss(self): + """Test NTL-WAS loss computation.""" + vocab_size = 1000 + batch_size = 2 + seq_len = 10 + + # Create dummy logits and labels + logits = torch.randn(batch_size, seq_len, vocab_size) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + + # Create a simple token-to-number mapping + token_to_number = {0: 1.0, 1: 2.0, 2: 3.0} # Only first few tokens are numerical + + # Compute loss + loss = ntl_was_loss(logits, labels, token_to_number, vocab_size, alpha=0.1) + + # Should be a tensor + self.assertIsInstance(loss, torch.Tensor) + + # Should be a scalar + self.assertEqual(loss.dim(), 0) + + # Should be positive + self.assertGreater(loss.item(), 0) + + def test_ntl_mse_loss(self): + """Test NTL-MSE loss computation.""" + vocab_size = 1000 + batch_size = 2 + seq_len = 10 + + # Create dummy logits and labels + logits = torch.randn(batch_size, seq_len, vocab_size) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + + # Create a simple token-to-number mapping + token_to_number = {0: 1.0, 1: 2.0, 2: 3.0} # Only first few tokens are numerical + + # Compute loss + loss = ntl_mse_loss(logits, labels, token_to_number, vocab_size, alpha=0.1) + + # Should be a tensor + self.assertIsInstance(loss, torch.Tensor) + + # Should be a scalar + self.assertEqual(loss.dim(), 0) + + # Should be positive + self.assertGreater(loss.item(), 0) + + def test_for_causal_lm_with_ntl_was(self): + """Test ForCausalLMWithNTLWAS function.""" + # Get actual vocab size from tokenizer + tokenizer = AutoTokenizer.from_pretrained("gpt2") + vocab_size = tokenizer.vocab_size + batch_size = 2 + seq_len = 10 + + # Create dummy logits and labels + logits = torch.randn(batch_size, seq_len, vocab_size) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + + # Test with tokenizer + loss = ForCausalLMWithNTLWAS( + logits, labels, vocab_size, tokenizer=tokenizer, alpha=0.1 + ) + + # Should be a tensor + self.assertIsInstance(loss, torch.Tensor) + + # Should be a scalar + self.assertEqual(loss.dim(), 0) + + # Should be positive + self.assertGreater(loss.item(), 0) + + # Test without tokenizer (should fall back to CE) + loss_no_tokenizer = ForCausalLMWithNTLWAS( + logits, labels, vocab_size, tokenizer=None + ) + + # Should still be a valid loss + self.assertIsInstance(loss_no_tokenizer, torch.Tensor) + self.assertGreater(loss_no_tokenizer.item(), 0) + + def test_for_causal_lm_with_ntl_mse(self): + """Test ForCausalLMWithNTLMSE function.""" + # Get actual vocab size from tokenizer + tokenizer = AutoTokenizer.from_pretrained("gpt2") + vocab_size = tokenizer.vocab_size + batch_size = 2 + seq_len = 10 + + # Create dummy logits and labels + logits = torch.randn(batch_size, seq_len, vocab_size) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + + # Test with tokenizer + loss = ForCausalLMWithNTLMSE( + logits, labels, vocab_size, tokenizer=tokenizer, alpha=0.1 + ) + + # Should be a tensor + self.assertIsInstance(loss, torch.Tensor) + + # Should be a scalar + self.assertEqual(loss.dim(), 0) + + # Should be positive + self.assertGreater(loss.item(), 0) + + # Test without tokenizer (should fall back to CE) + loss_no_tokenizer = ForCausalLMWithNTLMSE( + logits, labels, vocab_size, tokenizer=None + ) + + # Should still be a valid loss + self.assertIsInstance(loss_no_tokenizer, 
torch.Tensor) + self.assertGreater(loss_no_tokenizer.item(), 0) + + def test_loss_with_ignore_index(self): + """Test that ignore_index is handled correctly.""" + vocab_size = 1000 + batch_size = 2 + seq_len = 10 + + # Create dummy logits and labels with some ignored positions + logits = torch.randn(batch_size, seq_len, vocab_size) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + labels[0, 5:] = -100 # Ignore some positions + + # Create a simple token-to-number mapping + token_to_number = {0: 1.0, 1: 2.0, 2: 3.0} + + # Test both loss variants + loss_was = ntl_was_loss(logits, labels, token_to_number, vocab_size, ignore_index=-100) + loss_mse = ntl_mse_loss(logits, labels, token_to_number, vocab_size, ignore_index=-100) + + # Both should be valid losses + self.assertIsInstance(loss_was, torch.Tensor) + self.assertIsInstance(loss_mse, torch.Tensor) + self.assertGreater(loss_was.item(), 0) + self.assertGreater(loss_mse.item(), 0) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file
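A minimal usage sketch of the APIs added in this diff (it assumes the `gpt2` tokenizer can be downloaded; the dummy tensor shapes and `alpha=0.1` are illustrative placeholders, not tuned values):

import torch
from transformers import AutoTokenizer
from transformers.loss.number_token_loss import (
    ForCausalLMWithNTLWAS,
    NumberTokenLoss,
    build_token_to_number_map,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Inspect which vocabulary entries are treated as numerical
token_to_number = build_token_to_number_map(tokenizer)
print(f"{len(token_to_number)} numerical tokens found in the vocabulary")

# Dummy logits/labels standing in for a model forward pass
vocab_size = tokenizer.vocab_size
logits = torch.randn(2, 8, vocab_size)
labels = torch.randint(0, vocab_size, (2, 8))

# Functional API: cross-entropy plus alpha * NTL-WAS; labels are shifted internally,
# so these are the same unshifted labels a standard causal LM loss would receive
loss_was = ForCausalLMWithNTLWAS(logits, labels, vocab_size, tokenizer=tokenizer, alpha=0.1)

# Class API: drop-in loss object; __call__ does not shift, so the labels must already
# be aligned with the logits
ntl = NumberTokenLoss(tokenizer, variant="mse", alpha=0.1)
loss_mse = ntl(logits, labels, vocab_size)

print(loss_was.item(), loss_mse.item())

One design point worth flagging: the new `LOSS_MAPPING` entries only take effect if a `tokenizer` reaches the loss call. The standard model loss path does not pass one, in which case both `ForCausalLMWithNTLWAS` and `ForCausalLMWithNTLMSE` silently fall back to plain cross-entropy; threading a tokenizer (or a precomputed token-to-number map) through to the loss would be needed for those mapping entries to be useful.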