카테고리 없음
[tensorflow 2.0] input 값에 따라 masking 하기
jaehwi0823
2019. 11. 14. 21:33
input 값에 0이 있으면, 해당 데이터를 False로 masking하는 코드입니다. 출처
class CustomEmbedding(tf.keras.layers.Layer):
def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):
super(CustomEmbedding, self).__init__(**kwargs)
self.input_dim = input_dim
self.output_dim = output_dim
self.mask_zero = mask_zero
def build(self, input_shape):
self.embeddings = self.add_weight(
shape=(self.input_dim, self.output_dim),
initializer='random_normal',
dtype='float32')
def call(self, inputs):
return tf.nn.embedding_lookup(self.embeddings, inputs)
def compute_mask(self, inputs, mask=None):
if not self.mask_zero:
return None
return tf.not_equal(inputs, 0)
layer = CustomEmbedding(10, 32, mask_zero=True)
x = np.random.random((3, 10)) * 9
x = x.astype('int32')
print(x)
y = layer(x)
mask = layer.compute_mask(x)
print(mask)