Skip to content

Commit 1b9987e

Browse files
committed
clarified number of elements in reg
1 parent c999ba6 commit 1b9987e

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/nn_regression.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import pandas as pd
88
from sklearn.datasets import fetch_california_housing
9-
from sklearn.metrics import mean_absolute_error, mean_squared_error
9+
from sklearn.metrics import mean_squared_error
1010
from sklearn.model_selection import train_test_split
1111
from sklearn.neighbors import KNeighborsRegressor
1212
from sklearn.preprocessing import StandardScaler
@@ -30,7 +30,7 @@ def gs_knearest_regressor(
3030
(Tuple[KNeighborsRegressor, signedinteger, np.ndarray]):
3131
KNeighborsRegressor: The model that was trained with the optimal number of k.
3232
signedinteger: k that produced the smallest mean squared error on the val set.
33-
np.ndarray: The array with means squared errors for k values between 1 and 39.
33+
np.ndarray: The array with means squared errors for k values between 1 and 40 (40 numbers in total).
3434
"""
3535
# TODO: Implement me.
3636
return None
@@ -101,10 +101,10 @@ def gs_knearest_regressor(
101101
# compute and print MSE of best estimator on test set
102102
# TODO
103103

104-
# plot mean squared error for k values between 1 and 39
104+
# plot mean squared error for k values between 1 and 40
105105
plt.figure(figsize=(12, 6))
106106
plt.plot(
107-
range(1, 40),
107+
range(1, 41),
108108
error_array,
109109
color="red",
110110
linestyle="dashed",

tests/test_nn_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_gs_knearest_regressor() -> None:
2525
# Check the returned objects
2626
assert isinstance(knn, KNeighborsRegressor)
2727
assert isinstance(error_array, np.ndarray)
28-
assert error_array.shape[0] == 39
28+
assert error_array.shape[0] == 40
2929

3030
# Generate a validation set
3131
x_val = np.random.rand(10, n_features)

0 commit comments

Comments
 (0)