Pytorch框架下的CNN和RNN

1.CNN

建立了3层(3层=2层+1层全连接层)。分别是conv1、conv2和分类问题中的全连接层线性层out

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,              # input height
                out_channels=16,            # n_filters
                kernel_size=5,              # filter size
                stride=1,                   # filter movement/step
                padding=2,                  # if want same width and length of this image after Conv2d, padding=(kernel_size-1)/2 if stride=1
            ),                              # output shape (16, 28, 28)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(kernel_size=2),    # choose max value in 2x2 area, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(         # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),     # output shape (32, 14, 14)
            nn.ReLU(),                      # activation
            nn.MaxPool2d(2),                # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)   # fully connected layer, output 10 classes

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)           # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output, x    # return x for visualization

2.RNN

2.1RNN分类问题代码

设计了rnn层【输入(INPUT_SIZE),隐藏层1层(hidden_size)】和分类问题的全连接线性层out

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.LSTM(     # LSTM 效果要比 nn.RNN() 好多了
            input_size=28,      # 图片每行的数据像素点
            hidden_size=64,     # rnn hidden unit
            num_layers=1,       # 有几层 RNN layers
            batch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
        )

        self.out = nn.Linear(64, 10)    # 输出层

    def forward(self, x):
        # x shape (batch, time_step, input_size)
        # r_out shape (batch, time_step, output_size)
        # h_n shape (n_layers, batch, hidden_size)   LSTM 有两个 hidden states, h_n 是分线, h_c 是主线
        # h_c shape (n_layers, batch, hidden_size)
        r_out, (h_n, h_c) = self.rnn(x, None)   # None 表示 hidden state 会用全0的 state

        # 选取最后一个时间点的 r_out 输出
        # 这里 r_out[:, -1, :] 的值也是 h_n 的值
        out = self.out(r_out[:, -1, :])
        return out

rnn = RNN()
print(rnn)
"""
RNN (
  (rnn): LSTM(28, 64, batch_first=True)
  (out): Linear (64 -> 10)
)
"""

2.2RNN回归问题代码

具体参考:https://mofanpy.com/tutorials/machine-learning/torch/RNN-regression

class RNN(nn.Module):
    def __init__(self):
        super(RNN, self).__init__()

        self.rnn = nn.RNN(  # 这回一个普通的 RNN 就能胜任
            input_size=1,
            hidden_size=32,     # rnn hidden unit
            num_layers=1,       # 有几层 RNN layers
            batch_first=True,   # input & output 会是以 batch size 为第一维度的特征集 e.g. (batch, time_step, input_size)
        )
        self.out = nn.Linear(32, 1)

    def forward(self, x, h_state):  # 因为 hidden state 是连续的, 所以我们要一直传递这一个 state
        # x (batch, time_step, input_size)
        # h_state (n_layers, batch, hidden_size)
        # r_out (batch, time_step, output_size)
        r_out, h_state = self.rnn(x, h_state)   # h_state 也要作为 RNN 的一个输入

        outs = []    # 保存所有时间点的预测值
        for time_step in range(r_out.size(1)):    # 对每一个时间点计算 output
            outs.append(self.out(r_out[:, time_step, :]))
        return torch.stack(outs, dim=1), h_state


rnn = RNN()
print(rnn)
"""
RNN (
  (rnn): RNN(1, 32, batch_first=True)
  (out): Linear (32 -> 1)
)
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/589443.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前端Web开发基础知识

HTML定义 超文本标记语言(英语:HyperText Markup Language,简称:HTML)是一种用于创建网页的标准标记语言。 什么是 HTML? HTML 是用来描述网页的一种语言。 HTML 指的是超文本标记语言: HyperText Markup LanguageH…

ELK Stack 8 接入ElasticFlow

介绍 Netflow v5 / v9 / v10(IPFIX),支持大部分网络厂商及VMware的分布式交换机。 NetFlow是一种数据交换方式。Netflow提供网络流量的会话级视图,记录下每个TCP/IP事务的信息。当汇集起来时,它更加易于管理和易读。…

EasyExcel 处理 Excel

序言 本文介绍在日常的开发中,如何使用 EasyExcel 高效处理 Excel。 一、EasyExcel 是什么 EasyExcel 是阿里巴巴开源的一个 Java Excel 操作类库,它基于 Apache POI 封装了简单易用的 API,使得我们能够方便地读取、写入 Excel 文件。Easy…

常用AI工具分享 + IDEA内使用通义灵码

引言 随着人工智能技术的飞速发展,AI工具已经渗透到我们日常生活和工作的各个领域,带来了前所未有的便利。现在我将分享一下常用的AI工具,以及介绍如何在IDEA中使用通义灵码。 常用AI工具 1. 通义灵码 (TONGYI Lingma) - 由阿里云开发的智能…

Neo4j v5 中 Cypher 的变化

How Cypher changed in Neo4j v5 Neo4j v5 中 Cypher 的变化 几周前,Neo4j 5 发布了。如果你像我一样,在 Neo4j 4 的后期版本中忽略了所有的弃用警告,你可能需要更新你的 Cypher 查询以适应最新版本的 Neo4j。幸运的是,新的 Cyp…

【翻译】REST API

自动伸缩 API 创建或更新自动伸缩策略 API 此特性设计用于 Elasticsearch Service、Elastic Cloud Enterprise 和 Kubernetes 上的 Elastic Cloud 的间接使用。不支持直接用户使用。 创建或更新一个自动伸缩策略。 请求 PUT /_autoscaling/policy/<name> {"rol…

什么是UDP反射放大攻击,有什么安全措施可以防护UDP攻击

随着互联网的飞速发展和业务复杂性的提升&#xff0c;网络安全问题日益凸显&#xff0c;其中分布式拒绝服务&#xff08;DDoS&#xff09;攻击成为危害最为严重的一类网络威胁之一。 近些年&#xff0c;网络攻击越来越频繁&#xff0c;常见的网络攻击类型包括&#xff1a;蠕虫…

AI图书推荐:用ChatGPT快速创建在线课程

您是否是您领域的专家&#xff0c;拥有丰富的知识和技能可以分享&#xff1f;您是否曾想过创建一个在线课程&#xff0c;但被这个过程吓倒了&#xff1f;那么&#xff0c;是时候把这些担忧放在一边&#xff0c;迈出这一步了&#xff01;有了这本指南和ChatGPT的帮助&#xff0c…

ssh远程访问windows系统下的jupyterlab

网上配置这一堆那一堆&#xff0c;特别乱&#xff0c;找了好久整理后发在这里 由于既想打游戏又想做深度学习&#xff0c;不舍得显卡性能白白消耗&#xff0c;这里尝试使用笔记本连接主机 OpenSSH 最初是为 Linux 系统开发的&#xff0c;现在也支持包括 Windows 和 macOS 在内…

[1673]jsp在线考试管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 在线考试管理系统是一套完善的java web信息管理系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为Mysql5.0&…

R语言学习—4—数据矩阵及R表示

1、创建向量、矩阵 在R中&#xff0c;c()函数用于创建向量或组合数据对象。它在某些情况下可能会被省略&#xff0c;因为R有一些隐式的向量创建规则。例如&#xff0c;当你使用:操作符创建一个数字序列时&#xff0c;R会自动创建一个向量&#xff0c;所以你不需要显式地调用c()…

《QT实用小工具·五十二》文本或窗口炫酷有趣的滚动条——果冻条

1、概述 源码放在文章末尾 该项目实现了文本或窗口纤细的滚动条——果冻条 一个可以像弓弦一样拉出来&#xff0c;并且来回弹动的普通滚动条。 思路为此&#xff0c;但发现实际效果更像条状果冻&#xff0c;并且略有谐音&#xff0c; 故&#xff0c;称之为——“果冻条”&am…

条件依赖性的方法示例

5个条件判断一件事情是否发生&#xff0c;每个条件可能性只有2种&#xff08;发生或者不发生&#xff09;&#xff0c;计算每个条件对这件事情发生的影响力&#xff0c;条件之间有很强的依赖关系。 例一 如果条件之间有很强的依赖关系&#xff0c;那么简单地计算每个条件独立的…

初探 Google 云原生的CICD - CloudBuild

大纲 Google Cloud Build 简介 Google Cloud Build&#xff08;谷歌云构建&#xff09;是谷歌云平台&#xff08;Google Cloud Platform&#xff0c;GCP&#xff09;提供的一项服务&#xff0c;可帮助开发人员以一致和自动化的方式构建、测试和部署应用程序或构件。它为构建和…

B树:原理、操作及应用

B树&#xff1a;原理、操作及应用 一、引言二、B树概述1. 定义与性质2. B树与磁盘I/O 三、B树的基本操作1. 搜索&#xff08;B-TREE-SEARCH&#xff09;2. 插入&#xff08;B-TREE-INSERT&#xff09;3. 删除&#xff08;B-TREE-DELETE&#xff09; 四、B树的C代码实现示例五、…

基于 Wireshark 分析 IP 协议

一、IP 协议 IP&#xff08;Internet Protocol&#xff09;协议是一种网络层协议&#xff0c;它用于在计算机网络中实现数据包的传输和路由。 IP协议的主要功能有&#xff1a; 1. 数据报格式&#xff1a;IP协议将待传输的数据分割成一个个数据包&#xff0c;每个数据包包含有…

mac电脑关于ios端的appium真机自动化测试环境搭建

一、app store 下载xcode,需要登录apple id 再开始下载 二、安装homebrew 1、终端输入命令&#xff1a; curl -fsSL <https://gitee.com/cunkai/HomebrewCN/raw/master/Homebrew.sh>如果不能直接安装&#xff0c;而是出现了很多内容&#xff0c;那么这个时候不要着急&…

06.Git远程仓库

Git远程仓库 #仓库种类&#xff0c;举例说明 github gitlab gitee #以这个仓库为例子操作登录码云 https://gitee.com/projects/new 创建仓库 选择ssh方式 需要配置ssh公钥 在系统上获取公钥输入命令&#xff1a;ssh-keygen 查看文件&#xff0c;复制公钥信息内…

【画图】读取无人机IMU数据并打印成log用matlab分析

一、修改IMU频率 原来的imu没有加速度信息&#xff0c;查看加速度信息的指令为&#xff1a; rostopic echo /mavros/imu/data 修改imu频率&#xff0c;分别修改的是 原始IMU数据话题 /mavros/imu/data_raw。飞控计算过后的IMU数据 /mavros/imu/data rosrun mavros mavcmd l…

Uniapp好看登录注册页面

个人介绍 hello hello~ &#xff0c;这里是 code袁~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的…
最新文章