🧪 testax
testax provides runtime assertions for JAX through the testing interface familiar to NumPy users.
>>> import jax
>>> from jax import numpy as jnp
>>> import testax
>>>
>>> def safe_log(x):
...     testax.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(jnp.arange(2))
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
x: Array(0, dtype=int32, weak_type=True)
y: Array([0, 1], dtype=int32)
testax assertions are jit-able, although errors need to be functionalized to conform to JAX's requirement that functions are pure and do not have side effects (see the checkify guide for details). In short, a checkify-d function returns a tuple (error, value). The first element is an error that may have occurred, and the second is the return value of the original function.
>>> jitted = jax.jit(safe_log)
>>> checkified = testax.checkify(jitted)
>>> error, y = checkified(jnp.arange(2))
>>> error.throw()
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
x: Array(0, dtype=int32, weak_type=True)
y: Array([0, 1], dtype=int32)
>>> y
Array([-inf, 0.], dtype=float32)
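The same functionalization pattern can be reproduced with JAX's own `jax.experimental.checkify` module, whose `checkify()` testax's version is documented to mirror. The sketch below uses a plain `checkify.check` in place of a testax assertion; the message text is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # Record an error (instead of raising) if any element is non-positive.
    checkify.check(jnp.all(x > 0), "all elements of x must be positive")
    return jnp.log(x)


# Functionalize the check first, then jit the resulting pure function.
checked = jax.jit(checkify.checkify(safe_log))

error, y = checked(jnp.arange(2.0))
print(error.get())  # the failure message, because x[0] == 0
print(y)            # the computation still ran: log(0) = -inf, log(1) = 0
```

As with `testax.checkify`, nothing is raised until the caller inspects the error, for example with `error.throw()`.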
Installation
testax is pip-installable and can be installed by running:
pip install testax
Interface
testax mirrors the numpy.testing interface familiar to NumPy users, with functions such as assert_allclose.
- testax.assert_allclose(actual: Array, desired: Array, rtol: float = 1e-07, atol: float = 0, equal_nan: bool = True, err_msg: str = '', verbose: bool = True, *, debug: bool = False) -> None
Raises an AssertionError if two objects are not equal up to desired tolerance.
Given two Arrays, check that their shapes match and all elements are equal up to the desired tolerance (but see the Notes for the special handling of a scalar). An exception is raised if the shapes mismatch or any values conflict. In contrast to the standard usage in NumPy, NaNs are compared like numbers: no assertion is raised if both objects have NaNs in the same positions.
The test is equivalent to allclose(actual, desired, rtol, atol) (note that allclose() has different default values). It checks elementwise that abs(actual - desired) <= atol + rtol * abs(desired).
- Parameters:
actual – Array obtained.
desired – Array desired.
rtol – Relative tolerance.
atol – Absolute tolerance.
equal_nan – If True, NaNs will compare equal.
err_msg – The error message to be printed in case of failure.
verbose – If True, the conflicting values are appended to the error message.
- Raises:
AssertionError – If actual and desired are not equal up to the specified precision.
See also
assert_array_almost_equal_nulp(), assert_array_max_ulp()
Notes
When one of actual and desired is a scalar and the other is an Array, the function checks that each element of the Array is equal to the scalar.
Examples
>>> x = jnp.asarray([1e-5, 1e-3, 1e-1])
>>> y = jnp.arccos(jnp.cos(x))
>>> testax.assert_allclose(x, y, atol=1e-4)
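The elementwise criterion described above, abs(actual - desired) <= atol + rtol * abs(desired) with NaNs in matching positions treated as equal, can be sketched in plain jax.numpy. `allclose_mask` is a hypothetical helper for illustration, not part of testax, and it does not reproduce testax's error reporting. The extra exact-equality term, so that matching infinities count as close, is an assumption borrowed from NumPy's `isclose`.

```python
import jax.numpy as jnp


def allclose_mask(actual, desired, rtol=1e-7, atol=0.0, equal_nan=True):
    """Hypothetical helper: elementwise form of the assert_allclose criterion."""
    actual = jnp.asarray(actual)
    desired = jnp.asarray(desired)
    # abs(actual - desired) <= atol + rtol * abs(desired); exact equality also
    # counts, so matching infinities are treated as close (assumption).
    close = (actual == desired) | (
        jnp.abs(actual - desired) <= atol + rtol * jnp.abs(desired)
    )
    if equal_nan:
        # NaNs in the same positions compare equal.
        close = close | (jnp.isnan(actual) & jnp.isnan(desired))
    return close


# Elementwise result: True, False, True (2.0 vs 2.1 exceeds atol=0.05).
print(allclose_mask([1.0, 2.0, jnp.nan], [1.0, 2.1, jnp.nan], atol=0.05))
```

A scalar `desired` broadcasts against `actual`, which matches the scalar handling described in the Notes.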
- testax.assert_array_almost_equal(x: Array, y: Array, decimal: int = 6, err_msg: str = '', verbose: bool = True, *, debug: bool = False) -> None
Raises an AssertionError if two Arrays are not equal up to the desired precision.
Note
It is recommended to use one of assert_allclose(), assert_array_almost_equal_nulp() or assert_array_max_ulp() instead of this function for more consistent floating point comparisons.
The test verifies identical shapes and that the elements of x and y satisfy abs(y - x) < 1.5 * 10**(-decimal).
An exception is raised at shape mismatch or conflicting values. In contrast to the standard usage in NumPy, NaNs are compared like numbers: no assertion is raised if both objects have NaNs in the same positions.
- Parameters:
x – The actual object to check.
y – The desired, expected object.
decimal – Desired precision, default is 6.
err_msg – The error message to be printed in case of failure.
verbose – If True, the conflicting values are appended to the error message.
- Raises:
AssertionError – If x and y are not equal up to the specified precision.
See also
assert_allclose(), assert_array_almost_equal_nulp(), assert_array_max_ulp(), assert_equal()
Examples
The first assert does not raise an exception:
>>> testax.assert_array_almost_equal([1.0, 2.333, jnp.nan],
...                                  [1.0, 2.333, jnp.nan])
>>> testax.assert_array_almost_equal([1.0, 2.33333, jnp.nan],
...                                  [1.0, 2.33339, jnp.nan], decimal=5)
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not almost equal to 5 decimals
Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 6.0081e-05
Max relative difference: 2.5749e-05
x: Array([1. , 2.33333, nan], dtype=float32)
y: Array([1. , 2.33339, nan], dtype=float32)
>>> testax.assert_array_almost_equal([1.0, 2.33333, jnp.nan],
...                                  [1.0, 2.33333, 5], decimal=5)
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not almost equal to 5 decimals
x and y nan location mismatch:
x: Array([1. , 2.33333, nan], dtype=float32)
y: Array([1. , 2.33333, 5. ], dtype=float32)
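The two failure modes in these examples, a value outside 1.5 * 10**(-decimal) and NaNs in different positions, can be sketched with jax.numpy. Both `almost_equal_mask` and `nan_locations_match` are hypothetical helpers written for illustration; they are not testax functions and say nothing about how testax orders its checks.

```python
import jax.numpy as jnp


def almost_equal_mask(x, y, decimal=6):
    """Hypothetical helper: elementwise assert_array_almost_equal criterion."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    # abs(y - x) < 1.5 * 10**(-decimal), as stated above.
    close = jnp.abs(y - x) < 1.5 * 10.0 ** (-decimal)
    # NaNs in matching positions count as equal.
    return close | (jnp.isnan(x) & jnp.isnan(y))


def nan_locations_match(x, y):
    """Hypothetical helper: True if x and y have NaNs in exactly the same places."""
    return bool(jnp.array_equal(jnp.isnan(jnp.asarray(x)), jnp.isnan(jnp.asarray(y))))


# 2.33339 - 2.33333 is about 6e-05, which is not below 1.5e-05 for decimal=5.
print(almost_equal_mask([1.0, 2.33333], [1.0, 2.33339], decimal=5))
print(nan_locations_match([1.0, 2.33333, jnp.nan], [1.0, 2.33333, 5.0]))  # False
```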
- testax.assert_array_equal(x: Array, y: Array, err_msg: str = '', verbose: bool = True, *, strict: bool = False, debug: bool = False) -> None
Raises an AssertionError if two Arrays are not equal.
Given two Arrays, check that the shape is equal and all elements of these objects are equal (but see the Notes for the special handling of a scalar). An exception is raised at shape mismatch or conflicting values. In contrast to the standard usage in NumPy, NaNs are compared like numbers: no assertion is raised if both objects have NaNs in the same positions.
The usual caution for verifying equality with floating point numbers is advised.
- Parameters:
x – The actual object to check.
y – The desired, expected object.
err_msg – The error message to be printed in case of failure.
verbose – If True, the conflicting values are appended to the error message.
strict – If True, raise an AssertionError when either the shape or the data type of the Arrays does not match. The special handling for scalars mentioned in the Notes section is disabled.
- Raises:
AssertionError – If actual and desired objects are not equal.
See also
assert_allclose(), assert_array_almost_equal_nulp(), assert_array_max_ulp(), assert_equal()
Notes
When one of x and y is a scalar and the other is an Array, the function checks that each element of the Array is equal to the scalar. This behaviour can be disabled with the strict parameter.
Examples
The first assert does not raise an exception:
>>> testax.assert_array_equal([1.0, 2.33333, jnp.nan],
...                           [jnp.exp(0), 2.33333, jnp.nan])
Assert fails with numerical imprecision with floats:
>>> testax.assert_array_equal([1.0, 1e-5, jnp.nan],
...                           [1, jnp.arccos(jnp.cos(1e-5)), jnp.nan])
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not equal
Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 1e-05
Max relative difference: 0
x: Array([1.e+00, 1.e-05, nan], dtype=float32)
y: Array([ 1., 0., nan], dtype=float32)
Use assert_allclose() or one of the nulp (number of floating point values) functions for these cases instead:
>>> testax.assert_allclose([1.0, 1e-5, jnp.nan],
...                        [1, jnp.arccos(jnp.cos(1e-5)), jnp.nan], atol=1e-5)
As mentioned in the Notes section, assert_array_equal() has special handling for scalars. Here the test checks that each value in x is 3:
>>> x = jnp.full((2, 5), fill_value=3)
>>> testax.assert_array_equal(x, 3)
Use strict to raise an AssertionError when comparing a scalar with an array:
>>> testax.assert_array_equal(x, 3, strict=True)
Traceback (most recent call last):
...
testax.TestaxError: Arrays are not equal
(shapes (2, 5), () mismatch)
x: Array([[3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3]], dtype=int32, weak_type=True)
y: Array(3, dtype=int32, weak_type=True)
The strict parameter also ensures that the array data types match:
>>> x = jnp.array([2, 2, 2])
>>> y = jnp.array([2., 2., 2.], dtype=jnp.float32)
>>> testax.assert_array_equal(x, y, strict=True)
Traceback (most recent call last):
...
testax.TestaxError: Arrays are not equal
(dtypes int32, float32 mismatch)
x: Array([2, 2, 2], dtype=int32)
y: Array([2., 2., 2.], dtype=float32)
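The scalar and `strict` rules above can be summarized in a small pure-Python-plus-jax.numpy sketch. `array_equal_check` is a hypothetical helper, not part of testax. It relies on the fact that shapes and dtypes of JAX arrays are static, so they can be compared with ordinary Python; whether testax performs these checks the same way is not documented here.

```python
import jax.numpy as jnp


def array_equal_check(x, y, strict=False):
    """Hypothetical helper mirroring the assert_array_equal rules described above."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    if strict:
        # Shapes and dtypes are static, so plain Python comparisons suffice.
        if x.shape != y.shape or x.dtype != y.dtype:
            return False
    elif x.shape != y.shape and x.ndim != 0 and y.ndim != 0:
        # Without strict, only a 0-d (scalar) operand may differ in shape.
        return False
    # Elementwise equality, with NaNs in matching positions counted as equal.
    equal = (x == y) | (jnp.isnan(x) & jnp.isnan(y))
    return bool(jnp.all(equal))


x = jnp.full((2, 5), fill_value=3)
print(array_equal_check(x, 3))               # True: scalar compared to each element
print(array_equal_check(x, 3, strict=True))  # False: shapes (2, 5) and () differ
```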
- testax.assert_array_less(x: Array, y: Array, err_msg: str = '', verbose: bool = True, *, debug: bool = False) -> None
Raises an AssertionError if two Arrays are not ordered by less than.
Given two Arrays, check that the shape is equal and all elements of the first array are strictly smaller than those of the second. An exception is raised at shape mismatch or incorrectly ordered values. Shape mismatch does not raise if either object is zero-dimensional. In contrast to the standard usage in NumPy, NaNs are compared: no assertion is raised if both objects have NaNs in the same positions.
- Parameters:
x – The smaller object to check.
y – The larger object to compare.
err_msg – The error message to be printed in case of failure.
verbose – If True, the conflicting values are appended to the error message.
- Raises:
AssertionError – If x is not strictly smaller than y, element-wise.
Examples
>>> testax.assert_array_less([1.0, 1.0, jnp.nan], [1.1, 2.0, jnp.nan])
>>> testax.assert_array_less([1.0, 1.0, jnp.nan], [1, 2.0, jnp.nan])
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 1
Max relative difference: 0.5
x: Array([ 1., 1., nan], dtype=float32)
y: Array([ 1., 2., nan], dtype=float32)
>>> testax.assert_array_less([1.0, 4.0], 3)
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 2
Max relative difference: 0.666667
x: Array([1., 4.], dtype=float32)
y: Array(3, dtype=int32, weak_type=True)
>>> testax.assert_array_less([1.0, 2.0, 3.0], [4])
Traceback (most recent call last):
...
testax.TestaxError: Arrays are not less-ordered
(shapes (3,), (1,) mismatch)
x: Array([1., 2., 3.], dtype=float32)
y: Array([4], dtype=int32)
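The ordering rule and its shape exemption can be sketched as follows. `less_mask` is a hypothetical helper, not part of testax. It raises a plain `ValueError` on a shape mismatch purely for illustration; testax itself reports this case as a `TestaxError`, as in the last example.

```python
import jax.numpy as jnp


def less_mask(x, y):
    """Hypothetical helper: elementwise assert_array_less criterion."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    if x.shape != y.shape and x.ndim != 0 and y.ndim != 0:
        # Only a zero-dimensional operand is exempt from the shape check,
        # so (3,) vs (1,) is rejected even though it would broadcast.
        raise ValueError(f"shapes {x.shape}, {y.shape} mismatch")
    # Strictly smaller, except that NaNs in matching positions pass.
    return (x < y) | (jnp.isnan(x) & jnp.isnan(y))


print(less_mask([1.0, 1.0, jnp.nan], [1, 2.0, jnp.nan]))  # first element fails: 1 < 1 is False
print(less_mask([1.0, 4.0], 3))                          # second element fails: 4 < 3 is False
```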
- testax.checkify(func: Callable[..., T], errors: frozenset[Type[JaxException]] | None = None) -> Callable[..., Tuple[Error, T]]
Functionalizes testax assertions and jax.experimental.checkify.check() calls in func, and optionally adds run-time error checks. This function has the same behavior as jax.experimental.checkify.checkify() except that it ensures testax errors are handled properly. See the jax.experimental.checkify.checkify() documentation for details.
- Parameters:
func – Callable which can contain testax assertions and user checks (see jax.experimental.checkify.check() for details).
errors – A set of exception types which defines the set of enabled checks. By default, testax assertions and explicit jax.experimental.checkify.check() calls are enabled.
- Returns:
A function which accepts the same arguments as func and returns a tuple whose first element is a jax.experimental.checkify.Error value, representing the first failed testax assertion or jax.experimental.checkify.check(), and whose second element is the original output of func.
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import testax
>>>
>>> @jax.jit
... def f(x):
...     testax.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> err, out = testax.checkify(f)(jnp.arange(2))
>>> err.throw()
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
x: Array(0, dtype=int32, weak_type=True)
y: Array([0, 1], dtype=int32)
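The errors parameter selects which classes of checks are enabled. The sketch below uses plain `jax.experimental.checkify`, not testax, and its predefined sets `checkify.user_checks` and `checkify.nan_checks`. It shows how enabling automatic NaN checks catches a failure that carries no explicit check. Whether testax.checkify accepts exactly these sets alongside its own assertion errors is an assumption not confirmed by the signature above.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def f(x):
    # Explicit user check, similar in spirit to a testax assertion.
    checkify.check(jnp.all(x >= 0), "x must be non-negative")
    return jnp.log(x)


def g(x):
    # No explicit check: a NaN is only caught if NaN checks are enabled.
    return jnp.log(x)


# Explicit checks plus automatic NaN checks on every primitive.
checked_f = checkify.checkify(jax.jit(f), errors=checkify.user_checks | checkify.nan_checks)
# Automatic NaN checks only.
checked_g = checkify.checkify(jax.jit(g), errors=checkify.nan_checks)

err, _ = checked_f(jnp.array([1.0, 2.0]))
print(err.get())  # None: no check failed

err, _ = checked_g(jnp.array([-1.0, 2.0]))
print(err.get())  # describes the NaN produced by log(-1.0)
```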