您当前的位置:首页 > 电脑百科 > 程序开发 > 语言 > Python

突破Pytorch核心点,CNN !!!

时间:2024-01-03 14:15:05  来源:微信公众号  作者:DOWHAT小壮

创建卷积神经网络(CNN),很多初学者不太熟悉,今儿咱们来大概说说,给一个完整的案例进行说明。

CNN 用于图像分类、目标检测、图像生成等任务。它的关键思想是通过卷积层和池化层来自动提取图像的特征,并通过全连接层进行分类。

原理

1.卷积层(Convolutional Layer):

卷积层使用卷积操作从输入图像中提取特征。卷积操作涉及一个可学习的卷积核(filter/kernel),该核在输入图像上滑动,并计算滑动窗口下的点积。这有助于提取局部特征,使网络对平移不变性更强。

公式:

突破Pytorch核心点,CNN !!!

其中,x是输入,w是卷积核,b是偏置。

2.池化层(Pooling Layer):

池化层用于减小数据的空间维度,减少计算量,并提取最显著的特征。最大池化是常用的一种方式,在每个窗口中选择最大的值。

公式(最大池化):

突破Pytorch核心点,CNN !!!

3.全连接层(Fully Connected Layer):

全连接层用于将卷积和池化层提取的特征映射到输出类别。它连接到前一层的所有神经元。

实战步骤和详解

1.步骤

  • 导入必要的库和模块。
  • 定义网络结构:使用nn.Module定义一个继承自它的自定义神经网络类,定义卷积层、激活函数、池化层和全连接层。
  • 定义损失函数和优化器。
  • 加载和预处理数据。
  • 训练网络:使用训练数据迭代训练网络参数。
  • 测试网络:使用测试数据评估模型性能。

2.代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 定义卷积神经网络类
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # 卷积层1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 卷积层2
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        # 全连接层
        self.fc1 = nn.Linear(32 * 7 * 7, 10)  # 输入大小根据数据调整

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(-1, 32 * 7 * 7)
        x = self.fc1(x)
        return x

# 定义损失函数和优化器
net = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam.NET.parameters(), lr=0.001)

# 加载和预处理数据
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trAIn_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# 训练网络
num_epochs = 5
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')

# 测试网络
net.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print('Accuracy on the test set: {}%'.format(100 * accuracy))

这个示例展示了一个简单的CNN模型,使用MNIST数据集进行训练和测试。

接下来,咱们添加可视化步骤,更直观地了解模型的性能和训练过程。

可视化

1.导入matplotlib

import matplotlib.pyplot as plt

2.在训练过程中记录损失和准确率:

在训练循环中,记录每个epoch的损失和准确率。

# 在训练循环中添加以下代码
train_loss_list = []
accuracy_list = []

for epoch in range(num_epochs):
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = net(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')

    epoch_loss = running_loss / len(train_loader)
    accuracy = correct / total

    train_loss_list.Append(epoch_loss)
    accuracy_list.append(accuracy)

3.可视化损失和准确率:

# 在训练循环后,添加以下代码
plt.figure(figsize=(12, 4))

# 可视化损失
plt.subplot(1, 2, 1)
plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# 可视化准确率
plt.subplot(1, 2, 2)
plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()
 

这样,咱们就可以在训练过程结束后看到训练损失和准确率的变化。

导入代码后,大家可以根据需要调整可视化的内容和格式。



Tags:Pytorch   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,不构成投资建议。投资者据此操作,风险自担。如有任何标注错误或版权侵犯请与我们联系,我们将及时更正、删除。
▌相关推荐
突破Pytorch核心点,优化器 !!
今儿咱们聊聊pytorch中的优化器。优化器在深度学习中的选择直接影响模型的训练效果和速度。不同的优化器适用于不同的问题,其性能的差异可能导致模型更快、更稳定地收敛,或者...【详细内容】
2024-01-05  Search: Pytorch  点击:(90)  评论:(0)  加入收藏
突破Pytorch核心点,CNN !!!
创建卷积神经网络(CNN),很多初学者不太熟悉,今儿咱们来大概说说,给一个完整的案例进行说明。CNN 用于图像分类、目标检测、图像生成等任务。它的关键思想是通过卷积层和池化层来...【详细内容】
2024-01-03  Search: Pytorch  点击:(86)  评论:(0)  加入收藏
PyTorch团队重写「分割一切」模型,比原始实现快八倍
编辑:陈萍我们该如何优化 Meta 的「分割一切」模型,PyTorch 团队撰写的这篇博客由浅入深的帮你解答。从年初到现在,生成式 AI 发展迅猛。但很多时候,我们又不得不面临一个难题:如...【详细内容】
2023-11-23  Search: Pytorch  点击:(249)  评论:(0)  加入收藏
基于Pytorch的从零开始的目标检测
引言目标检测是计算机视觉中一个非常流行的任务,在这个任务中,给定一个图像,你预测图像中物体的包围盒(通常是矩形的) ,并且识别物体的类型。在这个图像中可能有多个对象,而且现...【详细内容】
2023-11-10  Search: Pytorch  点击:(201)  评论:(0)  加入收藏
深度学习中实现PyTorch和NumPy之间的数据转换知多少?
在深度学习中,PyTorch和NumPy是两个常用的工具,用于处理和转换数据。PyTorch是一个基于Python的科学计算库,用于构建神经网络和深度学习模型。NumPy是一个用于科学计算的Python...【详细内容】
2023-10-13  Search: Pytorch  点击:(67)  评论:(0)  加入收藏
Star量近8万,大火AutoGPT星标超PyTorch,网友:看清它的局限性
机器之心编辑部英伟达 AI 科学家 Jim Fan 表示,「AutoGPT 只是一项有趣的实验,虽然火爆但并不意味着可以投入生产。」他的观点得到了很多人的附和和现身说法。仿佛一夜之间,AI...【详细内容】
2023-04-18  Search: Pytorch  点击:(171)  评论:(0)  加入收藏
PyTorch将塑造生成式人工智能系统(GPT-4及以上)的未来
PyTorch不仅用于研究,还用于生产目的,每天有数十亿个请求得到服务和训练。...【详细内容】
2023-04-13  Search: Pytorch  点击:(171)  评论:(0)  加入收藏
微信基于 PyTorch 的大规模推荐系统训练实践
本文将介绍微信基于 PyTorch 进行的大规模推荐系统训练。推荐系统和其它一些深度学习领域不同,仍在使用 Tensorflow 作为训练框架,被广大开发者诟病。虽然也有使用 PyTorch 进...【详细内容】
2023-04-04  Search: Pytorch  点击:(236)  评论:(0)  加入收藏
PyTorch张量的四种乘法运算
在PyTorch中有四种类型的乘法运算(位置乘法、点积、矩阵与向量乘法、矩阵乘法),非常容易搞混,我们一起来看看这四种乘法运算的区别。位置乘法先构建两个张量a,b他们都是4行5列。a...【详细内容】
2023-03-21  Search: Pytorch  点击:(249)  评论:(0)  加入收藏
PyTorch 并行训练 DistributedDataParallel 完整代码示例
使用大型数据集训练大型深度神经网络 (DNN) 的问题是深度学习领域的主要挑战。 随着 DNN 和数据集规模的增加,训练这些模型的计算和内存需求也会增加。 这使得在计算资源有限...【详细内容】
2023-02-19  Search: Pytorch  点击:(275)  评论:(0)  加入收藏
▌简易百科推荐
Python 可视化:Plotly 库使用基础
当使用 Plotly 进行数据可视化时,我们可以通过以下示例展示多种绘图方法,每个示例都会有详细的注释和说明。1.创建折线图import plotly.graph_objects as go# 示例1: 创建简单...【详细内容】
2024-04-01  Python技术    Tags:Python   点击:(8)  评论:(0)  加入收藏
Python 办公神器:教你使用 Python 批量制作 PPT
介绍本文将介绍如何使用openpyxl和pptx库来批量制作PPT奖状。本文假设你已经安装了python和这两个库。本文的场景是:一名基层人员,要给一次比赛活动获奖的500名选手制作奖状,并...【详细内容】
2024-03-26  Python技术  微信公众号  Tags:Python   点击:(15)  评论:(0)  加入收藏
Python实现工厂模式、抽象工厂,单例模式
工厂模式是一种常见的设计模式,它可以帮助我们创建对象的过程更加灵活和可扩展。在Python中,我们可以使用函数和类来实现工厂模式。一、Python中实现工厂模式工厂模式是一种常...【详细内容】
2024-03-07  Python都知道  微信公众号  Tags:Python   点击:(31)  评论:(0)  加入收藏
不可不学的Python技巧:字典推导式使用全攻略
Python的字典推导式是一种优雅而强大的工具,用于创建字典(dict)。这种方法不仅代码更加简洁,而且执行效率高。无论你是Python新手还是有经验的开发者,掌握字典推导式都将是你技能...【详细内容】
2024-02-22  子午Python  微信公众号  Tags:Python技巧   点击:(32)  评论:(0)  加入收藏
如何进行Python代码的代码重构和优化?
Python是一种高级编程语言,它具有简洁、易于理解和易于维护的特点。然而,代码重构和优化对于保持代码质量和性能至关重要。什么是代码重构?代码重构是指在不改变代码外部行为的...【详细内容】
2024-02-22  编程技术汇    Tags:Python代码   点击:(32)  评论:(0)  加入收藏
Python开发者必备的八个PyCharm插件
在编写代码的过程中,括号几乎无处不在,以至于有时我们会拼命辨别哪个闭合括号与哪个开头的括号相匹配。这款插件能帮助解决这个众所周知的问题。前言在PyCharm中浏览插件列表...【详细内容】
2024-01-26  Python学研大本营  微信公众号  Tags:PyCharm插件   点击:(84)  评论:(0)  加入收藏
Python的Graphlib库,再也不用手敲图结构了
Python中的graphlib库是一个功能强大且易于使用的工具。graphlib提供了许多功能,可以帮助您创建、操作和分析图形对象。本文将介绍graphlib库的主要用法,并提供一些示例代码和...【详细内容】
2024-01-26  科学随想录  微信公众号  Tags:Graphlib库   点击:(85)  评论:(0)  加入收藏
Python分布式爬虫打造搜索引擎
简单分布式爬虫结构主从模式是指由一台主机作为控制节点负责所有运行网络爬虫的主机进行管理,爬虫只需要从控制节点那里接收任务,并把新生成任务提交给控制节点就可以了,在这个...【详细内容】
2024-01-25  大雷家吃饭    Tags:Python   点击:(58)  评论:(0)  加入收藏
使用Python进行数据分析,需要哪些步骤?
Python是一门动态的、面向对象的脚本语言,同时也是一门简约,通俗易懂的编程语言。Python入门简单,代码可读性强,一段好的Python代码,阅读起来像是在读一篇外语文章。Python这种特...【详细内容】
2024-01-15  程序员不二    Tags:Python   点击:(161)  评论:(0)  加入收藏
Python语言的特点及应用场景, 同其它语言对比优势
Python语言作为一种高级编程语言,具有许多独特的特点和优势,这使得它在众多编程语言中脱颖而出。在本文中,我们将探讨Python语言的特点、应用场景以及与其他语言的对比优势。一...【详细内容】
2024-01-09    今日头条  Tags:Python语言   点击:(250)  评论:(0)  加入收藏
站内最新
站内热门
站内头条