-
SRGAN Tensor flow 코드 구현 및 테스트pytorch & tensorflow 2021. 10. 7. 17:55
SRGAN을 텐서플로우로 모델 구현을 해보겠습니다.
논문에 대한 설명은 밑의 주소에 간단하게 정리 했습니다.
https://hwanny-yy.tistory.com/18
Generator
SRGAN의 generator부분입니다.
중요부분으로는 residual block과 subpixel block입니다.
Residual block에는 특별하게 PReLU를 사용한것이 특징입니다.
Sub-pixel block에는 PixelShuffel이 사용되었습니다.
from tensorflow.keras import Input, Model, layers # 그림의 Residual 블록을 정의합니다. def gene_base_block(x): out = layers.Conv2D(64, 3, 1, "same")(x) out = layers.BatchNormalization()(out) out = layers.PReLU(shared_axes=[1,2])(out) out = layers.Conv2D(64, 3, 1, "same")(out) out = layers.BatchNormalization()(out) return layers.Add()([x, out]) # 그림의 뒤쪽 Sub-pixel 블록을 정의합니다. def upsample_block(x): out = layers.Conv2D(256, 3, 1, "same")(x) # 그림의 PixelShuffler 라고 쓰여진 부분을 아래와 같이 구현합니다. out = layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(out) return layers.PReLU(shared_axes=[1,2])(out) # 전체 Generator를 정의합니다. def get_generator(input_shape=(None, None, 3)): inputs = Input(input_shape) out = layers.Conv2D(64, 9, 1, "same")(inputs) out = residual = layers.PReLU(shared_axes=[1,2])(out) for _ in range(5): out = gene_base_block(out) out = layers.Conv2D(64, 3, 1, "same")(out) out = layers.BatchNormalization()(out) out = layers.Add()([residual, out]) for _ in range(2): out = upsample_block(out) out = layers.Conv2D(3, 9, 1, "same", activation="tanh")(out) return Model(inputs, out)
Discriminator
# 그림의 블록을 정의합니다. def disc_base_block(x, n_filters=128): out = layers.Conv2D(n_filters, 3, 1, "same")(x) out = layers.BatchNormalization()(out) out = layers.LeakyReLU()(out) out = layers.Conv2D(n_filters, 3, 2, "same")(out) out = layers.BatchNormalization()(out) return layers.LeakyReLU()(out) # 전체 Discriminator 정의합니다. def get_discriminator(input_shape=(None, None, 3)): inputs = Input(input_shape) out = layers.Conv2D(64, 3, 1, "same")(inputs) out = layers.LeakyReLU()(out) out = layers.Conv2D(64, 3, 2, "same")(out) out = layers.BatchNormalization()(out) out = layers.LeakyReLU()(out) for n_filters in [128, 256, 512]: out = disc_base_block(out, n_filters) out = layers.Dense(1024)(out) out = layers.LeakyReLU()(out) out = layers.Dense(1, activation="sigmoid")(out) return Model(inputs, out)
Content Loss
SRGAN에서는 새로 제시한 Content loss로 VGG19의 활성화 함수 통과 값을 사용합니다.
from tensorflow.python.keras import applications def get_feature_extractor(input_shape=(None, None, 3)): vgg = applications.vgg19.VGG19( include_top=False, weights="imagenet", input_shape=input_shape ) # 아래 vgg.layers[20]은 vgg 내의 마지막 convolutional layer 입니다. return Model(vgg.input, vgg.layers[20].output)
Training
from tensorflow.keras import losses, metrics, optimizers generator = get_generator() discriminator = get_discriminator() vgg = get_feature_extractor() # 사용할 loss function 및 optimizer 를 정의합니다. bce = losses.BinaryCrossentropy(from_logits=False) mse = losses.MeanSquaredError() gene_opt = optimizers.Adam() disc_opt = optimizers.Adam() def get_gene_loss(fake_out): return bce(tf.ones_like(fake_out), fake_out) def get_disc_loss(real_out, fake_out): return bce(tf.ones_like(real_out), real_out) + bce(tf.zeros_like(fake_out), fake_out) @tf.function def get_content_loss(hr_real, hr_fake): hr_real = applications.vgg19.preprocess_input(hr_real) hr_fake = applications.vgg19.preprocess_input(hr_fake) hr_real_feature = vgg(hr_real) / 12.75 hr_fake_feature = vgg(hr_fake) / 12.75 return mse(hr_real_feature, hr_fake_feature) @tf.function def step(lr, hr_real): with tf.GradientTape() as gene_tape, tf.GradientTape() as disc_tape: hr_fake = generator(lr, training=True) real_out = discriminator(hr_real, training=True) fake_out = discriminator(hr_fake, training=True) perceptual_loss = get_content_loss(hr_real, hr_fake) + 1e-3 * get_gene_loss(fake_out) discriminator_loss = get_disc_loss(real_out, fake_out) gene_gradient = gene_tape.gradient(perceptual_loss, generator.trainable_variables) disc_gradient = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables) gene_opt.apply_gradients(zip(gene_gradient, generator.trainable_variables)) disc_opt.apply_gradients(zip(disc_gradient, discriminator.trainable_variables)) return perceptual_loss, discriminator_loss gene_losses = metrics.Mean() disc_losses = metrics.Mean() for epoch in range(1, 2): for i, (lr, hr) in enumerate(train): g_loss, d_loss = step(lr, hr) gene_losses.update_state(g_loss) disc_losses.update_state(d_loss) # 10회 반복마다 loss를 출력합니다. if (i+1) % 10 == 0: print(f"EPOCH[{epoch}] - STEP[{i+1}] \nGenerator_loss:{gene_losses.result():.4f} \nDiscriminator_loss:{disc_losses.result():.4f}", end="\n\n") if (i+1) == 200: break gene_losses.reset_states() disc_losses.reset_states()
Test
모델을 불러옵니다.
import tensorflow as tf import os model_file = './srgan_G.h5' srgan = tf.keras.models.load_model(model_file)
SRGAN을 사용할 함수를 설정합니다.
predict결과 float32--> uint8로 바꿔줍니다.
import numpy as np def apply_srgan(image): image = tf.cast(image[np.newaxis, ...], tf.float32) sr = srgan.predict(image) sr = tf.clip_by_value(sr, 0, 255) sr = tf.round(sr) sr = tf.cast(sr, tf.uint8) return np.array(sr)[0]
이미지를 입력해 봅시다.
원본 큰 이미지를 0.25배 한것을 입력으로 사용했습니다.
import cv2 import matplotlib.pyplot as plt img_ori = cv2.imread('./zoa.jpg',1) img_ori = cv2.cvtColor(img_ori,cv2.COLOR_BGR2RGB) img_resize = cv2.resize(img_ori,dsize=(0,0),fx=0.25,fy=0.25,interpolation=cv2.INTER_CUBIC) bicubic_hr = cv2.resize(img_resize,dsize=(0,0),fx=4,fy=4,interpolation=cv2.INTER_CUBIC)
확대해서 한번 보겠습니다.
bicubic은 뭉개지는(?) smooth한 결과가 나오고 SRGAN은 머리카락 부분은 다소 거칠게 나오는것 같습니다.
'pytorch & tensorflow' 카테고리의 다른 글
resnet-36, resnet-50 구현 tensorflow (0) 2021.09.29 VGG-16, VGG-19 Tensorflow 구현 (0) 2021.09.29 Pytorch mobile (0) 2021.06.07 Cycle gan webcam (0) 2021.06.07 전이학습 Transfer learning (0) 2021.06.07