From ad2a73c18dcff95d844c382c94ab7f73b5571cf3 Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Tue, 4 May 2021 17:43:26 -0500 Subject: [PATCH] MAINT: Adjust NumPy float hashing to Python's slightly changed hash This is necessary, since we use the Python double hash and the semi-private function to calculate it in Python has a new signature to return the identity-hash when the value is NaN. closes gh-18833, gh-18907 --- numpy/core/src/common/npy_pycompat.h | 16 ++++++++++ numpy/core/src/multiarray/scalartypes.c.src | 13 ++++---- numpy/core/tests/test_scalarmath.py | 34 +++++++++++++++++++++ 3 files changed, 57 insertions(+), 6 deletions(-) diff --git a/numpy/core/src/common/npy_pycompat.h b/numpy/core/src/common/npy_pycompat.h index aa0b5c1224d..9e94a971090 100644 --- a/numpy/core/src/common/npy_pycompat.h +++ b/numpy/core/src/common/npy_pycompat.h @@ -3,4 +3,20 @@ #include "numpy/npy_3kcompat.h" + +/* + * In Python 3.10a7 (or b1), python started using the identity for the hash + * when a value is NaN. See https://bugs.python.org/issue43475 + */ +#if PY_VERSION_HEX > 0x030a00a6 +#define Npy_HashDouble _Py_HashDouble +#else +static NPY_INLINE Py_hash_t +Npy_HashDouble(PyObject *NPY_UNUSED(identity), double val) +{ + return _Py_HashDouble(val); +} +#endif + + #endif /* _NPY_COMPAT_H_ */ diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index a001500b0a9..9930f7791d6 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -3172,7 +3172,7 @@ static npy_hash_t static npy_hash_t @lname@_arrtype_hash(PyObject *obj) { - return _Py_HashDouble((double) PyArrayScalar_VAL(obj, @name@)); + return Npy_HashDouble(obj, (double)PyArrayScalar_VAL(obj, @name@)); } /* borrowed from complex_hash */ @@ -3180,14 +3180,14 @@ static npy_hash_t c@lname@_arrtype_hash(PyObject *obj) { npy_hash_t hashreal, hashimag, combined; - hashreal = _Py_HashDouble((double) - PyArrayScalar_VAL(obj, C@name@).real); + hashreal = Npy_HashDouble( + obj, (double)PyArrayScalar_VAL(obj, C@name@).real); if (hashreal == -1) { return -1; } - hashimag = _Py_HashDouble((double) - PyArrayScalar_VAL(obj, C@name@).imag); + hashimag = Npy_HashDouble( + obj, (double)PyArrayScalar_VAL(obj, C@name@).imag); if (hashimag == -1) { return -1; } @@ -3202,7 +3202,8 @@ c@lname@_arrtype_hash(PyObject *obj) static npy_hash_t half_arrtype_hash(PyObject *obj) { - return _Py_HashDouble(npy_half_to_double(PyArrayScalar_VAL(obj, Half))); + return Npy_HashDouble( + obj, npy_half_to_double(PyArrayScalar_VAL(obj, Half))); } static npy_hash_t diff --git a/numpy/core/tests/test_scalarmath.py b/numpy/core/tests/test_scalarmath.py index d91b4a39146..09a734284a7 100644 --- a/numpy/core/tests/test_scalarmath.py +++ b/numpy/core/tests/test_scalarmath.py @@ -712,6 +712,40 @@ def test_shift_all_bits(self, type_code, op): assert_equal(res_arr, res_scl) +class TestHash: + @pytest.mark.parametrize("type_code", np.typecodes['AllInteger']) + def test_integer_hashes(self, type_code): + scalar = np.dtype(type_code).type + for i in range(128): + assert hash(i) == hash(scalar(i)) + + @pytest.mark.parametrize("type_code", np.typecodes['AllFloat']) + def test_float_and_complex_hashes(self, type_code): + scalar = np.dtype(type_code).type + for val in [np.pi, np.inf, 3, 6.]: + numpy_val = scalar(val) + # Cast back to Python, in case the NumPy scalar has less precision + if numpy_val.dtype.kind == 'c': + val = complex(numpy_val) + else: + val = float(numpy_val) + assert val == numpy_val + print(repr(numpy_val), repr(val)) + assert hash(val) == hash(numpy_val) + + if hash(float(np.nan)) != hash(float(np.nan)): + # If Python distinguises different NaNs we do so too (gh-18833) + assert hash(scalar(np.nan)) != hash(scalar(np.nan)) + + @pytest.mark.parametrize("type_code", np.typecodes['Complex']) + def test_complex_hashes(self, type_code): + # Test some complex valued hashes specifically: + scalar = np.dtype(type_code).type + for val in [np.pi+1j, np.inf-3j, 3j, 6.+1j]: + numpy_val = scalar(val) + assert hash(complex(numpy_val)) == hash(numpy_val) + + @contextlib.contextmanager def recursionlimit(n): o = sys.getrecursionlimit()