-
[tensorflow 2.0] input 값에 따라 masking 하기카테고리 없음 2019. 11. 14. 21:33
input 값에 0이 있으면, 해당 데이터를 False로 masking하는 코드입니다. 출처
class CustomEmbedding(tf.keras.layers.Layer): def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs): super(CustomEmbedding, self).__init__(**kwargs) self.input_dim = input_dim self.output_dim = output_dim self.mask_zero = mask_zero def build(self, input_shape): self.embeddings = self.add_weight( shape=(self.input_dim, self.output_dim), initializer='random_normal', dtype='float32') def call(self, inputs): return tf.nn.embedding_lookup(self.embeddings, inputs) def compute_mask(self, inputs, mask=None): if not self.mask_zero: return None return tf.not_equal(inputs, 0) layer = CustomEmbedding(10, 32, mask_zero=True) x = np.random.random((3, 10)) * 9 x = x.astype('int32') print(x) y = layer(x) mask = layer.compute_mask(x) print(mask)