您当前的位置:首页 > 电脑百科 > 程序开发 > 编程百科

你需要知道的11个Torchvision计算机视觉数据集

时间:2023-04-07 13:30:04  来源:  作者:王瑞平

计算机视觉是一个显著增长的领域,有许多实际应用,从自动驾驶汽车到面部识别系统。该领域的主要挑战之一是获得高质量的数据集来训练机器学习模型。

Torchvision作为Pytorch的图形库,一直服务于PyTorch深度学习框架,主要用于构建计算机视觉模型。

为了解决这一挑战,Torchvision提供了访问预先构建的数据集、模型和专门为计算机视觉任务设计的转换。此外,Torchvision还支持CPU和GPU的加速,使其成为开发计算机视觉应用程序的灵活且强大的工具。

一、什么是“Torchvision数据集”?

Torchvision数据集是计算机视觉中常用的用于开发和测试机器学习模型的流行数据集集合。运用Torchvision数据集,开发人员可以在一系列任务上训练和测试他们的机器学习模型,例如,图像分类、对象检测和分割。数据集还经过预处理、标记并组织成易于加载和使用的格式。

据了解,Torchvision包由流行的数据集、模型体系结构和通用的计算机视觉图像转换组成。简单地说就是“常用数据集+常见模型+常见图像增强”方法。

Torchvision中的数据集共有11种:MNIST、CIFAR-10等,下面具体说说。

二、Torchvision中的11种数据集

1.MNIST手写数字数据库

这个Torchvision数据集在机器学习和计算机视觉领域中非常流行和广泛应用。它由7万张手写数字0-9的灰度图像组成。其中,6万张用于训练,1万张用于测试。每张图像的大小为28×28像素,并有相应的标签表示它所代表的数字。

要访问此数据集,您可以直接从Kaggle下载或使用torchvision加载数据集:

import torchvision.datasets as datasets# Load the trAIning dataset
train_dataset = datasets.MNIST(root='data/', train=True, transform=None, download=True)# Load the testing dataset
test_dataset = datasets.MNIST(root='data/', train=False, transform=None, download=True)

 

图片

 

2.CIFAR-10(广泛使用的标准数据集)

CIFAR-10数据集由6万张32×32彩色图像组成,分为10个类别,每个类别有6000张图像,总共有5万张训练图像和1万张测试图像。这些图像又分为5个训练批次和一个测试批次,每个批次有1万张图像。数据集可以从Kaggle下载。

import torchimport torchvisionimport torchvision.transforms as transforms

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)

在此提醒一句,您可以根据需要调整数据加载器的批处理大小和工作进程的数量。

3.CIFAR-100(广泛使用的标准数据集)

CIFAR-100数据集在100个类中有60,000张(50,000张训练图像和10,000张测试图像)32×32的彩色图像。每个类有600张图像。这100个类被分成20个超类,用一个细标签表示它的类,另一个粗标签表示它所属的超类。

import torchimport torchvisionimport torchvision.transforms as transforms

import torchvision.datasets as datasetsimport torchvision.transforms as transforms# Define transform to normalize data
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
# Load CIFAR-100 train and test datasets
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
# Create data loaders for train and test datasets
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

4.Imag.NET数据集

Torchvision中的ImageNet数据集包含大约120万张训练图像,5万张验证图像和10万张测试图像。数据集中的每张图像都被标记为1000个类别中的一个,如“猫”、“狗”、“汽车”、“飞机”等。

import torchvision.datasets as datasetsimport torchvision.transforms as transforms
# Set the path to the ImageNet dataset on your machine
data_path = "/path/to/imagenet"
# Create the ImageNet dataset object with custom options
imagenet_train = datasets.ImageNet(
root=data_path,
split='train',
transform=transforms.Compose([
transforms.Resize(256),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]),
download=False)

imagenet_val = datasets.ImageNet(
root=data_path,
split='val',
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]),
download=False)
# Print the number of images in the training and validation setsprint("Number of images in the training set:", len(imagenet_train))print("Number of images in the validation set:", len(imagenet_val))

 

图片

 

5.MSCoco数据集

Microsoft Common Objects in Context(MS Coco)数据集包含32.8万张日常物体和人类的高质量视觉图像,通常用作实时物体检测中比较算法性能的标准。

6.Fashion-MNIST数据集

时尚MNIST数据集是由Zalando Research创建的,作为原始MNIST数据集的替代品。Fashion MNIST数据集由70000张服装灰度图像(训练集60000张,测试集10000张)组成。

图片大小为28×28像素,代表10种不同类别的服装,包括:t恤/上衣、裤子、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包和短靴。它类似于原始的MNIST数据集,但由于服装项目的复杂性和多样性,分类任务更具挑战性。这个Torchvision数据集可以从Kaggle下载。

import torchimport torchvisionimport torchvision.transforms as transforms
# Define transformations
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])# Load the dataset
trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
download=True, transform=transform)

testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
download=True, transform=transform)
# Create data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)

testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)

7.SVHN数据集

SVHN(街景门牌号)数据集是一个来自谷歌街景图像的图像数据集,它由从街道级图像中截取的门牌号的裁剪图像组成。它包含所有门牌号及其包围框的完整格式和仅包含门牌号的裁剪格式。完整格式通常用于对象检测任务,而裁剪格式通常用于分类任务。

SVHN数据集也包含在Torchvision包中,它包含了73,257张用于训练的图像、26,032张用于测试的图像和531,131张用于额外训练数据的额外图像。

import torchvisionimport torch
# Load the train and test sets
train_set = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=torchvision.transforms.ToTensor())
test_set = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=torchvision.transforms.ToTensor())
# Create data loaders
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

8.STL-10数据集

STL-10数据集是一个图像识别数据集,由10个类组成,总共约6000 +张图像。STL-10代表“图像识别标准训练和测试集-10类”,数据集中的10个类是:飞机、鸟、汽车、猫、鹿、狗、马、猴子、船、卡车。您可以直接从Kaggle下载数据集。

 

import torchvision.datasets as datasetsimport torchvision.transforms as transforms
# Define the transformation to Apply to the data
transform = transforms.Compose([
    transforms.ToTensor(),   
# Convert PIL image to PyTorch tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))   # Normalize the data])
# Load the STL-10 dataset
train_dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)
test_dataset = datasets.STL10(root='./data', split='test', download=True, transform=transform)

 

9.CelebA数据集

这个Torchvision数据集是一个流行的大规模面部属性数据集,包含超过20万张名人图像。2015年,香港中文大学的研究人员首次发布了这一数据。CelebA中的图像包含40个面部属性,如,年龄、头发颜色、面部表情和性别。

此外,这些图片是从互联网上检索到的,涵盖了广泛的面部外观,包括不同的种族、年龄和性别。每个图像中面部位置的边界框注释,以及眼睛、鼻子和嘴巴的5个地标点。

import torchvision.datasets as datasetsimport torchvision.transforms as transforms
transform = transforms.Compose([
transforms.CenterCrop(178),
transforms.Resize(128),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
celeba_dataset = datasets.CelebA(root='./data', split='train', transform=transform, download=True)

10.PASCAL VOC数据集

VOC数据集(视觉对象类)于2005年作为PASCAL VOC挑战的一部分首次引入。该挑战旨在推进视觉识别的最新水平。它由20种不同类别的物体组成,包括:动物、交通工具和常见的家用物品。这些图像中的每一个都标注了图像中物体的位置和分类。注释包括边界框和像素级分割掩码。

数据集分为两个主要集:训练集和验证集。

训练集包含大约5000张带有注释的图像,而验证集包含大约5000张没有注释的图像。此外,该数据集还包括一个包含大约10,000张图像的测试集,但该测试集的注释是不可公开的。

 
import torchimport torchvisionfrom torchvision import transforms
# Define transformations to apply to the images
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
# Load the train and validation datasets
train_dataset = torchvision.datasets.VOCDetection(root='./data', year='2007', image_set='train', transform=transform)
val_dataset = torchvision.datasets.VOCDetection(root='./data', year='2007', image_set='val', transform=transform)# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

11.Places365数据集

Places365数据集是一个大型场景识别数据集,拥有超过180万张图像,涵盖365个场景类别。Places365标准数据集包含约180万张图像,而Places365挑战数据集包含5万张额外的验证图像,这些图像对识别模型更具挑战性。

import torchimport torchvisionfrom torchvision import transforms
# Define transformations to apply to the images
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
# Load the train and validation datasets
train_dataset = torchvision.datasets.Places365(root='./data', split='train-standard', transform=transform)
val_dataset = torchvision.datasets.Places365(root='./data', split='val', transform=transform)# Create data loaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=False)

 

图片

 

三、总结

总之,Torchvision数据集通常用于训练和评估机器学习模型,如卷积神经网络(CNNs)。这些模型通常用于计算机视觉应用,任何人都可以免费下载和使用。本文的主要图像是通过HackerNoon的AI稳定扩散模型生成的。

参考链接:​​https://hackernoon.com/11-torchvision-datasets-for-computer-vision-you-need-to-know​



Tags:Torchvision   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,不构成投资建议。投资者据此操作,风险自担。如有任何标注错误或版权侵犯请与我们联系,我们将及时更正、删除。
▌相关推荐
你需要知道的11个Torchvision计算机视觉数据集
计算机视觉是一个显著增长的领域,有许多实际应用,从自动驾驶汽车到面部识别系统。该领域的主要挑战之一是获得高质量的数据集来训练机器学习模型。Torchvision作为Pytorch的图...【详细内容】
2023-04-07  Search: Torchvision  点击:(308)  评论:(0)  加入收藏
▌简易百科推荐
Netflix 是如何管理 2.38 亿会员的
作者 | Surabhi Diwan译者 | 明知山策划 | TinaNetflix 高级软件工程师 Surabhi Diwan 在 2023 年旧金山 QCon 大会上发表了题为管理 Netflix 的 2.38 亿会员 的演讲。她在...【详细内容】
2024-04-08    InfoQ  Tags:Netflix   点击:(0)  评论:(0)  加入收藏
即将过时的 5 种软件开发技能!
作者 | Eran Yahav编译 | 言征出品 | 51CTO技术栈(微信号:blog51cto) 时至今日,AI编码工具已经进化到足够强大了吗?这未必好回答,但从2023 年 Stack Overflow 上的调查数据来看,44%...【详细内容】
2024-04-03    51CTO  Tags:软件开发   点击:(6)  评论:(0)  加入收藏
跳转链接代码怎么写?
在网页开发中,跳转链接是一项常见的功能。然而,对于非技术人员来说,编写跳转链接代码可能会显得有些困难。不用担心!我们可以借助外链平台来简化操作,即使没有编程经验,也能轻松实...【详细内容】
2024-03-27  蓝色天纪    Tags:跳转链接   点击:(13)  评论:(0)  加入收藏
中台亡了,问题到底出在哪里?
曾几何时,中台一度被当做“变革灵药”,嫁接在“前台作战单元”和“后台资源部门”之间,实现企业各业务线的“打通”和全域业务能力集成,提高开发和服务效率。但在中台如火如荼之...【详细内容】
2024-03-27  dbaplus社群    Tags:中台   点击:(9)  评论:(0)  加入收藏
员工写了个比删库更可怕的Bug!
想必大家都听说过删库跑路吧,我之前一直把它当一个段子来看。可万万没想到,就在昨天,我们公司的某位员工,竟然写了一个比删库更可怕的 Bug!给大家分享一下(不是公开处刑),希望朋友们...【详细内容】
2024-03-26  dbaplus社群    Tags:Bug   点击:(5)  评论:(0)  加入收藏
我们一起聊聊什么是正向代理和反向代理
从字面意思上看,代理就是代替处理的意思,一个对象有能力代替另一个对象处理某一件事。代理,这个词在我们的日常生活中也不陌生,比如在购物、旅游等场景中,我们经常会委托别人代替...【详细内容】
2024-03-26  萤火架构  微信公众号  Tags:正向代理   点击:(11)  评论:(0)  加入收藏
看一遍就理解:IO模型详解
前言大家好,我是程序员田螺。今天我们一起来学习IO模型。在本文开始前呢,先问问大家几个问题哈~什么是IO呢?什么是阻塞非阻塞IO?什么是同步异步IO?什么是IO多路复用?select/epoll...【详细内容】
2024-03-26  捡田螺的小男孩  微信公众号  Tags:IO模型   点击:(9)  评论:(0)  加入收藏
为什么都说 HashMap 是线程不安全的?
做Java开发的人,应该都用过 HashMap 这种集合。今天就和大家来聊聊,为什么 HashMap 是线程不安全的。1.HashMap 数据结构简单来说,HashMap 基于哈希表实现。它使用键的哈希码来...【详细内容】
2024-03-22  Java技术指北  微信公众号  Tags:HashMap   点击:(11)  评论:(0)  加入收藏
如何从头开始编写LoRA代码,这有一份教程
选自 lightning.ai作者:Sebastian Raschka机器之心编译编辑:陈萍作者表示:在各种有效的 LLM 微调方法中,LoRA 仍然是他的首选。LoRA(Low-Rank Adaptation)作为一种用于微调 LLM(大...【详细内容】
2024-03-21  机器之心Pro    Tags:LoRA   点击:(12)  评论:(0)  加入收藏
这样搭建日志中心,传统的ELK就扔了吧!
最近客户有个新需求,就是想查看网站的访问情况。由于网站没有做google的统计和百度的统计,所以访问情况,只能通过日志查看,通过脚本的形式给客户导出也不太实际,给客户写个简单的...【详细内容】
2024-03-20  dbaplus社群    Tags:日志   点击:(4)  评论:(0)  加入收藏
相关文章
    无相关信息
站内最新
站内热门
站内头条