python – numpy einsum的替代品

当我计算具有N行和n列的矩阵X的三阶矩时,我通常使用einsum:

M3 = sp.einsum('ij,ik,il->jkl',X,X,X) /N

这通常很好,但现在我正在使用更大的值,即n = 120和N = 100000,并且einsum返回以下错误:

ValueError: iterator is too large

做3个嵌套循环的替代方案是不可行的,所以我想知道是否有任何替代方案.

最佳答案
请注意,计算此值至少需要进行~n3×N = 1730亿次操作(不考虑对称性),因此除非numpy可以访问GPU或其他东西,否则它将会很慢.在具有~3 GHz CPU的现代计算机上,假设没有SIMD /并行加速,整个计算预计需要大约60秒才能完成.

为了测试,让我们从N = 1000开始.我们将使用它来检查正确性和性能:

#!/usr/bin/env python3

import numpy
import time

numpy.random.seed(0)

n = 120
N = 1000
X = numpy.random.random((N, n))

start_time = time.time()

M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)

end_time = time.time()

print('check:', M3[2,4,6], '= 125.401852515?')
print('check:', M3[4,2,6], '= 125.401852515?')
print('check:', M3[6,4,2], '= 125.401852515?')
print('check:', numpy.sum(M3), '= 218028826.631?')
print('total time =', end_time - start_time)

这大约需要8秒钟.这是基线.

让我们从3嵌套循环开始作为替代:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        for l in range(n):
            M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
# ~27 seconds

这大约需要半分钟,没有好处!一个原因是因为这实际上是四个嵌套循环:numpy.sum也可以被认为是一个循环.

我们注意到总和可以变成点积来移除第4个循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        for l in range(n):
            M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds

现在好多了,但仍然很慢.但是我们注意到点积可以改成矩阵乘法来移除一个循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    for k in range(n):
        M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds

咦?现在这比einsum更有效!我们还可以检查答案是否应该是正确的.

我们可以走得更远吗?是!我们可以通过以下方法消除k循环:

M3 = numpy.zeros((n, n, n))
for j in range(n):
    Y = numpy.repeat(X[:,j], n).reshape((N, n))
    M3[j] = (Y * X).T @ X
# ~0.3 seconds

我们也可以使用广播(即a * [b,c] == [a * b,a * c]为X的每一行)以避免做numpy.repeat(感谢@Divakar):

M3 = numpy.zeros((n, n, n))
for j in range(n):
    Y = X[:,j].reshape((N, 1))
    ## or, equivalently: 
    # Y = X[:, numpy.newaxis, j]
    M3[j] = (Y * X).T @ X
# ~0.16 seconds

如果我们将其缩放到N = 100000,程序预计需要16秒,这在理论极限内,因此消除j可能没有太多帮助(但这可能使代码真的很难理解).我们可以接受这个作为最终解决方案.

注意:如果您使用的是Python 2,则@ b等效于a.dot(b).

转载注明原文:python – numpy einsum的替代品 - 代码日志