2018年2月23日金曜日

tf.scatter_ndでゼロ埋め式テンソルを一発で作る

 (10, 4) の形のテンソルを作る際に、10個のうち4個だけ該当するインデックスに対して値を埋めて、あとはゼロにしたいみたいなことがある。
numpy を使ってまっすぐやると


tensor = np.zeros((10, 4))
indices = [2, 5, 6, 9]
tensor[indices] = [1, 1, 1, 1]

# array([[ 0.,  0.,  0.,  0.],
#        [ 0.,  0.,  0.,  0.],
#        [ 1.,  1.,  1.,  1.],
#        [ 0.,  0.,  0.,  0.],
#        [ 0.,  0.,  0.,  0.],
#        [ 1.,  1.,  1.,  1.],
#        [ 1.,  1.,  1.,  1.],
#        [ 0.,  0.,  0.,  0.],
#        [ 0.,  0.,  0.,  0.],
#        [ 1.,  1.,  1.,  1.]])


tf.scatter_ndを使うと

indices = [[2], [5], [6], [9]]
updates = [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]
shape = [10, 4]
tensor = tf.scatter_nd(indices, updates, shape)

with tf.Session() as sess:
    print(sess.run(tensor))

# [[0 0 0 0]
# [0 0 0 0]
# [1 1 1 1]
# [0 0 0 0]
# [0 0 0 0]
# [1 1 1 1]
# [1 1 1 1]
# [0 0 0 0]
# [0 0 0 0]
# [1 1 1 1]]


一発でとか言ったんですが大概でした