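  • [tensorflow 2.0] Masking based on input values
    Uncategorized 2019. 11. 14. 21:33

    This code defines an embedding layer that, when mask_zero=True, marks every position whose input value is 0 as False in its mask, so 0 is treated as padding. Source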

     

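    import numpy as np
    import tensorflow as tf


    class CustomEmbedding(tf.keras.layers.Layer):
      """Embedding layer that also emits a mask: positions whose id is 0 become False."""

      def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):
        super(CustomEmbedding, self).__init__(**kwargs)
        self.input_dim = input_dim    # vocabulary size (number of distinct ids)
        self.output_dim = output_dim  # size of each embedding vector
        self.mask_zero = mask_zero    # if True, id 0 is treated as padding

      def build(self, input_shape):
        # Lookup table: one output_dim-sized row per id
        self.embeddings = self.add_weight(
          shape=(self.input_dim, self.output_dim),
          initializer='random_normal',
          dtype='float32')

      def call(self, inputs):
        # (batch, time) integer ids -> (batch, time, output_dim) vectors
        return tf.nn.embedding_lookup(self.embeddings, inputs)

      def compute_mask(self, inputs, mask=None):
        # Keras uses this to get the mask it passes on to downstream layers
        if not self.mask_zero:
          return None
        return tf.not_equal(inputs, 0)  # True = real data, False = masked (id 0)


    layer = CustomEmbedding(10, 32, mask_zero=True)

    # 3 sequences of length 10 with random integer ids in [0, 8]
    x = np.random.random((3, 10)) * 9
    x = x.astype('int32')
    print(x)

    y = layer(x)                  # embeddings, shape (3, 10, 32)
    mask = layer.compute_mask(x)  # boolean mask, shape (3, 10)
    print(mask)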
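    To make the mechanics explicit without TensorFlow, here is a NumPy-only sketch of the same idea: the embedding lookup is a row gather from the weight table, the mask is simply `inputs != 0`, and a mask-aware consumer uses that mask to skip padded steps. The helper names (`embedding_lookup`, `compute_mask`, `masked_mean`) and the masked-average consumer are illustrative assumptions, not TensorFlow APIs; real Keras layers that support masking apply the mask in their own way.

```python
import numpy as np


def embedding_lookup(embeddings, inputs):
    # Hypothetical helper: each integer id selects one row of the
    # (input_dim, output_dim) table, mirroring a single-table embedding lookup.
    return embeddings[inputs]


def compute_mask(inputs, mask_zero=True):
    # Same rule as CustomEmbedding.compute_mask:
    # id 0 -> False (padding), any other id -> True.
    if not mask_zero:
        return None
    return np.not_equal(inputs, 0)


def masked_mean(values, mask):
    # Hypothetical consumer: average the vectors over the time axis,
    # ignoring positions where the mask is False.
    weights = mask.astype(values.dtype)[..., np.newaxis]  # (batch, time, 1)
    total = (values * weights).sum(axis=1)                # (batch, dim)
    count = np.maximum(weights.sum(axis=1), 1.0)          # guard all-padding rows
    return total / count


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(10, 4)).astype("float32")
x = np.array([[5, 3, 0, 0],
              [7, 0, 0, 0]], dtype="int32")

y = embedding_lookup(embeddings, x)  # shape (2, 4, 4)
mask = compute_mask(x)
print(mask)
# [[ True  True False False]
#  [ True False False False]]

pooled = masked_mean(y, mask)
print(pooled.shape)  # (2, 4)
```

    The `np.maximum(..., 1.0)` guard means a row that is entirely padding pools to a zero vector instead of dividing by zero.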
