diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 945a6f55..b36ae9e6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,7 +29,7 @@ python -m pip install --editable .[dev] Note that `python -m` can be omitted most of the times, but within virtualenvs, it can prevent certain errors. Also, in certain terminals (such as `zsh`), the square brackets must be escaped, e.g. replace `.[dev]` by `.\[dev\]`. -In addition to `numpy`, `scipy` and `ruptures`, this command will install all packages needed to develop `ruptures`. +In addition to `numpy`, `scipy`, `scikit-learn` and `ruptures`, this command will install all packages needed to develop `ruptures`. The exact list of librairies can be found in the [`setup.cfg` file](https://github.com/deepcharles/ruptures/blob/master/setup.cfg) (section `[options.extras_require]`). ### Pre-commit hooks diff --git a/docs/install.md b/docs/install.md index 48234181..021f82e3 100644 --- a/docs/install.md +++ b/docs/install.md @@ -1,6 +1,6 @@ # Installation -This library requires Python >=3.6 and the following packages: `numpy`, `scipy` and `matplotlib` (the last one is optional and only for display purposes). +This library requires Python >=3.6 and the following packages: `numpy`, `scipy`, `scikit-learn` and `matplotlib` (the last one is optional and only for display purposes). You can either install the latest stable release or the development version. ## Stable release diff --git a/pyproject.toml b/pyproject.toml index baaf717e..fae65be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ requires = [ "setuptools_scm[toml]>=3.4", # https://scikit-hep.org/developer/packaging#git-tags-official-pypa-method "oldest-supported-numpy", # https://github.com/scipy/oldest-supported-numpy "scipy>=0.19.1", + "scikit-learn>=1.0", ] build-backend = "setuptools.build_meta" diff --git a/setup.cfg b/setup.cfg index a921da57..ca5a9bd8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,6 +35,7 @@ python_requires = >= 3.6 install_requires = numpy scipy + scikit-learn packages = find: package_dir = =src diff --git a/src/ruptures/metrics/__init__.py b/src/ruptures/metrics/__init__.py index d75ed187..f96fdf45 100644 --- a/src/ruptures/metrics/__init__.py +++ b/src/ruptures/metrics/__init__.py @@ -21,3 +21,4 @@ from .precisionrecall import precision_recall from .hamming import hamming from .randindex import randindex +from .adjusted_randindex import adjusted_randindex diff --git a/src/ruptures/metrics/adjusted_randindex.py b/src/ruptures/metrics/adjusted_randindex.py new file mode 100644 index 00000000..92f81517 --- /dev/null +++ b/src/ruptures/metrics/adjusted_randindex.py @@ -0,0 +1,41 @@ +r"""Adjusted Rand index (`adjusted_randindex`)""" +import numpy as np +from ruptures.metrics.sanity_check import sanity_check +from sklearn.metrics import adjusted_rand_score + + +def chpt_to_label(bkps): + """Return the segment index each sample belongs to. + + Example: + ------- + >>> chpt_to_label([4, 10]) + array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + """ + duration = np.diff([0] + bkps) + return np.repeat(np.arange(len(bkps)), duration) + + +def adjusted_randindex(bkps1, bkps2): + """Compute the adjusted Rand index (between -0.5 and 1.) between two + segmentations. + + The Rand index (RI) measures the similarity between two segmentations and + is equal to the proportion of aggreement between two partitions. + + The metric implemented here is RI variant, adjusted for chance, and based + on [scikit-learn's implementation](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.adjusted_rand_score.html). + + Args: + ---- + bkps1 (list): sorted list of the last index of each regime. + bkps2 (list): sorted list of the last index of each regime. + + Return: + ------ + float: Adjusted Rand index + """ # noqa E501 + sanity_check(bkps1, bkps2) + label1 = chpt_to_label(bkps1) + label2 = chpt_to_label(bkps2) + return adjusted_rand_score(label1, label2) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index b9cca471..14aa6d33 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -6,6 +6,7 @@ meantime, precision_recall, randindex, + adjusted_randindex, ) from ruptures.metrics.sanity_check import BadPartitions @@ -31,6 +32,14 @@ def test_randindex(b_mb): assert m == 1 +def test_adjusted_randindex(b_mb): + b, mb = b_mb + m = adjusted_randindex(b, mb) + assert 1 > m > -0.5 + m = adjusted_randindex(b, b) + assert m == 1 + + def test_meantime(b_mb): b, mb = b_mb m = meantime(b, mb)