极市 Editor's Note
This article walks you through the "do-everything" nature of einsum. Through usage examples ranging from basic to advanced, it shows how einsum implements a wide range of tensor operations both concisely and elegantly, and makes dimension matching easy.
Introduction
If you had to name the single most powerful math function in PyTorch, the answer this article argues for is torch.einsum: one function, built on Einstein summation notation, that covers transposes, reductions, products, and contractions alike.
I. The rules behind einsum
Take matrix multiplication as an example. Its element-wise formula is C[i,j] = sum_k A[i,k]*B[k,j]. Einstein notation drops the summation sign over k, and einsum writes the whole operation as:

C = torch.einsum("ik,kj->ij",A,B)

The notation follows three rules:

1. Express the tensor operation through its element-wise computation formula.
2. An index that appears only on the left of the arrow is called a dummy index.
3. The summation sign over dummy indices is omitted from the formula.
import torch
A = torch.tensor([[1,2],[3,4.0]])
B = torch.tensor([[5,6],[7,8.0]])
C1 = A@B
print(C1)
C2 = torch.einsum("ik,kj->ij",A,B)
print(C2)
tensor([[19., 22.],
[43., 50.]])
tensor([[19., 22.],
[43., 50.]])
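The rules above generalize directly to batched operations: an index that appears on both sides of the arrow (here b) is kept rather than summed. A minimal sketch, with shapes of my own choosing, comparing einsum against torch.bmm:

```python
import torch

torch.manual_seed(0)
A = torch.randn(2, 3, 4)  # batch of 2 matrices, each 3x4
B = torch.randn(2, 4, 5)  # batch of 2 matrices, each 4x5

# "b" appears on both sides -> kept as the batch dimension;
# "k" appears only on the left -> dummy index, summed over
C = torch.einsum("bik,bkj->bij", A, B)

assert C.shape == torch.Size([2, 3, 5])
assert torch.allclose(C, torch.bmm(A, B))
```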
II. Basic einsum examples
Example 1: tensor transpose
#Example 1: tensor transpose
A = torch.randn(3,4,5)
#B = torch.permute(A,[0,2,1])
B = torch.einsum("ijk->ikj",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([3, 4, 5])
after: torch.Size([3, 5, 4])
Example 2: extracting the diagonal
#Example 2: extracting the diagonal
A = torch.randn(5,5)
#B = torch.diagonal(A)
B = torch.einsum("ii->i",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([5, 5])
after: torch.Size([5])
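A closely related recipe, not among the original examples: dropping the repeated index from the right-hand side as well sums the diagonal, which gives the trace. A small sketch:

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 5)

# "ii->i" keeps the diagonal; "ii->" also sums it, i.e. the trace
tr = torch.einsum("ii->", A)

assert tr.dim() == 0            # a scalar tensor
assert torch.allclose(tr, torch.trace(A))
```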
Example 3: sum reduction
#Example 3: sum reduction
A = torch.randn(4,5)
#B = torch.sum(A,1)
B = torch.einsum("ij->i",A)
print("before:",A.shape)
print("after:",B.shape)
before: torch.Size([4, 5])
after: torch.Size([4])
Example 4: Hadamard (element-wise) product
#Example 4: Hadamard (element-wise) product
A = torch.randn(5,5)
B = torch.randn(5,5)
#C=A*B
C = torch.einsum("ij,ij->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([5, 5]) torch.Size([5, 5])
after: torch.Size([5, 5])
Example 5: vector inner product
#Example 5: vector inner product
A = torch.randn(10)
B = torch.randn(10)
#C=torch.dot(A,B)
C = torch.einsum("i,i->",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([10]) torch.Size([10])
after: torch.Size([])
Example 6: vector outer product
#Example 6: vector outer product
A = torch.randn(10)
B = torch.randn(5)
#C = torch.outer(A,B)
C = torch.einsum("i,j->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([10]) torch.Size([5])
after: torch.Size([10, 5])
Example 7: matrix multiplication
#Example 7: matrix multiplication
A = torch.randn(5,4)
B = torch.randn(4,6)
#C = torch.matmul(A,B)
C = torch.einsum("ik,kj->ij",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([5, 4]) torch.Size([4, 6])
after: torch.Size([5, 6])
Example 8: tensor contraction
#Example 8: tensor contraction
A = torch.randn(3,4,5)
B = torch.randn(4,3,6)
#C = torch.tensordot(A,B,dims=[(0,1),(1,0)])
C = torch.einsum("ijk,jih->kh",A,B)
print("before:",A.shape, B.shape)
print("after:",C.shape)
before: torch.Size([3, 4, 5]) torch.Size([4, 3, 6])
after: torch.Size([5, 6])
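The commented-out tensordot call and the einsum expression should be equivalent; a quick sanity check, using the same shapes as above:

```python
import torch

torch.manual_seed(0)
A = torch.randn(3, 4, 5)
B = torch.randn(4, 3, 6)

# i (size 3) pairs A's dim 0 with B's dim 1, j (size 4) pairs A's dim 1
# with B's dim 0; both are dummy indices and are summed away
C1 = torch.einsum("ijk,jih->kh", A, B)
C2 = torch.tensordot(A, B, dims=[(0, 1), (1, 0)])

assert C1.shape == torch.Size([5, 6])
assert torch.allclose(C1, C2, atol=1e-5)
```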
III. Advanced einsum examples
Example 9: bilinear attention
#Example 9: bilinear attention
#==== without the batch dimension ====
q = torch.randn(10) #query_features
k = torch.randn(10) #key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features
#a = q@W@k.t()+b
a = torch.bilinear(q,k,W,b)
print("a.shape:",a.shape)
#==== with the batch dimension ====
Q = torch.randn(8,10) #batch_size,query_features
K = torch.randn(8,10) #batch_size,key_features
W = torch.randn(5,10,10) #out_features,query_features,key_features
b = torch.randn(5) #out_features
#A = torch.bilinear(Q,K,W,b)
A = torch.einsum('bq,oqk,bk->bo',Q,W,K) + b
print("A.shape:",A.shape)
a.shape: torch.Size([5])
A.shape: torch.Size([8, 5])
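As a sanity check (my addition, not from the original), the batched einsum form can be compared against torch.nn.functional.bilinear, which computes the same q^T W k + b form:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(8, 10)      # batch_size, query_features
K = torch.randn(8, 10)      # batch_size, key_features
W = torch.randn(5, 10, 10)  # out_features, query_features, key_features
b = torch.randn(5)          # out_features

A1 = torch.einsum("bq,oqk,bk->bo", Q, W, K) + b
A2 = F.bilinear(Q, K, W, b)

assert A1.shape == torch.Size([8, 5])
assert torch.allclose(A1, A2, atol=1e-5)
```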
Example 10: scaled dot-product attention
#Example 10: scaled dot-product attention
#==== without the batch dimension ====
q = torch.randn(10) #query_features
k = torch.randn(6,10) #key_size, key_features
d_k = k.shape[-1]
a = torch.softmax(q@k.t()/d_k**0.5,-1)
print("a.shape=",a.shape )
#==== with the batch dimension ====
Q = torch.randn(8,10) #batch_size,query_features
K = torch.randn(8,6,10) #batch_size,key_size,key_features
d_k = K.shape[-1]
A = torch.softmax(torch.einsum("in,ijn->ij",Q,K)/d_k**0.5,-1)
print("A.shape=",A.shape )
a.shape= torch.Size([6])
A.shape= torch.Size([8, 6])
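Extending the batched version above, the attention weights can be applied to a value tensor with a second einsum. The value tensor V and its feature size here are my own assumptions for illustration:

```python
import torch

torch.manual_seed(0)
Q = torch.randn(8, 10)      # batch_size, query_features
K = torch.randn(8, 6, 10)   # batch_size, key_size, key_features
V = torch.randn(8, 6, 10)   # batch_size, key_size, value_features (assumed)
d_k = K.shape[-1]

# scores -> softmax over the key dimension -> weighted sum of values
A = torch.softmax(torch.einsum("bn,bjn->bj", Q, K) / d_k**0.5, -1)
out = torch.einsum("bj,bjd->bd", A, V)

assert out.shape == torch.Size([8, 10])
assert torch.allclose(A.sum(-1), torch.ones(8), atol=1e-6)
```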