Bug in jnp.take with negative indexing - Python jax

Hello, I am seeing the following issue with jax.numpy.take and a -1 index: take returns the first row instead of the last one.

    import jax.numpy as jnp

    A = jnp.array([[1, 2], [3, 4], [5, 6]])
    A[0]                                     # --> [1, 2]
    A[-1]                                    # --> [5, 6]
    A.take(jnp.array([0, -1]), axis=0)       # --> [[1, 2], [1, 2]]
    jnp.take(A, jnp.array([0, -1]), axis=0)  # --> [[1, 2], [1, 2]]

Thanks!

Asked Oct 11 '21 at 14:10 by katebaumli

1 Answer:

Thanks for the report! It looks like the issue is that this call is lowered to lax.gather, which doesn't support negative indices.
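Since lax.gather does not interpret negative indices, another option is to normalize them yourself before calling take. This is a minimal sketch, not part of the original answer, and it assumes every index lies in [-n, n), where n is the length of the indexed axis; idx_normalized is just an illustrative name. It uses jnp.where to map each negative index i to i + n and checks the result against NumPy's np.take.

```python
import jax.numpy as jnp
import numpy as np

A = jnp.array([[1, 2], [3, 4], [5, 6]])
idx = jnp.array([0, -1])

# Map each negative index i to i + n so that take/gather only ever
# sees non-negative indices (assumes -n <= i < n).
n = A.shape[0]
idx_normalized = jnp.where(idx < 0, idx + n, idx)  # [0, 2]

result = jnp.take(A, idx_normalized, axis=0)  # rows 0 and 2

# Compare against NumPy, which treats -1 as "last row".
expected = np.take(np.asarray(A), [0, -1], axis=0)
assert np.array_equal(np.asarray(result), expected)
print(result)
```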

One current workaround is to use mode='wrap', which wraps negative indices correctly.
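A short sketch of the mode='wrap' workaround applied to the example from the question. With wrapping, each index is reduced modulo the axis length (3 here), so -1 selects the last row:

```python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4], [5, 6]])

# mode='wrap' takes each index modulo A.shape[0] (= 3),
# so -1 becomes 2 and picks the last row.
wrapped = jnp.take(A, jnp.array([0, -1]), axis=0, mode='wrap')
print(wrapped)
```

One caveat: 'wrap' also silently wraps out-of-range positive indices (for example, 3 becomes 0). NumPy's default mode='raise' would reject those instead, so this is only equivalent to NumPy for indices in [-n, n).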

Answered Feb 17 '21 at 18:56 by jakevdp