ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Pytorch mobile
    pytorch & tensorflow 2021. 6. 7. 12:37

    파이토치에서 모바일로 실행할 수 있게 데모앱을 지원해 주고 있다.

     

    https://github.com/pytorch/android-demo-app

     

    pytorch/android-demo-app

    PyTorch android examples of usage in applications. Contribute to pytorch/android-demo-app development by creating an account on GitHub.

    github.com

    클론해서 사용할 수 있으며 데모 앱실행은 자연스럽게 된다.

     

    app단위의 gradle의 depenencies에 아래 코드를 추가하면 android pytorch를 사용할 수 있다.

    implementation 'org.pytorch:pytorch_android:1.7.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.7.0'

     

    데모앱에서 사용하는 것 말고 내가 학습한 모델을 적용하기 위해서는 양자화 된 mobile용 모델을 사용해야 한다.

     

    DEPLOYMENT WORKFLOW

    파이토치 모델을 android용 모델로 바꾸는 일을 WORKFLOW이다.

     

    yolov5의 경우는 yolov5 git repo에서 model 디렉토리에 export.py를 변경해준다.

     

    50번 라인의

    model.model[-1].export = False --> model.model[-1].export = True로 바꿔주고

     

    57번과 58번 라인 사이에 아래 코드를 넣어주고 실행하면

    from torch.utils.mobile_optimizer import optimize_for_mobile
    ts = optimize_for_mobile(ts)

    진행하면 안드로이드용 yolo모델이 저장된다.

     

    저장된 모델은 안드로이드 프로젝트 폴더의 asset에 넣어주고 build를 진행하면 된다.

     

    내가 원하는 모델로 바꾸고 싶은 경우에는 yolov5의 export.py가 아닌 아래의 코드를 사용하면 된다.

     

    import torch
    import torchvision
    from torch.utils.mobile_optimizer import optimize_for_mobile
    #torchvision.models에 있는 모델로 바꿔서 사용할 수 있다.
    model = torchvision.models.mobilenet_v3_small(pretrained=True)
    model.eval()
    example = torch.rand(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    optimized_traced_model = optimize_for_mobile(traced_script_module)
    optimized_traced_model.save("model.pt")

     

    load_state_dict를 통해서도 내가 학습한 모델을 적용 시킬 수도 있다.

    import torch
    import torchvision
    from torch.utils.mobile_optimizer import optimize_for_mobile
    
    model = torchvision.models.mobilenet_v3_small(pretrained=True)
    model.eval()
    #내가 학습한 모델을 사용할 경우(전이학습을 했거나 새로운 데이터로 학습을 했다면)
    model.load_state_dict(torch.load("./model.pth", strict=False))
    example = torch.rand(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    optimized_traced_model = optimize_for_mobile(traced_script_module)
    optimized_traced_model.save("model.pt")

     

    'pytorch & tensorflow' 카테고리의 다른 글

    resnet-36, resnet-50 구현 tensorflow  (0) 2021.09.29
    VGG-16, VGG-19 Tensorflow 구현  (0) 2021.09.29
    Cycle gan webcam  (0) 2021.06.07
    전이학습 Transfer learning  (0) 2021.06.07
    Mask rcnn 빠르게 사용하기  (0) 2021.04.18
Designed by Tistory.