데이터분석

[tensorflow 2.0] tf.data.Dataset enumerate

jaehwi0823 2019. 12. 1. 17:55

tf.data.Dataset API를 보면 enumerate() 함수가 있다. 배치마다의 값 확인을 위해서 아래 함수를 쓰면 되겠구나! 했는데..

 

# NOTE: The following examples use `{ ... }` to represent the
# contents of a dataset.
a = { 1, 2, 3 }
b = { (7, 8), (9, 10) }

# The nested structure of the `datasets` argument determines the
# structure of elements in the resulting dataset.
a.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) }
b.enumerate() == { (0, (7, 8)), (1, (9, 10)) }

 

가이드 상에서는 분명 from_tensor_slices()의 return이 dataset임에도, 아래와 같이 사용하면 에러가 난다. 에러 메세지는 myds가 BatchedDataset이라서 enumerate() 함수가 없다는 것이다.. (공식 가이드 ??)

myds = tf.data.Dataset.from_tensor_slices(데이터).batch(batch_size)
for itr, (train, label, origin) in myds.enumerate():
    # 내용

 

그래서 위의 내용을 구현하려면 python base 기능을 이용해야 한다.. 

myds = tf.data.Dataset.from_tensor_slices(데이터).batch(batch_size)
for itr, (train, label, origin) in enumerate(myds):
    # 내용