Ensure that every column in a matrix has at least `e` non-zero elements

Time:08-27

I would like to ensure that each column in a matrix has at least e non-zero elements, and, for each column that does not, randomly replace zero-valued elements with the value y until the column contains e non-zero elements. Consider the following matrix, where the columns have 0, 1 or 2 non-zero elements. After the operation, each column should have at least e non-zero elements, with the newly added ones set to y.

before

tensor([[0, 7, 0, 0],
        [0, 0, 0, 0],
        [0, 1, 0, 4]], dtype=torch.int32)

after, e = 2

tensor([[y, 7, 0, y],
        [y, 0, y, 0],
        [0, 1, y, 4]], dtype=torch.int32)
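The requirement can be phrased as a check that any result should pass. Below is a minimal sketch using a hypothetical helper, satisfies_requirement (not part of PyTorch or the original post), assuming y = 9 for the example: every column must have at least e non-zero elements (capped at the number of rows), and no existing non-zero value may change.

```python
import torch

def satisfies_requirement(before, after, e):
    # Hypothetical helper: True if every column of `after` has at least
    # min(e, num_rows) non-zero entries and the original non-zeros are untouched.
    needed = min(e, before.shape[0])
    enough = bool(((after != 0).sum(dim=0) >= needed).all())
    nonzero = before != 0
    unchanged = torch.equal(after[nonzero], before[nonzero])
    return enough and unchanged

before = torch.tensor([[0, 7, 0, 0],
                       [0, 0, 0, 0],
                       [0, 1, 0, 4]], dtype=torch.int32)
# the "after" matrix from the question, with y = 9
after = torch.tensor([[9, 7, 0, 9],
                      [9, 0, 9, 0],
                      [0, 1, 9, 4]], dtype=torch.int32)

print(satisfies_requirement(before, after, e=2))   # True
print(satisfies_requirement(before, before, e=2))  # False
```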

I have a very slow and naive loop-based solution that works:

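import torch

def scatter_elements(x, e, y):
  # modifies x in place, one column at a time
  for i in range(x.shape[1]):
    col = x.data[:, i]                             # view of column i (bypasses autograd)
    num_connections = int(col.count_nonzero())
    to_add = max(e - num_connections, 0)           # zeros that still need replacing
    indices = torch.where(col == 0)[0]             # rows currently holding zeros
    perm = torch.randperm(indices.shape[0])[:to_add]
    col[indices[perm]] = y                         # writes through to x
  return x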

Is it possible to do this without loops? I've thought about using torch.scatter and generating an index array first, but since the number of elements to be added varies per column, I see no straightforward way to use it. Any suggestions or hints would be greatly appreciated!
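The per-column count itself is easy to compute without a loop; what makes a plain torch.scatter awkward is that the resulting deficits differ from column to column. A short sketch on the example matrix, assuming e = 2:

```python
import torch

x = torch.tensor([[0, 7, 0, 0],
                  [0, 0, 0, 0],
                  [0, 1, 0, 4]], dtype=torch.int32)
e = 2

nonzero_per_col = x.count_nonzero(dim=0)          # non-zeros in each column
to_add = torch.clamp(e - nonzero_per_col, min=0)  # zeros to replace per column
print(to_add.tolist())  # [2, 0, 2, 1]
```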

Edit: swapped indices and updated title and description based on comment.

Answer:

If you only care that each column has at least e non-zero elements, rather than exactly e, you can do it without a loop. The key is that in this case we can create an array with every zero replaced by y, and then sample e values from this array for each column.

For convenience let x.shape = [a,b]

  1. Create an array replace that is a copy of x with every 0 replaced by y.
  2. Create a random array of the same size as x.
  3. Use torch.topk to get the k largest random numbers per column; their indices give k random rows for each column (in your case k = e). Provided that x contains only non-negative integers, you can add x to the random array before the topk operation: every existing non-zero element then scores at least 1, above every zero (which scores below 1), so existing non-zero elements are selected first and only as many zeros are replaced as are needed to reach e.
  4. Replace the indexed values in each column with the corresponding values from replace.
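The effect of adding x in step 3 can be seen on a single column. Because torch.rand draws from [0, 1), every zero scores below 1 while every existing non-zero integer scores at least 1, so torch.topk always picks the existing non-zeros first. A small sketch, assuming non-negative integer entries:

```python
import torch

col = torch.tensor([[0.], [4.], [0.], [2.]])  # one column with two non-zeros
e = 3
score = torch.rand_like(col) + col            # zeros in [0, 1), non-zeros >= 1
ind = torch.topk(score, e, dim=0).indices.flatten()
picked = set(ind.tolist())
# rows 1 and 3 (the existing non-zeros) are always picked,
# plus one randomly chosen zero row (0 or 2)
print(sorted(picked))
```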
import torch

def scatter_elements(x, e, y):
  x = x.float()
  # 1. replace has the same shape as x, with all 0s replaced by y
  replace = torch.where(x > 0, x, torch.full_like(x, y))

  # 2-3. get random indices per column
  randn = torch.rand(x.shape)
  if True: # set to False to allow more than e non-zero elements in a column after the update
      randn = randn + x # assumes x is non-negative integer; existing non-zeros now outrank all zeros

  ind = torch.topk(randn, e, dim=0)[1] # first return is values, second return is indices

  # create a second index to indicate which column each index in ind corresponds to
  col_ind = torch.arange(x.shape[1]).unsqueeze(0).expand(ind.shape)

  # 4. Index x with ind and col_ind and set these values to the corresponding values in replace
  ind = ind.reshape(-1)         # flatten into 1D array so we can use it to index rows
  col_ind = col_ind.reshape(-1) # flatten into 1D array so we can use it to index columns
  x[ind, col_ind] = replace[ind, col_ind]

  return x

In my limited timing tests, the vectorized solution was about 5-6x faster than the original looping solution.
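Because torch.topk returns exactly e row indices for every column, step 4 can also be written with torch.gather and Tensor.scatter_, the functions the question considered; the fixed count per column is what makes them usable here. A minimal sketch on the question's example, assuming y = 9 and non-negative integer entries:

```python
import torch

x = torch.tensor([[0., 7., 0., 0.],
                  [0., 0., 0., 0.],
                  [0., 1., 0., 4.]])
y, e = 9.0, 2

replace = torch.where(x != 0, x, torch.full_like(x, y))
score = torch.rand_like(x) + x              # existing non-zeros (>= 1) outrank zeros (< 1)
ind = torch.topk(score, e, dim=0).indices   # shape [e, num_cols], distinct rows per column
# gather reads replace[ind[i, j], j]; scatter_ writes it to x[ind[i, j], j]
x.scatter_(0, ind, replace.gather(0, ind))
print(x)
```

Unlike the flattened advanced indexing, this keeps ind in its [e, num_cols] shape, so no separate column index is needed.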

Answer:

After some more experimentation with my original loop-based solution I realized that I was already pretty close to a vectorized version:

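import torch

def scatter_elements_vec(x, e, y):
    rand = torch.rand_like(x, dtype=torch.float)
    rand[x != 0] = torch.inf                      # existing non-zeros sort last
    num_connections = x.count_nonzero(0)          # non-zeros per column
    to_add = torch.clip(e - num_connections, 0, None)
    # argsort of argsort gives each element's rank within its column
    ranks = torch.argsort(torch.argsort(rand, 0), 0)
    mask = ranks < to_add                         # the to_add lowest-ranked zeros per column
    x[mask] = y                                   # modifies x in place
    return x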
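Comparing to_add against each element's rank within its column (argsort applied twice), rather than against the raw argsort output, is what keeps the existing non-zeros safe: argsort returns the row indices in sorted order, not the ranks. A one-column illustration with hand-picked scores:

```python
import torch

# One column: two zeros with random scores and one existing non-zero (inf).
rand = torch.tensor([[0.7], [float("inf")], [0.2]])
to_add = 1

order = torch.argsort(rand, dim=0)   # row indices in ascending order: [[2], [0], [1]]
ranks = torch.argsort(order, dim=0)  # rank of each row:               [[1], [2], [0]]

print((ranks < to_add).flatten().tolist())  # [False, False, True] -> row 2, a zero
print((order < to_add).flatten().tolist())  # [False, True, False] -> row 1, the non-zero
```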