데이터분석

[Tensorflow 2.0] custom gradient 함수로 reverse gradient 구현하기

jaehwi0823 2019. 11. 7. 20:15

Adversarial Network를 학습하기 위해서는 gradient를 반전해야 할 때가 있다. 그런 경우 Tensorflow 2.x에서는 custom gradient 함수를 정의하여 구현할 수 있다.

 

일단 공식 가이드의 custom gradient 예시를 보자.

 

# custom gradient sample
@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

 

먼저 @tf.custom_gradient 데코레이터를 사용한다. 그리고 정의하는 함수 내부에서 input 처리 및 gradient를 return하는 함수를 정의해서 다시 반환해주기만 하면 된다!

 

위 방식을 쉽게 활용하여 reverse gradient를 구현하는 방법은 stackoverflow에서 가져왔다.

# 먼저 custom gradient 정의
@tf.custom_gradient
def grad_reverse(x):
    y = tf.identity(x)
    def custom_grad(dy):
        return -dy
    return y, custom_grad
 
# Layer를 상속하는 class 정의
class GradReverse(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        
    def call(self, x):
        return grad_reverse(x)

# 일반 Layer처럼 활용!
model = Sequential()
conv = tf.keras.layers.Conv2D(...)(inp)
cust = CustomLayer()(conv)
flat = tf.keras.layers.Flatten()(cust)
fc = tf.keras.layers.Dense(num_classes)(flat)