ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • SRGAN Tensor flow 코드 구현 및 테스트
    pytorch & tensorflow 2021. 10. 7. 17:55

    SRGAN을 텐서플로우로 모델 구현을 해보겠습니다.

     

    논문에 대한 설명은 밑의 주소에 간단하게 정리 했습니다.

    https://hwanny-yy.tistory.com/18

     

    SR-GAN 정리 및 코드

    Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network 라는 제목으로 CVPR2017에 올라간 논문입니다. 2017년에는 좋은 논문들이 많이 나온것 같네요! SRGAN - Super Resolution..

    hwanny-yy.tistory.com

    Generator


    Generator 구조

    SRGAN의 generator부분입니다.

     

    중요부분으로는 residual block과 subpixel block입니다.

     

    Residual block에는 특별하게 PReLU를 사용한것이 특징입니다.

     

    Sub-pixel block에는 PixelShuffel이 사용되었습니다.

     

    from tensorflow.keras import Input, Model, layers
    
    # 그림의 Residual 블록을 정의합니다.
    def gene_base_block(x):
        out = layers.Conv2D(64, 3, 1, "same")(x)
        out = layers.BatchNormalization()(out)
        out = layers.PReLU(shared_axes=[1,2])(out)
        out = layers.Conv2D(64, 3, 1, "same")(out)
        out = layers.BatchNormalization()(out)
        return layers.Add()([x, out])
    
    # 그림의 뒤쪽 Sub-pixel 블록을 정의합니다.
    def upsample_block(x):
        out = layers.Conv2D(256, 3, 1, "same")(x)
        # 그림의 PixelShuffler 라고 쓰여진 부분을 아래와 같이 구현합니다.
        out = layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(out)
        return layers.PReLU(shared_axes=[1,2])(out)
        
    # 전체 Generator를 정의합니다.
    def get_generator(input_shape=(None, None, 3)):
        inputs = Input(input_shape)
        
        out = layers.Conv2D(64, 9, 1, "same")(inputs)
        out = residual = layers.PReLU(shared_axes=[1,2])(out)
        
        for _ in range(5):
            out = gene_base_block(out)
        
        out = layers.Conv2D(64, 3, 1, "same")(out)
        out = layers.BatchNormalization()(out)
        out = layers.Add()([residual, out])
        
        for _ in range(2):
            out = upsample_block(out)
            
        out = layers.Conv2D(3, 9, 1, "same", activation="tanh")(out)
        return Model(inputs, out)

    Discriminator


    Discriminator block

    # 그림의 블록을 정의합니다.
    def disc_base_block(x, n_filters=128):
        out = layers.Conv2D(n_filters, 3, 1, "same")(x)
        out = layers.BatchNormalization()(out)
        out = layers.LeakyReLU()(out)
        out = layers.Conv2D(n_filters, 3, 2, "same")(out)
        out = layers.BatchNormalization()(out)
        return layers.LeakyReLU()(out)
    
    # 전체 Discriminator 정의합니다.
    def get_discriminator(input_shape=(None, None, 3)):
        inputs = Input(input_shape)
        
        out = layers.Conv2D(64, 3, 1, "same")(inputs)
        out = layers.LeakyReLU()(out)
        out = layers.Conv2D(64, 3, 2, "same")(out)
        out = layers.BatchNormalization()(out)
        out = layers.LeakyReLU()(out)
        
        for n_filters in [128, 256, 512]:
            out = disc_base_block(out, n_filters)
        
        out = layers.Dense(1024)(out)
        out = layers.LeakyReLU()(out)
        out = layers.Dense(1, activation="sigmoid")(out)
        return Model(inputs, out)

     

    Content Loss


    SRGAN에서는 새로 제시한 Content loss로 VGG19의 활성화 함수 통과 값을 사용합니다.

     

    from tensorflow.python.keras import applications
    def get_feature_extractor(input_shape=(None, None, 3)):
        vgg = applications.vgg19.VGG19(
            include_top=False, 
            weights="imagenet", 
            input_shape=input_shape
        )
        # 아래 vgg.layers[20]은 vgg 내의 마지막 convolutional layer 입니다.
        return Model(vgg.input, vgg.layers[20].output)

    Training


    from tensorflow.keras import losses, metrics, optimizers
    
    generator = get_generator()
    discriminator = get_discriminator()
    vgg = get_feature_extractor()
    
    # 사용할 loss function 및 optimizer 를 정의합니다.
    bce = losses.BinaryCrossentropy(from_logits=False)
    mse = losses.MeanSquaredError()
    gene_opt = optimizers.Adam()
    disc_opt = optimizers.Adam()
    
    def get_gene_loss(fake_out):
        return bce(tf.ones_like(fake_out), fake_out)
    
    def get_disc_loss(real_out, fake_out):
        return bce(tf.ones_like(real_out), real_out) + bce(tf.zeros_like(fake_out), fake_out)
    
    
    @tf.function
    def get_content_loss(hr_real, hr_fake): 
        hr_real = applications.vgg19.preprocess_input(hr_real)
        hr_fake = applications.vgg19.preprocess_input(hr_fake)
        
        hr_real_feature = vgg(hr_real) / 12.75
        hr_fake_feature = vgg(hr_fake) / 12.75
        return mse(hr_real_feature, hr_fake_feature)
    
    
    @tf.function
    def step(lr, hr_real):
        with tf.GradientTape() as gene_tape, tf.GradientTape() as disc_tape:
            hr_fake = generator(lr, training=True)
            
            real_out = discriminator(hr_real, training=True)
            fake_out = discriminator(hr_fake, training=True)
            
            perceptual_loss = get_content_loss(hr_real, hr_fake) + 1e-3 * get_gene_loss(fake_out)
            discriminator_loss = get_disc_loss(real_out, fake_out)
            
        gene_gradient = gene_tape.gradient(perceptual_loss, generator.trainable_variables)
        disc_gradient = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)
        
        gene_opt.apply_gradients(zip(gene_gradient, generator.trainable_variables))
        disc_opt.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
        return perceptual_loss, discriminator_loss
    
    
    gene_losses = metrics.Mean()
    disc_losses = metrics.Mean()
    
    for epoch in range(1, 2):
        for i, (lr, hr) in enumerate(train):
            g_loss, d_loss = step(lr, hr)
            
            gene_losses.update_state(g_loss)
            disc_losses.update_state(d_loss)
            
            # 10회 반복마다 loss를 출력합니다.
            if (i+1) % 10 == 0:
                print(f"EPOCH[{epoch}] - STEP[{i+1}] \nGenerator_loss:{gene_losses.result():.4f} \nDiscriminator_loss:{disc_losses.result():.4f}", end="\n\n")
            
            if (i+1) == 200:
                break
                
        gene_losses.reset_states()
        disc_losses.reset_states()

     

    Test


    모델을 불러옵니다.

     

    import tensorflow as tf
    import os
    
    model_file = './srgan_G.h5'
    srgan = tf.keras.models.load_model(model_file)

    SRGAN을 사용할 함수를 설정합니다.

    predict결과 float32--> uint8로 바꿔줍니다.

    import numpy as np
    
    def apply_srgan(image):
        image = tf.cast(image[np.newaxis, ...], tf.float32)
        sr = srgan.predict(image)
        sr = tf.clip_by_value(sr, 0, 255)
        sr = tf.round(sr)
        sr = tf.cast(sr, tf.uint8)
        return np.array(sr)[0]

    이미지를 입력해 봅시다.

     

    원본 큰 이미지를 0.25배 한것을 입력으로 사용했습니다.

    import cv2
    import matplotlib.pyplot as plt
    
    img_ori = cv2.imread('./zoa.jpg',1)
    img_ori = cv2.cvtColor(img_ori,cv2.COLOR_BGR2RGB)
    img_resize = cv2.resize(img_ori,dsize=(0,0),fx=0.25,fy=0.25,interpolation=cv2.INTER_CUBIC)
    bicubic_hr = cv2.resize(img_resize,dsize=(0,0),fx=4,fy=4,interpolation=cv2.INTER_CUBIC)

    왼쪽부터 원본, SRGAN, bicubic방법입니다.

    확대해서 한번 보겠습니다.

     

    bicubic은 뭉개지는(?) smooth한 결과가 나오고 SRGAN은 머리카락 부분은 다소 거칠게 나오는것 같습니다.

    'pytorch & tensorflow' 카테고리의 다른 글

    resnet-36, resnet-50 구현 tensorflow  (0) 2021.09.29
    VGG-16, VGG-19 Tensorflow 구현  (0) 2021.09.29
    Pytorch mobile  (0) 2021.06.07
    Cycle gan webcam  (0) 2021.06.07
    전이학습 Transfer learning  (0) 2021.06.07
Designed by Tistory.