I want to implement the Connect 4 game in Python as a hobby project, and I have no idea why the search for matches on the diagonals is so slow.
When I profiled my code with pstats, I found that this is the bottleneck.
I want to build a computer opponent that analyzes thousands of future moves in the game, so performance matters.
Does anyone have an idea how to improve the performance of the following code? I chose NumPy because I thought it would speed things up. The problem is that I could not find a way to avoid a for loop.
import numpy as np

# Finds all the diagonal and off-diagonal-sequences in a 7x6 numpy array
def findseq(sm,seq=2,redyellow=1):
    matches=0
    # search in the diagonals
    # diags stores all the diagonals and off diagonals as rows of a matrix
    diags=np.zeros((1,6),dtype=np.int8)
    for k in range(-5,7):
        t=np.zeros(6,dtype=np.int8)
        a=np.diag(sm,k=k).copy()
        t[:len(a)] = a
        s=np.zeros(6,dtype=np.int8)
        a=np.diag(np.fliplr(sm),k=k).copy()
        s[:len(a)] = a
        diags=np.concatenate(( diags,t[None,:],s[None,:]),axis=0)
    diags=np.delete(diags,0,0)
    # print(diags)
    # now, search for sequences
    Na=np.size(diags,axis=1)
    n=np.arange(Na-seq+1)[:,None]+np.arange(seq)
    seqmat=np.all(diags[:,n]==redyellow,axis=2)
    matches+=seqmat.sum()
    return matches

def randomdebug():
    # sm=np.array([[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,2,1,1,0,0]])
    sm=np.random.randint(0,3,size=(6,7))
    return sm

# in my main program, I need to do this thousands of times
matches=[]
for i in range(1000):
    sm=randomdebug()
    matches.append(findseq(sm,seq=3,redyellow=1))
    matches.append(findseq(sm,seq=3,redyellow=2))
    # print(sm)
    # print(findseq(sm,seq=3))
Here are the pstats results:
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
         2000    1.965    0.001    4.887    0.002 Frage zu diag.py:4(findseq)
151002/103002    0.722    0.000    1.979    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
        48000    0.264    0.000    0.264    0.000 {method 'diagonal' of 'numpy.ndarray' objects}
        48072    0.251    0.000    0.251    0.000 {method 'copy' of 'numpy.ndarray' objects}
        48000    0.209    0.000    0.985    0.000 twodim_base.py:240(diag)
        48000    0.179    0.000    1.334    0.000 <__array_function__ internals>:177(diag)
        50000    0.165    0.000    0.165    0.000 {built-in method numpy.zeros}
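For readers who want to produce a table like this themselves, here is a minimal sketch using the standard-library cProfile and pstats modules. The workload `count_ones` and the helper `profile_calls` are hypothetical stand-ins, not part of the code above; swap in the function you actually want to measure.

```python
import cProfile
import io
import pstats

import numpy as np


def count_ones(board: np.ndarray) -> int:
    # Hypothetical stand-in workload; replace with the function under test.
    return int((board == 1).sum())


def profile_calls(n_calls: int = 200) -> str:
    """Profile n_calls invocations and return the pstats report as text."""
    rng = np.random.default_rng(0)
    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(n_calls):
        count_ones(rng.integers(0, 3, size=(6, 7)))
    profiler.disable()

    # Write the report into a string buffer instead of stdout.
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats("tottime").print_stats(20)
    return buffer.getvalue()


report = profile_calls()
print(report)
```

Sorting by "tottime" puts the functions that spend the most time in their own body at the top, which is how the table above is ordered.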
I am new to Python, so please imagine a tag "hopeless noob" ;-)
CodePudding user response:
As Andrey stated in a comment, the code calls many NumPy functions that each require additional memory allocations. I believe that is the bottleneck.
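To make the allocation point concrete, here is a small sketch comparing two ways of collecting zero-padded diagonals. Both function names are hypothetical and only illustrate the idea: the first grows the result with np.concatenate, which allocates a new array on every iteration, while the second allocates the result once and fills it in place.

```python
import numpy as np


def padded_diagonals_concat(board: np.ndarray) -> np.ndarray:
    # Grows the result with np.concatenate: a fresh array per iteration.
    h, w = board.shape
    size = min(h, w)
    rows = np.zeros((0, size), dtype=board.dtype)
    for k in range(-(h - 1), w):
        d = np.diag(board, k=k)
        padded = np.zeros(size, dtype=board.dtype)
        padded[:len(d)] = d
        rows = np.concatenate((rows, padded[None, :]), axis=0)
    return rows


def padded_diagonals_prealloc(board: np.ndarray) -> np.ndarray:
    # Allocates the result once and writes each diagonal into its row.
    h, w = board.shape
    size = min(h, w)
    rows = np.zeros((h + w - 1, size), dtype=board.dtype)
    for i, k in enumerate(range(-(h - 1), w)):
        d = board.diagonal(k)
        rows[i, :len(d)] = d
    return rows


board = np.arange(42).reshape(6, 7)
same = np.array_equal(padded_diagonals_concat(board),
                      padded_diagonals_prealloc(board))
print(same)
```

Preallocating still loops over the diagonals on every call, so it only reduces the allocation overhead; it does not remove the per-call work.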
I would suggest precomputing the indices of all diagonals, since they won't change much in your case (the matrix shape stays the same; the sequence length may change, I guess). You can then use them to address the diagonals quickly:
import numpy as np

known_diagonals = dict()

def diagonal_indices(h: int, w: int, length: int = 3) -> np.ndarray:
    '''
    Returns array (shape diagonal_count x length) of diagonal indices
    of a flattened array
    '''
    # one of many ways to store precomputed function output
    # cleaner way would probably be to do this outside this function
    diagonal_indices_key = (h, w, length)
    if diagonal_indices_key in known_diagonals:
        return known_diagonals[diagonal_indices_key]
    diagonals_count = (h + 1 - length) * (w + 1 - length) * 2
    # default value is meant to ease process with cumsum:
    # adding w + 1 selects an index 1 down and 1 right, w - 1 an index 1 down and 1 left
    # first half dedicated to right down diagonals
    diagonals = np.full((diagonals_count, length), w + 1, dtype=np.int32)
    # second half dedicated to left down diagonals
    diagonals[diagonals_count//2::] = w - 1
    # this could have been calculated mathematically
    flat_indices = np.arange(w * h).reshape((h, w))
    # print(flat_indices)
    # selects rectangle offset by length - 1 from right and bottom edges
    diagonal_starts_rd = flat_indices[:h + 1 - length, :w + 1 - length]
    # selects rectangle offset by length - 1 from left and bottom edges
    diagonal_starts_ld = flat_indices[:h + 1 - length, -(w + 1 - length):]
    # sets starts
    diagonals[:diagonals_count//2, 0] = diagonal_starts_rd.flatten()
    diagonals[diagonals_count//2::, 0] = diagonal_starts_ld.flatten()
    # sum triplets left to right
    # diagonals contains triplets (or vectors of other length) of (start, step, step),
    # where step is w + 1 or w - 1. cumsum turns them into diagonal indices
    diagonals = diagonals.cumsum(axis=1)
    # save output
    known_diagonals[diagonal_indices_key] = diagonals
    return diagonals
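
# Finds all the diagonal and off-diagonal sequences in a 6x7 (rows x columns) numpy array
def findseq(sm: np.ndarray, seq: int = 2, redyellow: int = 1) -> int:
    matches = 0
    # look up the cached index table for this board shape and sequence length
    diagonals = diagonal_indices(*sm.shape, seq)
    # a window matches when every cell in it equals redyellow
    seqmat = np.all(sm.flatten()[diagonals] == redyellow, axis=1)
    matches = seqmat.sum()
    return matches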

def randomdebug():
    # sm=np.array([[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,0,0,0,0,0],[0,0,2,1,1,0,0]])
    sm=np.random.randint(0,3,size=(6,7))
    return sm

# in my main program, I need to do this thousands of times
matches=[]
for i in range(1000):
    sm=randomdebug()
    matches.append(findseq(sm,seq=3,redyellow=1))
    matches.append(findseq(sm,seq=3,redyellow=2))
    # print(sm)
    # print(findseq(sm,seq=3))
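Before relying on a precomputed index table, it may be worth cross-checking it against a plain loop. The sketch below is self-contained and uses hypothetical helper names (`count_loop`, `diagonal_index_table`, `count_indexed`). It builds the table with broadcasting rather than cumsum, so it is an alternative formulation of the same idea, not the code above.

```python
import numpy as np


def count_loop(board: np.ndarray, length: int, player: int) -> int:
    # Reference: walk every diagonal and anti-diagonal, slide a window along it.
    h, w = board.shape
    total = 0
    for grid in (board, np.fliplr(board)):
        for k in range(-(h - 1), w):
            d = grid.diagonal(k)
            for start in range(len(d) - length + 1):
                total += int(np.all(d[start:start + length] == player))
    return total


def diagonal_index_table(h: int, w: int, length: int) -> np.ndarray:
    # Flat (row-major) indices of every diagonal window, one window per row.
    steps = np.arange(length)
    rows = np.arange(h - length + 1)[:, None]
    cols_rd = np.arange(w - length + 1)[None, :]   # starts of down-right windows
    cols_ld = np.arange(length - 1, w)[None, :]    # starts of down-left windows
    starts_rd = (rows * w + cols_rd).ravel()
    starts_ld = (rows * w + cols_ld).ravel()
    down_right = starts_rd[:, None] + steps * (w + 1)
    down_left = starts_ld[:, None] + steps * (w - 1)
    return np.concatenate((down_right, down_left), axis=0)


def count_indexed(board: np.ndarray, table: np.ndarray, player: int) -> int:
    # One fancy-indexing lookup replaces the per-diagonal loop.
    return int(np.all(board.ravel()[table] == player, axis=1).sum())


table = diagonal_index_table(6, 7, 3)
rng = np.random.default_rng(0)
mismatches = 0
for _ in range(200):
    board = rng.integers(0, 3, size=(6, 7))
    for player in (1, 2):
        if count_loop(board, 3, player) != count_indexed(board, table, player):
            mismatches += 1
print(mismatches)
```

The table is built once per board shape and sequence length, so inside a search loop only `count_indexed` runs per position.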