-
[Tensorflow 2.0] custom gradient 함수로 reverse gradient 구현하기데이터분석 2019. 11. 7. 20:15
Adversarial Network를 학습하기 위해서는 gradient를 반전해야 할 때가 있다. 그런 경우 Tensorflow 2.x에서는 custom gradient 함수를 정의하여 구현할 수 있다.
일단 공식 가이드의 custom gradient 예시를 보자.
# custom gradient sample @tf.custom_gradient def log1pexp(x): e = tf.exp(x) def grad(dy): return dy * (1 - 1 / (1 + e)) return tf.math.log(1 + e), grad
먼저 @tf.custom_gradient 데코레이터를 사용한다. 그리고 정의하는 함수 내부에서 input 처리 및 gradient를 return하는 함수를 정의해서 다시 반환해주기만 하면 된다!
위 방식을 쉽게 활용하여 reverse gradient를 구현하는 방법은 stackoverflow에서 가져왔다.
# 먼저 custom gradient 정의 @tf.custom_gradient def grad_reverse(x): y = tf.identity(x) def custom_grad(dy): return -dy return y, custom_grad # Layer를 상속하는 class 정의 class GradReverse(tf.keras.layers.Layer): def __init__(self): super().__init__() def call(self, x): return grad_reverse(x) # 일반 Layer처럼 활용! model = Sequential() conv = tf.keras.layers.Conv2D(...)(inp) cust = CustomLayer()(conv) flat = tf.keras.layers.Flatten()(cust) fc = tf.keras.layers.Dense(num_classes)(flat)
'데이터분석' 카테고리의 다른 글
[tensorflow 2.0] tf.data.Dataset enumerate (0) 2019.12.01 [tensorflow 2.0] tf.tile (0) 2019.11.14 [tensorflow 2.0] tf.slice (0) 2019.11.01 텐서플로우(tensorflow) 2.x 에서 iris dataset 분류하기 (0) 2019.10.17 배열 3차원 이상 곱셈 (380) 2019.10.16