From 87c4932ca0e4dbe10f918bc5e8096e30fc6557e7 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Mon, 29 Jul 2024 17:57:11 +0100 Subject: [PATCH 1/2] BUG: stats: adapt to `np.floor` type promotion removal `rv_discrete._cdf` relied on `np.floor` promoting its integer input to `np.float64`. This is no longer the case since numpy/numpy#26766. [skip cirrus] [skip circle] --- scipy/stats/_distn_infrastructure.py | 2 +- scipy/stats/tests/test_discrete_basic.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/scipy/stats/_distn_infrastructure.py b/scipy/stats/_distn_infrastructure.py index a7a0bb66b396..83954af2c11f 100644 --- a/scipy/stats/_distn_infrastructure.py +++ b/scipy/stats/_distn_infrastructure.py @@ -3392,7 +3392,7 @@ def _cdf_single(self, k, *args): return np.sum(self._pmf(m, *args), axis=0) def _cdf(self, x, *args): - k = floor(x) + k = floor(x).astype(np.float64) return self._cdfvec(k, *args) # generic _logcdf, _sf, _logsf, _ppf, _isf, _rvs defined in rv_generic diff --git a/scipy/stats/tests/test_discrete_basic.py b/scipy/stats/tests/test_discrete_basic.py index 1ebc9371c075..3db2f0666dbd 100644 --- a/scipy/stats/tests/test_discrete_basic.py +++ b/scipy/stats/tests/test_discrete_basic.py @@ -549,3 +549,15 @@ def test_rv_sample(): rng = np.random.default_rng(98430143469) rvs0 = dist.ppf(rng.random(size=100)) assert_allclose(rvs, rvs0) + +def test__pmf_float_input(): + # gh-21272 + # test that `rvs()` can be computed when `_pmf` requires float input + + class rv_exponential(stats.rv_discrete): + def _pmf(self, i): + return (2/3)*3**(1 - i) + + rv = rv_exponential(a=0.0, b=float('inf')) + rvs = rv.rvs() # should not crash due to integer input to `_pmf` + assert_allclose(rvs, 0) From fe924edb6564358df24de8b915861d6754c6e94d Mon Sep 17 00:00:00 2001 From: Matt Haberland Date: Sat, 10 Aug 2024 12:15:56 -0700 Subject: [PATCH 2/2] Update scipy/stats/tests/test_discrete_basic.py [skip ci] --- scipy/stats/tests/test_discrete_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scipy/stats/tests/test_discrete_basic.py b/scipy/stats/tests/test_discrete_basic.py index 3db2f0666dbd..2bc53d95228a 100644 --- a/scipy/stats/tests/test_discrete_basic.py +++ b/scipy/stats/tests/test_discrete_basic.py @@ -559,5 +559,5 @@ def _pmf(self, i): return (2/3)*3**(1 - i) rv = rv_exponential(a=0.0, b=float('inf')) - rvs = rv.rvs() # should not crash due to integer input to `_pmf` + rvs = rv.rvs(random_state=42) # should not crash due to integer input to `_pmf` assert_allclose(rvs, 0)