找到一种更好的计算矩阵的方法

mar*_*all 8 python algorithm math performance numpy

我想计算只有1和0条目的2d数组的数量,这些条目具有一对具有相等矢量和的不相交的不相交行对.对于4乘4矩阵,下面的代码通过迭代所有这些并依次测试每个代码来实现这一点.

import numpy as np
from itertools import combinations
n = 4
nxn = np.arange(n*n).reshape(n, -1)
count = 0
for i in xrange(2**(n*n)):
   A = (i >> nxn) %2
   p = 1
   for firstpair in combinations(range(n), 2):
       for secondpair in combinations(range(n), 2):
           if firstpair < secondpair and not set(firstpair) & set(secondpair):
              if (np.array_equal(A[firstpair[0]] + A[firstpair[1]], A[secondpair[0]] + A[secondpair[1]] )):
                  if (p):
                      count +=1
                      p = 0
print count
Run Code Online (Sandbox Code Playgroud)

输出是3136.

这个问题是它使用了2 ^(4 ^ 2)次迭代,并且我希望将它运行为n到8次.是否有更聪明的方法来计算这些而不迭代所有矩阵?例如,一遍又一遍地创建相同矩阵的排列似乎毫无意义.

Dav*_*tat 8

使用CPython 3.3在我的机器上大约一分钟计算:

4 3136
5 3053312
6 7247819776
7 53875134036992
8 1372451668676509696
Run Code Online (Sandbox Code Playgroud)

代码,基于记忆包含 - 排除:

#!/usr/bin/env python3
import collections
import itertools

def pairs_of_pairs(n):
    for (i, j, k, m) in itertools.combinations(range(n), 4):
        (yield ((i, j), (k, m)))
        (yield ((i, k), (j, m)))
        (yield ((i, m), (j, k)))

def columns(n):
    return itertools.product(range(2), repeat=n)

def satisfied(pair_of_pairs, column):
    ((i, j), (k, m)) = pair_of_pairs
    return ((column[i] + column[j]) == (column[k] + column[m]))

def pop_count(valid_columns):
    return bin(valid_columns).count('1')

def main(n):
    pairs_of_pairs_n = list(pairs_of_pairs(n))
    columns_n = list(columns(n))
    universe = ((1 << len(columns_n)) - 1)
    counter = collections.defaultdict(int)
    counter[universe] = (- 1)
    for pair_of_pairs in pairs_of_pairs_n:
        mask = 0
        for (i, column) in enumerate(columns_n):
            mask |= (int(satisfied(pair_of_pairs, column)) << i)
        for (valid_columns, count) in list(counter.items()):
            counter[(valid_columns & mask)] -= count
    counter[universe] += 1
    return sum(((count * (pop_count(valid_columns) ** n)) for (valid_columns, count) in counter.items()))
if (__name__ == '__main__'):
    for n in range(4, 9):
        print(n, main(n))
Run Code Online (Sandbox Code Playgroud)