-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathnbsvm.py
More file actions
37 lines (30 loc) · 1.29 KB
/
nbsvm.py
File metadata and controls
37 lines (30 loc) · 1.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
from scipy import sparse
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_X_y, check_is_fitted
class NbSvmClassifier(BaseEstimator, ClassifierMixin):
    """Binary classifier: logistic regression on Naive-Bayes-scaled features.

    ``fit`` computes, for every feature, the log of the ratio between its
    add-one-smoothed average value in the rows labelled ``1`` and in the rows
    labelled ``0``.  Each feature column is multiplied by that log-ratio and
    a ``LogisticRegression`` is trained on the scaled matrix.  ``predict``
    and ``predict_proba`` apply the same scaling before delegating to the
    fitted logistic regression.

    Labels are expected to be ``1`` (positive) and ``0`` (negative): the
    log-ratio is computed from exactly those two label values.  Feature
    values should be non-negative (e.g. counts or TF-IDF weights) so that
    the smoothed ratios stay positive and their logarithm is defined.

    Parameters
    ----------
    C : float, default=1.0
        Passed through to ``LogisticRegression`` as ``C``.
    dual : bool, default=False
        Passed through to ``LogisticRegression`` as ``dual``.  When true the
        ``liblinear`` solver is requested, because the dual formulation is
        only implemented for that solver.
    n_jobs : int, default=1
        Passed through to ``LogisticRegression`` as ``n_jobs``.
    """

    def __init__(self, C=1.0, dual=False, n_jobs=1):
        self.C = C
        self.dual = dual
        self.n_jobs = n_jobs

    def _scale(self, x):
        """Return *x* as a CSR matrix with each column scaled by ``self._r``."""
        # csr_matrix() accepts dense arrays as well as any sparse format, so
        # callers are not forced to pass a SciPy sparse matrix.
        return sparse.csr_matrix(x).multiply(self._r).tocsr()

    def predict(self, x):
        """Return the predicted class label for every row of *x*.

        Raises ``sklearn.exceptions.NotFittedError`` if ``fit`` has not been
        called yet.
        """
        # Verify that model has been fit
        check_is_fitted(self, ['_r', '_clf'])
        return self._clf.predict(self._scale(x))

    def predict_proba(self, x):
        """Return the class probability estimates for every row of *x*.

        Raises ``sklearn.exceptions.NotFittedError`` if ``fit`` has not been
        called yet.
        """
        # Verify that model has been fit
        check_is_fitted(self, ['_r', '_clf'])
        return self._clf.predict_proba(self._scale(x))

    def fit(self, x, y):
        """Learn the per-feature log-ratios and fit the logistic regression.

        Parameters
        ----------
        x : array-like or sparse matrix of shape (n_samples, n_features)
            Training features.
        y : array-like of shape (n_samples,)
            Training labels, ``1`` for the positive and ``0`` for the
            negative class.  Lists, NumPy arrays and pandas Series are all
            accepted.

        Returns
        -------
        self
        """
        # check_X_y validates shapes and converts y (including a pandas
        # Series) to a 1-D ndarray, so no manual ``.values`` access is needed.
        x, y = check_X_y(x, y, accept_sparse=True)
        # Normalise to CSR: boolean row masking and ``multiply`` below need an
        # indexable sparse matrix, whatever format the caller supplied.
        x = sparse.csr_matrix(x)

        def pr(y_i):
            """Add-one-smoothed mean feature values over rows labelled y_i."""
            mask = y == y_i
            feature_sums = np.asarray(x[mask].sum(axis=0))
            return (feature_sums + 1) / (mask.sum() + 1)

        self._r = sparse.csr_matrix(np.log(pr(1) / pr(0)))
        x_nb = x.multiply(self._r).tocsr()

        clf_params = {'C': self.C, 'dual': self.dual, 'n_jobs': self.n_jobs}
        if self.dual:
            # The dual formulation is only available with liblinear; without
            # this, scikit-learn versions whose default solver is lbfgs
            # reject dual=True with a ValueError.
            clf_params['solver'] = 'liblinear'
        self._clf = LogisticRegression(**clf_params).fit(x_nb, y)
        # Mirror the fitted label set, following the scikit-learn classifier
        # convention of exposing ``classes_`` after fit.
        self.classes_ = self._clf.classes_
        return self
# Example usage:
# model = NbSvmClassifier(C=4, dual=True, n_jobs=-1).fit(training_features, training_labels)