Commit

Merge branch 'master' into 8057

Joan Massich committed Sep 1, 2017
2 parents 57f345f + d6a4235 commit 5a11bd0
Showing 24 changed files with 404 additions and 88 deletions.
5 changes: 5 additions & 0 deletions doc/developers/utilities.rst
@@ -43,6 +43,11 @@ should be used when applicable.
be sliced or indexed using safe_index. This is used to validate input for
cross-validation.

- :func:`validation.check_memory` checks that the input is ``joblib.Memory``-like,
  which means that it can either be converted into a
  ``sklearn.externals.joblib.Memory`` instance (typically a str naming the
  ``cachedir``) or that it already provides the same interface (see the sketch
  below).

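A minimal sketch of how this helper might be used (the cache directory and the
``expensive`` function below are illustrative only)::

    from sklearn.utils.validation import check_memory

    def expensive(a, b):
        return a + b

    # A str is treated as the cachedir and wrapped in a joblib Memory instance.
    memory = check_memory('/tmp/sklearn_cache')
    cached = expensive if memory is None else memory.cache(expensive)
    print(cached(1, 2))  # computed once, then read back from the cache

    # None yields a no-op Memory; a Memory-like object is passed through as is.
    no_cache = check_memory(None)
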
If your code relies on a random number generator, it should never use
functions like ``numpy.random.random`` or ``numpy.random.normal``. This
approach can lead to repeatability issues in unit tests. Instead, a
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
@@ -1081,6 +1081,7 @@ Model validation
naive_bayes.BernoulliNB
naive_bayes.GaussianNB
naive_bayes.MultinomialNB
naive_bayes.ComplementNB


.. _neighbors_ref:
@@ -1377,6 +1378,7 @@ Low-level methods
utils.sparsefuncs.inplace_swap_column
utils.sparsefuncs.mean_variance_axis
utils.validation.check_is_fitted
utils.validation.check_memory
utils.validation.check_symmetric
utils.validation.column_or_1d
utils.validation.has_fit_parameter
2 changes: 1 addition & 1 deletion doc/modules/clustering.rst
@@ -427,7 +427,7 @@ Spectral clustering
:class:`SpectralClustering` computes a low-dimensional embedding of the
affinity matrix between samples, followed by KMeans clustering in this low
dimensional space. It is especially efficient if the affinity matrix is
sparse and the `pyamg <http://pyamg.org/>`_ module is installed.
sparse and the `pyamg <https://github.com/pyamg/pyamg>`_ module is installed.
SpectralClustering requires the number of clusters to be specified. It
works well for a small number of clusters but is not advised when using
many clusters.
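
A rough usage sketch (the toy data and parameters below are illustrative
only)::

    import numpy as np
    from sklearn.cluster import SpectralClustering

    # Two well separated blobs.
    rng = np.random.RandomState(0)
    X = np.vstack([rng.randn(20, 2), rng.randn(20, 2) + 5])

    # Default rbf affinity; affinity='nearest_neighbors' would build a
    # sparse k-nearest-neighbors affinity graph instead.
    labels = SpectralClustering(n_clusters=2, random_state=0).fit_predict(X)
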
43 changes: 43 additions & 0 deletions doc/modules/naive_bayes.rst
@@ -133,6 +133,49 @@ in further computations.
Setting :math:`\alpha = 1` is called Laplace smoothing,
while :math:`\alpha < 1` is called Lidstone smoothing.

.. _complement_naive_bayes:

Complement Naive Bayes
----------------------

:class:`ComplementNB` implements the complement naive Bayes (CNB) algorithm.
CNB is an adaptation of the standard multinomial naive Bayes (MNB) algorithm
that is particularly suited for imbalanced data sets. Specifically, CNB uses
statistics from the *complement* of each class to compute the model's weights.
The inventors of CNB show empirically that the parameter estimates for CNB are
more stable than those for MNB. Further, CNB regularly outperforms MNB (often
by a considerable margin) on text classification tasks. The procedure for
calculating the weights is as follows:

.. math::

   \hat{\theta}_{ci} = \frac{\alpha_i + \sum_{j:y_j \neq c} d_{ij}}
                            {\alpha + \sum_{j:y_j \neq c} \sum_{k} d_{kj}}

   w_{ci} = \log \hat{\theta}_{ci}

   w_{ci} = \frac{w_{ci}}{\sum_{j} w_{cj}}

where the summations are over all documents :math:`j` not in class :math:`c`,
:math:`d_{ij}` is either the count or tf-idf value of term :math:`i` in document
:math:`j`, :math:`\alpha_i` is a smoothing hyperparameter like that found in
MNB, and :math:`\alpha = \sum_{i} \alpha_i`. The second normalization addresses
the tendency for longer documents to dominate parameter estimates in MNB. The
classification rule is:

.. math::

   \hat{c} = \arg\min_c \sum_{i} t_i w_{ci}

where :math:`t_i` is the count of term :math:`i` in the document to be
classified; i.e., a document is assigned to the class that is the *poorest*
complement match.

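As a rough usage sketch (the tiny corpus and labels below are purely
illustrative)::

    import numpy as np
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.naive_bayes import ComplementNB

    docs = ["free money now", "meeting at noon",
            "claim your free offer now", "project meeting moved to monday"]
    y = np.array([1, 0, 1, 0])  # e.g. spam vs. not spam

    vec = CountVectorizer()
    X = vec.fit_transform(docs)
    clf = ComplementNB(alpha=1.0).fit(X, y)
    print(clf.predict(vec.transform(["free offer, claim now"])))
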
.. topic:: References:

* Rennie, J. D., Shih, L., Teevan, J., & Karger, D. R. (2003).
`Tackling the poor assumptions of naive bayes text classifiers.
<http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf>`_
In ICML (Vol. 3, pp. 616-623).

.. _bernoulli_naive_bayes:

4 changes: 4 additions & 0 deletions doc/related_projects.rst
@@ -155,6 +155,10 @@ and tasks.
- `xgboost <https://github.com/dmlc/xgboost>`_ Optimised gradient boosted decision
tree library.

- `ML-Ensemble <http://mlens.readthedocs.io/en/latest/>`_ Generalized
ensemble learning (stacking, blending, subsemble, deep ensembles,
etc.).

- `lightning <https://github.com/scikit-learn-contrib/lightning>`_ Fast
state-of-the-art linear model solvers (SDCA, AdaGrad, SVRG, SAG, etc.).

3 changes: 3 additions & 0 deletions doc/whats_new.rst
@@ -38,6 +38,9 @@ Classifiers and regressors

- Added support for sparse multilabel ``y`` in :class:`NeighborsBase`.
:issue:`8057` by :user:`Aman Dalmia <dalmia>` and :user:`Joan Massich <massich>`.

- Added :class:`naive_bayes.ComplementNB`, which implements the Complement
Naive Bayes classifier described in Rennie et al. (2003).
By :user:`Michael A. Alcorn <airalcorn2>`.

Enhancements
............
2 changes: 1 addition & 1 deletion examples/classification/plot_classifier_comparison.py
@@ -54,7 +54,7 @@
KNeighborsClassifier(3),
SVC(kernel="linear", C=0.025),
SVC(gamma=2, C=1),
GaussianProcessClassifier(1.0 * RBF(1.0), warm_start=True),
GaussianProcessClassifier(1.0 * RBF(1.0)),
DecisionTreeClassifier(max_depth=5),
RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
MLPClassifier(alpha=1),
3 changes: 2 additions & 1 deletion examples/text/document_classification_20newsgroups.py
@@ -42,7 +42,7 @@
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import Perceptron
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.naive_bayes import BernoulliNB, MultinomialNB
from sklearn.naive_bayes import BernoulliNB, ComplementNB, MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import NearestCentroid
from sklearn.ensemble import RandomForestClassifier
@@ -283,6 +283,7 @@ def benchmark(clf):
print("Naive Bayes")
results.append(benchmark(MultinomialNB(alpha=.01)))
results.append(benchmark(BernoulliNB(alpha=.01)))
results.append(benchmark(ComplementNB(alpha=.1)))

print('=' * 80)
print("LinearSVC with L1-based feature selection")
2 changes: 1 addition & 1 deletion sklearn/cluster/affinity_propagation_.py
@@ -158,7 +158,7 @@ def affinity_propagation(S, preference=None, convergence_iter=15, max_iter=200,
if verbose:
print("Did not converge")

I = np.where(np.diag(A + R) > 0)[0]
I = np.flatnonzero(E)
K = I.size # Identify exemplars

if K > 0:
19 changes: 4 additions & 15 deletions sklearn/cluster/hierarchical.py
@@ -15,10 +15,10 @@
from scipy.sparse.csgraph import connected_components

from ..base import BaseEstimator, ClusterMixin
from ..externals.joblib import Memory
from ..externals import six
from ..metrics.pairwise import paired_distances, pairwise_distances
from ..utils import check_array
from ..utils.validation import check_memory

from . import _hierarchical
from ._feature_agglomeration import AgglomerationTransform
@@ -609,8 +609,7 @@ class AgglomerativeClustering(BaseEstimator, ClusterMixin):
"manhattan", "cosine", or 'precomputed'.
If linkage is "ward", only "euclidean" is accepted.
memory : Instance of sklearn.externals.joblib.Memory or string, optional \
(default=None)
memory : None, str or object with the joblib.Memory interface, optional
Used to cache the output of the computation of the tree.
By default, no caching is done. If a string is given, it is the
path to the caching directory.
@@ -693,16 +692,7 @@ def fit(self, X, y=None):
self
"""
X = check_array(X, ensure_min_samples=2, estimator=self)
memory = self.memory
if memory is None:
memory = Memory(cachedir=None, verbose=0)
elif isinstance(memory, six.string_types):
memory = Memory(cachedir=memory, verbose=0)
elif not isinstance(memory, Memory):
raise ValueError("'memory' should either be a string or"
" a sklearn.externals.joblib.Memory"
" instance, got 'memory={!r}' instead.".format(
type(memory)))
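# check_memory accepts None, a cachedir str, or any object exposing the
# joblib.Memory interface, and returns a ready-to-use Memory-like object.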
memory = check_memory(self.memory)

if self.n_clusters <= 0:
raise ValueError("n_clusters should be an integer greater than 0."
@@ -779,8 +769,7 @@ class FeatureAgglomeration(AgglomerativeClustering, AgglomerationTransform):
"manhattan", "cosine", or 'precomputed'.
If linkage is "ward", only "euclidean" is accepted.
memory : Instance of sklearn.externals.joblib.Memory or string, optional \
(default=None)
memory : None, str or object with the joblib.Memory interface, optional
Used to cache the output of the computation of the tree.
By default, no caching is done. If a string is given, it is the
path to the caching directory.
2 changes: 1 addition & 1 deletion sklearn/decomposition/pca.py
@@ -220,7 +220,7 @@ class PCA(_BasePCA):
mean_ : array, shape (n_features,)
Per-feature empirical mean, estimated from the training set.
Equal to `X.mean(axis=1)`.
Equal to `X.mean(axis=0)`.
n_components_ : int
The estimated number of components. When n_components is set
4 changes: 2 additions & 2 deletions sklearn/linear_model/randomized_l1.py
@@ -278,7 +278,7 @@ class RandomizedLasso(BaseRandomizedLinearModel):
- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'
memory : Instance of sklearn.externals.joblib.Memory or string, optional \
memory : None, str or object with the joblib.Memory interface, optional \
(default=None)
Used for internal caching. By default, no caching is done.
If a string is given, it is the path to the caching directory.
@@ -472,7 +472,7 @@ class RandomizedLogisticRegression(BaseRandomizedLinearModel):
- A string, giving an expression as a function of n_jobs,
as in '2*n_jobs'
memory : Instance of sklearn.externals.joblib.Memory or string, optional \
memory : None, str or object with the joblib.Memory interface, optional \
(default=None)
Used for internal caching. By default, no caching is done.
If a string is given, it is the path to the caching directory.
Expand Down
2 changes: 0 additions & 2 deletions sklearn/linear_model/tests/test_ransac.py
@@ -1,5 +1,3 @@
from scipy import sparse

import numpy as np
from scipy import sparse

5 changes: 0 additions & 5 deletions sklearn/model_selection/_validation.py
@@ -639,11 +639,6 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,

cv = check_cv(cv, y, classifier=is_classifier(estimator))

# Ensure the estimator has implemented the passed decision function
if not callable(getattr(estimator, method)):
raise AttributeError('{} not implemented in estimator'
.format(method))

if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
le = LabelEncoder()
y = le.fit_transform(y)
9 changes: 8 additions & 1 deletion sklearn/model_selection/tests/test_validation.py
@@ -51,7 +51,7 @@
from sklearn.metrics import r2_score
from sklearn.metrics.scorer import check_scoring

from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.linear_model import Ridge, LogisticRegression, SGDClassifier
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
@@ -1194,6 +1194,13 @@ def test_cross_val_predict_with_method():
check_cross_val_predict_with_method(LogisticRegression())


def test_cross_val_predict_method_checking():
# Regression test for issue #9639. Tests that cross_val_predict does not
# check that the estimator implements the requested method
# (e.g. predict_proba) before fitting.
est = SGDClassifier(loss='log', random_state=2)
check_cross_val_predict_with_method(est)
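
# A commented-out, illustrative sketch (not part of the original change) of
# the kind of call this enables: the requested method is resolved on the
# fitted estimator rather than probed up front, e.g.
#
#     from sklearn.datasets import load_iris
#     from sklearn.model_selection import cross_val_predict
#
#     X, y = load_iris(return_X_y=True)
#     proba = cross_val_predict(SGDClassifier(loss='log', random_state=0),
#                               X, y, method='predict_proba')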


def test_gridsearchcv_cross_val_predict_with_method():
est = GridSearchCV(LogisticRegression(random_state=42),
{'C': [0.1, 1]},
93 changes: 92 additions & 1 deletion sklearn/naive_bayes.py
@@ -33,7 +33,7 @@
from .utils.validation import check_is_fitted
from .externals import six

__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB']
__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB']


class BaseNB(six.with_metaclass(ABCMeta, BaseEstimator, ClassifierMixin)):
@@ -726,6 +726,97 @@ def _joint_log_likelihood(self, X):
self.class_log_prior_)


class ComplementNB(BaseDiscreteNB):
"""The Complement Naive Bayes classifier described in Rennie et al. (2003).
The Complement Naive Bayes classifier was designed to correct the "severe
assumptions" made by the standard Multinomial Naive Bayes classifier. It is
particularly suited for imbalanced data sets.
Read more in the :ref:`User Guide <complement_naive_bayes>`.
Parameters
----------
alpha : float, optional (default=1.0)
Additive (Laplace/Lidstone) smoothing parameter (0 for no smoothing).
fit_prior : boolean, optional (default=True)
Only used in edge case with a single class in the training set.
class_prior : array-like, size (n_classes,), optional (default=None)
Prior probabilities of the classes. Not used.
Attributes
----------
class_log_prior_ : array, shape (n_classes, )
Smoothed empirical log probability for each class. Only used in edge
case with a single class in the training set.
feature_log_prob_ : array, shape (n_classes, n_features)
Empirical weights for class complements.
class_count_ : array, shape (n_classes,)
Number of samples encountered for each class during fitting. This
value is weighted by the sample weight when provided.
feature_count_ : array, shape (n_classes, n_features)
Number of samples encountered for each (class, feature) during fitting.
This value is weighted by the sample weight when provided.
feature_all_ : array, shape (n_features,)
Number of samples encountered for each feature during fitting. This
value is weighted by the sample weight when provided.
Examples
--------
>>> import numpy as np
>>> X = np.random.randint(5, size=(6, 100))
>>> y = np.array([1, 2, 3, 4, 5, 6])
>>> from sklearn.naive_bayes import ComplementNB
>>> clf = ComplementNB()
>>> clf.fit(X, y)
ComplementNB(alpha=1.0, class_prior=None, fit_prior=True)
>>> print(clf.predict(X[2:3]))
[3]
References
----------
Rennie, J. D., Shih, L., Teevan, J., & Karger, D. R. (2003).
Tackling the poor assumptions of naive bayes text classifiers. In ICML
(Vol. 3, pp. 616-623).
http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf
"""

    def __init__(self, alpha=1.0, fit_prior=True, class_prior=None):
        self.alpha = alpha
        self.fit_prior = fit_prior
        self.class_prior = class_prior

    def _count(self, X, Y):
        """Count feature occurrences."""
        if np.any((X.data if issparse(X) else X) < 0):
            raise ValueError("Input X must be non-negative")
        # Accumulate per-class feature counts and class counts, and keep the
        # total count of each feature over all classes for the complement.
        self.feature_count_ += safe_sparse_dot(Y.T, X)
        self.class_count_ += Y.sum(axis=0)
        self.feature_all_ = self.feature_count_.sum(axis=0)

    def _update_feature_log_prob(self, alpha):
        """Apply smoothing to raw counts and compute the weights."""
        # Smoothed counts of each feature over all classes other than the
        # class itself (the "complement" counts).
        comp_count = self.feature_all_ + alpha - self.feature_count_
        logged = np.log(comp_count / comp_count.sum(axis=1, keepdims=True))
        # Normalize the weights (the second normalization described in the
        # user guide), which limits the influence of long documents.
        self.feature_log_prob_ = logged / logged.sum(axis=1, keepdims=True)

    def _joint_log_likelihood(self, X):
        """Calculate the class scores for the samples in X."""
        check_is_fitted(self, "classes_")

        X = check_array(X, accept_sparse="csr")
        jll = safe_sparse_dot(X, self.feature_log_prob_.T)
        if len(self.classes_) == 1:
            # Edge case with a single class: fall back to the class prior.
            jll += self.class_log_prior_
        return jll


class BernoulliNB(BaseDiscreteNB):
"""Naive Bayes classifier for multivariate Bernoulli models.