首页 文章详情

使用OpenCV与sklearn实现基于词袋模型的图像分类预测与搜索

小白学视觉 | 331 2022-01-20 06:53 0 0 0
UniSMS (合一短信)

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

基于OpenCV实现SIFT特征提取与BOW(Bag of Word)生成向量数据,然后使用sklearn的线性SVM分类器训练模型,实现图像分类预测。实现基于词袋模型的图像分类预测与搜索,大致要分为如下四步:

1.特征提取与描述生成

这里选择SIFT特征,SIFT特征具有放缩、旋转、光照不变性,同时兼有对几何畸变,图像几何变形的一定程度的鲁棒性,使用Python OpenCV扩展模块中的SIFT特征提取接口,就可以提取图像的SIFT特征点与描述子。

2.词袋生成

词袋生成,是基于描述子数据的基础上,生成一系列的向量数据,最常见就是首先通过K-Means实现对描述子数据的聚类分析,一般会分成100个聚类、得到每个聚类的中心数据,就生成了100 词袋,根据每个描述子到这些聚类中心的距离,决定了它属于哪个聚类,这样就生成了它的直方图表示数据。

3.SVM分类训练与模型生成

使用SVM进行数据的分类训练,得到输出模型,这里通过sklearn的线性SVM训练实现了分类模型训练与导出。

4.模型使用与预测

加载预训练好的模型,使用模型在测试集上进行数据预测,测试表明,对于一些简单的图像分类与相似图像预测都可以获得比较好的效果。

完整步骤图示如下:

其中SIFT特征提取算法主要有如下几步:

1.构建高斯金子塔图像,寻找极值点 2.极值点亚像素级别定位 3.图像梯度与角度直方图建立 4.特征描述子建立

K-Means聚类方法 - 参见公众号以前的文章即可

OpenCV中KMeans算法介绍与应用

代码实现,特征提取与训练模型导出

  1. import cv2

  2. import imutils

  3. import numpy as np

  4. import os

  5. from sklearn.svm import LinearSVC

  6. from sklearn.externals import joblib

  7. from scipy.cluster.vq import *

  8. from sklearn.preprocessing import StandardScaler

  9. # Get the training classes names and store them in a list

  10. train_path = "dataset/train/"

  11. training_names = os.listdir(train_path)

  12. # Get all the path to the images and save them in a list

  13. # image_paths and the corresponding label in image_paths

  14. image_paths = []

  15. image_classes = []

  16. class_id = 0

  17. for training_name in training_names:

  18.    dir = os.path.join(train_path, training_name)

  19.    class_path = imutils.imlist(dir)

  20.    image_paths += class_path

  21.    image_classes += [class_id] * len(class_path)

  22.    class_id += 1

  23. # 创建SIFT特征提取器

  24. sift = cv2.xfeatures2d.SIFT_create()

  25. # 特征提取与描述子生成

  26. des_list = []

  27. for image_path in image_paths:

  28.    im = cv2.imread(image_path)

  29.    im = cv2.resize(im, (300, 300))

  30.    kpts = sift.detect(im)

  31.    kpts, des = sift.compute(im, kpts)

  32.    des_list.append((image_path, des))

  33.    print("image file path : ", image_path)

  34. # 描述子向量

  35. descriptors = des_list[0][1]

  36. for image_path, descriptor in des_list[1:]:

  37.    descriptors = np.vstack((descriptors, descriptor))

  38. # 100 聚类 K-Means

  39. k = 100

  40. voc, variance = kmeans(descriptors, k, 1)

  41. # 生成特征直方图

  42. im_features = np.zeros((len(image_paths), k), "float32")

  43. for i in range(len(image_paths)):

  44.    words, distance = vq(des_list[i][1], voc)

  45.    for w in words:

  46.        im_features[i][w] += 1

  47. # 实现动词词频与出现频率统计

  48. nbr_occurences = np.sum((im_features > 0) * 1, axis=0)

  49. idf = np.array(np.log((1.0 * len(image_paths) + 1) / (1.0 * nbr_occurences + 1)), 'float32')

  50. # 尺度化

  51. stdSlr = StandardScaler().fit(im_features)

  52. im_features = stdSlr.transform(im_features)

  53. # Train the Linear SVM

  54. clf = LinearSVC()

  55. clf.fit(im_features, np.array(image_classes))

  56. # Save the SVM

  57. print("training and save model...")

  58. joblib.dump((clf, training_names, stdSlr, k, voc), "bof.pkl", compress=3)

在训练图像上的运行输出:

  1. "C:\Program Files\Python\Python36\python.exe" D:/python/image_classification/feature_detection.py

  2. image file path :  dataset/train/aeroplane\1.jpg

  3. image file path :  dataset/train/aeroplane\10.jpg

  4. image file path :  dataset/train/aeroplane\11.jpg

  5. image file path :  dataset/train/aeroplane\12.jpg

  6. image file path :  dataset/train/aeroplane\13.jpg

  7. image file path :  dataset/train/aeroplane\14.jpg

  8. image file path :  dataset/train/aeroplane\15.jpg

  9. image file path :  dataset/train/aeroplane\16.jpg

  10. image file path :  dataset/train/aeroplane\17.jpg

  11. image file path :  dataset/train/aeroplane\2.jpg

  12. image file path :  dataset/train/aeroplane\3.jpg

  13. image file path :  dataset/train/aeroplane\4.jpg

  14. image file path :  dataset/train/aeroplane\5.jpg

  15. image file path :  dataset/train/aeroplane\6.jpg

  16. image file path :  dataset/train/aeroplane\7.jpg

  17. image file path :  dataset/train/aeroplane\8.jpg

  18. image file path :  dataset/train/aeroplane\9.jpg

  19. image file path :  dataset/train/bicycle\1.jpg

  20. image file path :  dataset/train/bicycle\10.jpg

  21. image file path :  dataset/train/bicycle\11.jpg

  22. image file path :  dataset/train/bicycle\12.jpg

  23. image file path :  dataset/train/bicycle\13.jpg

  24. image file path :  dataset/train/bicycle\14.JPG

  25. image file path :  dataset/train/bicycle\15.png

  26. image file path :  dataset/train/bicycle\16.jpg

  27. image file path :  dataset/train/bicycle\17.jpg

  28. image file path :  dataset/train/bicycle\2.jpg

  29. image file path :  dataset/train/bicycle\3.jpg

  30. image file path :  dataset/train/bicycle\4.png

  31. image file path :  dataset/train/bicycle\5.jpg

  32. image file path :  dataset/train/bicycle\6.jpg

  33. image file path :  dataset/train/bicycle\7.jpg

  34. image file path :  dataset/train/bicycle\8.JPG

  35. image file path :  dataset/train/bicycle\9.jpg

  36. image file path :  dataset/train/car\1.jpg

  37. image file path :  dataset/train/car\10.jpg

  38. image file path :  dataset/train/car\11.jpg

  39. image file path :  dataset/train/car\12.jpg

  40. image file path :  dataset/train/car\13.jpg

  41. image file path :  dataset/train/car\14.jpg

  42. image file path :  dataset/train/car\15.jpg

  43. image file path :  dataset/train/car\16.jpg

  44. image file path :  dataset/train/car\17.jpg

  45. image file path :  dataset/train/car\2.jpeg

  46. image file path :  dataset/train/car\3.jpg

  47. image file path :  dataset/train/car\4.jpg

  48. image file path :  dataset/train/car\5.jpg

  49. image file path :  dataset/train/car\6.jpg

  50. image file path :  dataset/train/car\7.jpg

  51. image file path :  dataset/train/car\8.jpg

  52. image file path :  dataset/train/car\9.jpg

  53. training and save model...


程序测试


  1. import os

  2. import imutils

  3. import cv2 as cv

  4. import numpy as np

  5. from sklearn.externals import joblib

  6. from scipy.cluster.vq import *

  7. # Load the classifier, class names, scaler, number of clusters and vocabulary

  8. clf, classes_names, stdSlr, k, voc = joblib.load("bof.pkl")

  9. # Create feature extraction and keypoint detector objects

  10. sift = cv.xfeatures2d.SIFT_create()

  11. def predict_image(image_path):

  12.    # List where all the descriptors are stored

  13.    des_list = []

  14.    im = cv.imread(image_path, cv.IMREAD_GRAYSCALE)

  15.    im = cv.resize(im, (300, 300))

  16.    kpts = sift.detect(im)

  17.    kpts, des = sift.compute(im, kpts)

  18.    des_list.append((image_path, des))

  19.    descriptors = des_list[0][1]

  20.    for image_path, descriptor in des_list[0:]:

  21.        descriptors = np.vstack((descriptors, descriptor))

  22.    test_features = np.zeros((1, k), "float32")

  23.    words, distance = vq(des_list[0][1], voc)

  24.    for w in words:

  25.        test_features[0][w] += 1

  26.    # Perform Tf-Idf vectorization

  27.    nbr_occurences = np.sum((test_features > 0) * 1, axis=0)

  28.    idf = np.array(np.log((1.0 + 1) / (1.0 * nbr_occurences + 1)), 'float32')

  29.    # Scale the features

  30.    test_features = stdSlr.transform(test_features)

  31.    # Perform the predictions

  32.    predictions = [classes_names[i] for i in clf.predict(test_features)]

  33.    return predictions

  34. if __name__ == "__main__":

  35.    test_path = "dataset/test/"

  36.    testing_names = os.listdir(test_path)

  37.    image_paths = []

  38.    for training_name in testing_names:

  39.        dir = os.path.join(test_path, training_name)

  40.        class_path = imutils.imlist(dir)

  41.        image_paths += class_path

  42.    for image_path in image_paths:

  43.        predictions = predict_image(image_path)

  44.        print("image: %s, classes : %s"%(image_path, predictions))

测试集预测运行结果:

  1. "C:\Program Files\Python\Python36\python.exe" D:/python/image_classification/demo.py

  2. image: dataset/test/aeroplane\test_1.jpg, classes : ['aeroplane']

  3. image: dataset/test/aeroplane\test_2.jpg, classes : ['aeroplane']

  4. image: dataset/test/aeroplane\test_3.jpg, classes : ['aeroplane']

  5. image: dataset/test/bicycle\test_1.jpg, classes : ['bicycle']

  6. image: dataset/test/bicycle\test_2.JPG, classes : ['bicycle']

  7. image: dataset/test/bicycle\test_3.jpg, classes : ['bicycle']

  8. image: dataset/test/car\test_1.jpg, classes : ['car']

  9. image: dataset/test/car\test_2.jpg, classes : ['car']

  10. image: dataset/test/car\test_3.jpg, classes : ['car']


总结

只需要几十张图像训练集,就可以对后续的图像做出一个简单的分类预测,对于一些要求不高的web项目来说,植入的成本与代价很小,值得一试!同时为了减小计算量,我对图像的最大尺度resize到300x300大小。

下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


good-icon 0
favorite-icon 0
收藏
回复数量: 0
    暂无评论~~
    Ctrl+Enter