diff --git a/bulbea/_util/_util.py b/bulbea/_util/_util.py index 9ff2cfe4f..fb94c3242 100644 --- a/bulbea/_util/_util.py +++ b/bulbea/_util/_util.py @@ -1,123 +1,101 @@ -# imports - compatibility packages from __future__ import absolute_import from six import string_types -# imports - standard packages import os -import collections import numbers +import collections from datetime import datetime -# imports - third-party packages import pandas as pd -# module imports from bulbea.exceptions import TYPE_ERROR_STRING -def _raise_type_error(expected_type_name, recieved_type_name): +__all__ = [ + "_raise_type_error", "_get_type_name", "_get_datetime_str", "_check_type", + "_check_str", "_check_int", "_check_real", "_check_pandas_series", "_check_pandas_dataframe", + "_check_iterable", "_check_sequence", "_check_environment_variable_set", + "_validate_in_range", "_validate_date", "_assign_if_none", "_is_sequence_all" +] + +def _raise_type_error(expected_type_name: str, received_type_name: str) -> None: raise TypeError(TYPE_ERROR_STRING.format( - expected_type_name = expected_type_name, - recieved_type_name = recieved_type_name + expected_type_name=expected_type_name, + recieved_type_name=received_type_name )) -def _get_type_name(o): - type_ = type(o) - name = type_.__name__ +def _get_type_name(o: object) -> str: + return type(o).__name__ - return name - -def _get_datetime_str(dt, format_): - if _check_type(dt, pd.Timestamp): +def _get_datetime_str(dt: datetime, format_: str) -> str: + if isinstance(dt, pd.Timestamp): dt = dt.to_pydatetime() - _check_type(dt, type_ = datetime, raise_err = True, expected_type_name = 'datetime.datetime') - - string = dt.strftime(format_) - - return string + _check_type(dt, type_=datetime, raise_err=True, expected_type_name='datetime.datetime') + return dt.strftime(format_) -def _check_type(o, type_, raise_err = False, expected_type_name = None): +def _check_type(o: object, type_: type, raise_err: bool = False, expected_type_name: str = None) -> bool: if not isinstance(o, type_): if raise_err: _raise_type_error( - expected_type_name = expected_type_name, - recieved_type_name = _get_type_name(o) + expected_type_name=expected_type_name or type_.__name__, + received_type_name=_get_type_name(o) ) - else: - return False - else: - return True + return False + return True -def _check_str(o, raise_err = False): - return _check_type(o, string_types, raise_err = raise_err, expected_type_name = 'str') +def _check_str(o: object, raise_err: bool = False) -> bool: + return _check_type(o, string_types, raise_err=raise_err, expected_type_name='str') -def _check_int(o, raise_err = False): - return _check_type(o, numbers.Integral, raise_err = raise_err, expected_type_name = 'int') +def _check_int(o: object, raise_err: bool = False) -> bool: + return _check_type(o, numbers.Integral, raise_err=raise_err, expected_type_name='int') -def _check_real(o, raise_err = False): - return _check_type(o, numbers.Real, raise_err = raise_err, expected_type_name = '(int, float)') +def _check_real(o: object, raise_err: bool = False) -> bool: + return _check_type(o, numbers.Real, raise_err=raise_err, expected_type_name='(int, float)') -def _check_pandas_series(data, raise_err = False): - return _check_type(data, pd.Series, raise_err = raise_err, expected_type_name = 'pandas.Series') +def _check_pandas_series(data: object, raise_err: bool = False) -> bool: + return _check_type(data, pd.Series, raise_err=raise_err, expected_type_name='pandas.Series') -def _check_pandas_dataframe(data, raise_err = False): - return _check_type(data, pd.DataFrame, raise_err = raise_err, expected_type_name = 'pandas.DataFrame') +def _check_pandas_dataframe(data: object, raise_err: bool = False) -> bool: + return _check_type(data, pd.DataFrame, raise_err=raise_err, expected_type_name='pandas.DataFrame') -def _check_iterable(o, raise_err = False): - return _check_type(o, collections.Iterable, raise_err = raise_err, expected_type_name = '(str, list, tuple)') +def _check_iterable(o: object, raise_err: bool = False) -> bool: + return _check_type(o, collections.abc.Iterable, raise_err=raise_err, expected_type_name='Iterable') -def _check_sequence(o, string = True, raise_err = False): - return _check_type(o, collections.Sequence, raise_err = raise_err, expected_type_name = '(list, tuple)') +def _check_sequence(o: object, string: bool = True, raise_err: bool = False) -> bool: + return _check_type(o, collections.abc.Sequence, raise_err=raise_err, expected_type_name='(list, tuple, str)' if string else '(list, tuple)') -def _check_environment_variable_set(variable, raise_err = False): - _check_str(variable, raise_err = raise_err) +def _check_environment_variable_set(variable: str, raise_err: bool = False) -> bool: + _check_str(variable, raise_err=raise_err) - try: - os.environ[variable] - except KeyError: + if variable not in os.environ: if raise_err: - raise ValueError('Environment variable {variable} not set.') - else: - return False + raise ValueError(f"Environment variable '{variable}' not set.") + return False return True -def _validate_in_range(value, low, high, raise_err = False): - if not low <= value <= high: +def _validate_in_range(value: float, low: float, high: float, raise_err: bool = False) -> bool: + if not (low <= value <= high): if raise_err: - raise ValueError('{value} out of bounds, must be in range [{low}, {high}].'.format( - value = value, - low = low, - high = high - )) - else: - return False - else: - return True - -def _validate_date(value, format_ = '%Y-%m-%d', raise_err = False): - _check_str(value, raise_err = raise_err) + raise ValueError(f"{value} out of bounds, must be in range [{low}, {high}].") + return False + return True + +def _validate_date(value: str, format_: str = '%Y-%m-%d', raise_err: bool = False) -> bool: + _check_str(value, raise_err=raise_err) try: datetime.strptime(value, format_) except ValueError: if raise_err: - raise ValueError('Expected {format_} format, got {value} instead.'.format( - format_ = format_, - value = value - )) - else: - return False + raise ValueError(f"Expected {format_} format, got '{value}' instead.") + return False return True def _assign_if_none(a, b): return b if a is None else a -def _is_sequence_all(seq): - _check_sequence(seq, raise_err = True) - - length = len(seq) - is_seq = True if length != 0 and seq.count(seq[0]) == length else False - - return is_seq +def _is_sequence_all(seq: list) -> bool: + _check_sequence(seq, raise_err=True) + return len(seq) > 0 and seq.count(seq[0]) == len(seq)