-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathonline_learning_example.py
More file actions
22 lines (17 loc) · 1.03 KB
/
online_learning_example.py
File metadata and controls
22 lines (17 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def iter_minibatches(documents, labels, batch_size):
    """Yield consecutive ``(documents, labels)`` slices of at most *batch_size* items.

    The last batch is shorter when ``len(documents)`` is not a multiple of
    *batch_size*; empty input yields nothing.

    Raises:
        ValueError: if *batch_size* is not positive or the two sequences
            have different lengths.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if len(documents) != len(labels):
        raise ValueError(
            f"documents and labels differ in length: {len(documents)} != {len(labels)}"
        )
    for start in range(0, len(documents), batch_size):
        stop = start + batch_size
        yield documents[start:stop], labels[start:stop]


def main(n_test=1000, batch_size=1000, n_repeats=10, random_state=1):
    """Train an SGD text classifier incrementally on 20 Newsgroups and print its accuracy.

    Documents are hashed into sparse feature vectors and the classifier is
    updated one mini-batch at a time with ``partial_fit`` (online learning).

    Args:
        n_test: number of leading documents held out as the test set.
        batch_size: number of documents passed to each ``partial_fit`` call.
        n_repeats: how many times the training documents are repeated, to
            simulate a larger training stream.
        random_state: seed for both the dataset shuffle and the classifier.

    Returns:
        The mean accuracy on the held-out documents (also printed).
    """
    # Imported here so importing this module (e.g. for iter_minibatches)
    # neither requires scikit-learn nor triggers a dataset download.
    from sklearn.datasets import fetch_20newsgroups
    from sklearn.feature_extraction.text import HashingVectorizer
    from sklearn.linear_model import SGDClassifier

    # HashingVectorizer learns no vocabulary, so transform() alone maps any
    # batch into the same 2**16-dimensional feature space.
    vectorizer = HashingVectorizer(decode_error='ignore', n_features=2 ** 16)
    dataset = fetch_20newsgroups(shuffle=True, random_state=random_state,
                                 remove=('headers', 'footers'))

    X_test = vectorizer.transform(dataset.data[:n_test])
    y_test = dataset.target[:n_test]

    # SVM classifier trained online with stochastic gradient descent.
    classifier = SGDClassifier(random_state=random_state)

    # A single mini-batch need not contain every class, so the full label set
    # is passed to partial_fit; it is loop-invariant, so build it once.
    all_classes = list(range(len(dataset.target_names)))

    # Artificially increase the size of the training set n_repeats-fold.
    train_data = n_repeats * dataset.data[n_test:]
    train_target = n_repeats * list(dataset.target[n_test:])

    # Update the classifier with the documents of each mini-batch in turn.
    for docs, labels in iter_minibatches(train_data, train_target, batch_size):
        classifier.partial_fit(vectorizer.transform(docs), labels,
                               classes=all_classes)

    score = classifier.score(X_test, y_test)
    print(score)
    return score


if __name__ == "__main__":
    main()