【NLP 24、实践 ⑤ 计算Bert模型中的参数数量】

news/2025/2/26 3:52:44

以前不甘心,总想争个对错,现在不会了

人心各有所愿,没有道理可讲

                                                        —— 25.1.18

计算Bert模型结构中的参数数量 

BertModel.from_pretrained():用于从预训练模型目录或 Hugging Face 模型库加载 BERT 模型的权重及配置。

参数名称类型是否必填说明
pretrained_model_name_or_path字符串模型名称(如 bert-base-uncased)或本地路径。
configBertConfig对象自定义配置类,用于覆盖默认配置。
state_dict字典预训练权重字典,用于部分加载模型。
cache_dir字符串缓存目录,用于存储下载的模型文件。
from_tf布尔值是否从 TensorFlow 模型加载权重,默认为 False
ignore_mismatched_sizes布尔值是否忽略权重大小不匹配的错误,默认为 False
local_files_only布尔值是否仅从本地文件加载模型,默认为 False

return_dict参数:

  • 当 return_dict 设置为 True 时,forward() 方法返回一个 BaseModelOutput 对象,该对象包含了模型的各种输出,如最后一层的隐藏状态、[CLS] 标记的输出等。
  • 当 return_dict 设置为 False 时,forward() 方法返回一个元组,包含与 BaseModelOutput 对象相同的元素,但不包含对象结构。

numel():计算张量(Tensor)中的元素总数

参数名称类型是否必填说明
tensortorch.Tensor输入的PyTorch张量。

parameters():返回模型中所有可训练参数的迭代器。

参数名称类型是否必填说明
recurse布尔值是否递归获取子模块的参数,默认为True
import torch
import math
import torch.nn as nn
import numpy as np
from transformers import BertModel

model = BertModel.from_pretrained("F:\人工智能NLP\\NLP资料\week6 语言模型//bert-base-chinese", return_dict=False)
n = 2                       # 输入最大句子个数
vocab = 21128               # 词表数目
max_sequence_length = 512   # 最大句子长度
embedding_size = 768        # embedding维度
hide_size = 3072            # 隐藏层维数
num_layers = 1              # 隐藏层层数

# embedding过程中的参数,其中 vocab * embedding_size是词表embedding参数, max_sequence_length * embedding_size是位置参数, n * embedding_size是句子参数
# embedding_size + embedding_sizes是layer_norm层参数
embedding_parameters = vocab * embedding_size + max_sequence_length * embedding_size + n * embedding_size + embedding_size + embedding_size

# self_attention过程的参数, 其中embedding_size * embedding_size是权重参数,embedding_size是bias, *3是K Q V三个
self_attention_parameters = (embedding_size * embedding_size + embedding_size) * 3

# self_attention_out参数 其中 embedding_size * embedding_size + embedding_size + embedding_size是self输出的线性层参数,embedding_size + embedding_size是layer_norm层参数
self_attention_out_parameters = embedding_size * embedding_size + embedding_size + embedding_size + embedding_size

# Feed Forward参数 其中embedding_size * hide_size + hide_size第一个线性层,embedding_size * hide_size + embedding_size第二个线性层,
# embedding_size + embedding_size是layer_norm层
feed_forward_parameters = embedding_size * hide_size + hide_size + embedding_size * hide_size + embedding_size + embedding_size + embedding_size

# pool_fc层参数
pool_fc_parameters = embedding_size * embedding_size + embedding_size

# 模型总参数 = embedding层参数 + self_attention参数 + self_attention_out参数 + Feed_Forward参数 + pool_fc层参数
all_paramerters = embedding_parameters + (self_attention_parameters + self_attention_out_parameters + \
    feed_forward_parameters) * num_layers + pool_fc_parameters
print("模型实际参数个数为%d" % sum(p.numel() for p in model.parameters()))
print("diy计算参数个数为%d" % all_paramerters)


http://www.niftyadmin.cn/n/5867144.html

相关文章

echarts图表初始化搭建

vue搭建echarts折线图 Examples - Apache ECharts <template><div><div ref"chart" class"chart-container"></div></div> </template><script> import * as echarts from echarts;export default {name: Li…

蓝桥杯 Java B 组之最短路径算法(Dijkstra、Floyd-Warshall)

Day 2&#xff1a;最短路径算法&#xff08;Dijkstra、Floyd-Warshall&#xff09; &#x1f4d6; 一、最短路径算法简介 最短路径问题是图论中的经典问题&#xff0c;主要用于求解 单源最短路径 或 多源最短路径。在实际应用中&#xff0c;最短路径广泛应用于 导航系统、网络…

【Java项目】基于Spring Boot的火车订票管理系统

【Java项目】基于Spring Boot的火车订票管理系统 技术简介&#xff1a;采用Spring Boot框架、Java技术、MySQL数据库等实现。 系统简介&#xff1a;火车订票管理系统是一个面向管理员和用户的在线订票平台&#xff0c;主要分为前台和后台两大模块。前台功能模块包括&#xff08…

WPS中Word表格做好了,忘记写标题了怎么办?

大家好&#xff0c;我是小鱼。 在使用wps制作Word表格时经常会遇到这种情况&#xff0c;就是辛辛苦苦把word表格制作好了&#xff0c;却突然发现忘了为表格添加标题了。怎么都没法为表格重写添加标题&#xff0c;真是一阵操作猛如虎&#xff0c;结果觉得表格真是白做了。其实&…

设计模式教程:备忘录模式(Memento Pattern)

备忘录模式&#xff08;Memento Pattern&#xff09;详解 一、模式概述 备忘录模式&#xff08;Memento Pattern&#xff09;是一种行为型设计模式&#xff0c;允许在不暴露对象实现细节的情况下&#xff0c;保存对象的内部状态&#xff0c;并在需要时恢复该状态。备忘录模式…

Qt开发⑦Qt的窗口_上_菜单栏+工具栏+状态栏

目录 1. 菜单栏 1.1 创建菜单栏 1.2 在菜单栏中添加菜单 1.3 创建菜单项 1.4 在菜单项之间添加分割线 1.5 添加快捷键 1.6 添加子菜单 1.7 添加图标 1.8 综合示例 2. 工具栏 2.1 创建工具栏 2.2 设置停靠位置 2.3 设置浮动属性 2.4 设置移动属性 2.5 综合示例 …

机器学习数学基础:34.点二列

点二列相关教程 一、点二列相关的定义 点二列相关是一种统计方法&#xff0c;用于衡量两个变量之间的相关程度。在这种相关分析中&#xff0c;一个变量是正态连续性变量&#xff0c;取值可以是连续的数值&#xff0c;比如身高、体重、考试分数等&#xff1b;另一个是真正的二…

Android之图片保存相册及分享图片

文章目录 前言一、效果图二、实现步骤1.引入依赖库2.二维码生成3.布局转图片保存或者分享 总结 前言 其实现在很多分享都是我们自定义的&#xff0c;更多的是在界面加了很多东西&#xff0c;然后把整个界面转成图片保存相册和分享&#xff0c;而且现在分享都不需要第三方&…