Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions hail/hail/src/is/hail/expr/ir/functions/MathFunctions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -425,16 +425,15 @@ object MathFunctions extends RegistryFunctions {
fetStruct.virtualType,
(_, _, _, _, _) => fetStruct.sType,
) { case (r, cb, _, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) =>
val res = cb.newLocal[Array[Double]](
"fisher_exact_test_res",
val res = cb.memoize[Array[Double]](
Code.invokeScalaObject4[Int, Int, Int, Int, Array[Double]](
statsPackageClass,
"fisherExactTest",
a.value,
b.value,
c.value,
d.value,
),
)
)

fetStruct.constructFromFields(
Expand All @@ -450,6 +449,29 @@ object MathFunctions extends RegistryFunctions {
)
}

// FIXME: delete when PruneDeadField can optimize fisher_exact_test when only
// the pvalue is used from the result struct
//
// Registers `fisher_exact_test_pvalue_only`: takes four TInt32 arguments and yields a
// TFloat64 by invoking `fisherExactTestPValueOnly` on `statsPackageClass`.
registerSCode4(
  "fisher_exact_test_pvalue_only",
  TInt32,
  TInt32,
  TInt32,
  TInt32,
  TFloat64,
  (_, _, _, _, _) => SFloat64,
) { case (_, cb, _, n11: SInt32Value, n12: SInt32Value, n21: SInt32Value, n22: SInt32Value, _) =>
  // Build the static call first, then store its result and wrap it as a primitive value.
  val pValueCall = Code.invokeScalaObject4[Int, Int, Int, Int, Double](
    statsPackageClass,
    "fisherExactTestPValueOnly",
    n11.value,
    n12.value,
    n21.value,
    n22.value,
  )
  primitive(cb.memoize[Double](pValueCall))
}

registerSCode4(
"chi_squared_test",
TInt32,
Expand Down
113 changes: 68 additions & 45 deletions hail/hail/src/is/hail/stats/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package is.hail
import is.hail.types.physical.{PCanonicalStruct, PFloat64}
import is.hail.utils._

import scala.annotation.tailrec

import net.sourceforge.jdistlib.{Beta, ChiSquare, Gamma, NonCentralChiSquare, Normal, Poisson}
import net.sourceforge.jdistlib.disttest.{DistributionTest, TestKind}
import org.apache.commons.math3.distribution.HypergeometricDistribution
Expand Down Expand Up @@ -162,45 +164,28 @@ package object stats {
)

def fisherExactTest(a: Int, b: Int, c: Int, d: Int): Array[Double] =
fisherExactTest(a, b, c, d, 1.0, 0.95, "two.sided")

def fisherExactTest(
a: Int,
b: Int,
c: Int,
d: Int,
oddsRatio: Double = 1d,
confidenceLevel: Double = 0.95,
alternative: String = "two.sided",
): Array[Double] = {
fisherExactTest(a, b, c, d, 0.95)

def fisherExactTest(a: Int, b: Int, c: Int, d: Int, confidenceLevel: Double): Array[Double] = {
if (!(a >= 0 && b >= 0 && c >= 0 && d >= 0))
fatal(s"fisher_exact_test: all arguments must be non-negative, got $a, $b, $c, $d")

if (confidenceLevel < 0d || confidenceLevel > 1d)
fatal("Confidence level must be between 0 and 1")

if (oddsRatio < 0d)
fatal("Odds ratio must be non-negative")

if (alternative != "greater" && alternative != "less" && alternative != "two.sided")
fatal("Did not recognize test type string. Use one of greater, less, two.sided")

val popSize = a + b + c + d
val numSuccessPopulation = a + c
val sampleSize = a + b
val nGood = a + c
val nSample = a + b
val numSuccessSample = a

if (
!(popSize > 0 && sampleSize > 0 && sampleSize < popSize && numSuccessPopulation > 0 && numSuccessPopulation < popSize)
)
if (!(popSize > 0 && nSample > 0 && nSample < popSize && nGood > 0 && nGood < popSize))
return Array(Double.NaN, Double.NaN, Double.NaN, Double.NaN)

val low = math.max(0, (a + b) - (b + d))
val high = math.min(a + b, a + c)
val support = (low to high).toArray

val hgd = new HypergeometricDistribution(null, popSize, numSuccessPopulation, sampleSize)
val hgd = new HypergeometricDistribution(null, popSize, nGood, nSample)
val epsilon = 2.220446e-16

def dhyper(k: Int, logProb: Boolean): Double =
Expand Down Expand Up @@ -320,36 +305,74 @@ package object stats {
}
}

val pvalue: Double = (alternative: @unchecked) match {
case "less" => pnhyper(numSuccessSample, oddsRatio)
case "greater" => pnhyper(numSuccessSample, oddsRatio, upper_tail = true)
case "two.sided" =>
if (oddsRatio == 0)
if (low == numSuccessSample) 1d else 0d
else if (oddsRatio == Double.PositiveInfinity)
if (high == numSuccessSample) 1d else 0d
else {
val relErr = 1d + 1e-7
val d = dnhyper(oddsRatio)
d.filter(_ <= d(numSuccessSample - low) * relErr).sum
}
}

assert(pvalue >= 0d && pvalue <= 1.000000000002)
val pvalue = fisherExactTestPValueOnly(a, b, c, d)

val oddsRatioEstimate = mle(numSuccessSample.toDouble)

val confInterval = alternative match {
case "less" => (0d, ncpUpper(numSuccessSample, 1 - confidenceLevel))
case "greater" => (ncpLower(numSuccessSample, 1 - confidenceLevel), Double.PositiveInfinity)
case "two.sided" =>
val alpha = (1 - confidenceLevel) / 2d
(ncpLower(numSuccessSample, alpha), ncpUpper(numSuccessSample, alpha))
val confInterval = {
val alpha = (1 - confidenceLevel) / 2d
(ncpLower(numSuccessSample, alpha), ncpUpper(numSuccessSample, alpha))
}

Array(pvalue, oddsRatioEstimate, confInterval._1, confInterval._2)
}

/** Computes only the p-value of Fisher's exact test for the 2x2 table with cells `a`, `b`, `c`,
  * `d`, using a hypergeometric distribution with population `a + b + c + d`, `a + c` successes
  * and sample size `a + b`.
  *
  * The result is the probability mass on the observed side of the mode (from `a` outward) plus
  * the mass on the opposite side starting at the first outcome whose probability does not exceed
  * the observed probability (with a 1e-14 relative tolerance). If the observed probability
  * matches the probability at the mode within that tolerance, the result is exactly 1.0.
  *
  * @param a
  *   first cell; also the observed number of successes in the sample
  * @param b
  *   second cell
  * @param c
  *   third cell
  * @param d
  *   fourth cell
  * @return
  *   the p-value, or `Double.NaN` when the population is empty or either `a + b` or `a + c` is
  *   zero or equal to the population size
  *
  * Calls `fatal` if any argument is negative.
  */
def fisherExactTestPValueOnly(a: Int, b: Int, c: Int, d: Int): Double = {
  if (a < 0 || b < 0 || c < 0 || d < 0)
    fatal(s"fisher_exact_test: all arguments must be non-negative, got $a, $b, $c, $d")

  val total = a + b + c + d
  val successes = a + c
  val draws = a + b

  val marginsUsable =
    total > 0 && draws > 0 && draws < total && successes > 0 && successes < total

  if (!marginsUsable) Double.NaN
  else {
    val dist = new HypergeometricDistribution(null, total, successes, draws)

    // Bisection over [from, until): returns the first index at which `goRight` is false
    // (or `until` if there is none), assuming `goRight` is true on a prefix of the range
    // and false on the rest.
    @tailrec def partitionPoint(goRight: Int => Boolean, from: Int, until: Int): Int =
      if (from >= until) from
      else {
        val mid = (from + until) >>> 1
        if (goRight(mid)) partitionPoint(goRight, mid + 1, until)
        else partitionPoint(goRight, from, mid)
      }

    val tolerance = 1e-14
    val observed = dist.probability(a)
    // Outcomes count as "at most as likely as observed" up to this slightly inflated cutoff.
    val cutoff = observed * (1 + tolerance)

    val modeIdx = ((draws + 1.0) * (successes + 1.0) / (total + 2.0)).toInt
    val atMode = dist.probability(modeIdx)
    val relativeGap = math.abs(observed - atMode) / math.max(observed, atMode)

    val result =
      if (relativeGap <= tolerance) {
        // Observed outcome is (numerically) as likely as the mode: every outcome qualifies.
        1.0
      } else if (a < modeIdx) {
        // Observed is left of the mode: take the left tail through `a`, then search the
        // decreasing right side for the first outcome at or below the cutoff.
        val nearTail = dist.cumulativeProbability(a)
        val farStart = partitionPoint(k => dist.probability(k) > cutoff, modeIdx + 1, draws + 1)
        nearTail + dist.upperCumulativeProbability(farStart)
      } else {
        // Observed is right of the mode: take the right tail from `a`, then search the
        // increasing left side for the first outcome above the cutoff.
        val nearTail = dist.upperCumulativeProbability(a)
        val farEnd = partitionPoint(k => dist.probability(k) <= cutoff, 0, modeIdx)
        nearTail + dist.cumulativeProbability(farEnd - 1)
      }

    assert(result >= 0d && result <= 1.000000000002)

    result
  }
}

def dnorm(x: Double, mu: Double, sigma: Double, logP: Boolean): Double =
Normal.density(x, mu, sigma, logP)

Expand Down
57 changes: 53 additions & 4 deletions hail/hail/test/src/is/hail/stats/FisherExactTestSuite.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package is.hail.stats

import is.hail.HailSuite
import is.hail.utils.D_==

import org.testng.annotations.Test

Expand All @@ -14,9 +15,57 @@ class FisherExactTestSuite extends HailSuite {

val result = fisherExactTest(a, b, c, d)

assert(math.abs(result(0) - 0.2828) < 1e-4)
assert(math.abs(result(1) - 0.4754059) < 1e-4)
assert(math.abs(result(2) - 0.122593) < 1e-4)
assert(math.abs(result(3) - 1.597972) < 1e-4)
assert(D_==(result(0), 0.2828, 1e-3))
assert(D_==(result(1), 0.4754059, 1e-4))
assert(D_==(result(2), 0.122593, 1e-4))
assert(D_==(result(3), 1.597972, 1e-4))
}

/** Checks the p-value (element 0 of the result) of `fisherExactTest` for cells 10, 5, 90, 95. */
@Test def testPvalue2(): Unit = {
  val stats = fisherExactTest(10, 5, 90, 95)
  val pValue = stats(0)

  assert(D_==(pValue, 0.2828, 1e-3))
}

/** Compares `fisherExactTestPValueOnly` against reference p-values and checks that tables with
  * an empty row or column yield NaN.
  */
@Test def test_basic(): Unit = {
  // Local shorthand for the function under test.
  def pv(a: Int, b: Int, c: Int, d: Int): Double = fisherExactTestPValueOnly(a, b, c, d)

  // test cases taken from scipy/stats/tests/test_stats.py
  assert(D_==(pv(14500, 20000, 30000, 40000), 0.01106, 1e-3))
  assert(D_==(pv(100, 2, 1000, 5), 0.1301, 1e-3))
  assert(D_==(pv(2, 7, 8, 2), 0.0230141, 1e-5))
  assert(D_==(pv(5, 1, 10, 10), 0.1973244, 1e-6))
  assert(D_==(pv(5, 15, 20, 20), 0.0958044, 1e-6))
  assert(D_==(pv(5, 16, 20, 25), 0.1725862, 1e-5))
  assert(D_==(pv(10, 5, 10, 1), 0.1973244, 1e-6))
  assert(D_==(pv(5, 0, 1, 4), 0.04761904, 1e-6))
  assert(pv(0, 1, 3, 2) == 1.0)
  assert(D_==(pv(0, 2, 6, 4), 0.4545454545))
  assert(D_==(pv(2, 7, 8, 2), 0.0230141, 1e-5))

  assert(D_==(pv(6, 37, 108, 200), 0.005092697748126))
  assert(D_==(pv(22, 0, 0, 102), 7.175066786244549e-25))
  assert(D_==(pv(94, 48, 3577, 16988), 2.069356340993818e-37))
  assert(pv(5829225, 5692693, 5760959, 5760959) <= 1e-170)

  // An all-zero row or column must yield NaN.
  val degenerateTables = Array((0, 0, 5, 10), (5, 10, 0, 0), (0, 5, 0, 10), (5, 0, 10, 0))
  degenerateTables.foreach { case (a, b, c, d) =>
    assert(pv(a, b, c, d).isNaN)
  }
}
}
17 changes: 10 additions & 7 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def contingency_table_test(c1, c2, c3, c4, min_cell_count) -> StructExpression:
Struct(p_value=1.4626257805267089e-07, odds_ratio=4.959830866807611)

>>> hl.eval(hl.contingency_table_test(51, 43, 22, 92, min_cell_count=23))
Struct(p_value=2.1564999740157304e-07, odds_ratio=4.918058171469967)
Struct(p_value=2.156499974015729e-07, odds_ratio=4.918058171469967)

Notes
-----
Expand Down Expand Up @@ -1133,20 +1133,20 @@ def exp(x) -> Float64Expression:
return _func("exp", tfloat64, x)


@typecheck(c1=expr_int32, c2=expr_int32, c3=expr_int32, c4=expr_int32)
def fisher_exact_test(c1, c2, c3, c4) -> StructExpression:
@typecheck(c1=expr_int32, c2=expr_int32, c3=expr_int32, c4=expr_int32, _pvalue_only=bool)
def fisher_exact_test(c1, c2, c3, c4, *, _pvalue_only=False) -> StructExpression:
"""Calculates the p-value, odds ratio, and 95% confidence interval using
Fisher's exact test for a 2x2 table.

Examples
--------

>>> hl.eval(hl.fisher_exact_test(10, 10, 10, 10))
Struct(p_value=1.0000000000000002, odds_ratio=1.0,
Struct(p_value=1.0, odds_ratio=1.0,
ci_95_lower=0.24385796914260355, ci_95_upper=4.100747675033819)

>>> hl.eval(hl.fisher_exact_test(51, 43, 22, 92))
Struct(p_value=2.1564999740157304e-07, odds_ratio=4.918058171469967,
Struct(p_value=2.156499974015729e-07, odds_ratio=4.918058171469967,
ci_95_lower=2.5659373368248444, ci_95_upper=9.677929632035475)

Notes
Expand Down Expand Up @@ -1176,8 +1176,11 @@ def fisher_exact_test(c1, c2, c3, c4) -> StructExpression:
`ci_95_lower` (:py:data:`.tfloat64`), and `ci_95_upper`
(:py:data:`.tfloat64`).
"""
ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64)
return _func("fisher_exact_test", ret_type, c1, c2, c3, c4)
if _pvalue_only:
return struct(p_value=_func("fisher_exact_test_pvalue_only", tfloat64, c1, c2, c3, c4))
else:
ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64)
return _func("fisher_exact_test", ret_type, c1, c2, c3, c4)


@typecheck(x=expr_oneof(expr_float32, expr_float64, expr_ndarray(expr_float64)))
Expand Down