How to Deploy a PyTorch Deep Learning Model

Deploying a PyTorch-trained model with TorchServe
Uploaded by Dave 4 years ago
Tags: PyTorch

How to Deploy a PyTorch Deep Learning Model #

This notebook introduces TorchServe, a flexible and easy-to-use tool for serving PyTorch models.

Why Deploy #

Some of you may ask: why does the model need to be deployed at all? Can't I just run python run.py directly?


Imagine the simplest face-recognition door-entry program: if you had to manually run python run.py every time someone walked up, your hands would be cramping by lunchtime...

So we need a service that is always listening for requests, to take the place of our hands.

How to Deploy #

In short, the service receives input data sent over the network via a protocol such as HTTP, runs inference with a model that has been stored in advance, and returns the result.

Of course, the incoming requests follow certain conventions (REST) and the returned data follows certain formats (JSON, XML); if you are interested in these details, they are worth studying on your own.
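
To make this request/inference/response loop concrete, here is a minimal sketch using only Python's standard library. It is purely an illustration of the pattern (not how TorchServe itself is implemented), and run_inference is a hypothetical placeholder for calling a pre-loaded model.

from http.server import BaseHTTPRequestHandler, HTTPServer
import json

def run_inference(payload: bytes) -> dict:
    # Hypothetical placeholder: run your pre-loaded model on the raw request bytes
    return {"label": "example", "score": 1.0}

class InferenceHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Read the input data sent over HTTP
        body = self.rfile.read(int(self.headers.get("Content-Length", 0)))
        # Run inference and return the result as JSON
        result = json.dumps(run_inference(body)).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(result)

# HTTPServer(("0.0.0.0", 8080), InferenceHandler).serve_forever()  # uncomment to actually listen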

TorchServe Architecture #

(image: TorchServe architecture diagram)

Installation #

# Install TorchServe and the other packages we need
%pip install torchserve torch-model-archiver torch-workflow-archiver captum timm    
Looking in indexes: https://mirrors.aliyun.com/pypi/simple
Collecting torchserve
  Using cached https://mirrors.aliyun.com/pypi/packages/b3/97/81e2d0fae8c6697501d7a63d67eb74b43bde33dca66e42ae223d5403a622/torchserve-0.4.2-py2.py3-none-any.whl (18.1 MB)
Collecting torch-model-archiver
  Using cached https://mirrors.aliyun.com/pypi/packages/8f/8a/996b77e076e3aeac6dda067537d6a407af7e399f129b3fc479e5d03a6a3d/torch_model_archiver-0.4.2-py2.py3-none-any.whl (14 kB)
Collecting torch-workflow-archiver
  Using cached https://mirrors.aliyun.com/pypi/packages/1c/c1/3a6867df0ed8d7a89dccf37ac491fae63e0e20104bf21abb4728fa4fee43/torch_workflow_archiver-0.1.2-py2.py3-none-any.whl (12 kB)
Collecting captum
  Using cached https://mirrors.aliyun.com/pypi/packages/cc/da/50dd447964766b92a0d1e3781559401b5c58d3b524b8dbb8fab75dc98070/captum-0.4.0-py3-none-any.whl (1.4 MB)
Collecting timm
  Using cached https://mirrors.aliyun.com/pypi/packages/90/fc/606bc5cf46acac3aa9bd179b3954433c026aaf88ea98d6b19f5d14c336da/timm-0.4.12-py3-none-any.whl (376 kB)
Requirement already satisfied: numpy in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from captum) (1.19.5)
Requirement already satisfied: matplotlib in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from captum) (3.3.4)
Requirement already satisfied: torch>=1.2 in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from captum) (1.9.0+cu111)
Requirement already satisfied: typing-extensions in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from torch>=1.2->captum) (3.10.0.2)
Requirement already satisfied: torchvision in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from timm) (0.10.0+cu111)
Collecting future
  Using cached future-0.18.2-py3-none-any.whl
Collecting enum-compat
  Using cached https://mirrors.aliyun.com/pypi/packages/55/ae/467bc4509246283bb59746e21a1a2f5a8aecbef56b1fa6eaca78cd438c8b/enum_compat-0.0.3-py3-none-any.whl (1.3 kB)
Requirement already satisfied: Pillow in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from torchserve) (8.0.0)
Collecting psutil
  Using cached https://mirrors.aliyun.com/pypi/packages/84/da/f7efdcf012b51506938553dbe302aecc22f3f43abd5cffa8320e8e0588d5/psutil-5.8.0-cp37-cp37m-manylinux2010_x86_64.whl (296 kB)
Requirement already satisfied: packaging in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from torchserve) (21.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.3 in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from matplotlib->captum) (2.4.7)
Requirement already satisfied: cycler>=0.10 in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from matplotlib->captum) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from matplotlib->captum) (1.2.0)
Requirement already satisfied: python-dateutil>=2.1 in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from matplotlib->captum) (2.8.1)
Requirement already satisfied: six in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from cycler>=0.10->matplotlib->captum) (1.15.0)
Installing collected packages: psutil, future, enum-compat, torchserve, torch-workflow-archiver, torch-model-archiver, timm, captum
Successfully installed captum-0.4.0 enum-compat-0.0.3 future-0.18.2 psutil-5.8.0 timm-0.4.12 torch-model-archiver-0.4.2 torch-workflow-archiver-0.1.2 torchserve-0.4.2
Note: you may need to restart the kernel to use updated packages.
# Create a directory for storing packaged models
!mkdir model_store

Weight File #

In a real project you would use the weights you trained yourself; for the sake of this demo we use an ImageNet pre-trained model instead.

import timm
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.backbone = timm.create_model('efficientnet_b0', pretrained=True)

    def forward(self, x):
        x = self.backbone(x)
        return x

model = Net()
# Save only the state_dict; this file is later passed to torch-model-archiver as --serialized-file
torch.save(model.state_dict(), 'model.pth.tar')

model.py #

model.py must contain the model class (a single model), and that class must be able to successfully load the model.pth.tar saved above (via load_state_dict).

!featurize dataset download ee8b0992-df2e-4ad6-a240-1f486b8eef8b
100%|██████████████████████████████████████████| 845/845 [00:00<00:00, 200kiB/s]
🍬  Download complete, extracting...
🏁  Dataset added successfully
!cat /home/featurize/data/torchserve/model.py
import torch
import timm

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.backbone = timm.create_model('efficientnet_b0', pretrained=False)

    def forward(self, x):
        x = self.backbone(x)
        return x
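
As a quick sanity check (a sketch, assuming the paths used above), the Net class in this model.py should load the model.pth.tar we saved earlier with every key matched:

import sys
import torch

# Make the downloaded model.py importable, then load the saved weights into it
sys.path.append('/home/featurize/data/torchserve')
from model import Net

net = Net()
print(net.load_state_dict(torch.load('model.pth.tar', map_location='cpu')))
# Expected output: <All keys matched successfully>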

(Optional) preprocess.py #

Some optional preprocessing, e.g. flip, resize, and so on.

!cat /home/featurize/data/torchserve/preprocess.py
from torchvision.transforms.transforms import Grayscale
from ts.torch_handler.image_classifier import ImageClassifier
from torchvision import transforms


class CustomHandler(ImageClassifier):
    image_processing = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
    ])

Model Packaging #

!wget https://raw.githubusercontent.com/pytorch/serve/master/examples/image_classifier/index_to_name.json
--2021-11-02 12:03:01--  https://raw.githubusercontent.com/pytorch/serve/master/examples/image_classifier/index_to_name.json
Connecting to 127.0.0.1:7890... connected.
Proxy request sent, awaiting response... 200 OK
Length: 35363 (35K) [text/plain]
Saving to: 'index_to_name.json'

index_to_name.json  100%[===================>]  34.53K   206KB/s    in 0.2s    

2021-11-02 12:03:03 (206 KB/s) - 'index_to_name.json' saved [35363/35363]

!torch-model-archiver  \
--model-name efficientnetb0 \
--handler image_classifier  \
--version 1.0  \
--model-file /home/featurize/data/torchserve/model.py  \
--serialized-file model.pth.tar  \
--export-path model_store  \
--extra-files index_to_name.json
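
The resulting efficientnetb0.mar is written to model_store (a quick !ls model_store confirms it). Note that the command above uses the built-in image_classifier handler, so the CustomHandler in the optional preprocess.py is not actually applied. torch-model-archiver also accepts a path to a custom handler file, so a variant that does use it would look roughly like this (a sketch; --force overwrites an existing archive with the same name):

!torch-model-archiver  \
--model-name efficientnetb0 \
--handler /home/featurize/data/torchserve/preprocess.py  \
--version 1.0  \
--model-file /home/featurize/data/torchserve/model.py  \
--serialized-file model.pth.tar  \
--export-path model_store  \
--extra-files index_to_name.json \
--force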

Start the Service #

import os
os.system("torchserve --start --ncs --model-store model_store --models efficientnetb0.mar")
0
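
Once the server is up, it can be checked before sending any predictions (assuming TorchServe's default ports: inference API on 8080, management API on 8081):

# Health check and list of registered models
!curl http://127.0.0.1:8080/ping
!curl http://127.0.0.1:8081/models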

Download a Test Image #

!curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg

import cv2
import matplotlib.pyplot as plt

plt.imshow(cv2.cvtColor(cv2.imread('kitten_small.jpg'), cv2.COLOR_BGR2RGB));
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  7341  100  7341    0     0   3607      0  0:00:02  0:00:02 --:--:--  3607

Send an Inference Request #

%%time
!curl http://127.0.0.1:8080/predictions/efficientnetb0 -T /home/featurize/kitten_small.jpg
{
  "tabby": 0.45582714676856995,
  "lynx": 0.2556627094745636,
  "Egyptian_cat": 0.1583441197872162,
  "tiger_cat": 0.04835391417145729,
  "tiger": 0.003294622991234064
}
CPU times: user 48.2 ms, sys: 52.8 ms, total: 101 ms
Wall time: 1.5 s
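
The same request can also be sent from Python. A minimal sketch using the requests library (assuming it is available in the environment), posting the raw image bytes to the predictions endpoint:

import requests

# Send the image bytes to the inference API and print the parsed JSON response
with open('/home/featurize/kitten_small.jpg', 'rb') as f:
    resp = requests.post('http://127.0.0.1:8080/predictions/efficientnetb0', data=f.read())
print(resp.json())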

Stop the Service #

os.system("torchserve --stop")
0

