ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Mask rcnn 빠르게 사용하기
    pytorch & tensorflow 2021. 4. 18. 00:12

    torch vision에 있는 pretrained된 mask_rcnn을 활용해서 결과를 바로바로 확인해보고 싶다.

     

    단순하게 input폴더에 이미지를 넣으면 output폴더에 segmentation결과가 바로 나올 수 있게...

     

    또한 결과를 확인 하면서 특정 객체만 segmentation하고 싶다.

     

    다른 예제들은 모든 객체에 대해 결과를 보여준다.

     

    mask rcnn 결과

    위와 같이 모든 클래스에 대해 결과가 나온다. 하지만 특정 물체만 찾고 싶을때는 조금 어렵다.

     

    사람만 찾아서 segmentation함

    두가지의 대해서 사용을 하고 싶다! 간단하게 pytorch를 활용하여 만들어 보자!

     

    프로그램 진행은 간단하다.

    1. dataloader을 정의한다.
    2. model을 정의 한다.
    3. dataloader를 통해 불러온 이미지를 모델에 넣기
    4. 모델을 통해 나온 결과를 처리 해줌.

    위의 4가지 과정을 간단하게 구현해 보았다.

     

    1.dataloader

     

    import glob
    import random
    import os
    import natsort
    from torch.utils.data import Dataset
    from PIL import Image
    import torchvision.transforms as transforms
    
    class ImageDataset(Dataset):
        def __init__(self, root, transforms_ = None):
            self.transform = transforms.Compose(transforms_)
            self.files_A = sorted(glob.glob(os.path.join(root ) + '*.*'))        
            self.files_A = natsort.natsorted(self.files_A)        
        def __getitem__(self, index):
            
            item_A = self.transform(Image.open(self.files_A[index]))        
            return {'A': item_A}
            
        def __len__(self):
            return len(self.files_A)

    이미지가 들어있는 폴더를 root로 받아서 정렬한뒤 이름을 A로 정의해서 전달해주는 아주 간단한 dataloader이다.

     

    2. model 정의

    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
    
    def init(data_path):
        #Dataset Loader
        transforms_ = [ transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
        dataloader = DataLoader(ImageDataset(data_path, transforms_=transforms_), 
                                batch_size=opt.batchSize, shuffle=False, num_workers=opt.n_cpu)
        return dataloader

    모델을 정의 하면서 dataloader에서 같이 정의해 주었다. 

     

    3. dataloader을 통해 model에 집어넣기.

     

    def mask_rcnn(file_list,dataloader,save_path):
        for i,batch in enumerate(dataloader):        
        
            img = batch['A']        
            print(file_list[i])
            
            #use cuda
            #result = model(img.to('cuda'))
            #use cpu
            result = model(img)
               
            
            image = tensor2im(img)
            scores = list(result[0]['scores'].detach().numpy())
            thresholded_preds_inidices = [scores.index(i) for i in scores if i > 0.965]
            thresholded_preds_count = len(thresholded_preds_inidices)        
            mask = result[0]['masks']
            mask = mask[:thresholded_preds_count]
            labels = result[0]['labels']        
            boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))]  for i in result[0]['boxes']]
            boxes = boxes[:thresholded_preds_count]
            
            mask = mask.data.float().numpy()        
            #사람을 지우고 싶을때
            human_remove(image,mask,labels,file_list[i],save_path)
            # 모든 객체에 segmetation과 박스를 출력하고 싶을때
            apply_mask(image,mask,labels,boxes,file_list[i],save_path)

    mask rcnn은 여러 정보들을 output으로 출력한다.

    box의 좌표, segmentation용  mask

    inference 레이블 결과들이 나온다. 그리고 각각의 점수들을 출력한다.

    mask와 boxes는 다른 레이블의 결과들도 나올 수 있어 0.965이상의 score만 처리를 한다.

     

    그리고 output들을 4번 후 처리의 함수에 넣어 준다.

     

    4. 모델을 통해 나온 결과 처리(사람만 지우기)

    def human_remove(image,mask,labels,file_name,save_path):
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)    
        image = cv2.cvtColor(image,cv2.COLOR_RGB2RGBA)
        
        image[:,:,3] = 255.0
        channel = image.shape[2]
        area = 0    
        _,_,w,h = mask.shape
        for n in range(mask.shape[0]):
                if labels[n] == 1:                
                    mask[n] = np.where(mask[n] >0.5, 255,0)
                else :
                    continue        
        for n in range(mask.shape[0]):
                if labels[n] == 1:
                    for c in range(channel):                
                        image[:,:,c] = np.where(mask[n] == 255, 0, image[:,:,c])
                else :
                    continue                
        #image save
        cv2.imwrite(save_path +'seg_' + file_name ,image)

    투명도를 사용하기 위해서는 RGB영역이 아닌 RGBA를 사용해야 한다.

    cocodataset을 기반으로 mask rcnn을 활용했을때, 사람은 label중 1번이다 label이 1인 것만 mask를 처리한다.

    mask가 0.5 이상인 부분만 사람인 영역이 된다.

    mask를 만들었으면 mask를 기반으로 image를 RGBA값들을 모두 0으로 만들어 준다.

     

    4. 모델을 통해 나온 결과 처리(모든 객체에 적용)

    def apply_mask(image,mask,labels,boxes,file_name,save_path):
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)    
        # image = cv2.cvtColor(image,cv2.COLOR_RGB2RGBA)
        
        alpha = 1 
        beta = 0.6 # transparency for the segmentation map
        gamma = 0 # scalar added to each sum
        COLORS = np.random.uniform(0, 255, size=(len(class_names), 3))    
        channel = image.shape[2]
        area = 0    
        _,_,w,h = mask.shape    
        segmentation_map = np.zeros((w,h,3),np.uint8)
        
        for n in range(mask.shape[0]): 
            if labels[n] == 0:
                continue
            else:
                color = COLORS[random.randrange(0,len(COLORS))]
                segmentation_map[:,:,0] = np.where(mask[n] > 0.5, COLORS[labels[n]][0], 0)
                segmentation_map[:,:,1] = np.where(mask[n] > 0.5, COLORS[labels[n]][1], 0)
                segmentation_map[:,:,2] = np.where(mask[n] > 0.5, COLORS[labels[n]][2], 0)            
                image = cv2.addWeighted(image,alpha,segmentation_map,beta,gamma,dtype = cv2.CV_8U)
            # draw the bounding boxes around the objects        
            cv2.rectangle(image, boxes[n][0], boxes[n][1],color = color ,thickness = 2)
            
            print(class_names[labels[n]])
            # put the label text above the objects
            cv2.putText(image , class_names[labels[n]], (boxes[n][0][0], boxes[n][0][1]-10), 
                        cv2.FONT_HERSHEY_SIMPLEX, 1, color, 
                        thickness=2, lineType=cv2.LINE_AA)
        #image save
        cv2.imwrite(save_path +'seg_2' + file_name ,image)

    위와 비슷하지만 모든 객체에 대해 진행한다. label==0인부분은 BG background라서 무시한다

    color값들은 각각의 클래스에 맞게 random하게 적용한다.

     

    cv2.addWeight를 이용하여 색과 원래 이미지를 적절하게 합친다.

     

    박스의 위치들을 기반으로 박스를 그리고 그 위에 글씨를 입력한다.

     

    4가지의 과정을 통해 간단하게 mask_rcnn을 활용하고 inference 결과들을 입맛에 맞게 처리할 수 있게 됐다.

     

    코드 전문 : github.com/LeeJuwhan/Image-segmentation

     

    LeeJuwhan/Image-segmentation

    Contribute to LeeJuwhan/Image-segmentation development by creating an account on GitHub.

    github.com

     

    'pytorch & tensorflow' 카테고리의 다른 글

    VGG-16, VGG-19 Tensorflow 구현  (0) 2021.09.29
    Pytorch mobile  (0) 2021.06.07
    Cycle gan webcam  (0) 2021.06.07
    전이학습 Transfer learning  (0) 2021.06.07
    파이토치 Image segmentation  (0) 2021.04.17
Designed by Tistory.