diff --git a/tests/test_indexes.py b/tests/test_indexes.py index 356017ea..3dccb41f 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -1,58 +1,71 @@ import pandas as pd from texthero import nlp, visualization, preprocessing, representation +import texthero as hero from . import PandasTestCase import unittest import string from parameterized import parameterized +import importlib +import inspect +import numpy as np -# Define valid inputs for different functions. +""" +This file intends to test each function whether the input Series's index is the same as +the output Series. + +This will go through all functions under texthero, and automatically +generate test cases, according to the HeroSeries type they accept, e.g. TokenSeries, TextSeries, etc. + +Normally, if you want a function to be auto-tested, you must specify default values for all parameters +except the first Series param, i.e., + +``` +def replace_punctuation(s: TextSeries, symbol: str = " ") -> TextSeries: + ... +``` + +However there might be exceptions that you want to specify test cases manually, or omit some functions +for testing. For example, +- Functions that has arguments without default value (preprocessing.replace_hashtags) +- Functions that returns Series with different index (representation.tfidf) +- Functions that doesn't return Series(mostly in visualization) +- Functions that doesn't take HeroSeries (yet) + +In those cases, you should add your custom test case so as to override the default one, +in the form of [name_of_test_case, function_to_test, tuple_of_valid_input_for_the_function]. +If you want to omit some functions, add their string name to func_white_list variable. + +The tests will be run by AbstractIndexTest below through the @parameterized +decorator. The names will be expanded automatically, so e.g. "named_entities" +creates test cases test_correct_index_named_entities and test_incorrect_index_named_entities. +""" + +# Define the valid input for each HeroSeries type s_text = pd.Series(["Test"], index=[5]) s_tokenized_lists = pd.Series([["Test", "Test2"], ["Test3"]], index=[5, 6]) s_numeric = pd.Series([5.0], index=[5]) s_numeric_lists = pd.Series([[5.0, 5.0], [6.0, 6.0]], index=[5, 6]) -# Define all test cases. Every test case is a list -# of [name of test case, function to test, tuple of valid input for the function]. -# First argument of valid input has to be the Pandas Series where we -# want to keep the index. If this is different for a function, a separate -# test case has to implemented in the class below. -# The tests will be run by AbstractIndexTest below through the @parameterized -# decorator. -# The names will be expanded automatically, so e.g. "named_entities" -# creates test cases test_correct_index_named_entities and test_incorrect_index_named_entities. - -test_cases_nlp = [ - ["named_entities", nlp.named_entities, (s_text,)], - ["noun_chunks", nlp.noun_chunks, (s_text,)], - ["stem", nlp.stem, (s_text,)], -] +valid_inputs = { + "TokenSeries": s_tokenized_lists, + "TextSeries": s_text, + "VectorSeries": s_numeric_lists, +} + +# Specify your custom test cases here (functions that +# has multiple arguments, doesn't accpet HeroSeries, etc.) +test_cases_nlp = [] test_cases_preprocessing = [ - ["fillna", preprocessing.fillna, (s_text,)], - ["lowercase", preprocessing.lowercase, (s_text,)], - ["replace_digits", preprocessing.replace_digits, (s_text, "")], - ["remove_digits", preprocessing.remove_digits, (s_text,)], - ["replace_punctuation", preprocessing.replace_punctuation, (s_text, "")], - ["remove_punctuation", preprocessing.remove_punctuation, (s_text,)], - ["remove_diacritics", preprocessing.remove_diacritics, (s_text,)], - ["remove_whitespace", preprocessing.remove_whitespace, (s_text,)], + # ["replace_digits", preprocessing.replace_digits, (s_text, "")], + # ["replace_punctuation", preprocessing.replace_punctuation, (s_text, "")], + # TODO: Add default params for these functions to make them auto-tested? ["replace_stopwords", preprocessing.replace_stopwords, (s_text, "")], - ["remove_stopwords", preprocessing.remove_stopwords, (s_text,)], - ["clean", preprocessing.clean, (s_text,)], - ["remove_round_brackets", preprocessing.remove_round_brackets, (s_text,)], - ["remove_curly_brackets", preprocessing.remove_curly_brackets, (s_text,)], - ["remove_square_brackets", preprocessing.remove_square_brackets, (s_text,)], - ["remove_angle_brackets", preprocessing.remove_angle_brackets, (s_text,)], - ["remove_brackets", preprocessing.remove_brackets, (s_text,)], - ["remove_html_tags", preprocessing.remove_html_tags, (s_text,)], - ["tokenize", preprocessing.tokenize, (s_text,)], - ["phrases", preprocessing.phrases, (s_tokenized_lists,)], ["replace_urls", preprocessing.replace_urls, (s_text, "")], - ["remove_urls", preprocessing.remove_urls, (s_text,)], ["replace_tags", preprocessing.replace_tags, (s_text, "")], - ["remove_tags", preprocessing.remove_tags, (s_text,)], + ["replace_hashtags", preprocessing.replace_hashtags, (s_text, "")], ] test_cases_representation = [ @@ -69,13 +82,51 @@ test_cases_visualization = [] -test_cases = ( +# Custom test cases, a dictionary of {func_str: test_case} +test_case_custom = {} +for case in ( test_cases_nlp + test_cases_preprocessing + test_cases_representation + test_cases_visualization +): + test_case_custom[case[0]] = case + + +# Put functions' name into white list if you want to omit them +func_white_list = set( + [s for s in inspect.getmembers(visualization, inspect.isfunction)] ) +test_cases = [] + +# Find all functions under texthero +func_strs = [ + s[0] + for s in inspect.getmembers(hero, inspect.isfunction) + if s not in func_white_list +] + +for func_str in func_strs: + # Use a custom test case + if func_str in test_case_custom: + test_cases.append(test_case_custom[func_str]) + else: + # Generate one by default + func = getattr(hero, func_str) + # Functions that accept HeroSeries + if ( + hasattr(func, "allowed_hero_series_type") + and func.allowed_hero_series_type.__name__ in valid_inputs + ): + test_cases.append( + [ + func_str, + func, + (valid_inputs[func.allowed_hero_series_type.__name__],), + ] + ) + class AbstractIndexTest(PandasTestCase): """ diff --git a/texthero/_types.py b/texthero/_types.py index 3bb5d8c7..b67b90c3 100644 --- a/texthero/_types.py +++ b/texthero/_types.py @@ -238,6 +238,8 @@ def InputSeries(allowed_hero_series_type): """ def decorator(func): + func.allowed_hero_series_type = allowed_hero_series_type + @functools.wraps(func) def wrapper(*args, **kwargs): s = args[0] # The first input argument will be checked.