2018年2月17日土曜日

多次元配列で最大値とそのインデックスをとる


li = np.random.rand(1,20).reshape(4,5)

# array([[ 0.08902982, 0.07044699, 0.80305672, 0.16506277, 0.48436258],
#        [ 0.34410367, 0.52830111, 0.29357735, 0.52468119, 0.55911187],
#        [ 0.98363013, 0.63727013, 0.3395793 , 0.94851833, 0.42974549],
#        [ 0.48017984, 0.50335503, 0.03280679, 0.82860064, 0.77796785]])


mx = li.argmax(axis=1)
l = []
for i in range(len(li)):
    l.append([[position, value] for position, value in enumerate(li[i])][mx[i]])

# l ->
# [[2, 0.80305672332582601],
# [4, 0.55911186642729027],
# [0, 0.98363012662030236],
# [3, 0.82860063949987028]]

処理のしやすさから l = np.array(l) としてここから使っていく。
物体検知のカテゴリ推論で使いました

https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html