-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_tree_regression.py
40 lines (32 loc) · 1.19 KB
/
plot_tree_regression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""Plot decision tree regression.

A decision tree is used to fit a sine curve with additional noisy
observations. Because each leaf predicts a constant value, the tree learns a
piecewise-constant (step-function) approximation of the sine curve.

If the maximum depth of the tree (controlled by the ``max_depth`` parameter)
is set too high, the tree learns overly fine details of the training data and
fits the noise, i.e. it overfits.
"""
import numpy as np


def make_noisy_sine(n_samples=80, noise_every=5, seed=1):
    """Return a sorted 1-D sine dataset with every ``noise_every``-th target perturbed.

    Parameters
    ----------
    n_samples : int
        Number of points, drawn uniformly from ``[0, 5)``.
    noise_every : int
        Stride of the targets that receive additive noise in ``(-1.5, 1.5]``.
    seed : int
        Seed for ``np.random.RandomState``. The defaults reproduce the
        original example's dataset exactly.

    Returns
    -------
    X : ndarray of shape (n_samples, 1)
        Sorted inputs.
    y : ndarray of shape (n_samples,)
        ``sin(X)`` with noise added at every ``noise_every``-th position.
    """
    rng = np.random.RandomState(seed)
    X = np.sort(5 * rng.rand(n_samples, 1), axis=0)
    y = np.sin(X).ravel()
    # Size the noise from the strided slice rather than hard-coding 16
    # (= 80 / 5), so any n_samples / noise_every combination works.
    n_noisy = y[::noise_every].shape[0]
    y[::noise_every] += 3 * (0.5 - rng.rand(n_noisy))
    return X, y


def fit_trees(X, y, depths=(2, 5)):
    """Fit one ``DecisionTreeRegressor`` per entry of *depths*; return them in order."""
    # Imported lazily so make_noisy_sine stays usable without scikit-learn.
    from sklearn.tree import DecisionTreeRegressor

    return [DecisionTreeRegressor(max_depth=depth).fit(X, y) for depth in depths]


def main():
    """Fit a shallow and a deep tree to noisy sine data and plot both predictions."""
    # Imported lazily so the data helper can be imported without matplotlib.
    import matplotlib.pyplot as plt

    print(__doc__)

    # Create a random dataset
    X, y = make_noisy_sine()

    # Fit regression models: a shallow tree and a deep (overfitting) one
    depths = (2, 5)
    regressors = fit_trees(X, y, depths)

    # Predict on a dense, evenly spaced grid covering the input range
    X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]

    # Plot the results
    plt.figure()
    plt.scatter(X, y, c="k", label="data")
    for regr, depth, color in zip(regressors, depths, ("g", "r")):
        plt.plot(
            X_test,
            regr.predict(X_test),
            c=color,
            label="max_depth={}".format(depth),
            linewidth=2,
        )
    plt.xlabel("data")
    plt.ylabel("target")
    plt.title("Decision Tree Regression")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()