데이터분석
[tensorflow 2.0] tf.slice
jaehwi0823
2019. 11. 1. 22:50
tf.slice(
input_,
begin,
size,
name=None
)
tf.slice는 python list slice의 함수 형태로 이해하면 된다.
- input_: 원본 input tensor
- begin: 시작 위치
- size: 잘라낼 size (shape)
# 예제 tensor
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
# 원본 t에서 시작위치 [1, 0, 0]은 [3, 3, 3] 리스트의 맨 앞 [3] 이다.
# 시작 위치에서 [1, 1, 3] shape으로 내용물을 꺼내오면,
# 총 1*1*3 개의 원소가 차례대로 선택되고
# 그 결과는 [[[3, 3, 3]]] 이 된다.
tf.slice(t, [1, 0, 0], [1, 1, 3])
# 같은 위치에서 1*2*3 개의 원소를 [1, 2, 3] shape으로 꺼내오면,
# [[[3, 3, 3],
# [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [1, 2, 3])
# 같은 원리로 아래의 결과를 생각해보자
tf.slice(t, [1, 0, 0], [2, 1, 3])
# [[[3, 3, 3]],
# [[5, 5, 5]]]