爱因斯坦求和(einsum)

2022-04-06
11分钟阅读时长

在数学里,特别是将线性代数套用到物理时,爱因斯坦求和约定(Einstein summation convention)是一种标记的约定,又称为爱因斯坦标记法(Einstein notation),在处理关于坐标的方程式时非常有用。这约定是由阿尔伯特·爱因斯坦于1916年提出的。后来,爱因斯坦与友人半开玩笑地说:“这是数学史上的一大发现,若不信的话,可以试着返回那不使用这方法的古板日子”。

一、einsum 记法

在 PyTorch / TensorFlow 中的那些点积、外积、转置、矩阵-向量乘法、矩阵-矩阵乘法都是可以试用 einsum 记法来表达上面的这些运算的,包括复杂张量运算在内的优雅方式,基本上,可以把 einsum 看成一种领域特定语言。一旦理解了并能利用 einsum,除了不用记忆和频繁查找特定库函数说明文档外,还可以更迅速地编写更加紧凑、高效的代码。而不使用 einsum 的时候,容易出现引入不必要的张量变形或转置运算,以及可以省略的中间张量的现象。此外,einsum 这样的领域特定语言有时可以编译到高性能代码,事实上,PyTorch 新引入的能够自动生成 GPU 代码并为特定输入尺寸自动调整代码的张量理解(Tensor Comprehensions)就基于类似 einsum 的领域特定语言。此外,可以是要用 Optimized Einsumtf einsum opt 这样的项目优化 einsum 表达式的构造顺序。

假如,我们想要将两个矩阵 ${\color{red}A} \in \mathbb R^{I \times K}$ 和 ${\color{blue}B} \in \mathbb R^{K \times J}$ 相乘,接着计算每列的和,最终得到向量 ${\color{green}c} \in \mathbb R^J$。使用爱因斯坦求和约定,这可以表达为: $$ \Large {\color{green}C_j}={\color{gray}\sum_i\sum_k\color{red}A_{ik}\color{blue}{B_{kj}}}=\color{red}A_{ik}\color{blue}B_{kj} $$ 这一表达式指明了 $\color{green}c$ 中每个元素 $\color{green}c_i$ 是如何计算的,列向量 $\color{red}A_{i:}$ 乘以行向量 $\color{blue}B_{:j}$,然后求和。

注意,在爱因斯坦求和约定中,我们省略了求和符号 $\Sigma$,因为我们隐式地累加重复的下标(这里是 $k$)和输出中未指明的下标(这里是 $i$)。当然 einsum 也能表达更基本的运算。比如,计算两个向量 ${\color{red}a},{\color{blue}b} \in \mathbb R^{J}$ 的点积可以表达为: $$ \Large {\color{green}c} = {\color{gray}\sum_i \color{red}a_i \color{blue}b_i}={\color{red}a_i \color{blue}b_i}. $$ 在深度学习中,我们经常碰到的一个问题是,变换高阶张量到向量。例如,我们可能有一个向量,其中包含一个 batch 中的 $N$ 个训练样本,每个样本是一个长度为 $T$ 的 $K$ 维词向量序列,然后想把词向量投影到一个不同的维度 $Q$ 。如果将这个张量记作 ${\color{red}\mathcal T}\in \mathbb R^{N \times T \times K}$,将这个投影矩阵记作 ${\color{blue}W} \in \mathbb R^{K\times Q}$,那么所需要计算可以用 einsum 表达为: $$ \Large {\color{green}C_{ntq}} = {\color{gray}\sum_k \color{red}T_{ntk} \color{blue}W_{kq}}={\color{red}T_{ntk} \color{blue}W_{kq}}. $$ 最后一个例子,如果有一个四阶张量 ${\color{red}\mathcal T} \in \mathbb R_{N \times T \times K \times M}$,想要使用之前的投影矩阵将第三维投影至 $Q$ 维,并累加第二维,然后转置结果中的第一维和最后一维,最终得到张量 ${\color{green}C} \in \mathbb R^{M \times Q \times N}$。einsum 可以非常简洁地表达这一切: $$ \Large {\color{green}C_{mqn}} = {\color{gray}\sum_t\sum_k \color{red}T_{ntkm} \color{blue}W_{kq}} = {\color{red}T_{ntkm} \color{blue}W_{kq}}. $$ 注意,我们通过交换下标 $n$ 和 $m$(【公式】而不是 $C_{nqm}$),转置了张量构造结果。

二、Numpy、PyTorch、TensorFlow 中的 einsum

einsum 在 numpy 中实现为 np.einsum,在 PyTorch 中实现为 torch.einsum,在 TensorFlow 中实现为 tf.einsum,均试用一致的函数签名 einsum(euqation, eperands),其中 equation 是表示爱因斯坦求和约定的字符串,而 operands 则是张量序列(在 numpy 和 TensorFlow 中是可变长参数列表,而在 PyTorch 中也可以是列表)。例如,我们的第一个例子,${\color{green}C_j}={\color{gray}\sum_i\sum_k\color{red}A_{ik}\color{blue}{B_{kj}}}$ 写成 equation 字符串就是 “${\color{red}ik},{\color{blue}kj}$->${\color{green}j}$"。注意这里的 (i,j,k) 的命名是任意的,但需要保持一致。

简单起见,选择最常用的字母串方式。矩阵乘法是便于演示理解的,其涉及了行和列的相乘以及乘积结果的求和。对于两个二维数组 A 和 B,矩阵乘法可以实现为(以 PyTorch 为例):

torch.einsum('ij,jk->ik', A, B)

ij,jk->ik 想象成在箭头 -> 处一分为二。其中,左边 ij,jk 是输入数组的轴:ij 标记 A,jk 标记 B,右边的 ik 就是输出数组的轴。

换句话说,是将两个输入二维数组放到一个新的输出二维数据中。

如,数组 A 和 B,

A = np.array([[1, 1, 1],
              [2, 2, 2],
              [5, 5, 5]])

B = np.array([[0, 1, 0],
              [1, 1, 0],
              [1, 1, 1]])

torch.einsum('ij,jk->ik', A, B) 实现如图:

matrix-mul-reduce

为了便于理解输出数组的计算过程,记住如下三个规则:

  • 输入数组之间的重复字母,表示沿着这些轴(axis)的值将相乘,乘积构成输出数组的值。

    如,字母 $j$ 重复了两次,一次是 $\color{red}A$,一次是 $\color{blue}B$。意味着是将 $\color{red}A$ 的每一行和 $\color{blue}B$ 的每一列相乘。其仅对两个数组在 $j$ 标记 axis 长度一致或其中一个长度为 1 时才有效。

  • 输出中忽略的字母,表示沿该轴的值将被求和。

    如,字母 $j$ 不在输出数组的标记中。忽略它沿轴求和,并显式地将最终数据的维度减 1。

    如果输出标记是 ijk,则最终会得到一个 $i \times j \times k$ 的乘积数组。(且,如果没有输出标记,而只是写箭头 ->,则需要对整个数组求和)

  • 可以以任意次序返回未求和的轴。

    如果忽略箭头 ->,将会获取出现过一次的字母标记,并且按照字母顺序进行排列,实际上 ij,jk->ik 等价于 ij,jk

    如果期望控制输出的形式,可以自定义选择输出字母标记的顺序,如,ij,jk->ki 表示矩阵乘法的转置(注意 ki 在输出标记的次序有切换)。

至此,应该更容易理解矩阵乘法的计算流程。如图示,如果不对 j 轴求和,而是采用 torch.einsum('ij,jk->ijk', A, B)j 轴包含在输出中。右侧,j 轴已经被求和。

mat-mul-full-and-reduce

注:np.einsum('ij,jk->ik', A, B) 函数并未构建 3D 数组再求和,其只是将总和累计到一个 2D 数组中。

PyTorch 和 TensorFlow 像 numpy 支持 einsum 的好处之一是 einsum 可以用于神经网络架构的任意计算图,并且可以反向传播。典型的 einsum 调用格式如下: $$ \Large {\color{green}\textbf{result}} = \color{gray}\text{einsum}(\texttt{"}\color{red}\square\square,\color{purple}\square\square\square,\color{blue}\square\square\textbf{->}\color{green}\square\square\color{gray}\texttt{"},\color{red}\text{arg1},\color{purple}\text{arg2},\color{blue}\text{arg3}\color{gray}) $$ 上式中 $\square$ 是占位符,表示张量维度。上面例子中,$\color{red}\text{arg1}$ 和 $\color{blue}\text{arg3}$ 是矩阵,$\color{purple}\text{arg2}$ 是三阶张量,这一 einsum 运算的结果是($\color{green}\textbf{result}$)是矩阵。注意 einsum 处理的是可变数量的参数输入。在上面例子中,einsum 指定了三个参数上的操作,但它同样可以用在牵涉一个参数、两个参数、三个以上参数的操作上。学习 einsum 的最佳途径是通过一些例子,下面展示一下具体的例子,在许多深度学习模型中常用的库函数,用 einsum 该如何表达(以 PyTorch 为例)。

  1. 矩阵转置(Matrix transpose) $$ \Large {\color{green}B_{ji}} = {\color{red}A_{ij}} $$

    import torch
    a = torch.arange(6).reshape(2, 3)
    torch.einsum('ij->ji', [a])  # 与 torch.einsum('ij->ji', a) 等价
    

    输出:

     tensor([[0, 3],
             [1, 4],
             [2, 5]])
    
  2. 求和(Sum) $$ \Large {\color{green}b}={\color{gray}\sum_i\sum_j\color{red}A_{ij}}=\color{red}A_{ij} $$

    a = torch.arange(6).reshape(2, 3)
    torch.einsum('ij->', a)
    

    输出:

    tensor(15)
    
  3. 列求和(Column sum) $$ \Large {\color{green}b_j}={\color{gray}\sum_i\sum_j\color{red}A_{ij}}=\color{red}A_{ij} $$

    a = torch.arange(6).reshape(2, 3)
    torch.einsum('ij->j', a)
    

    输出:

    tensor([3, 5, 7])
    
  4. 行求和(Row sum) $$ \Large {\color{green}b_i}={\color{gray}\sum_i\sum_j\color{red}A_{ij}}=\color{red}A_{ij} $$

    a = torch.arange(6).reshape(2, 3)
    torch.einsum('ij->i', a)
    

    输出:

    tensor([3, 12])
    
  5. 矩阵-向量相乘(Matrix-vector maltiplication) $$ \Large {\color{green}C_i}={\color{gray}\sum_k \color{red}A_{ik} \color{blue}b_k} = \color{red}A_{ik} \color{blue}b_k $$

    a = torch.arange(6).reshape(2, 3)
    b = torch.arange(3)
    torch.einsum('ik,k->i', a, b)
    

    输出:

    tensor([5, 14])
    
  6. 矩阵-矩阵相乘(Matrix-matrix multiplication) $$ \Large {\color{green}C_{ij}}={\color{gray}\sum_k \color{red}A_{ik} \color{blue}B_{kj}} = \color{red}A_{ik} \color{blue}B_{kj} $$

    **解释:**对于矩阵 ${\color{red}A} \in \mathbb R^{i \times k}$ 和矩阵 ${\color{blue}B} \in \mathbb R^{k \times j}$的矩阵乘法结果的第 $i$ 行 j 列的结果是 ${\color{green}C_{ij}}={\color{red}A_{i0}\color{blue}B_{0j}}+{\color{red}A_{i1}\color{blue}B_{1j}}+\cdots+\color{red}A_{ik}\color{blue}B_{kj}$。

    a = torch.arange(6).reshape(2, 3)
    b = torch.arange(15).reshape(3, 5)
    torch.einsum('ik,kj->ij', [a, b])
    

    输出:

    tensor([[ 25,  28,  31,  34,  37],
            [ 70,  82,  94, 106, 118]])
    
  7. 点积(Dot product)

    • 向量 $$ \Large {\color{green}c}={\color{gray}\sum_i \color{red}a_i \color{blue} b_i} = {\color{red}a_i \color{blue} b_i} $$

      a = torch.arange(3)
      b = torch.arange(3, 6) # [3, 4, 5]
      torch.einsum('i,i->', a, b)
      

      输出:

      tensor(14)
      
    • 矩阵 $$ \Large {\color{green}c}={\color{gray}\sum_i \sum_j \color{red}A_{ij} \color{blue} B_{ij}} = {\color{red}A_{ij} \color{blue} B_{ij}} $$

      a = torch.arange(6).reshape(2, 3)
      b = torch.arange(6, 12).reshape(2, 3)
      torch.einsum('ij,ij->', [a, b])
      

      输出:

      tensor(145)
      
  8. 哈达玛积(Hadamard product) $$ \Large {\color{green} C_{ij} } = \color{red}A_{ij} \color{blue}B_{ij} $$

    a = torch.arange(6).reshape(2, 3)
    b = torch.arange(6, 12).reshape(2, 3)
    torch.einsum('ij,ij->ij', [a, b])
    

    输出:

    tensor([[ 0,  7, 16],
            [27, 40, 55]])
    
  9. 外积(Outer product) $$ \Large {\color{green} C_{ij} } = \color{red}a_{i} \color{blue}b_{j} $$

    a = torch.arange(3)
    b = torch.arange(3, 7)
    torch.einsum('i,j->ij', [a, b])
    

    输出:

    tensor([[ 0,  0,  0,  0],
           [ 3,  4,  5,  6],
           [ 6,  8, 10, 12]])
    
  10. 批量矩阵相乘(Batch matrix multiplication) $$ \Large {\color{green}C_{bij}}={\color{gray}\sum_k \color{red} A_{bik} \color{blue}B_{bkj}}={\color{red} A_{bik} \color{blue}B_{bkj}} $$

    a = torch.randn(3, 2, 5)
    b = torch.randn(3, 5, 3)
    torch.einsum('bik,bkj->bij', a, b)
    

    输出:

    tensor([[[-3.2780, -0.2889,  1.6346],
             [ 1.1263, -3.0012,  1.6123]],
    
            [[-0.3938,  0.9549,  3.3076],
             [-1.1168,  1.2746, -1.2815]],
    
            [[-0.3073,  5.2579, -0.8160],
             [-2.5301, -0.6242,  0.5145]]])
    
  11. 张量缩约(Tensor contraction)

    批量矩阵相乘是张量缩约的一个特例。假如,我们有两个张量,一个 $n$ 阶张量 $A \in \mathbb R^{I_1 \times \cdots \times I_n}$,一个 $m$ 阶张量 $B \in \mathbb R^{J_1 \times \cdots \times J_m}$。设 $n=4$,$m=5$,并假定 $I_2=J_3$ 且 $I_3=J_5$。我们可以将这两个张量在这两个维度上相乘(A 张量的第 2,3 维度,B 张量的 3,5 维度),最终得到一个新张量 $C \in \mathbb R^{I_1 \times I_4 \times J_1 \times J_2 \times J_4}$,如下所示: $$ \Large {\color{green}C_{pstuv} }={\color{gray}\sum_q\sum_r \color{red}A_{pqrs} \color{blue}B_{tuqvr}}={\color{red}A_{pqrs} \color{blue}B_{tuqvr}} $$

    a = torch.randn(2,3,5,7)
    b = torch.randn(11,13,3,17,5)
    torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape
    

    输出:

    torch.Size([2, 7, 11, 13, 17])
    
  12. 双线性变换

    einsum 可用于超过两个张量的计算。这里举一个实际的双线性变换 例子。 $$ \Large {\color{green}D_{ij}}={\color{gray}\sum_k\sum_l\color{red}A_{ik}\color{purple}B_{jkl}\color{blue}C_{il}}={\color{red}A_{ik}\color{purple}B_{jkl}\color{blue}C_{il}} $$

    a = torch.randn(2,3)
    b = torch.randn(5,3,7)
    c = torch.randn(2,7)
    torch.einsum('ik,jkl,il->ij', [a, b, c])
    

    输出:

    tensor([[ 4.3686,  2.4327, -4.7635, -0.9871,  4.2039],
            [-2.1500,  0.2848, -1.4769,  0.7634,  1.4486]])
    

三、常见 einsum 的 Numpy 等价形式

  • 假设,A 和 B 是两个一维数组(向量),其具有兼容的 shape ,即配对在一起的轴(要么长度相等,要么其中之一长度为1)

    标记符号 Numpy 等价形式 描述
    ('i', A) A 向量 A 本身
    ('i->', A) sum(A) 向量 A 所有元素求和
    ('i,i->i', A, B) A * B 向量 A 和 向量 B 逐元素相乘
    ('i,i', A, B) inner(A, B) 向量 A 和 向量 B 的内积
    ('i,j->ij', A, B) outer(A, B) 向量 A 和 向量 B 的外积
  • 假设,A 和 B 是两个二维数组(矩阵),其具有兼容的 shape。

    标记符号 Numpy 等价形式 描述
    ('ij', A) A 矩阵 A 本身
    ('ji', A) A.T 矩阵 A 转置
    ('ii->i', A) diag(A) 矩阵 A 对角元素
    ('ii', A) trace(A) 矩阵 A 的 迹(主对角线之和)
    ('ij->', A) sum(A) 矩阵 A 所有元素求和
    ('ij->j', A) sum(A, axis=0) 矩阵 A 列求和,保留列(axis=1),对轴(axis=0)求和
    ('ij->i', A) sum(A, axis=1) 矩阵 A 行求和,保留行(axis=0),对轴(axis=1)求和
    ('ij,ij->ij', A, B) A * B A 和 B 哈达玛积(逐元素相乘)
    ('ij,ji->ij', A, B) A * B.T A 和 B 的转置逐元素相乘
    ('ij,jk', A, B) dot(A, B) A 和 B 的矩阵乘法
    ('ij,kj->ik', A, B) inner(A, B) 矩阵 A 和 矩阵 B 的内积
    ('ij,kj->ikj', A, B) A[:, None] * B
    ('ij,kl->ijkl', A, B) A[:, :, None, None] * B

四、实际案例

  1. 假设两个数组 A 和 B,需要进行如下操作:

    • 乘法:首先以特定的方式将 A 和 B 相乘,以得到乘积结果数组;
    • 求和:然后,沿着特定的轴(axis)求和,得到新的数组;
    • 转置:再以特定顺序,对数组转置。

    使用 einsum 可以更快速和更少内存实现 multipy、sum 和 transpose 函数功能,如,

    import torch
    A = torch.tensor([0, 1, 2])
    B = torch.tensor([[0,  1,  2,  3],
                      [4,  5,  6,  7],
                      [8,  9, 10, 11]])
    

    采用一般实现是:

    • 首先,需要对 A 进行 reshape,以便于与 B 进行操作,A 需要是列向量;
    • 然后,对 B 第一行乘以 0,第二行乘以 1,第三行乘以 2,得到新的数组;
    • 接着,对三列相加求和。

    即:

    (A.reshape(-1, 1) * B).sum(axis=1)  # output: tensor([ 0, 22, 76])
    

    而,使用 einsum 方式如下:

    torch.einsum('i,ij->i', A, B)
    

    为什么更有好,其原因是,不需要对 A 进行 reshape 操作;最重要的是,不会产生临时数据,如:A.reshape(-1, 1) * B。即使上面这个小例子,einsum 也具有三倍更快的速度。

  2. TreeQN

    论文 TreeQN(arXiv:1710.11417)中的等式 6 可以使用 einsum:给定网络层 $l$ 上的低维状态表示 $z_l$,和激活函数 a 上转换函数 $\mathbf W^a$,我们想要计算残差连接的下一层状态表示。 $$ \Large \mathbf z^a_{l+1}=\mathbf z_l+\tanh(\mathbf W^a\mathbf z_l) $$ 在实践中,我们想要高效地计算大小为 $B$ 的 batch 中 $K$ 维状态表示 $\mathbf Z \in \mathbb R^{B \times K}$,并同时计算所有转换函数(即,所有激活 $A$)。我们可以将这些转换函数安排为一个张量 $\mathcal W \in \mathbb R^{A \times K \times K}$,并使用 einsum 高效地计算下一层状态表示。

    import torch.nn.functional as F
    
    def random_tensors(shape, num=1,requires_grad=False):
      tensors = [torch.randn(shape, requires_grad=requires_grad) for i in range(0, num)]
      return tensors[0] if num == 1 else tensors
    
    # 参数
    # -- 【激活数 ✕ 隐藏层维度】
    b = random_tensors([5, 3], requires_grad=True)
    # -- 【激活数 ✕ 隐藏层维度 ✕ 隐藏层维度】
    W = random_tensors([5, 3, 3], requires_grad=True)
    
    def transition(zl):
        # -- [batch大小 x 激活数 x 隐藏层维度]
        return zl.unsqueeze(1) + F.tanh(torch.einsum("bk,aki->bai", [zl, W]) + b)
    
    # 随机取样仿造输入
    # -- [batch大小 x 隐藏层维度]
    zl = random_tensors([2, 3])
    
    transition(zl)
    
  3. 注意力

    再来看看注意力机制(arXiv:1509.06664)的真实例子。 $$ \Large \begin{align} \mathbf M_t &= \tanh(\mathbf W^yY+(\mathbf W^h\mathbf h_t+\mathbf W^r\mathbf r_{t-1})\otimes \mathbf e_L) &M &\in \mathbb R^{k \times L} \\ \alpha_t &= \text{softmax}(\mathbf w^T\mathbf M_t) &\alpha_t &\in \mathbb R^L \\ \mathbf r_t &= \mathbf Y\alpha^T_t+\tanh(\mathbf W^t\mathbf r_{t-1}) &\mathbf r_t &\in \mathbb R^k \end{align} $$ 用传统的写法实现这些要废不少力气,特别是考虑 batch 实现。使用 einsum 就简单多了。

    # 参数
    # -- [隐藏层维度]
    bM, br, w = random_tensors([7], num=3, requires_grad=True)
    # -- [隐藏层维度 x 隐藏层维度]
    WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True)
    
    # 注意力机制的单次应用
    def attention(Y, ht, rt1):
        # -- [batch大小 x 隐藏层维度]
        tmp = torch.einsum("ik,kl->il", [ht, Wh]) + torch.einsum("ik,kl->il", [rt1, Wr])
        Mt = F.tanh(torch.einsum("ijk,kl->ijl", [Y, WY]) + tmp.unsqueeze(1).expand_as(Y) + bM)
        # -- [batch大小 x 序列长度]
        at = F.softmax(torch.einsum("ijk,k->ij", [Mt, w]))
        # -- [batch大小 x 隐藏层维度]
        rt = torch.einsum("ijk,ij->ik", [Y, at]) + F.tanh(torch.einsum("ij,jk->ik", [rt1, Wt]) + br)
        # -- [batch大小 x 隐藏层维度], [batch大小 x 序列维度]
        return rt, at
    
    # 取样仿造输入
    # -- [batch大小 x 序列长度 x 隐藏层维度]
    Y = random_tensors([3, 5, 7])
    # -- [batch大小 x 隐藏层维度]
    ht, rt1 = random_tensors([3, 7], num=2)
    
    rt, at = attention(Y, ht, rt1)
    

五、总结

einsum 是一个函数走天下,是处理各种张量操作的瑞士军刀。话虽如此,“einsum 满足你一切需要” 显然夸大其词了。从上面的真实用例,可以看到,我们仍然需要在 einsum 之外应用非线性和构造额外的维度(unsqueeze)。类似地,分割、连接、索引张量仍需要应用其他库函数。

使用 einsum 的麻烦之处是你需要手动实例化参数,操心它们的初始化,并在模型中注册这些参数,不过仍然强烈建议在实现模型的时候,考虑一下是否适合使用 einsum 。


参考

Avatar

YISH

这个人很懒,什么都没有