diff --git a/flair/distributed_utils.py b/flair/distributed_utils.py index 862984983e..d3ad96697d 100644 --- a/flair/distributed_utils.py +++ b/flair/distributed_utils.py @@ -2,9 +2,8 @@ import os import random from multiprocessing.connection import Connection -from typing import Callable +from typing import Callable, Collection, Iterable, TypeVar -import numpy as np import torch import torch.multiprocessing as mp from torch.distributed import destroy_process_group, init_process_group @@ -15,8 +14,10 @@ log = logging.getLogger("flair") +T = TypeVar("T") -def launch_distributed(fn, *args, **kwargs): + +def launch_distributed(fn: Callable, *args, **kwargs): """Executes the function fn(*args, **kwargs) on multiple processes (one for each local GPU). If training with multi_gpu=True, launch_distributed should wrap your code that calls .train or .fine_tune. @@ -61,16 +62,6 @@ def is_main_process() -> bool: return True -def aggregate(value, aggregation_fn=np.mean): - """Gather `value` from all processes and send to `aggregation_fn` to get a single return value.""" - if torch.distributed.is_initialized(): - gathered_values = [None for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather_object(gathered_values, value) - else: - gathered_values = [value] - return aggregation_fn(gathered_values) - - def validate_corpus_same_each_process(corpus: Corpus) -> None: """Catches most cases in which a corpus is not the same on each process. @@ -84,9 +75,76 @@ def validate_corpus_same_each_process(corpus: Corpus) -> None: def _validate_dataset_same_each_process(dataset: Dataset, sample_size: int = 10) -> None: + """:raises: ValueError if the dataset is not the same on each process.""" random_indices = random.sample(range(_len_dataset(dataset)), min(sample_size, _len_dataset(dataset))) for i in random_indices: example = str(dataset[i]) examples = aggregate(example, list) if not all(example == examples[0] for example in examples): raise ValueError("Dataset must be the same on each process") + + +def gather(value: T) -> list[T]: + """Gather `value` from all processes and return a list of values.""" + if torch.distributed.is_initialized(): + gathered_values = [value for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_values, value) + else: + gathered_values = [value] + return gathered_values + + +def aggregate(value: T, aggregation_fn: Callable): + """Gather `value` from all processes and send to `aggregation_fn` to get a single return value.""" + gathered_values = gather(value) + return aggregation_fn(gathered_values) + + +def broadcast_value(value: T, src: int = 0) -> T: + """ + Broadcasts a Python object from the source process (src) to all other processes. + Every process returns the same object. 
+ """ + obj_list = [value] + torch.distributed.broadcast_object_list(obj_list, src=src) + return obj_list[0] + + +# aggregation functions +def flatten(l: Iterable[Iterable[T]]) -> list[T]: + """Flattens all elements in an iterable, such as a list, of iterables into a single list.""" + return [x for s in l for x in s] + + +def flatten_set(list_of_sets: Iterable[Iterable[T]]) -> set[T]: + """Flattens all elements in an iterable, such as a list, of iterables into a single set.""" + return {x for subset in list_of_sets for x in subset} + + +def merge_sets(list_of_sets: Collection[set[T]]) -> set[T]: + """Merges a collection of sets into a single set.""" + merged_set = set() + for s in list_of_sets: + merged_set.update(s) + return merged_set + + +def flatten_dicts(list_of_dicts: list[dict[str, list[T]]]) -> dict[str, list[T]]: + """This function merges a list of dictionaries with list values into a single dictionary with merged list values.""" + merged_dict: dict[str, list[T]] = {} + for d in list_of_dicts: + for k, v in d.items(): + if k not in merged_dict: + merged_dict[k] = [] + merged_dict[k].extend(v) + return merged_dict + + +def aggregate_tensor_sum(list_of_tensors: list[torch.Tensor]) -> torch.Tensor: + """ + Custom aggregation function to sum loss values from all processes. + Moves all tensors to CPU and converts them to Python scalars before summing. + Returns a single tensor containing the summed loss. + """ + total = sum(t.cpu().item() for t in list_of_tensors) + return torch.tensor(total) diff --git a/flair/models/pairwise_regression_model.py b/flair/models/pairwise_regression_model.py index bc77b54dce..11051fbb1d 100644 --- a/flair/models/pairwise_regression_model.py +++ b/flair/models/pairwise_regression_model.py @@ -4,6 +4,7 @@ import torch from torch import nn +from torch.utils.data import DistributedSampler from torch.utils.data.dataset import Dataset from tqdm import tqdm @@ -11,6 +12,7 @@ import flair.nn from flair.data import Corpus, Dictionary, Sentence, TextPair, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset +from flair.distributed_utils import aggregate, aggregate_tensor_sum, broadcast_value, flatten, is_main_process from flair.nn.model import ReduceTransformerVocabMixin from flair.training_utils import EmbeddingStorageMode, MetricRegression, Result, store_embeddings @@ -288,13 +290,21 @@ def evaluate( exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, + multi_gpu: bool = False, **kwargs, ) -> Result: exclude_labels = exclude_labels if exclude_labels is not None else [] + # read Dataset into data loader, if list of sentences passed, make Dataset first if not isinstance(data_points, Dataset): data_points = FlairDatapointDataset(data_points) - data_loader = DataLoader(data_points, batch_size=mini_batch_size) + + data_loader = DataLoader( + data_points, + batch_size=mini_batch_size, + shuffle=False, + sampler=DistributedSampler(data_points, shuffle=False) if multi_gpu else None, + ) with torch.no_grad(): eval_loss = torch.zeros(1, device=flair.device) @@ -311,7 +321,7 @@ def evaluate( if isinstance(batch, Sentence): batch = [batch] - loss, num, scores = self._forward_loss_and_scores(batch, return_scores=True) + loss, num, scores_forward = self._forward_loss_and_scores(batch, return_scores=True) true_values = [] for sentence in batch: @@ -319,7 +329,7 @@ def evaluate( for label in sentence.get_labels(gold_label_type): true_values.append(float(label.value)) - results = 
scores.cpu().tolist() + results = scores_forward.cpu().tolist() eval_loss += loss @@ -336,30 +346,46 @@ def evaluate( if out_path is not None: out_file.close() + if multi_gpu: + metric.true = aggregate(metric.true, flatten) + metric.pred = aggregate(metric.pred, flatten) + eval_loss = aggregate(eval_loss, aggregate_tensor_sum) + total_count = aggregate(total_count, sum) + eval_loss /= total_count - detailed_result = ( - f"AVG: mse: {metric.mean_squared_error():.4f} - " - f"mae: {metric.mean_absolute_error():.4f} - " - f"pearson: {metric.pearsonr():.4f} - " - f"spearman: {metric.spearmanr():.4f}" - ) + if is_main_process(): # only calculate metrics in main process - eval_metrics = { - "loss": eval_loss.item(), - "mse": metric.mean_squared_error(), - "mae": metric.mean_absolute_error(), - "pearson": metric.pearsonr(), - "spearman": metric.spearmanr(), - } + detailed_result = ( + f"AVG: mse: {metric.mean_squared_error():.4f} - " + f"mae: {metric.mean_absolute_error():.4f} - " + f"pearson: {metric.pearsonr():.4f} - " + f"spearman: {metric.spearmanr():.4f}" + ) - if main_evaluation_metric[0] in ("correlation", "other"): - main_score = eval_metrics[main_evaluation_metric[1]] - else: - main_score = eval_metrics["spearman"] + scores = { + "loss": eval_loss.item(), + "mse": metric.mean_squared_error(), + "mae": metric.mean_absolute_error(), + "pearson": metric.pearsonr(), + "spearman": metric.spearmanr(), + } - return Result( - main_score=main_score, - detailed_results=detailed_result, - scores=eval_metrics, - ) + if main_evaluation_metric[0] in ("correlation", "other"): + main_score = scores[main_evaluation_metric[1]] + else: + main_score = scores["spearman"] + + result = Result( + main_score=main_score, + detailed_results=detailed_result, + scores=scores, + ) + + else: # if it's not the main process, just set a dummy Result + result = Result(0.0, "", {}, {"loss": 0.0}) + + if multi_gpu: + result = broadcast_value(result, src=0) + + return result diff --git a/flair/models/text_regression_model.py b/flair/models/text_regression_model.py index a0a99e6402..56cc635a72 100644 --- a/flair/models/text_regression_model.py +++ b/flair/models/text_regression_model.py @@ -5,6 +5,7 @@ import torch from torch import nn +from torch.utils.data import DistributedSampler from torch.utils.data.dataset import Dataset from tqdm import tqdm @@ -12,6 +13,7 @@ import flair.embeddings from flair.data import Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset +from flair.distributed_utils import aggregate, aggregate_tensor_sum, broadcast_value, flatten, is_main_process from flair.embeddings.base import load_embeddings from flair.nn.model import ReduceTransformerVocabMixin from flair.training_utils import EmbeddingStorageMode, MetricRegression, Result, store_embeddings @@ -141,13 +143,21 @@ def evaluate( exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, + multi_gpu: bool = False, **kwargs, ) -> Result: exclude_labels = exclude_labels if exclude_labels is not None else [] + # read Dataset into data loader, if list of sentences passed, make Dataset first if not isinstance(data_points, Dataset): data_points = FlairDatapointDataset(data_points) - data_loader = DataLoader(data_points, batch_size=mini_batch_size) + + data_loader = DataLoader( + data_points, + batch_size=mini_batch_size, + shuffle=False, + sampler=DistributedSampler(data_points, shuffle=False) if multi_gpu else None, + ) with 
torch.no_grad(): eval_loss = torch.zeros(1, device=flair.device) @@ -156,11 +166,11 @@ def evaluate( lines: list[str] = [] total_count = 0 - for batch in data_loader: + for batch in tqdm(data_loader): if isinstance(batch, Sentence): batch = [batch] - scores, loss = self.forward_labels_and_loss(batch) + scores_forward, loss = self.forward_labels_and_loss(batch) true_values = [] for sentence in batch: @@ -168,7 +178,7 @@ def evaluate( for label in sentence.get_labels(gold_label_type): true_values.append(float(label.value)) - results = scores[:, 0].cpu().tolist() + results = scores_forward[:, 0].cpu().tolist() eval_loss += loss @@ -181,6 +191,12 @@ def evaluate( store_embeddings(batch, embedding_storage_mode) + if multi_gpu: + metric.true = aggregate(metric.true, flatten) + metric.pred = aggregate(metric.pred, flatten) + eval_loss = aggregate(eval_loss, aggregate_tensor_sum) + total_count = aggregate(total_count, sum) + eval_loss /= total_count # TODO: not saving lines yet @@ -188,31 +204,39 @@ def evaluate( with open(out_path, "w", encoding="utf-8") as outfile: outfile.write("".join(lines)) - detailed_result = ( - f"AVG: mse: {metric.mean_squared_error():.4f} - " - f"mae: {metric.mean_absolute_error():.4f} - " - f"pearson: {metric.pearsonr():.4f} - " - f"spearman: {metric.spearmanr():.4f}" - ) - - eval_metrics = { - "loss": eval_loss.item(), - "mse": metric.mean_squared_error(), - "mae": metric.mean_absolute_error(), - "pearson": metric.pearsonr(), - "spearman": metric.spearmanr(), - } - - if main_evaluation_metric[0] in ("correlation", "other"): - main_score = eval_metrics[main_evaluation_metric[1]] - else: - main_score = eval_metrics["spearman"] - - result = Result( - main_score=main_score, - detailed_results=detailed_result, - scores=eval_metrics, - ) + if is_main_process(): # only calculate metrics in main process + + detailed_result = ( + f"AVG: mse: {metric.mean_squared_error():.4f} - " + f"mae: {metric.mean_absolute_error():.4f} - " + f"pearson: {metric.pearsonr():.4f} - " + f"spearman: {metric.spearmanr():.4f}" + ) + + scores = { + "loss": eval_loss.item(), + "mse": metric.mean_squared_error(), + "mae": metric.mean_absolute_error(), + "pearson": metric.pearsonr(), + "spearman": metric.spearmanr(), + } + + if main_evaluation_metric[0] in ("correlation", "other"): + main_score = scores[main_evaluation_metric[1]] + else: + main_score = scores["spearman"] + + result = Result( + main_score=main_score, + detailed_results=detailed_result, + scores=scores, + ) + + else: # if it's not the main process, just set a dummy Result + result = Result(0.0, "", {}, {"loss": 0.0}) + + if multi_gpu: + result = broadcast_value(result, src=0) return result diff --git a/flair/nn/model.py b/flair/nn/model.py index 5cf030c5d6..ec2451e17a 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -10,6 +10,7 @@ import torch.nn from torch import Tensor from torch.nn.modules.loss import _Loss +from torch.utils.data import DistributedSampler from torch.utils.data.dataset import Dataset from tqdm import tqdm @@ -17,7 +18,14 @@ from flair.class_utils import get_non_abstract_subclasses from flair.data import DT, DT2, Corpus, Dictionary, Sentence, _iter_dataset from flair.datasets import DataLoader, FlairDatapointDataset -from flair.distributed_utils import is_main_process +from flair.distributed_utils import ( + aggregate, + aggregate_tensor_sum, + broadcast_value, + flatten_dicts, + is_main_process, + merge_sets, +) from flair.embeddings import Embeddings from flair.embeddings.base import load_embeddings from 
flair.file_utils import Tqdm, load_torch_state @@ -92,7 +100,6 @@ def evaluate( Returns: The evaluation results. """ - exclude_labels = exclude_labels if exclude_labels is not None else [] raise NotImplementedError def _get_state_dict(self) -> dict: @@ -362,6 +369,7 @@ def evaluate( exclude_labels: Optional[list[str]] = None, gold_label_dictionary: Optional[Dictionary] = None, return_loss: bool = True, + multi_gpu: bool = False, **kwargs, ) -> Result: exclude_labels = exclude_labels if exclude_labels is not None else [] @@ -390,7 +398,13 @@ def evaluate( all_true_values = {} all_predicted_values = {} - loader = DataLoader(data_points, batch_size=mini_batch_size) + loader = DataLoader( + data_points, + batch_size=mini_batch_size, + shuffle=False, + sampler=DistributedSampler(data_points, shuffle=False) if multi_gpu else None, + ) + rank = torch.distributed.get_rank() if multi_gpu else 0 sentence_id = 0 for batch in Tqdm.tqdm(loader, disable=not is_main_process()): @@ -417,7 +431,7 @@ def evaluate( # get the gold labels for datapoint in batch: for gold_label in datapoint.get_labels(gold_label_type): - representation = str(sentence_id) + ": " + gold_label.unlabeled_identifier + representation = f"{rank}-{sentence_id}: {gold_label.unlabeled_identifier}" value = gold_label.value if gold_label_dictionary and gold_label_dictionary.get_idx_for_item(value) == 0: @@ -432,7 +446,7 @@ def evaluate( all_spans.add(representation) for predicted_span in datapoint.get_labels("predicted"): - representation = str(sentence_id) + ": " + predicted_span.unlabeled_identifier + representation = f"{rank}-{sentence_id}: {predicted_span.unlabeled_identifier}" # add to all_predicted_values if representation not in all_predicted_values: @@ -451,6 +465,16 @@ def evaluate( if out_path: lines.extend(self._print_predictions(batch, gold_label_type)) + if multi_gpu: + all_spans = aggregate(all_spans, merge_sets) + all_true_values = aggregate(all_true_values, flatten_dicts) + all_predicted_values = aggregate(all_predicted_values, flatten_dicts) + average_over = aggregate(average_over, sum) + eval_loss = aggregate(eval_loss, aggregate_tensor_sum) + + result = Result(0.0, "", {}, {"loss": 0.0}) + if is_main_process(): + # convert true and predicted values to two span-aligned lists true_values_span_aligned = [] predicted_values_span_aligned = [] @@ -481,137 +505,142 @@ def evaluate( for label in predicted_values: evaluation_label_dictionary.add_item(label) - # check if this is a multi-label problem - multi_label = False - for true_instance, predicted_instance in zip(true_values_span_aligned, predicted_values_span_aligned): - if len(true_instance) > 1 or len(predicted_instance) > 1: - multi_label = True - break - - log.debug(f"Evaluating as a multi-label problem: {multi_label}") - - # compute numbers by formatting true and predicted such that Scikit-Learn can use them - y_true = [] - y_pred = [] - if multi_label: - # multi-label problems require a multi-hot vector for each true and predicted label - for true_instance in true_values_span_aligned: - y_true_instance = np.zeros(len(evaluation_label_dictionary), dtype=int) - for true_value in true_instance: - y_true_instance[evaluation_label_dictionary.get_idx_for_item(true_value)] = 1 - y_true.append(y_true_instance.tolist()) - - for predicted_values in predicted_values_span_aligned: - y_pred_instance = np.zeros(len(evaluation_label_dictionary), dtype=int) - for predicted_value in predicted_values: - y_pred_instance[evaluation_label_dictionary.get_idx_for_item(predicted_value)] = 1 - 
y_pred.append(y_pred_instance.tolist()) - else: - # single-label problems can do with a single index for each true and predicted label - y_true = [ - evaluation_label_dictionary.get_idx_for_item(true_instance[0]) - for true_instance in true_values_span_aligned - ] - y_pred = [ - evaluation_label_dictionary.get_idx_for_item(predicted_instance[0]) - for predicted_instance in predicted_values_span_aligned - ] - - # now, calculate evaluation numbers - target_names = [] - labels = [] - - counter = Counter(itertools.chain.from_iterable(all_true_values.values())) - counter.update(list(itertools.chain.from_iterable(all_predicted_values.values()))) - - for label_name, _count in counter.most_common(): - if label_name == "O": - continue - target_names.append(label_name) - labels.append(evaluation_label_dictionary.get_idx_for_item(label_name)) - - # there is at least one gold label or one prediction (default) - if len(all_true_values) + len(all_predicted_values) > 1: - classification_report = sklearn.metrics.classification_report( - y_true, - y_pred, - digits=4, - target_names=target_names, - zero_division=0, - labels=labels, - ) + # check if this is a multi-label problem + multi_label = False + for true_instance, predicted_instance in zip(true_values_span_aligned, predicted_values_span_aligned): + if len(true_instance) > 1 or len(predicted_instance) > 1: + multi_label = True + break + + log.debug(f"Evaluating as a multi-label problem: {multi_label}") + + # compute numbers by formatting true and predicted such that Scikit-Learn can use them + y_true = [] + y_pred = [] + if multi_label: + # multi-label problems require a multi-hot vector for each true and predicted label + for true_instance in true_values_span_aligned: + y_true_instance = np.zeros(len(evaluation_label_dictionary), dtype=int) + for true_value in true_instance: + y_true_instance[evaluation_label_dictionary.get_idx_for_item(true_value)] = 1 + y_true.append(y_true_instance.tolist()) + + for predicted_values in predicted_values_span_aligned: + y_pred_instance = np.zeros(len(evaluation_label_dictionary), dtype=int) + for predicted_value in predicted_values: + y_pred_instance[evaluation_label_dictionary.get_idx_for_item(predicted_value)] = 1 + y_pred.append(y_pred_instance.tolist()) + else: + # single-label problems can do with a single index for each true and predicted label + y_true = [ + evaluation_label_dictionary.get_idx_for_item(true_instance[0]) + for true_instance in true_values_span_aligned + ] + y_pred = [ + evaluation_label_dictionary.get_idx_for_item(predicted_instance[0]) + for predicted_instance in predicted_values_span_aligned + ] + + # now, calculate evaluation numbers + target_names = [] + labels = [] + + counter = Counter(itertools.chain.from_iterable(all_true_values.values())) + counter.update(list(itertools.chain.from_iterable(all_predicted_values.values()))) + + for label_name, _count in counter.most_common(): + if label_name == "O": + continue + target_names.append(label_name) + labels.append(evaluation_label_dictionary.get_idx_for_item(label_name)) + + # there is at least one gold label or one prediction (default) + if is_main_process() and len(all_true_values) + len(all_predicted_values) > 1: + classification_report = sklearn.metrics.classification_report( + y_true, + y_pred, + digits=4, + target_names=target_names, + zero_division=0, + labels=labels, + ) - classification_report_dict = sklearn.metrics.classification_report( - y_true, - y_pred, - target_names=target_names, - zero_division=0, - output_dict=True, - 
labels=labels, - ) + classification_report_dict = sklearn.metrics.classification_report( + y_true, + y_pred, + target_names=target_names, + zero_division=0, + output_dict=True, + labels=labels, + ) - # compute accuracy separately as it is not always in classification_report (e.g. when micro avg exists) - accuracy_score = round(sklearn.metrics.accuracy_score(y_true, y_pred), 4) + # compute accuracy separately as it is not always in classification_report (e.g. when micro avg exists) + accuracy_score = round(sklearn.metrics.accuracy_score(y_true, y_pred), 4) - # if there is only one label, then "micro avg" = "macro avg" - if len(target_names) == 1: - classification_report_dict["micro avg"] = classification_report_dict["macro avg"] + # if there is only one label, then "micro avg" = "macro avg" + if len(target_names) == 1: + classification_report_dict["micro avg"] = classification_report_dict["macro avg"] - # The "micro avg" appears only in the classification report if no prediction is possible. - # Otherwise, it is identical to the "macro avg". In this case, we add it to the report. - if "micro avg" not in classification_report_dict: - classification_report_dict["micro avg"] = {} - for metric_key in classification_report_dict["macro avg"]: - if metric_key != "support": - classification_report_dict["micro avg"][metric_key] = classification_report_dict["accuracy"] - else: - classification_report_dict["micro avg"][metric_key] = classification_report_dict["macro avg"][ - "support" - ] - - detailed_result = ( - "\nResults:" - f"\n- F-score (micro) {round(classification_report_dict['micro avg']['f1-score'], 4)}" - f"\n- F-score (macro) {round(classification_report_dict['macro avg']['f1-score'], 4)}" - f"\n- Accuracy {accuracy_score}" - "\n\nBy class:\n" + classification_report - ) + # The "micro avg" appears only in the classification report if no prediction is possible. + # Otherwise, it is identical to the "macro avg". In this case, we add it to the report. 
+ if "micro avg" not in classification_report_dict: + classification_report_dict["micro avg"] = {} + for metric_key in classification_report_dict["macro avg"]: + if metric_key != "support": + classification_report_dict["micro avg"][metric_key] = classification_report_dict["accuracy"] + else: + classification_report_dict["micro avg"][metric_key] = classification_report_dict[ + "macro avg" + ]["support"] + + detailed_result = ( + "\nResults:" + f"\n- F-score (micro) {round(classification_report_dict['micro avg']['f1-score'], 4)}" + f"\n- F-score (macro) {round(classification_report_dict['macro avg']['f1-score'], 4)}" + f"\n- Accuracy {accuracy_score}" + "\n\nBy class:\n" + classification_report + ) - # Create and populate score object for logging with all evaluation values, plus the loss - scores: dict[Union[tuple[str, ...], str], Any] = {} + # Create and populate score object for logging with all evaluation values, plus the loss + scores: dict[Union[tuple[str, ...], str], Any] = {} - for avg_type in ("micro avg", "macro avg"): - for metric_type in ("f1-score", "precision", "recall"): - scores[(avg_type, metric_type)] = classification_report_dict[avg_type][metric_type] + for avg_type in ("micro avg", "macro avg"): + for metric_type in ("f1-score", "precision", "recall"): + scores[(avg_type, metric_type)] = classification_report_dict[avg_type][metric_type] - scores["accuracy"] = accuracy_score + scores["accuracy"] = accuracy_score - if average_over > 0: - eval_loss /= average_over - scores["loss"] = eval_loss.item() + if average_over > 0: + eval_loss /= average_over + scores["loss"] = eval_loss.item() - return Result( - main_score=classification_report_dict[main_evaluation_metric[0]][main_evaluation_metric[1]], - detailed_results=detailed_result, - classification_report=classification_report_dict, - scores=scores, - ) + result = Result( + main_score=classification_report_dict[main_evaluation_metric[0]][main_evaluation_metric[1]], + detailed_results=detailed_result, + classification_report=classification_report_dict, + scores=scores, + ) - else: - # issue error and default all evaluation numbers to 0. - error_text = ( - f"It was not possible to compute evaluation values because: \n" - f"- The evaluation data has no gold labels for label_type='{gold_label_type}'!\n" - f"- And no predictions were made!\n" - "Double check your corpus (if the test split has labels), and how you initialize the ModelTrainer!" - ) + else: + # issue error and default all evaluation numbers to 0. + error_text = ( + f"It was not possible to compute evaluation values because: \n" + f"- The evaluation data has no gold labels for label_type='{gold_label_type}'!\n" + f"- And no predictions were made!\n" + "Double check your corpus (if the test split has labels), and how you initialize the ModelTrainer!" 
+ ) - return Result( - main_score=0.0, - detailed_results=error_text, - classification_report={}, - scores={"loss": 0.0}, - ) + result = Result( + main_score=0.0, + detailed_results=error_text, + classification_report={}, + scores={"loss": 0.0}, + ) + + if multi_gpu: + result = broadcast_value(result, src=0) + + return result @abstractmethod def predict( diff --git a/flair/trainers/trainer.py b/flair/trainers/trainer.py index f47bb86b53..bbb23a4d69 100644 --- a/flair/trainers/trainer.py +++ b/flair/trainers/trainer.py @@ -734,7 +734,7 @@ def wrapped_forward_loss(*args, **kwargs2): if epoch_train_samples > 0 else epoch_train_samples / (batch_no + 1) ) - intermittent_loss = aggregate(intermittent_loss) + intermittent_loss = aggregate(intermittent_loss, np.mean) current_time = time.time() samples_per_second = epoch_train_samples / (current_time - epoch_start_time) @@ -755,7 +755,7 @@ def wrapped_forward_loss(*args, **kwargs2): self.dispatch("after_training_batch", **batch_kw) train_loss = epoch_train_loss / epoch_train_samples - train_loss = aggregate(train_loss) + train_loss = aggregate(train_loss, np.mean) self._record(MetricRecord.scalar(("train", "loss"), train_loss, epoch)) total_train_samples += epoch_train_samples @@ -784,6 +784,7 @@ def wrapped_forward_loss(*args, **kwargs2): embedding_storage_mode=embeddings_storage_mode, gold_label_type=self.model.label_type, gold_label_dictionary_for_eval=gold_label_dictionary_for_eval, + multi_gpu=multi_gpu, ) # log results @@ -886,6 +887,7 @@ def wrapped_forward_loss(*args, **kwargs2): gold_label_dictionary=gold_label_dictionary_for_eval, exclude_labels=exclude_labels, return_loss=False, + multi_gpu=multi_gpu, ) log.info(test_results.detailed_results)
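
Usage sketch (not part of the diff): a minimal, hypothetical example of how the additions above are meant to fit together. launch_distributed spawns one process per local GPU, and passing multi_gpu=True to the trainer forwards multi_gpu=True into evaluate() (see the trainer.py hunk), which enables the DistributedSampler plus the new gather/aggregate/broadcast_value helpers so every rank ends up with the same Result. The corpus, embeddings, and output path below are placeholder choices, not part of this change.

from flair.datasets import TREC_6
from flair.distributed_utils import launch_distributed
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def train():
    # Each spawned process builds its own copy of the corpus and model;
    # validate_corpus_same_each_process() guards against divergence.
    corpus = TREC_6()
    label_type = "question_class"
    label_dict = corpus.make_label_dictionary(label_type=label_type)

    classifier = TextClassifier(
        TransformerDocumentEmbeddings("distilbert-base-uncased"),
        label_dictionary=label_dict,
        label_type=label_type,
    )

    trainer = ModelTrainer(classifier, corpus)
    # multi_gpu=True routes evaluation through the DistributedSampler and the
    # aggregate()/broadcast_value() code paths added in this diff.
    trainer.fine_tune("resources/taggers/trec", multi_gpu=True)


if __name__ == "__main__":
    # One process per local GPU, each running train() with DDP initialized.
    launch_distributed(train)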