๐Ÿงช testax#

testax provides runtime assertions for JAX through the testing interface familiar to NumPy users.

>>> import jax
>>> from jax import numpy as jnp
>>> import testax
>>>
>>> def safe_log(x):
...     testax.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(jnp.arange(2))
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered

Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
 x: Array(0, dtype=int32, weak_type=True)
 y: Array([0, 1], dtype=int32)

testax assertions are jit-able, although errors need to be functionalized to conform to JAX's requirement that functions are pure and do not have side effects (see the checkify guide for details). In short, a checkified function returns a tuple (error, value): the first element is an error that may have occurred, and the second is the return value of the original function.

>>> jitted = jax.jit(safe_log)
>>> checkified = testax.checkify(jitted)
>>> error, y = checkified(jnp.arange(2))
>>> error.throw()
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered

Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
 x: Array(0, dtype=int32, weak_type=True)
 y: Array([0, 1], dtype=int32)
>>> y
Array([-inf,   0.], dtype=float32)
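The (error, value) protocol can also be seen with JAX's own checkify module, which testax.checkify() is documented to build on. The sketch below uses only jax.experimental.checkify (check(), checkify() and Error.get()); the function name safe_log_checked and its message are illustrative and not part of testax.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log_checked(x):
    # Record a failed predicate instead of raising immediately; the message is illustrative.
    checkify.check(jnp.all(x > 0), "x must be positive")
    return jnp.log(x)


# Functionalize first, then jit: the result is a pure function returning (error, value).
checked = jax.jit(checkify.checkify(safe_log_checked))

# Valid input: no check fails, so the error carries no message.
err_ok, y_ok = checked(jnp.arange(1.0, 3.0))
print(err_ok.get())  # None

# Invalid input: the value is still computed, and the error records the failed check.
err_bad, y_bad = checked(jnp.arange(2.0))
print(err_bad.get())
```

Because Error.get() returns None when every check passed, a caller can inspect the error without raising; error.throw() only raises if a check failed.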

Installation

testax is pip-installable and can be installed by running:

pip install testax

Interface

testax mirrors the numpy.testing interface familiar to NumPy users, providing functions such as assert_allclose().


testax.assert_allclose(actual: Array, desired: Array, rtol: float = 1e-07, atol: float = 0, equal_nan: bool = True, err_msg: str = '', verbose: bool = True, *, debug: bool = False) -> None

Raises an AssertionError if two objects are not equal up to the desired tolerance.

Given two Arrays, check that their shapes and all elements are equal (but see the Notes for the special handling of a scalar). An exception is raised if the shapes mismatch or any values conflict. In contrast to the standard usage in NumPy, NaNs are compared like numbers: no assertion is raised if both objects have NaNs in the same positions.

The test is equivalent to allclose(actual, desired, rtol, atol) (note that allclose() has different default values). It checks element-wise that abs(actual - desired) <= atol + rtol * abs(desired).
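As a rough illustration, the tolerance rule can be written directly with jax.numpy. This is a sketch of the criterion described above, not testax's implementation; the helper name allclose_criterion is hypothetical, it returns a bool instead of raising, and unlike numpy.isclose it does not special-case infinities.

```python
import jax.numpy as jnp


def allclose_criterion(actual, desired, rtol=1e-7, atol=0.0, equal_nan=True):
    """Hypothetical helper: True if every element satisfies the allclose rule."""
    actual = jnp.asarray(actual)
    desired = jnp.asarray(desired)
    # |actual - desired| <= atol + rtol * |desired|; a scalar broadcasts against an Array.
    within = jnp.abs(actual - desired) <= atol + rtol * jnp.abs(desired)
    if equal_nan:
        # NaNs in matching positions count as equal.
        within = within | (jnp.isnan(actual) & jnp.isnan(desired))
    return bool(jnp.all(within))


print(allclose_criterion([1.0, 1.05], [1.0, 1.0], rtol=0.1))   # True
print(allclose_criterion([1.0, 1.05], [1.0, 1.0], rtol=0.01))  # False
```

Because only desired scales the relative term, the rule is not symmetric in actual and desired.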

Parameters:
  • actual – Array obtained.

  • desired – Array desired.

  • rtol – Relative tolerance.

  • atol – Absolute tolerance.

  • equal_nan – If True, NaNs will compare equal.

  • err_msg – The error message to be printed in case of failure.

  • verbose – If True, the conflicting values are appended to the error message.

Raises:

AssertionError – If actual and desired are not equal up to specified precision.

See also

assert_array_almost_equal_nulp(), assert_array_max_ulp()

Notes

When one of actual and desired is a scalar and the other is an Array, the function checks that each element of the Array is equal to the scalar.

Examples

>>> x = jnp.asarray([1e-5, 1e-3, 1e-1])
>>> y = jnp.arccos(jnp.cos(x))
>>> testax.assert_allclose(x, y, atol=1e-4)

testax.assert_array_almost_equal(x: Array, y: Array, decimal: int = 6, err_msg: str = '', verbose: bool = True, *, debug: bool = False) -> None

Raises an AssertionError if two Arrays are not equal up to the desired precision.

Note

It is recommended to use one of assert_allclose(), assert_array_almost_equal_nulp() or assert_array_max_ulp() instead of this function for more consistent floating point comparisons.

The test verifies identical shapes and that the elements of x and y satisfy abs(y - x) < 1.5 * 10 ** -decimal.

An exception is raised on shape mismatch or conflicting values. In contrast to the standard usage in NumPy, NaNs are compared like numbers: no assertion is raised if both objects have NaNs in the same positions.
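To make the threshold concrete: with decimal=5 the bound is 1.5 * 10 ** -5 = 1.5e-5, so a difference of about 6e-5 between 2.33333 and 2.33339 fails, while decimal=4 (bound 1.5e-4) passes. A sketch of the check follows; the helper name almost_equal_criterion is hypothetical, it returns a bool instead of raising, and NaNs in matching positions are treated as equal as described above.

```python
import jax.numpy as jnp


def almost_equal_criterion(x, y, decimal=6):
    """Hypothetical helper: True if abs(y - x) < 1.5 * 10 ** -decimal everywhere."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    threshold = 1.5 * 10.0 ** (-decimal)
    ok = jnp.abs(y - x) < threshold
    # NaNs at the same positions do not count as mismatches.
    ok = ok | (jnp.isnan(x) & jnp.isnan(y))
    return bool(jnp.all(ok))


print(almost_equal_criterion([1.0, 2.33333], [1.0, 2.33339], decimal=5))  # False
print(almost_equal_criterion([1.0, 2.33333], [1.0, 2.33339], decimal=4))  # True
```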

Parameters:
  • x – The actual object to check.

  • y – The desired, expected object.

  • decimal – Desired precision, default is 6.

  • err_msg – The error message to be printed in case of failure.

  • verbose – If True, the conflicting values are appended to the error message.

Raises:

AssertionError – If x and y are not equal up to specified precision.

See also

assert_allclose(), assert_array_almost_equal_nulp(), assert_array_max_ulp(), assert_equal()

Examples

The first assert does not raise an exception:

>>> testax.assert_array_almost_equal([1.0, 2.333, jnp.nan],
...                                  [1.0, 2.333, jnp.nan])
>>> testax.assert_array_almost_equal([1.0, 2.33333, jnp.nan],
...                                  [1.0, 2.33339, jnp.nan], decimal=5)
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not almost equal to 5 decimals

Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 6.0081e-05
Max relative difference: 2.5749e-05
 x: Array([1.     , 2.33333,     nan], dtype=float32)
 y: Array([1.     , 2.33339,     nan], dtype=float32)
>>> testax.assert_array_almost_equal([1.0, 2.33333, jnp.nan],
...                                  [1.0, 2.33333, 5], decimal=5)
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not almost equal to 5 decimals

x and y nan location mismatch:
 x: Array([1.     , 2.33333,     nan], dtype=float32)
 y: Array([1.     , 2.33333, 5.     ], dtype=float32)

testax.assert_array_equal(x: Array, y: Array, err_msg: str = '', verbose: bool = True, *, strict: bool = False, debug: bool = False) -> None

Raises an AssertionError if two Arrays are not equal.

Given two Arrays, check that their shapes are equal and that all elements of these objects are equal (but see the Notes for the special handling of a scalar). An exception is raised on shape mismatch or conflicting values. In contrast to the standard usage in NumPy, NaNs are compared like numbers: no assertion is raised if both objects have NaNs in the same positions.

The usual caution for verifying equality with floating point numbers is advised.

Parameters:
  • x – The actual object to check.

  • y – The desired, expected object.

  • err_msg – The error message to be printed in case of failure.

  • verbose – If True, the conflicting values are appended to the error message.

  • strict – If True, raise an AssertionError when either the shape or the data type of the Arrays does not match. The special handling for scalars mentioned in the Notes section is disabled.

Raises:

AssertionError – If x and y are not equal.

See also

assert_allclose(), assert_array_almost_equal_nulp(), assert_array_max_ulp(), assert_equal()

Notes

When one of x and y is a scalar and the other is an Array, the function checks that each element of the Array is equal to the scalar. This behaviour can be disabled with the strict parameter.
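The shape, dtype and scalar rules from the strict parameter and these Notes can be condensed into a short sketch. The helper array_equal_criterion is hypothetical: it returns a bool instead of raising and illustrates the documented rules rather than reproducing testax's implementation.

```python
import jax.numpy as jnp


def array_equal_criterion(x, y, strict=False):
    """Hypothetical helper mirroring the shape/dtype rules of assert_array_equal."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    if strict:
        # strict: shapes and dtypes must match exactly, so there is no scalar special case.
        if x.shape != y.shape or x.dtype != y.dtype:
            return False
    elif x.shape != y.shape and x.ndim > 0 and y.ndim > 0:
        # Otherwise shapes must match unless one side is a scalar.
        return False
    # Element-wise equality, with NaNs in the same positions treated as equal.
    equal = (x == y) | (jnp.isnan(x) & jnp.isnan(y))
    return bool(jnp.all(equal))


x = jnp.full((2, 5), fill_value=3)
print(array_equal_criterion(x, 3))               # True
print(array_equal_criterion(x, 3, strict=True))  # False
```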

Examples

The first assert does not raise an exception:

>>> testax.assert_array_equal([1.0, 2.33333, jnp.nan],
...                           [jnp.exp(0), 2.33333, jnp.nan])

The assertion fails when floating-point imprecision makes the values differ:

>>> testax.assert_array_equal([1.0, 1e-5, jnp.nan],
...                           [1, jnp.arccos(jnp.cos(1e-5)), jnp.nan])
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not equal

Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 1e-05
Max relative difference: 0
 x: Array([1.e+00, 1.e-05,    nan], dtype=float32)
 y: Array([ 1.,  0., nan], dtype=float32)

Use assert_allclose() or one of the nulp (number of units in the last place) functions for these cases instead:

>>> testax.assert_allclose([1.0, 1e-5, jnp.nan],
...                        [1, jnp.arccos(jnp.cos(1e-5)), jnp.nan], atol=1e-5)
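The summary statistics in the failure reports can be reproduced by hand, which also explains why the assert_array_equal() report above shows a maximum relative difference of 0 even though y contains a zero. The sketch below assumes testax computes them the way numpy.testing does: positions where both values are NaN are dropped, and the relative difference is only taken where y is nonzero. The helper mismatch_summary is hypothetical and only mirrors the equality-based reports.

```python
import jax.numpy as jnp


def mismatch_summary(x, y):
    """Hypothetical helper: (mismatched, total, max abs diff, max rel diff)."""
    x, y = jnp.broadcast_arrays(jnp.asarray(x, dtype=jnp.float32),
                                jnp.asarray(y, dtype=jnp.float32))
    total = x.size
    # Assumption: positions where both values are NaN are excluded from the statistics.
    keep = ~(jnp.isnan(x) & jnp.isnan(y))
    x, y = x[keep], y[keep]
    if x.size == 0:
        return 0, total, 0.0, 0.0
    n_mismatch = int(jnp.sum(x != y))
    error = jnp.abs(x - y)
    max_abs = float(jnp.max(error))
    # Assumption: relative differences are only computed where y is nonzero.
    nonzero = y != 0
    if bool(jnp.any(nonzero)):
        max_rel = float(jnp.max(error[nonzero] / jnp.abs(y[nonzero])))
    else:
        max_rel = float("inf")
    return n_mismatch, total, max_abs, max_rel


# Values from the assert_array_equal() example, with arccos(cos(1e-5)) written as its float32 result 0.
print(mismatch_summary([1.0, 1e-5, jnp.nan], [1.0, 0.0, jnp.nan]))
```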

As mentioned in the Notes section, assert_array_equal() has special handling for scalars. Here the test checks that each value in x is 3:

>>> x = jnp.full((2, 5), fill_value=3)
>>> testax.assert_array_equal(x, 3)

Use strict to raise an AssertionError when comparing a scalar with an array:

>>> testax.assert_array_equal(x, 3, strict=True)
Traceback (most recent call last):
    ...
testax.TestaxError:
Arrays are not equal

(shapes (2, 5), () mismatch)
 x: Array([[3, 3, 3, 3, 3],
           [3, 3, 3, 3, 3]], dtype=int32, weak_type=True)
 y: Array(3, dtype=int32, weak_type=True)

The strict parameter also ensures that the array data types match:

>>> x = jnp.array([2, 2, 2])
>>> y = jnp.array([2., 2., 2.], dtype=jnp.float32)
>>> testax.assert_array_equal(x, y, strict=True)
Traceback (most recent call last):
    ...
testax.TestaxError:
Arrays are not equal

(dtypes int32, float32 mismatch)
 x: Array([2, 2, 2], dtype=int32)
 y: Array([2., 2., 2.], dtype=float32)

testax.assert_array_less(x: Array, y: Array, err_msg: str = '', verbose: bool = True, *, debug: bool = False) -> None

Raises an AssertionError if two Arrays are not ordered by less than.

Given two Arrays, check that their shapes are equal and that all elements of the first Array are strictly smaller than those of the second. An exception is raised on shape mismatch or incorrectly ordered values; a shape mismatch does not raise if either object is zero-dimensional. In contrast to the standard usage in NumPy, NaNs are compared: no assertion is raised if both objects have NaNs in the same positions.
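The ordering rule can be sketched as a predicate. The helper less_criterion is hypothetical and returns a bool instead of raising; it encodes the rules stated above (matching shapes unless one side is zero-dimensional, NaNs in matching positions skipped), not testax's actual implementation.

```python
import jax.numpy as jnp


def less_criterion(x, y):
    """Hypothetical helper: True if x < y element-wise under the rules above."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    # Shapes must match unless one side is zero-dimensional.
    if x.shape != y.shape and x.ndim > 0 and y.ndim > 0:
        return False
    nan_x, nan_y = jnp.isnan(x), jnp.isnan(y)
    # NaNs must occur in the same positions; those positions are not compared.
    if not bool(jnp.all(nan_x == nan_y)):
        return False
    return bool(jnp.all((x < y) | (nan_x & nan_y)))


print(less_criterion([1.0, 1.0, jnp.nan], [1.1, 2.0, jnp.nan]))  # True
print(less_criterion([1.0, 4.0], 3))                             # False
```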

Parameters:
  • x – The smaller object to check.

  • y – The larger object to compare.

  • err_msg – The error message to be printed in case of failure.

  • verbose – If True, the conflicting values are appended to the error message.

Raises:

AssertionError โ€“ If x is not strictly smaller than y, element-wise.

Examples

>>> testax.assert_array_less([1.0, 1.0, jnp.nan], [1.1, 2.0, jnp.nan])
>>> testax.assert_array_less([1.0, 1.0, jnp.nan], [1, 2.0, jnp.nan])
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered

Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 1
Max relative difference: 0.5
 x: Array([ 1.,  1., nan], dtype=float32)
 y: Array([ 1.,  2., nan], dtype=float32)
>>> testax.assert_array_less([1.0, 4.0], 3)
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered

Mismatched elements: 1 / 2 (50%)
Max absolute difference: 2
Max relative difference: 0.666667
 x: Array([1., 4.], dtype=float32)
 y: Array(3, dtype=int32, weak_type=True)
>>> testax.assert_array_less([1.0, 2.0, 3.0], [4])
Traceback (most recent call last):
    ...
testax.TestaxError:
Arrays are not less-ordered

(shapes (3,), (1,) mismatch)
 x: Array([1., 2., 3.], dtype=float32)
 y: Array([4], dtype=int32)

testax.checkify(func: Callable[..., T], errors: frozenset[Type[JaxException]] | None = None) -> Callable[..., Tuple[Error, T]]

Functionalize testax assertions and jax.experimental.checkify.check() calls in func, and optionally add run-time error checks. This function behaves like jax.experimental.checkify.checkify(), except that it ensures testax errors are handled properly. See the jax.experimental.checkify.checkify() documentation for details.

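Parameters:
  • func – The function whose testax assertions and checks are functionalized.

  • errors – Optional set of error categories to check, as in jax.experimental.checkify.checkify().

Returns: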

A function which accepts the same arguments as func and returns a tuple whose first element is a jax.experimental.checkify.Error value, representing the first failed testax assertion or jax.experimental.checkify.check(), and whose second element is the original output of func.
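The optional errors argument selects additional run-time checks. Since testax.checkify() is documented to behave like jax.experimental.checkify.checkify(), the sketch below calls JAX's function directly with its predefined nan_checks set; that passing the same set to testax.checkify() behaves identically is an assumption, not something verified here. The function name log_only is illustrative.

```python
import jax.numpy as jnp
from jax.experimental import checkify


def log_only(x):
    # No explicit checks here: any reported error comes from the enabled categories.
    return jnp.log(x)


# nan_checks makes checkify record an error whenever a primitive produces a NaN.
checked = checkify.checkify(log_only, errors=checkify.nan_checks)

err, out = checked(jnp.array([1.0, 2.0]))
print(err.get())  # None

err, out = checked(jnp.array([-1.0]))
print(err.get())  # a message describing where the NaN was produced
```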

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import testax
>>>
>>> @jax.jit
... def f(x):
...     testax.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> err, out = testax.checkify(f)(jnp.arange(2))
>>> err.throw()
Traceback (most recent call last):
    ...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered

Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
 x: Array(0, dtype=int32, weak_type=True)
 y: Array([0, 1], dtype=int32)