데이터분석
[Tensorflow 2.0] custom gradient 함수로 reverse gradient 구현하기
jaehwi0823
2019. 11. 7. 20:15
Adversarial Network를 학습하기 위해서는 gradient를 반전해야 할 때가 있다. 그런 경우 Tensorflow 2.x에서는 custom gradient 함수를 정의하여 구현할 수 있다.
일단 공식 가이드의 custom gradient 예시를 보자.
# custom gradient sample
@tf.custom_gradient
def log1pexp(x):
e = tf.exp(x)
def grad(dy):
return dy * (1 - 1 / (1 + e))
return tf.math.log(1 + e), grad
먼저 @tf.custom_gradient 데코레이터를 사용한다. 그리고 정의하는 함수 내부에서 input 처리 및 gradient를 return하는 함수를 정의해서 다시 반환해주기만 하면 된다!
위 방식을 쉽게 활용하여 reverse gradient를 구현하는 방법은 stackoverflow에서 가져왔다.
# 먼저 custom gradient 정의
@tf.custom_gradient
def grad_reverse(x):
y = tf.identity(x)
def custom_grad(dy):
return -dy
return y, custom_grad
# Layer를 상속하는 class 정의
class GradReverse(tf.keras.layers.Layer):
def __init__(self):
super().__init__()
def call(self, x):
return grad_reverse(x)
# 일반 Layer처럼 활용!
model = Sequential()
conv = tf.keras.layers.Conv2D(...)(inp)
cust = CustomLayer()(conv)
flat = tf.keras.layers.Flatten()(cust)
fc = tf.keras.layers.Dense(num_classes)(flat)