-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathutils.py
148 lines (121 loc) · 4.68 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Copyright (c) 2023, Haruka Kiyohara, Ren Kishimoto, HAKUHODO Technologies Inc., and Hanjuku-kaso Co., Ltd. All rights reserved.
# Licensed under the Apache 2.0 License.
"""Useful tools."""
from dataclasses import dataclass
from typing import Union, Optional
import numpy as np
from sklearn.utils import check_scalar, check_random_state
from .types import Numeric
@dataclass
class NormalDistribution:
    """Sampler for a normal distribution with scalar or per-dimension parameters.

    Parameters
    -------
    mean: {int, float, ndarray}
        Mean of the normal distribution; a single number or a 1-dimensional NDArray.

    std: {int, float, ndarray}
        Standard deviation of the normal distribution; a non-negative number or
        a 1-dimensional NDArray of non-negative values. It must be of the same
        kind as `mean` (both numbers, or both NDArrays of equal length).

    random_state: int, default=None (>= 0)
        Random state passed to `check_random_state`. The default is None, but
        None is rejected with a ValueError, so a value must be given.

    """

    mean: Union[int, float, np.ndarray]
    std: Union[int, float, np.ndarray]
    random_state: Optional[int] = None

    def __post_init__(self):
        """Validate the parameters and prepare the random number generator."""
        mean_is_scalar = isinstance(self.mean, Numeric)

        # mean: a number, or a 1-dimensional array
        if not (
            mean_is_scalar
            or (isinstance(self.mean, np.ndarray) and self.mean.ndim == 1)
        ):
            raise ValueError(
                "mean must be a float number or an 1-dimensional NDArray of float values"
            )

        # std: a non-negative number, or a 1-dimensional array without negatives
        std_is_valid = (isinstance(self.std, Numeric) and self.std >= 0) or (
            isinstance(self.std, np.ndarray)
            and self.std.ndim == 1
            and self.std.min() >= 0
        )
        if not std_is_valid:
            raise ValueError(
                "std must be a non-negative float number or an 1-dimensional NDArray of non-negative float values"
            )

        # mean and std must be both numbers or both arrays of equal length
        params_are_consistent = (
            mean_is_scalar and isinstance(self.std, Numeric)
        ) or (
            isinstance(self.mean, np.ndarray)
            and isinstance(self.std, np.ndarray)
            and len(self.mean) == len(self.std)
        )
        if not params_are_consistent:
            raise ValueError("mean and std must have the same length")

        if self.random_state is None:
            raise ValueError("random_state must be given")

        self.random_ = check_random_state(self.random_state)
        # remembered so that `sample` can pick the output shape
        self.is_single_parameter = mean_is_scalar

    def sample(self, size: int = 1) -> np.ndarray:
        """Draw random variables from the configured normal distribution.

        Parameters
        -------
        size: int, default=1 (> 0)
            Number of samples to draw.

        Returns
        -------
        random_variables: ndarray of shape (size, ) or (size, len(mean))
            Sampled values. The shape is (size, ) when `mean` is a single
            number, and (size, len(mean)) when `mean` is an NDArray.

        """
        check_scalar(size, name="size", target_type=int, min_val=1)

        if self.is_single_parameter:
            output_shape = size
        else:
            # one row per sample, one column per (mean, std) pair
            output_shape = (size, len(self.mean))

        return self.random_.normal(loc=self.mean, scale=self.std, size=output_shape)
def sigmoid(x: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
    """Sigmoid function.

    Parameters
    -------
    x: {float, ndarray}
        Input value(s).

    Returns
    -------
    sigmoid: {float, ndarray}
        Element-wise value of 1 / (1 + exp(-x)).

    """
    # For large negative x, exp(-x) overflows to inf and the quotient becomes
    # 0.0, which is the correct limit. Only the overflow flag is silenced so
    # that this does not emit a RuntimeWarning (or raise under strict errstate).
    with np.errstate(over="ignore"):
        return 1 / (1 + np.exp(-x))
def check_array(
    array: np.ndarray,
    name: str,
    expected_dim: int = 1,
    expected_dtype: Optional[Union[type, tuple]] = None,
    min_val: Optional[float] = None,
    max_val: Optional[float] = None,
) -> None:
    """Input validation on array.

    Parameters
    -------
    array: object
        Input array to check.

    name: str
        Name of the input array, used in the error messages.

    expected_dim: int, default=1
        Expected dimension of the input array.

    expected_dtype: {type, tuple of type}, default=None
        Expected dtype of the input array. When a tuple is given, the array
        dtype must match at least one of its entries. Skipped when None.

    min_val: float, default=None
        Minimum value allowed in the input array. Skipped when None.

    max_val: float, default=None
        Maximum value allowed in the input array. Skipped when None.

    Raises
    -------
    ValueError
        If `array` is not an NDArray, or if it violates the expected
        dimension, dtype, or value range.

    """
    if not isinstance(array, np.ndarray):
        raise ValueError(f"{name} must be {expected_dim}D array, but got {type(array)}")

    if array.ndim != expected_dim:
        # report the actual dimension of the given array
        raise ValueError(
            f"{name} must be {expected_dim}D array, but got {array.ndim}D array"
        )

    if expected_dtype is not None:
        if isinstance(expected_dtype, tuple):
            allowed_dtypes = expected_dtype
        else:
            allowed_dtypes = (expected_dtype,)

        # np.issubdtype is used because np.issubsctype was removed in NumPy 2.0
        if not any(np.issubdtype(array.dtype, dtype) for dtype in allowed_dtypes):
            raise ValueError(
                f"The elements of {name} must be {expected_dtype}, but got {array.dtype}"
            )

    if min_val is not None:
        if array.min() < min_val:
            raise ValueError(
                f"The elements of {name} must be larger than {min_val}, but got minimum value {array.min()}"
            )

    if max_val is not None:
        if array.max() > max_val:
            raise ValueError(
                f"The elements of {name} must be smaller than {max_val}, but got maximum value {array.max()}"
            )