카테고리 없음

[tensorflow 2.0] input 값에 따라 masking 하기

jaehwi0823 2019. 11. 14. 21:33

input 값에 0이 있으면, 해당 데이터를 False로 masking하는 코드입니다. 출처

 

class CustomEmbedding(tf.keras.layers.Layer):
  
  def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):
    super(CustomEmbedding, self).__init__(**kwargs)
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.mask_zero = mask_zero
    
  def build(self, input_shape):
    self.embeddings = self.add_weight(
      shape=(self.input_dim, self.output_dim),
      initializer='random_normal',
      dtype='float32')
    
  def call(self, inputs):
    return tf.nn.embedding_lookup(self.embeddings, inputs)
  
  def compute_mask(self, inputs, mask=None):
    if not self.mask_zero:
      return None
    return tf.not_equal(inputs, 0)
  
  
layer = CustomEmbedding(10, 32, mask_zero=True)

x = np.random.random((3, 10)) * 9
x = x.astype('int32')
print(x)

y = layer(x)
mask = layer.compute_mask(x)
print(mask)