Custom
-
[Tensorflow 2.0] custom gradient 함수로 reverse gradient 구현하기데이터분석 2019. 11. 7. 20:15
Adversarial Network를 학습하기 위해서는 gradient를 반전해야 할 때가 있다. 그런 경우 Tensorflow 2.x에서는 custom gradient 함수를 정의하여 구현할 수 있다. 일단 공식 가이드의 custom gradient 예시를 보자. # custom gradient sample @tf.custom_gradient def log1pexp(x): e = tf.exp(x) def grad(dy): return dy * (1 - 1 / (1 + e)) return tf.math.log(1 + e), grad 먼저 @tf.custom_gradient 데코레이터를 사용한다. 그리고 정의하는 함수 내부에서 input 처리 및 gradient를 return하는 함수를 정의해서 다시 반환..