分享 | 基于图像分类网络ResNet50_vd实现桃子分类
新机器视觉
共 6938字,需浏览 14分钟
· 2021-12-23
点击下方卡片,关注“新机器视觉”公众号
视觉/图像重磅干货,第一时间送达
随着时代的快速发展,人工智能已经融入我们生活的方方面面。中国的农业也因人工智能而受益进入高速发展阶段。现今,看庄稼长势有卫星遥感技术,水果分拣有智能分拣系统,灌溉施肥有自动化机械……
本实践旨在通过桃子分类来让大家对图像分类问题有一个初步了解,同时理解和掌握如何使用PaddleHub搭建一个经典的卷积神经网络。
方案设计
环境搭建与准备
!pip install paddlehub==2.0.4 -i https://pypi.tuna.tsinghua.edu.cn/simple
数据处理
├─data: 数据目录
├─train_list.txt:训练集数据列表
├─test_list.txt:测试集数据列表
├─validate_list.txt:验证集数据列表
├─label_list.txt:标签列表
└─……
图片1路径 图片1标签
图片2路径 图片2标签
...
分类1名称
分类2名称
...
!unzip -q -o ./data/data67225/peach.zip -d ./work
__init__
,__getitem__
和__len__
三个部分。示例如下:#coding:utf-8
import os
import paddle
import paddlehub as hub
class DemoDataset(paddle.io.Dataset):
def __init__(self, transforms, num_classes=4, mode='train'):
# 数据集存放位置
self.dataset_dir = "./work/peach-classification" #dataset_dir为数据集实际路径,需要填写全路径
self.transforms = transforms
self.num_classes = num_classes
self.mode = mode
if self.mode == 'train':
self.file = 'train_list.txt'
elif self.mode == 'test':
self.file = 'test_list.txt'
else:
self.file = 'validate_list.txt'
self.file = os.path.join(self.dataset_dir , self.file)
self.data = []
with open(self.file, 'r') as f:
for line in f.readlines():
line = line.strip()
if line != '':
self.data.append(line)
def __getitem__(self, idx):
img_path, grt = self.data[idx].split(' ')
img_path = os.path.join(self.dataset_dir, img_path)
im = self.transforms(img_path)
return im, int(grt)
def __len__(self):
return len(self.data)
import paddlehub.vision.transforms as T
transforms = T.Compose(
[T.Resize((256, 256)),
T.CenterCrop(224),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])],
to_rgb=True)
peach_train = DemoDataset(transforms)
peach_validate = DemoDataset(transforms, mode='val')
peach_test = DemoDataset(transforms, mode='test')
模型构建
#安装预训练模型
!hub install resnet50_vd_imagenet_ssld==1.1.0
# 加载模型
import paddlehub as hub
model = hub.Module(name='resnet50_vd_imagenet_ssld', label_list=["R0", "B1", "M2", "S3"])
模型训练
from paddlehub.finetune.trainer import Trainer
import paddle
optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='img_classification_ckpt', use_gpu=True)
trainer.train(peach_train, epochs=10, batch_size=16, eval_dataset=peach_validate, save_interval=1)
learning_rate:全局学习率。默认为1e-3; parameters:待优化模型参数。
运行配置
model:被优化模型; optimizer:优化器选择; use_vdl:是否使用vdl可视化训练过程; checkpoint_dir:保存模型参数的地址; compare_metrics:保存最优模型的衡量指标;
train_dataset:训练时所用的数据集; epochs:训练轮数; batch_size:训练的批大小,如果使用GPU,请根据实际情况调整batch_size; num_workers:works的数量,默认为0; eval_dataset:验证集; log_interval:打印日志的间隔, 单位为执行批训练的次数。 save_interval:保存模型的间隔频次,单位为执行训练的轮数。
模型评估
# 模型评估
trainer.evaluate(peach_test, 16)
模型推理
import paddle
import paddlehub as hub
from PIL import Image
import matplotlib.pyplot as plt
img_path = './work/test.jpg'
img = Image.open(img_path)
plt.imshow(img)
plt.axis('off')
plt.show()
result = model.predict([img_path])
print("桃子的类别被预测为:{}".format(result))
模型部署
{
"modules_info": {
"resnet50_vd_imagenet_ssld": {
"init_args": {
"version": "1.1.0",
"label_list":["R0", "B1", "M2", "S3"],
"load_checkpoint": "img_classification_ckpt/best_model/model.pdparams"
},
"predict_args": {
"batch_size": 1
}
}
},
"port": 8866,
"gpu": "0"
}
$ hub serving start --config config.json
import requests
import json
import cv2
import base64
import numpy as np
def cv2_to_base64(image):
data = cv2.imencode('.jpg', image)[1]
return base64.b64encode(data.tostring()).decode('utf8')
def base64_to_cv2(b64str):
data = base64.b64decode(b64str.encode('utf8'))
data = np.fromstring(data, np.uint8)
data = cv2.imdecode(data, cv2.IMREAD_COLOR)
return data
# 发送HTTP请求
org_im = cv2.imread('/PATH/TO/IMAGE')
data = {'images':[cv2_to_base64(org_im)], 'top_k':1}
headers = {"Content-type": "application/json"}
url = "http://127.0.0.1:8866/predict/resnet50_vd_imagenet_ssld"
r = requests.post(url=url, headers=headers, data=json.dumps(data))
data =r.json()["results"]['data']
评论
分享几个前端中好玩且有用的开源工具,总有一个适合你!
点击上方 前端Q,关注公众号回复加群,加入前端Q技术交流群正所谓差生文具多,作为前端的我们,拥有几个合适的工具和网站可以很有效的提高我们的工具效率,还会有一些很有趣的网站可以在我们敲 bug 累了的时候供我们娱乐,接下来我就和大嘎分析一下我在用的一些工具和网站。聚合API该网站提供了大量的
前端Q
0
SpringBoot+Minio实现上传凭证、分片上传、秒传和断点续传
关注我们,设为星标,每天7:40不见不散,架构路上与您共享回复架构师获取资源大家好,我是你们的朋友架构君,一个会写代码吟诗的架构师。Spring Boot整合Minio后,前端的文件上传有两种方式:1、文件上传到后端,由后端保存到Minio这种方式好处是完全由后端集中管理,可以很好的做到、身份验证、
Java架构师社区
0
6大类最新AI工具,共计39个分类梳理!
你好,我是郭震俗话说,工欲善其事必先利其器,用好AI工具一定事半功倍!这也是AI技术革命带给我们最能感知到的地方之一。这篇文章总结了6大类AI工具,分别包括:问答,图像,视频,AI编程,AI提示词和AI大模型,一共梳理挑选了共计39个AI工具,其中很多都是开源!文末还包括完整思维导图,大家记得收藏这
Python与算法社区
10
分享一份抓取某东商品名称、价格和评论数的代码
点击上方“Python共享之家”,进行关注回复“资源”即可获赠Python学习资料今日鸡汤芳草已云暮,故人殊未来。大家好,我是皮皮。一、前言前几天在Python白银交流群【邮递员】问了一个Python网络爬虫的问题,提问截图如下:代码如下:import requestsfrom
IT共享之家
0
图像处理基础知识
点击上方“小白学视觉”,选择加"星标"或“置顶”重磅干货,第一时间送达图像1、模拟图像模拟图像,又称连续图像,是指在二维坐标系中连续变化的图像,即图像的像点是无限稠密的,同时具有灰度值(即图像从暗到亮的变化值)。2、数字图像数字图像,又称数码图像或数位图像,是二维图像用有限数字数值像素的表示。数字图
小白学视觉
429
如何使用 Python比较两张图像并获得准确度?
点击上方“小白学视觉”,选择加"星标"或“置顶”重磅干货,第一时间送达本文,将带你了解如何使用 Python、OpenCV 和人脸识别模块比较两张图像并获得这些图像之间的准确度水平。首先,你需要了解我们是如何比较两个图像的。我们正在使用Face Recognition python 模块来获取两张图
小白学视觉
142
超越原生,散点图实现华夫饼图
之前我们介绍过了如何使用新卡片图实现华夫饼图。参考:超越原生,PowerBI 华夫饼图实现但是利用卡片图实现的华夫饼图有一些缺点,形状之间的大小跟间距不太好把握,而且有时形状大一点的话显示就会不正常,需要做出二次调整。今天给大家介绍一种原生视觉对象生成华夫饼图的更佳方案,既简单又美观。上图是利用散点
PowerBI战友联盟
2
使用OpenCV测量图像中物体之间的距离
点击上方“小白学视觉”,选择加"星标"或“置顶”重磅干货,第一时间送达来源丨opcv学堂编辑丨极市平台极市导读 附详细代码操作。本文来自光头哥哥的博客【Measuring distance between objects in an image with OpenCV】,仅做学习分享。原文
小白学视觉
630