@@ -22,19 +22,14 @@ defmodule Scholar.Linear.RidgeRegression do
22
22
require Nx
23
23
import Nx.Defn
24
24
import Scholar.Shared
25
+ alias Scholar.Linear.LinearHelpers
25
26
26
27
@ derive { Nx.Container , containers: [ :coefficients , :intercept ] }
27
28
defstruct [ :coefficients , :intercept ]
28
29
29
30
opts = [
30
31
sample_weights: [
31
- type:
32
- { :or ,
33
- [
34
- { :custom , Scholar.Options , :non_negative_number , [ ] } ,
35
- { :list , { :custom , Scholar.Options , :non_negative_number , [ ] } } ,
36
- { :custom , Scholar.Options , :weights , [ ] }
37
- ] } ,
32
+ type: { :custom , Scholar.Options , :weights , [ ] } ,
38
33
doc: """
39
34
The weights for each observation. If not provided,
40
35
all observations are assigned equal weight.
@@ -126,13 +121,9 @@ defmodule Scholar.Linear.RidgeRegression do
126
121
] ++
127
122
opts
128
123
129
- { sample_weights , opts } = Keyword . pop ( opts , :sample_weights , 1.0 )
130
124
x_type = to_float_type ( x )
131
125
132
- sample_weights =
133
- if Nx . is_tensor ( sample_weights ) ,
134
- do: Nx . as_type ( sample_weights , x_type ) ,
135
- else: Nx . tensor ( sample_weights , type: x_type )
126
+ sample_weights = LinearHelpers . build_sample_weights ( x , opts )
136
127
137
128
{ alpha , opts } = Keyword . pop! ( opts , :alpha )
138
129
alpha = Nx . tensor ( alpha , type: x_type ) |> Nx . flatten ( )
@@ -160,7 +151,7 @@ defmodule Scholar.Linear.RidgeRegression do
160
151
161
152
{ a_offset , b_offset } =
162
153
if opts [ :fit_intercept? ] do
163
- preprocess_data ( a , b , sample_weights , opts )
154
+ LinearHelpers . preprocess_data ( a , b , sample_weights , opts )
164
155
else
165
156
a_offset_shape = Nx . axis_size ( a , 1 )
166
157
b_reshaped = if Nx . rank ( b ) > 1 , do: b , else: Nx . reshape ( b , { :auto , 1 } )
@@ -175,7 +166,7 @@ defmodule Scholar.Linear.RidgeRegression do
175
166
176
167
{ a , b } =
177
168
if opts [ :rescale_flag ] do
178
- rescale ( a , b , sample_weights )
169
+ LinearHelpers . rescale ( a , b , sample_weights )
179
170
else
180
171
{ a , b }
181
172
end
@@ -198,7 +189,7 @@ defmodule Scholar.Linear.RidgeRegression do
198
189
end
199
190
200
191
coeff = if flatten? , do: Nx . flatten ( coeff ) , else: coeff
201
- intercept = set_intercept ( coeff , a_offset , b_offset , opts [ :fit_intercept? ] )
192
+ intercept = LinearHelpers . set_intercept ( coeff , a_offset , b_offset , opts [ :fit_intercept? ] )
202
193
% __MODULE__ { coefficients: coeff , intercept: intercept }
203
194
end
204
195
@@ -222,20 +213,6 @@ defmodule Scholar.Linear.RidgeRegression do
222
213
if original_rank <= 1 , do: Nx . squeeze ( res , axes: [ 1 ] ) , else: res
223
214
end
224
215
225
- # Implements sample weighting by rescaling inputs and
226
- # targets by sqrt(sample_weight).
227
- defnp rescale ( a , b , sample_weights ) do
228
- case Nx . shape ( sample_weights ) do
229
- { } = scalar ->
230
- scalar = Nx . sqrt ( scalar )
231
- { scalar * a , scalar * b }
232
-
233
- _ ->
234
- scale = sample_weights |> Nx . sqrt ( ) |> Nx . make_diagonal ( )
235
- { Nx . dot ( scale , a ) , Nx . dot ( scale , b ) }
236
- end
237
- end
238
-
239
216
defnp solve_cholesky_kernel ( kernel , b , alpha , sample_weights , opts ) do
240
217
num_samples = Nx . axis_size ( kernel , 0 )
241
218
num_targets = Nx . axis_size ( b , 1 )
@@ -325,20 +302,4 @@ defmodule Scholar.Linear.RidgeRegression do
325
302
d_uty = d * uty
326
303
Nx . dot ( d_uty , [ 0 ] , vt , [ 0 ] )
327
304
end
328
-
329
- defnp set_intercept ( coeff , x_offset , y_offset , fit_intercept? ) do
330
- if fit_intercept? do
331
- y_offset - Nx . dot ( coeff , x_offset )
332
- else
333
- Nx . tensor ( 0.0 , type: Nx . type ( coeff ) )
334
- end
335
- end
336
-
337
- defnp preprocess_data ( a , b , sample_weights , opts ) do
338
- if opts [ :sample_weights_flag ] ,
339
- do:
340
- { Nx . weighted_mean ( a , sample_weights , axes: [ 0 ] ) ,
341
- Nx . weighted_mean ( b , sample_weights , axes: [ 0 ] ) } ,
342
- else: { Nx . mean ( a , axes: [ 0 ] ) , Nx . mean ( b , axes: [ 0 ] ) }
343
- end
344
305
end
0 commit comments