forked from ytang07/nn_examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsklearn_simple_nn.py
33 lines (27 loc) · 913 Bytes
/
sklearn_simple_nn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import warnings
from sklearn.datasets import fetch_openml
from sklearn.exceptions import ConvergenceWarning
from sklearn.neural_network import MLPClassifier
# load MNIST data from fetch_openml
X, y = fetch_openml("mnist_784", version=1, return_X_y=True)
X = X/255.0
# get train/test split
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]
# set up MLP Classifier
mlp = MLPClassifier(
hidden_layer_sizes=(50,),
max_iter=15,
alpha=1e-4,
solver="sgd",
verbose=True,
random_state=1,
learning_rate_init=0.1
)
# We probably won't converge so we'll catch the warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ConvergenceWarning, module="sklearn")
mlp.fit(X_train, y_train)
# print out the model scores
print(f"Training set score: {mlp.score(X_train, y_train)}")
print(f"Test set score: {mlp.score(X_test, y_test)}")