从Tensorflow中的张量中删除一组张量
问题内容:
我正在寻找一种简单的方法来从Tensorflow中的当前张量中删除一组张量,并且我有一个困难的合理解决方案。
例如,假设我具有以下当前张量:
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
我想从该张量中删除两个项目(2.0和5.0)。
创建张量后将其转换为[1.0、3.0、4.0、6.0]的最佳方法是什么?
提前谢谢了。
问题答案:
您可以调用tf.unstack
以获取子张量列表。然后,您可以修改列表并调用tf.stack
以从列表构造张量。例如,以下代码从以下位置删除[2.0,5.0]列:
a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
a_vecs = tf.unstack(a, axis=1)
del a_vecs[1]
a_new = tf.stack(a_vecs, 1)