How to Deploy a PyTorch Deep Learning Model
Deploying a PyTorch-trained model with TorchServe
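How to Deploy a PyTorch Deep Learning Model #
This notebook introduces TorchServe, a flexible and easy-to-use tool for serving PyTorch models.
Why deploy #
Some readers ask: why does a model need to be deployed at all? Can't I just run python run.py?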

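Picture the simplest possible face-recognition door-entry program. If someone had to run python run.py by hand for every person who walks up, their hand would be cramping by the lunchtime rush.
So we need a service that listens for requests at all times and takes the place of our hands.
How to deploy #
Put simply, the service receives input data sent over the network using some protocol (HTTP here), calls a model that was stored in advance to run inference, and returns the result.
The data sent follows certain conventions (REST), and the returned data follows a certain format (JSON, XML). Interested readers can look into these details on their own.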
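The receive, infer, return loop described above can be sketched with nothing but the Python standard library. This is only an illustration of the idea, not how TorchServe is implemented: `fake_model`, `handle_payload`, `InferenceHandler`, `serve`, and port 8000 are all hypothetical names and values chosen for this sketch, and the "model" simply reports the size of its input.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


def fake_model(payload: bytes) -> dict:
    """Hypothetical stand-in for a real model: reports the input size."""
    return {"num_bytes": len(payload)}


def handle_payload(payload: bytes) -> bytes:
    """Run inference on the request body and encode the result as JSON."""
    return json.dumps(fake_model(payload)).encode("utf-8")


class InferenceHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Read the raw request body (for example, the bytes of an image).
        length = int(self.headers.get("Content-Length", 0))
        body = handle_payload(self.rfile.read(length))
        # Send the inference result back to the client as JSON.
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, format, *args):
        pass  # Keep the sketch quiet; a real service would log requests.


def serve(port: int = 8000) -> None:
    """Listen for requests until interrupted (port 8000 is an arbitrary choice)."""
    HTTPServer(("127.0.0.1", port), InferenceHandler).serve_forever()
```

Calling `serve()` keeps the process listening, so every POST request triggers one inference without anyone rerunning a script. A serving tool such as TorchServe adds what this sketch leaves out, for example model loading and worker management.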
TorchServe architecture #
Installation #
# Install TorchServe and the packages it needs
%pip install torchserve torch-model-archiver torch-workflow-archiver captum timm
Looking in indexes: https://mirrors.aliyun.com/pypi/simple
Requirement already satisfied: torch>=1.2 in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from captum) (1.9.0+cu111)
Requirement already satisfied: torchvision in /environment/python/versions/miniconda3-4.7.12/lib/python3.7/site-packages (from timm) (0.10.0+cu111)
Successfully installed captum-0.4.0 enum-compat-0.0.3 future-0.18.2 psutil-5.8.0 timm-0.4.12 torch-model-archiver-0.4.2 torch-workflow-archiver-0.1.2 torchserve-0.4.2
Note: you may need to restart the kernel to use updated packages.
# Create a directory to store the packaged models
!mkdir model_store
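Weights file #
In a real project you would use weights you trained yourself. To keep the demonstration simple, this notebook uses an ImageNet-pretrained model.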
import timm
import torch

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # EfficientNet-B0 backbone with ImageNet-pretrained weights
        self.backbone = timm.create_model('efficientnet_b0', pretrained=True)

    def forward(self, x):
        x = self.backbone(x)
        return x

model = Net()
# Save only the weights (state_dict); model.py provides the class definition
torch.save(model.state_dict(), 'model.pth.tar')
model.py #
model.py must contain a single model class, and that class must be able to load the model.pth.tar file saved above (through the model's load_state_dict method).
!featurize dataset download ee8b0992-df2e-4ad6-a240-1f486b8eef8b
100%|██████████████████████████████████████████| 845/845 [00:00<00:00, 200kiB/s]
🍬 Download complete, extracting...
🏁 Dataset added successfully
!cat /home/featurize/data/torchserve/model.py
import torch
import timm

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.backbone = timm.create_model('efficientnet_b0', pretrained=False)

    def forward(self, x):
        x = self.backbone(x)
        return x
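(Optional) preprocess.py #
Optional preprocessing steps, such as flip and resize. Note that the packaging command in this notebook uses the built-in image_classifier handler; to apply this custom handler instead, pass the path of preprocess.py to the --handler option.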
!cat /home/featurize/data/torchserve/preprocess.py
from torchvision.transforms.transforms import Grayscale
from ts.torch_handler.image_classifier import ImageClassifier
from torchvision import transforms

class CustomHandler(ImageClassifier):
    image_processing = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
    ])
Packaging the model #
!wget https://raw.githubusercontent.com/pytorch/serve/master/examples/image_classifier/index_to_name.json
--2021-11-02 12:03:01--  https://raw.githubusercontent.com/pytorch/serve/master/examples/image_classifier/index_to_name.json
Connecting to 127.0.0.1:7890... connected.
Proxy request sent, awaiting response... 200 OK
Length: 35363 (35K) [text/plain]
Saving to: 'index_to_name.json'

index_to_name.json  100%[===================>]  34.53K   206KB/s    in 0.2s

2021-11-02 12:03:03 (206 KB/s) - 'index_to_name.json' saved [35363/35363]
!torch-model-archiver \
--model-name efficientnetb0 \
--handler image_classifier \
--version 1.0 \
--model-file /home/featurize/data/torchserve/model.py \
--serialized-file model.pth.tar \
--export-path model_store \
--extra-files index_to_name.json
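To check what the archiver actually packed, the resulting file can be opened from Python. The sketch below assumes the default archive format, in which a .mar file is an ordinary zip archive with a manifest at MAR-INF/MANIFEST.json; if your version of torch-model-archiver lays the archive out differently, the manifest is simply reported as None. `inspect_mar` is a hypothetical helper name, not part of TorchServe.

```python
import json
import zipfile


def inspect_mar(mar_path: str) -> tuple:
    """Return (sorted file names, manifest dict or None) for a .mar archive."""
    with zipfile.ZipFile(mar_path) as archive:
        names = sorted(archive.namelist())
        manifest = None
        # Assumption: the default format stores a manifest at this path.
        if "MAR-INF/MANIFEST.json" in names:
            manifest = json.loads(archive.read("MAR-INF/MANIFEST.json"))
    return names, manifest
```

For the archive built above, `inspect_mar("model_store/efficientnetb0.mar")` should list the model file, the serialized weights, and index_to_name.json, which is a quick way to confirm that nothing was left out before starting the server.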
Starting the service #
import os
os.system("torchserve --start --ncs --model-store model_store --models efficientnetb0.mar")
0
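The start command returns right away, while the server and its model workers may still be coming up in the background, so a request sent immediately can fail. One way to handle this is to poll the health check until it succeeds. The sketch below assumes the default inference address 127.0.0.1:8080 and that GET /ping answers with a JSON body whose "status" field is "Healthy"; `ping` and `wait_until_healthy` are hypothetical helper names.

```python
import http.client
import json
import time
import urllib.request


def ping(url: str = "http://127.0.0.1:8080/ping") -> bool:
    """Return True if the health endpoint reports a Healthy status."""
    # An empty ProxyHandler bypasses any configured proxy for this local call.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    try:
        with opener.open(url, timeout=2) as response:
            reply = json.loads(response.read())
    except (OSError, ValueError, http.client.HTTPException):
        return False  # Not reachable yet, or the reply was not valid JSON.
    return isinstance(reply, dict) and reply.get("status") == "Healthy"


def wait_until_healthy(probe=ping, timeout: float = 60.0, interval: float = 1.0) -> bool:
    """Call probe() repeatedly until it returns True or the timeout expires."""
    deadline = time.monotonic() + timeout
    while True:
        if probe():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)
```

Calling `wait_until_healthy()` right after the os.system line blocks until the server answers or a minute passes. The `probe` parameter exists so the waiting logic can be exercised without a running server.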
Downloading a test image #
!curl -O https://raw.githubusercontent.com/pytorch/serve/master/docs/images/kitten_small.jpg
import cv2
import matplotlib.pyplot as plt
plt.imshow(cv2.cvtColor(cv2.imread('kitten_small.jpg'), cv2.COLOR_BGR2RGB));
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  7341  100  7341    0     0   3607      0  0:00:02  0:00:02 --:--:--  3607
Sending an inference request #
%%time
!curl http://127.0.0.1:8080/predictions/efficientnetb0 -T /home/featurize/kitten_small.jpg
{
  "tabby": 0.45582714676856995,
  "lynx": 0.2556627094745636,
  "Egyptian_cat": 0.1583441197872162,
  "tiger_cat": 0.04835391417145729,
  "tiger": 0.003294622991234064
}
CPU times: user 48.2 ms, sys: 52.8 ms, total: 101 ms
Wall time: 1.5 s
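The same request can be sent from Python, which is handy when the caller is another program rather than a shell. This is a sketch under a few assumptions: it mirrors curl -T by issuing a PUT with the raw image bytes, and it sets the Content-Type to application/octet-stream explicitly so the body is not mistaken for form data (urllib would otherwise default to a form content type). `build_prediction_request`, `predict`, and `top_label` are hypothetical helper names.

```python
import json
import urllib.request


def build_prediction_request(model_name: str, image_bytes: bytes,
                             base_url: str = "http://127.0.0.1:8080") -> urllib.request.Request:
    """Build a PUT request carrying the raw image bytes, like `curl -T`."""
    return urllib.request.Request(
        f"{base_url}/predictions/{model_name}",
        data=image_bytes,
        headers={"Content-Type": "application/octet-stream"},
        method="PUT",
    )


def predict(model_name: str, image_path: str,
            base_url: str = "http://127.0.0.1:8080") -> dict:
    """Send the image to the running server and decode the JSON reply."""
    with open(image_path, "rb") as image_file:
        request = build_prediction_request(model_name, image_file.read(), base_url)
    # An empty ProxyHandler bypasses any configured proxy for this local call.
    opener = urllib.request.build_opener(urllib.request.ProxyHandler({}))
    with opener.open(request, timeout=30) as response:
        return json.loads(response.read())


def top_label(scores: dict) -> tuple:
    """Return the (label, probability) pair with the highest probability."""
    label = max(scores, key=scores.get)
    return label, scores[label]
```

With the server from this notebook running, `top_label(predict("efficientnetb0", "kitten_small.jpg"))` would return the most likely class together with its probability, which for the response shown above is the "tabby" entry.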
Stopping the service #
os.system("torchserve --stop")
0