Skip to content

Commit 09d500a

Browse files
authored May 16, 2024
Unify weight handling and refactor linear models' helper functions (elixir-nx#267)
1 parent 74ed5fe commit 09d500a

6 files changed

+81
-132
lines changed
 

‎lib/scholar/linear/bayesian_ridge_regression.ex

+6-45
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
6262
require Nx
6363
import Nx.Defn
6464
import Scholar.Shared
65+
alias Scholar.Linear.LinearHelpers
6566

6667
@derive {Nx.Container,
6768
containers: [
@@ -95,13 +96,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
9596
"""
9697
],
9798
sample_weights: [
98-
type:
99-
{:or,
100-
[
101-
{:custom, Scholar.Options, :non_negative_number, []},
102-
{:list, {:custom, Scholar.Options, :non_negative_number, []}},
103-
{:custom, Scholar.Options, :weights, []}
104-
]},
99+
type: {:custom, Scholar.Options, :weights, []},
105100
doc: """
106101
The weights for each observation. If not provided,
107102
all observations are assigned equal weight.
@@ -237,13 +232,9 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
237232
] ++
238233
opts
239234

240-
{sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
241235
x_type = to_float_type(x)
242236

243-
sample_weights =
244-
if Nx.is_tensor(sample_weights),
245-
do: Nx.as_type(sample_weights, x_type),
246-
else: Nx.tensor(sample_weights, type: x_type)
237+
sample_weights = LinearHelpers.build_sample_weights(x, opts)
247238

248239
# handle vector types
249240
# handle default alpha value, add eps to avoid division by 0
@@ -288,7 +279,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
288279

289280
{x_offset, y_offset} =
290281
if opts[:fit_intercept?] do
291-
preprocess_data(x, y, sample_weights, opts)
282+
LinearHelpers.preprocess_data(x, y, sample_weights, opts)
292283
else
293284
x_offset_shape = Nx.axis_size(x, 1)
294285
y_reshaped = if Nx.rank(y) > 1, do: y, else: Nx.reshape(y, {:auto, 1})
@@ -302,7 +293,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
302293

303294
{x, y} =
304295
if opts[:sample_weights_flag] do
305-
rescale(x, y, sample_weights)
296+
LinearHelpers.rescale(x, y, sample_weights)
306297
else
307298
{x, y}
308299
end
@@ -360,7 +351,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
360351
{x, y, xt_y, u, s, vh, eigenvals, alpha_1, alpha_2, lambda_1, lambda_2, iterations}}
361352
end
362353

363-
intercept = set_intercept(coef, x_offset, y_offset, opts[:fit_intercept?])
354+
intercept = LinearHelpers.set_intercept(coef, x_offset, y_offset, opts[:fit_intercept?])
364355
scaled_sigma = Nx.dot(vh, [0], vh / Nx.new_axis(eigenvals + lambda / alpha, -1), [0])
365356
sigma = scaled_sigma / alpha
366357
{coef, intercept, alpha, lambda, iter, has_converged, scores, sigma}
@@ -449,34 +440,4 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
449440
end
450441

451442
defnp predict_n(coeff, intercept, x), do: Nx.dot(x, [-1], coeff, [-1]) + intercept
452-
453-
# Implements sample weighting by rescaling inputs and
454-
# targets by sqrt(sample_weight).
455-
defnp rescale(x, y, sample_weights) do
456-
factor = Nx.sqrt(sample_weights)
457-
458-
x_scaled =
459-
case Nx.shape(factor) do
460-
{} -> factor * x
461-
_ -> Nx.new_axis(factor, 1) * x
462-
end
463-
464-
{x_scaled, factor * y}
465-
end
466-
467-
defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
468-
if fit_intercept? do
469-
y_offset - Nx.dot(x_offset, coeff)
470-
else
471-
Nx.tensor(0.0, type: Nx.type(coeff))
472-
end
473-
end
474-
475-
defnp preprocess_data(x, y, sample_weights, opts) do
476-
if opts[:sample_weights_flag],
477-
do:
478-
{Nx.weighted_mean(x, sample_weights, axes: [0]),
479-
Nx.weighted_mean(y, sample_weights, axes: [0])},
480-
else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
481-
end
482443
end

‎lib/scholar/linear/linear_helpers.ex

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
defmodule Scholar.Linear.LinearHelpers do
  require Nx
  import Nx.Defn
  import Scholar.Shared

  @moduledoc false

  # Shared helpers for the linear models (linear, ridge, bayesian ridge):
  # building per-observation weights, centering data, computing intercepts,
  # and applying sample-weight rescaling.

  @doc false
  def build_sample_weights(x, opts) do
    type = to_float_type(x)
    {num_samples, _num_features} = Nx.shape(x)

    # When the caller passes no :sample_weights, every observation
    # gets an equal weight of 1.
    default = Nx.broadcast(Nx.as_type(1.0, type), {num_samples})
    weights = Keyword.get(opts, :sample_weights, default)

    # Casting to the float type of `x` is required for ridge regression.
    if Nx.is_tensor(weights) do
      Nx.as_type(weights, type)
    else
      Nx.tensor(weights, type: type)
    end
  end

  @doc false
  defn preprocess_data(x, y, sample_weights, opts) do
    # Column-wise centers of `x` and `y`; weighted when sample
    # weights are in effect, plain means otherwise.
    case opts[:sample_weights_flag] do
      true ->
        {Nx.weighted_mean(x, sample_weights, axes: [0]),
         Nx.weighted_mean(y, sample_weights, axes: [0])}

      _ ->
        {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
    end
  end

  @doc false
  defn set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
    # With an intercept, recover it from the centered fit; without one,
    # the intercept is a zero scalar of the coefficient type.
    if fit_intercept? do
      y_offset - Nx.dot(coeff, x_offset)
    else
      Nx.tensor(0.0, type: Nx.type(coeff))
    end
  end

  # Applies sample weighting by scaling inputs and targets
  # by sqrt(sample_weight).
  @doc false
  defn rescale(x, y, sample_weights) do
    w_sqrt = Nx.sqrt(sample_weights)

    # A scalar weight multiplies directly; a vector of weights is
    # expanded to broadcast across the feature/target axis.
    scaled_x =
      case Nx.shape(w_sqrt) do
        {} -> w_sqrt * x
        _ -> x * Nx.new_axis(w_sqrt, -1)
      end

    scaled_y =
      case Nx.rank(y) do
        1 -> w_sqrt * y
        _ -> y * Nx.new_axis(w_sqrt, -1)
      end

    {scaled_x, scaled_y}
  end
end

‎lib/scholar/linear/linear_regression.ex

+5-40
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ defmodule Scholar.Linear.LinearRegression do
88
require Nx
99
import Nx.Defn
1010
import Scholar.Shared
11+
alias Scholar.Linear.LinearHelpers
1112

1213
@derive {Nx.Container, containers: [:coefficients, :intercept]}
1314
defstruct [:coefficients, :intercept]
@@ -75,13 +76,7 @@ defmodule Scholar.Linear.LinearRegression do
7576
] ++
7677
opts
7778

78-
{sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
79-
x_type = to_float_type(x)
80-
81-
sample_weights =
82-
if Nx.is_tensor(sample_weights),
83-
do: Nx.as_type(sample_weights, x_type),
84-
else: Nx.tensor(sample_weights, type: x_type)
79+
sample_weights = LinearHelpers.build_sample_weights(x, opts)
8580

8681
fit_n(x, y, sample_weights, opts)
8782
end
@@ -92,7 +87,7 @@ defmodule Scholar.Linear.LinearRegression do
9287

9388
{a_offset, b_offset} =
9489
if opts[:fit_intercept?] do
95-
preprocess_data(a, b, sample_weights, opts)
90+
LinearHelpers.preprocess_data(a, b, sample_weights, opts)
9691
else
9792
a_offset_shape = Nx.axis_size(a, 1)
9893
b_reshaped = if Nx.rank(b) > 1, do: b, else: Nx.reshape(b, {:auto, 1})
@@ -106,7 +101,7 @@ defmodule Scholar.Linear.LinearRegression do
106101

107102
{a, b} =
108103
if opts[:sample_weights_flag] do
109-
rescale(a, b, sample_weights)
104+
LinearHelpers.rescale(a, b, sample_weights)
110105
else
111106
{a, b}
112107
end
@@ -132,42 +127,12 @@ defmodule Scholar.Linear.LinearRegression do
132127
Nx.dot(x, coeff) + intercept
133128
end
134129

135-
# Implements sample weighting by rescaling inputs and
136-
# targets by sqrt(sample_weight).
137-
defnp rescale(x, y, sample_weights) do
138-
case Nx.shape(sample_weights) do
139-
{} = scalar ->
140-
scalar = Nx.sqrt(scalar)
141-
{scalar * x, scalar * y}
142-
143-
_ ->
144-
scale = sample_weights |> Nx.sqrt() |> Nx.make_diagonal()
145-
{Nx.dot(scale, x), Nx.dot(scale, y)}
146-
end
147-
end
148-
149130
# Implements ordinary least-squares by estimating the
150131
# solution A to the equation A.X = b.
151132
defnp lstsq(a, b, a_offset, b_offset, fit_intercept?) do
152133
pinv = Nx.LinAlg.pinv(a)
153134
coeff = Nx.dot(b, [0], pinv, [1])
154-
intercept = set_intercept(coeff, a_offset, b_offset, fit_intercept?)
135+
intercept = LinearHelpers.set_intercept(coeff, a_offset, b_offset, fit_intercept?)
155136
{coeff, intercept}
156137
end
157-
158-
defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
159-
if fit_intercept? do
160-
y_offset - Nx.dot(coeff, x_offset)
161-
else
162-
Nx.tensor(0.0, type: Nx.type(coeff))
163-
end
164-
end
165-
166-
defnp preprocess_data(x, y, sample_weights, opts) do
167-
if opts[:sample_weights_flag],
168-
do:
169-
{Nx.weighted_mean(x, sample_weights, axes: [0]),
170-
Nx.weighted_mean(y, sample_weights, axes: [0])},
171-
else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
172-
end
173138
end

‎lib/scholar/linear/polynomial_regression.ex

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ defmodule Scholar.Linear.PolynomialRegression do
1212

1313
opts = [
1414
sample_weights: [
15-
type: {:list, {:custom, Scholar.Options, :positive_number, []}},
15+
type: {:custom, Scholar.Options, :weights, []},
1616
doc: """
1717
The weights for each observation. If not provided,
1818
all observations are assigned equal weight.

‎lib/scholar/linear/ridge_regression.ex

+6-45
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,14 @@ defmodule Scholar.Linear.RidgeRegression do
2222
require Nx
2323
import Nx.Defn
2424
import Scholar.Shared
25+
alias Scholar.Linear.LinearHelpers
2526

2627
@derive {Nx.Container, containers: [:coefficients, :intercept]}
2728
defstruct [:coefficients, :intercept]
2829

2930
opts = [
3031
sample_weights: [
31-
type:
32-
{:or,
33-
[
34-
{:custom, Scholar.Options, :non_negative_number, []},
35-
{:list, {:custom, Scholar.Options, :non_negative_number, []}},
36-
{:custom, Scholar.Options, :weights, []}
37-
]},
32+
type: {:custom, Scholar.Options, :weights, []},
3833
doc: """
3934
The weights for each observation. If not provided,
4035
all observations are assigned equal weight.
@@ -126,13 +121,9 @@ defmodule Scholar.Linear.RidgeRegression do
126121
] ++
127122
opts
128123

129-
{sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
130124
x_type = to_float_type(x)
131125

132-
sample_weights =
133-
if Nx.is_tensor(sample_weights),
134-
do: Nx.as_type(sample_weights, x_type),
135-
else: Nx.tensor(sample_weights, type: x_type)
126+
sample_weights = LinearHelpers.build_sample_weights(x, opts)
136127

137128
{alpha, opts} = Keyword.pop!(opts, :alpha)
138129
alpha = Nx.tensor(alpha, type: x_type) |> Nx.flatten()
@@ -160,7 +151,7 @@ defmodule Scholar.Linear.RidgeRegression do
160151

161152
{a_offset, b_offset} =
162153
if opts[:fit_intercept?] do
163-
preprocess_data(a, b, sample_weights, opts)
154+
LinearHelpers.preprocess_data(a, b, sample_weights, opts)
164155
else
165156
a_offset_shape = Nx.axis_size(a, 1)
166157
b_reshaped = if Nx.rank(b) > 1, do: b, else: Nx.reshape(b, {:auto, 1})
@@ -175,7 +166,7 @@ defmodule Scholar.Linear.RidgeRegression do
175166

176167
{a, b} =
177168
if opts[:rescale_flag] do
178-
rescale(a, b, sample_weights)
169+
LinearHelpers.rescale(a, b, sample_weights)
179170
else
180171
{a, b}
181172
end
@@ -198,7 +189,7 @@ defmodule Scholar.Linear.RidgeRegression do
198189
end
199190

200191
coeff = if flatten?, do: Nx.flatten(coeff), else: coeff
201-
intercept = set_intercept(coeff, a_offset, b_offset, opts[:fit_intercept?])
192+
intercept = LinearHelpers.set_intercept(coeff, a_offset, b_offset, opts[:fit_intercept?])
202193
%__MODULE__{coefficients: coeff, intercept: intercept}
203194
end
204195

@@ -222,20 +213,6 @@ defmodule Scholar.Linear.RidgeRegression do
222213
if original_rank <= 1, do: Nx.squeeze(res, axes: [1]), else: res
223214
end
224215

225-
# Implements sample weighting by rescaling inputs and
226-
# targets by sqrt(sample_weight).
227-
defnp rescale(a, b, sample_weights) do
228-
case Nx.shape(sample_weights) do
229-
{} = scalar ->
230-
scalar = Nx.sqrt(scalar)
231-
{scalar * a, scalar * b}
232-
233-
_ ->
234-
scale = sample_weights |> Nx.sqrt() |> Nx.make_diagonal()
235-
{Nx.dot(scale, a), Nx.dot(scale, b)}
236-
end
237-
end
238-
239216
defnp solve_cholesky_kernel(kernel, b, alpha, sample_weights, opts) do
240217
num_samples = Nx.axis_size(kernel, 0)
241218
num_targets = Nx.axis_size(b, 1)
@@ -325,20 +302,4 @@ defmodule Scholar.Linear.RidgeRegression do
325302
d_uty = d * uty
326303
Nx.dot(d_uty, [0], vt, [0])
327304
end
328-
329-
defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
330-
if fit_intercept? do
331-
y_offset - Nx.dot(coeff, x_offset)
332-
else
333-
Nx.tensor(0.0, type: Nx.type(coeff))
334-
end
335-
end
336-
337-
defnp preprocess_data(a, b, sample_weights, opts) do
338-
if opts[:sample_weights_flag],
339-
do:
340-
{Nx.weighted_mean(a, sample_weights, axes: [0]),
341-
Nx.weighted_mean(b, sample_weights, axes: [0])},
342-
else: {Nx.mean(a, axes: [0]), Nx.mean(b, axes: [0])}
343-
end
344305
end

‎test/scholar/linear/bayesian_ridge_regression_test.exs

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ defmodule Scholar.Linear.BayesianRidgeRegressionTest do
5353
score = compute_score(x, y, alpha, lambda, alpha_1, alpha_2, lambda_1, lambda_2)
5454

5555
brr =
56-
BayesianRidgeRegression.fit(x, y,
56+
BayesianRidgeRegression.fit(x, Nx.flatten(y),
5757
alpha_1: alpha_1,
5858
alpha_2: alpha_2,
5959
lambda_1: lambda_1,

0 commit comments

Comments
 (0)
Please sign in to comment.