-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfind_best_reg.py
executable file
·65 lines (47 loc) · 1.57 KB
/
find_best_reg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/python2.7
import numpy as np
import matplotlib.pyplot as plt
from rsvd import RSVD, rating_t, MovieLensDataset
# Load the MovieLens ratings file and extract the ratings array from it.
ratingsDataset = MovieLensDataset.loadDat('data_movilens1m/ratings.dat')
ratings = ratingsDataset.ratings()

# Shuffle the ratings in place so the slices taken below are random samples.
np.random.shuffle(ratings)

# First split: leading 80% for training+validation, trailing 20% for test.
n_trainval = int(ratings.shape[0] * 0.8)
train, test = ratings[:n_trainval], ratings[n_trainval:]

# Second split: the trailing 10% of the training portion becomes the
# validation set; the leading 90% stays as the actual training set.
n_train = int(train.shape[0] * 0.9)
train, val = train[:n_train], train[n_train:]

# Model dimensions: (number of movie IDs, number of user IDs).
n_movies = ratingsDataset.movieIDs().shape[0]
n_users = ratingsDataset.userIDs().shape[0]
dims = (n_movies, n_users)

# Factor count handed to RSVD.train for every model in the sweep.
factor = 40

# Parallel lists filled by the sweep: lambdas[i] produced errors[i].
lambdas, errors = [], []
# lambda_f ne doit pas depasser 1
# maxEpochs = 1000
for lambda_f in np.arange(0.0, 0.05, 0.0005):
model = RSVD.train(factor, train, dims, probeArray=val, maxEpochs = 1000, regularization=lambda_f)
sqerr=0.0
for movieID,userID,rating in test:
err = rating - model(movieID,userID)
sqerr += err * err
sqerr /= test.shape[0]
print "-------------------------------------------------"
print "Pour lambda = ",lambda_f, " Test RMSE: ", np.sqrt(sqerr)
print "-------------------------------------------------"
lambdas.append(lambda_f)
errors.append(np.sqrt(sqerr))
# print the lamdas and errors vectors
print lambdas
print errors
# get minimal error and its corresponding lamda
min_err = min(errors)
id_min_err = errors.index(min(errors))
best_lambda = lambdas[errors.index(min(errors))]
print "minimum trouve pour l erreur", min_err
print "correspond a lambda =", best_lambda
#plot errors /lambdas
plt.plot(lambdas, errors)
plt.ylabel('erreur')
plt.xlabel('lambda')
plt.show()