I would like to ensure that each column in a matrix has at least e non-zero elements and, for each column that does not, randomly replace zero-valued elements with the value y until the column contains e non-zero elements. Consider the following matrix, where the columns have 0, 1, or 2 non-zero elements. After the operation, each column should have at least e non-zero elements, with every newly filled element set to y.
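The target condition can be written as a one-line check. This is a minimal sketch assuming a 2-D tensor; the helper name has_min_nonzero is hypothetical and not part of PyTorch.

```python
import torch

def has_min_nonzero(x: torch.Tensor, e: int) -> bool:
    # True when every column of the 2-D tensor x has at least e non-zero entries
    return bool((x.count_nonzero(dim=0) >= e).all())

m = torch.tensor([[1, 0],
                  [2, 3]])
print(has_min_nonzero(m, 1))  # True  (columns have 2 and 1 non-zeros)
print(has_min_nonzero(m, 2))  # False (second column has only 1)
```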
Before:
    tensor([[0, 7, 0, 0],
            [0, 0, 0, 0],
            [0, 1, 0, 4]], dtype=torch.int32)
After (e = 2):
    tensor([[y, 7, 0, y],
            [y, 0, y, 0],
            [0, 1, y, 4]], dtype=torch.int32)
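The number of zeros each column has to fill is e minus that column's non-zero count, floored at zero. A short sketch of that arithmetic on the example matrix; the variable names are illustrative only.

```python
import torch

x = torch.tensor([[0, 7, 0, 0],
                  [0, 0, 0, 0],
                  [0, 1, 0, 4]], dtype=torch.int32)
e = 2

# Non-zero entries per column
nonzero_per_col = x.count_nonzero(dim=0)
# Zeros each column must turn into y; clamp keeps already-full columns at 0
to_add = (e - nonzero_per_col).clamp(min=0)

print(nonzero_per_col.tolist())  # [0, 2, 0, 1]
print(to_add.tolist())           # [2, 0, 2, 1]
```

This agrees with the expected result: two y values in columns 0 and 2, one in column 3, and none in column 1.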
I have a very slow and naive loop-based solution that works:
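    import torch

    def scatter_elements(x, e, y):
        # Modifies x in place: each column slice is a view into x
        for i in range(x.shape[1]):
            col = x[:, i]
            # How many more non-zeros this column needs (never negative)
            to_add = max(e - int(col.count_nonzero()), 0)
            # Row positions of the zeros in this column
            indices = torch.where(col == 0)[0]
            # Pick to_add of those zeros at random and set them to y
            perm = torch.randperm(indices.shape[0])[:to_add]
            col[indices[perm]] = y
        return x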
Is it possible to do this without loops? I've thought about using torch.scatter and generating an index array first, but since the number of elements to add varies per column, I see no straightforward way to use it. Any suggestions or hints would be greatly appreciated!
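One possible loop-free approach, offered as a sketch rather than an answer taken from the thread: give every zero entry a random score and every non-zero entry a score of +inf, rank the entries within each column, and fill the zeros whose rank falls below that column's deficit. Because the per-column threshold is applied by broadcasting, the varying counts never require an explicit index array for torch.scatter. The name scatter_elements_vectorized is hypothetical, the input is assumed to be a single 2-D tensor, and unlike the loop version it returns a modified copy instead of writing into x.

```python
import torch

def scatter_elements_vectorized(x: torch.Tensor, e: int, y: int) -> torch.Tensor:
    """Return a copy of x in which every column has at least e non-zeros,
    filling randomly chosen zeros with y (hypothetical helper)."""
    is_zero = x == 0
    # Per-column deficit: how many zeros still need to become y (>= 0)
    deficit = (e - x.count_nonzero(dim=0)).clamp(min=0)  # shape (cols,)
    # Random score per entry; non-zeros get +inf so they rank after all zeros
    scores = torch.rand(x.shape, device=x.device)
    scores = scores.masked_fill(~is_zero, float("inf"))
    # Rank of each entry within its column (0 = smallest score)
    ranks = scores.argsort(dim=0).argsort(dim=0)
    # Select the `deficit` lowest-ranked entries per column (broadcast over rows);
    # `& is_zero` guards columns that have fewer zeros than their deficit
    fill = (ranks < deficit) & is_zero
    out = x.clone()
    out[fill] = y
    return out

x = torch.tensor([[0, 7, 0, 0],
                  [0, 0, 0, 0],
                  [0, 1, 0, 4]], dtype=torch.int32)
out = scatter_elements_vectorized(x, e=2, y=9)
print(out.count_nonzero(dim=0))  # tensor([2, 2, 2, 2])
```

Since the scores of the zeros are i.i.d. uniform, the ones with the smallest scores form a uniformly random subset of each column's zeros, matching the randperm selection in the loop. The cost is two sorts along dim 0, done for all columns at once.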
Edit: swapped indices and updated title and description based on comment.
source https://stackoverflow.com/questions/73501398/ensure-that-every-column-in-a-matrix-has-at-least-e-non-zero-elements