выбор оси массива numpy по int
Я пытаюсь систематически получить доступ к оси массива numpy. Например, предположим, что у меня есть массив
a = np.random.random((10, 10, 10, 10, 10, 10, 10))
# choosing 7:9 from axis 2
b = a[:, :, 7:9, ...]
# choosing 7:9 from axis 3
c = a[:, :, :, 7:9, ...]
Ввод двоеточий становится очень повторяющимся, если у меня есть массив высокой размерности. Теперь мне нужна некоторая функция choose_from_axis такая, что
# choosing 7:9 from axis 2
b = choose_from_axis(a, 2, 7, 9)
# choosing 7:9 from axis 3
c = choose_from_axis(a, 3, 7, 9)
Итак, в принципе, я хочу получить доступ к оси с номером. Единственный способ, который я знаю, как это сделать, - это использовать
rollaxis взад и вперед, но я ищу более прямой способ сделать это. 2 ответов:
Похоже, вы ищете Возьмите :
>>> a = np.random.randint(0,100, (3,4,5)) >>> a[:,1:3,:] array([[[61, 4, 89, 24, 86], [48, 75, 4, 27, 65]], [[57, 55, 55, 6, 95], [19, 16, 4, 61, 42]], [[24, 89, 41, 74, 85], [27, 84, 23, 70, 29]]]) >>> a.take(np.arange(1,3), axis=1) array([[[61, 4, 89, 24, 86], [48, 75, 4, 27, 65]], [[57, 55, 55, 6, 95], [19, 16, 4, 61, 42]], [[24, 89, 41, 74, 85], [27, 84, 23, 70, 29]]])Это также даст вам поддержку индексации кортежей. Пример:
>>> a = np.arange(2*3*4).reshape(2,3,4) >>> a array([[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]]) >>> a[:,:,(0,1,3)] array([[[ 0, 1, 3], [ 4, 5, 7], [ 8, 9, 11]], [[12, 13, 15], [16, 17, 19], [20, 21, 23]]]) >>> a.take((0,1,3), axis=2) array([[[ 0, 1, 3], [ 4, 5, 7], [ 8, 9, 11]], [[12, 13, 15], [16, 17, 19], [20, 21, 23]]])
Вы можете построить объект slice, который выполняет эту работу:
def choose_from_axis(a, axis, start, stop): s = [slice(None) for i in range(a.ndim)] s[axis] = slice(start, stop) return a[s]Например, следующие оба дают один и тот же результат:
x[:,1:2,:] choose_from_axis(x, 1, 1, 2) # [[[ 3 4 5]] # [[12 13 14]] # [[21 22 23]]]Как и пример в вопросе:
a = np.random.random((10, 10, 10, 10, 10, 10, 10)) a0 = a[:, :, 7:9, ...] a1 = choose_from_axis(a, 2, 7, 9) print np.all(a0==a1) # True
Comments