데이터분석

[tensorflow 2.0] tf.slice

jaehwi0823 2019. 11. 1. 22:50

tf.slice(
    input_
,
   
begin,
    size
,
    name
=None
)

 

tf.slice는 python list slice의 함수 형태로 이해하면 된다.

  • input_: 원본 input tensor
  • begin: 시작 위치
  • size: 잘라낼 size (shape)

 

# 예제 tensor
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
                 [[3, 3, 3], [4, 4, 4]],
                 [[5, 5, 5], [6, 6, 6]]])


# 원본 t에서 시작위치 [1, 0, 0]은 [3, 3, 3] 리스트의 맨 앞 [3] 이다.
# 시작 위치에서 [1, 1, 3] shape으로 내용물을 꺼내오면,
# 총 1*1*3 개의 원소가 차례대로 선택되고 
# 그 결과는 [[[3, 3, 3]]] 이 된다.
tf.slice(t, [1, 0, 0], [1, 1, 3])


# 같은 위치에서 1*2*3 개의 원소를 [1, 2, 3] shape으로 꺼내오면,
# [[[3, 3, 3],
#   [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [1, 2, 3])


# 같은 원리로 아래의 결과를 생각해보자
tf.slice(t, [1, 0, 0], [2, 1, 3])

# [[[3, 3, 3]],
#  [[5, 5, 5]]]