Bug in jnp.take with negative indexing - Python jax
Hello, I am seeing the following issue with jax.numpy.take and a -1 index:
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4], [5, 6]])
A[0]
# --> [1, 2]
A[-1]
# --> [5, 6]
A.take(jnp.array([0, -1]), axis=0)
# --> [[1, 2], [1, 2]]
jnp.take(A, jnp.array([0, -1]), axis=0)
# --> [[1, 2], [1, 2]]
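For comparison, plain NumPy's np.take treats negative indices like ordinary indexing in its default mode, while its mode="clip" clamps -1 to 0, which gives the same [[1, 2], [1, 2]] result reported above. This is a small NumPy-only sketch of the expected behavior; it is not a claim about which mode JAX used internally at the time.

```python
import numpy as np

A = np.array([[1, 2], [3, 4], [5, 6]])

# Default mode ("raise"): -1 means "last row", as with A[-1].
default_mode = np.take(A, [0, -1], axis=0)

# mode="clip": indices are clamped into [0, n - 1], so -1 becomes 0.
clip_mode = np.take(A, [0, -1], axis=0, mode="clip")

print(default_mode.tolist())
print(clip_mode.tolist())
```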
Thanks!
Asked Oct 11 '21 at 14:10
katebaumli
1 Answer:
Thanks for the report! It looks like the issue is that this operation is lowered to lax.gather, which doesn't support negative indices.
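Because the gather step only works with non-negative positions, one way around it is to normalize negative indices yourself before calling take. A minimal sketch, assuming the indices are a jnp array and every index lies in [-size, size); the names idx, normalized and rows are just illustrative:

```python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4], [5, 6]])
idx = jnp.array([0, -1])

axis = 0
size = A.shape[axis]

# Map negative indices into [0, size), as Python indexing would,
# so the underlying gather never sees a negative position.
normalized = jnp.where(idx < 0, idx + size, idx)

# Rows 0 and 2 of A.
rows = jnp.take(A, normalized, axis=axis)
print(rows)
```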
One current workaround is to use mode='wrap', which wraps negative indices correctly.
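A sketch of the mode='wrap' workaround on the array from the question. Keep in mind that 'wrap' reduces every index modulo the axis length, so out-of-range positive indices (e.g. 3 on a length-3 axis) also wrap around rather than raising or being clipped:

```python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4], [5, 6]])
idx = jnp.array([0, -1])

# mode="wrap": each index is taken modulo A.shape[0], so -1 -> 2.
wrapped = jnp.take(A, idx, axis=0, mode="wrap")

# The array method forwards the same keyword to jnp.take.
wrapped_method = A.take(idx, axis=0, mode="wrap")

# Caveat: positive out-of-range indices wrap too (3 % 3 == 0 -> row 0).
overflow = jnp.take(A, jnp.array([3]), axis=0, mode="wrap")

print(wrapped)
print(overflow)
```

If silently wrapping out-of-range indices is undesirable, normalizing only the negative indices (as with jnp.where) keeps the other modes' handling of out-of-range values intact.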
Answered Feb 17 '21 at 18:56
jakevdp