Pytorch函数总结-常用函数


一、Pytorch函数总结

函数 功能 参数 定义 备注 示例 示例
torch.mul 对位相乘 两个矩阵 或者一个矩阵一个标量 torch.mul(input, other, *, out=None) → Tensor “如果input和other都是矩阵向量,size必须一样 结果是对应位置的元素相乘 如果other是标量(只有一个元素),结果是把这个元素都乘上去 “
torch.mm 二维矩阵乘法 其他维度不行 两个矩阵 torch.mm(input, mat2, *, out=None) → Tensor 行列式相乘
torch.matmul 适用性最多的,能处理batch、可以广播不是对位相乘 是矩阵乘法 任意两个tensor torch.matmul(input, other, *, out=None) → Tensor “向量和向量对位相乘矩阵和向量 行列式乘法”
torch.ones 生成全为1的tensor shape 是一个整数序列,可以是一个数字也可以是list或者tuple ones(size, , out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
torch.zeros 生成全为1的tensor shape同上 比如2行3列可以写成(2,3) [2,3] zeros(size, , out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
torch.randn 按照正则分布产生数据 Size参数同上 randn(size, , out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
torch.normal 按照正则分布产生 但是每一个元素的mean和std不一样 是不同曲线生成的 means是一个tensor list表示第一个元素的mean std也是一个tensor list normal(mean, std, *, generator=None, out=None) -> Tensor 感觉参数比较诡异 不建议使用
torch.tensor 返回tensor变量 data可以是list,tuple,numpy的各种类型 data (array_like) – Initial datafor the tensor. Can be a list, tuple, NumPy ndarray, scalar, and other types. tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor 这个函数生成tensor的时候 会新建存储
torch.Tensor 与torch.FloatTensor相同,只返回cpu张量且是float类型 还没搞清楚
view 返回的变量只是原变量的一个软连接,两个变量共享存储No data movement occurs when creating a view, view tensor just changes the way it interprets the same data. shapeShape 可以是torch.Size 也可以是int Tensor.view(*shape) → Tensor “1. view 的参数是torch.Size 或很多int 参数,可以带括号也可以不用view([2,2])view((2,2))view(2,2)以上效果是一样的2. torch.Size是一个不可变序列 immutable sequence. 如果没有参数,则返回空tuple如果是可迭代的,比如list 则迭代后返回tuple如果是个tuple ,则直接返回”

二、Pytorch数据类型总结

Data type dtype Tensor types
32-bit floating point torch.float32 or torch.float torch.*.FloatTensor
64-bit floating point torch.float64 or torch.double torch.*.DoubleTensor
16-bit floating point torch.float16 or torch.half torch.*.HalfTensor
8-bit integer (unsigned) torch.uint8 torch.*.ByteTensor
8-bit integer (signed) torch.int8 torch.*.CharTensor
16-bit integer (signed) torch.int16 or torch.short torch.*.ShortTensor
32-bit integer (signed) torch.int32 or torch.int torch.*.IntTensor
64-bit integer (signed) torch.int64 or torch.long torch.*.LongTensor

文章作者: jasme
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 jasme !
  目录