I would like to replicate the following operation:
```python
import numpy as np

a = np.eye(3)
A = np.einsum("ij,kl->ikjl", a, a)
```
I have the following:
```python
def IndexPermutation(stringWithTensorIndices, Tensor):
    # Function to split the string with the indices
    def split(word):
        return [char for char in word]

    # Gets indices from string
    a, b, c, d = split(stringWithTensorIndices)
    store = np.zeros((3, 3, 3, 3))
    # Performs the tensor index permutations
    for i in [0, 1, 2]:
        for j in [0, 1, 2]:
            for k in [0, 1, 2]:
                for l in [0, 1, 2]:
                    store[i, j, k, l] = Tensor[vars()[a], vars()[b], vars()[c], vars()[d]]
    return store
```
I would like to remove the vars() function. Also, it currently only accepts a tensor whose "ij" and "kl" indices have already been combined, and then permutes its indices (equivalent to np.einsum("<some ordering of ijkl>->ijkl", Tensor)).
Can anyone give me a hand with this?
I understand that what you want to do is provide an input stringWithTensorIndices, with values like "ijkl" or "jlki", that determines how the indices are permuted. The logic of the loop is then:
- for the input "jlki", a = "j";
- so vars()[a] is vars()["j"], which is the current value of the loop variable j;
- so you get store[i,j,k,l] = Tensor[j,l,k,i].
There are a few things to say here.
First, your split function isn't doing anything for you. a, b, c, d = "jkli" will correctly resolve to a = "j" etc., because a string is itself an iterable of characters and can be unpacked directly. There's no safety against incorrect user input either way: a string of the wrong length raises a ValueError. Best practice might be to catch that ValueError and raise a more informative exception. In practice, though, we'll take a different approach anyway.
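A minimal sketch of the unpack-and-validate idea, in case you do keep the string-splitting step. The name unpack_indices is hypothetical (not from the question), and the error message wording is just an illustration:

```python
from typing import Tuple


def unpack_indices(spec: str) -> Tuple[str, str, str, str]:
    """Hypothetical helper: unpack a 4-character index string such as "jlki"."""
    try:
        # A str is iterable, so tuple unpacking yields one character per name.
        a, b, c, d = spec
    except ValueError as err:
        # Raised when spec has more or fewer than four characters.
        raise ValueError(
            f"expected exactly 4 index characters (e.g. 'ijkl'), got {spec!r}"
        ) from err
    return a, b, c, d


print(unpack_indices("jlki"))  # ('j', 'l', 'k', 'i')
```

Note that this only checks the length; a string such as "ijxz" would still unpack without complaint.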
Now, the important part. At any point in your heavily nested loop, you have some indices [i, j, k, l] (these being your loop variables). The input string acts as a key for choosing from this collection of indices. For example, "jlki" should pick out indices[1], indices[3], indices[2], indices[0].
What we need is a helper function:
```python
tensor_indices = "iljk"
lookup = {"i": 0, "j": 1, "k": 2, "l": 3}
mapped_indices = [lookup[x] for x in tensor_indices]


def get_indices(loop_indices, mapped_indices):
    # loop_indices is [i, j, k, l], changes every iteration
    # mapped_indices is [0, 3, 1, 2], never changes
    for_tensor = tuple(loop_indices[x] for x in mapped_indices)
    return for_tensor  # (i, l, j, k), as was required
```
Returning a tuple, not a list, is important. If you index the multi-dimensional NumPy array a with a list, a[[0,2,1,0]], that is quite different from indexing it as a[0,2,1,0]. The former uses integer-array ("fancy") indexing to select slices along the first dimension (here a[0], a[2], a[1], a[0]); the latter picks out a single element.
Fortunately, a[0,2,1,0] is just Python syntax for a[(0,2,1,0)]: the comma-separated indices are packed into a tuple, so a[t] with t = (0,2,1,0) is exactly equivalent to a[0,2,1,0].
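A quick way to see the tuple-versus-list difference for yourself. The 4-D test array built with np.arange is an arbitrary choice for illustration:

```python
import numpy as np

# A 3x3x3x3 array whose entries are their own flat positions 0..80
a = np.arange(81).reshape(3, 3, 3, 3)
t = (0, 2, 1, 0)

single = a[t]         # same as a[0, 2, 1, 0]: one element
slices = a[list(t)]   # fancy indexing: stacks a[0], a[2], a[1], a[0] along axis 0

print(single)         # 21  (0*27 + 2*9 + 1*3 + 0)
print(slices.shape)   # (4, 3, 3, 3)
```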
With this in hand I think we can move to a full solution:
```python
import numpy as np
from typing import List, Tuple

# type annotations are useful, and mean you don't have to use
# names like stringWithTensorIndices to keep track


def get_indices(
    loop_indices: List[int], mapped_indices: List[int]
) -> Tuple[int, ...]:
    """
    get_indices
    This function takes a set of loop variables, e.g. i, j, k, l
    and uses a set of mapped indices (e.g. [0, 1, 1, 2] for 'ijjk')
    and returns loop variables chosen according to the mapping
    (i.e. (i, j, j, k) in this example).
    """
    for_tensor = tuple(loop_indices[x] for x in mapped_indices)
    return for_tensor  # e.g. (i, l, j, k) for 'iljk'


def index_permutation(tensor_indices: str, tensor: np.ndarray) -> np.ndarray:
    """
    index_permutation
    Permutes indices of a tensor
    Inputs:
        tensor_indices - e.g. "ijkl"
        tensor - the tensor to operate on
    Output:
        new_tensor - the resulting tensor
    """
    lookup = {"i": 0, "j": 1, "k": 2, "l": 3}
    # for e.g. 'iljk' we want 0, 3, 1, 2:
    mapped_indices = [lookup[x] for x in tensor_indices]
    dim = 3
    shape = (dim, dim, dim, dim)
    new_tensor = np.zeros(shape)
    for i in range(dim):
        for j in range(dim):
            for k in range(dim):
                for l in range(dim):
                    indices = [i, j, k, l]
                    # a tuple, so this picks one element (see above)
                    reordered = get_indices(indices, mapped_indices)
                    new_tensor[i, j, k, l] = tensor[reordered]
    return new_tensor
```
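To get back to the operation in the question, the function has to be given the outer product of a with itself, together with the spec "ikjl": the loop fills new_tensor[i,j,k,l] from tensor[i,k,j,l]. The sketch below checks that reasoning using np.einsum("ikjl->ijkl", T) as a stand-in for that permutation, and np.multiply.outer to build the outer product (both are existing NumPy functions; using them here is my choice, not part of the answer above):

```python
import numpy as np

a = np.eye(3)

# Outer product: T[i, j, k, l] = a[i, j] * a[k, l]  (same as np.einsum("ij,kl->ijkl", a, a))
T = np.multiply.outer(a, a)

# Target from the question: A[i, k, j, l] = a[i, j] * a[k, l],
# i.e. A[i, j, k, l] = T[i, k, j, l], which is the permutation spec "ikjl"
A_target = np.einsum("ij,kl->ikjl", a, a)
A_from_T = np.einsum("ikjl->ijkl", T)

print(np.array_equal(A_target, A_from_T))  # True
```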
Miscellanea
I've gone ahead and implemented some best-practices here:
- D.R.Y. (don't repeat yourself): values repeated in several places all have to be changed together if you want to change the behaviour in future, so the [0,1,2] that appeared four times is replaced by a single dim = 3 that every loop references through range(dim)
- builtin iterators like range(N), enumerate and zip are handy
- docstrings, comments: future you will thank you even if no-one else does
- type annotations: even if you aren't using mypy to check types, they aid readability, especially where you have a bunch of variables with similar names. From Python 3.9 onward you don't even need List; the regular list[int] will do.
- standard Python conventions: snake_case for variable and function names, CamelCase for class names
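In the same spirit as the iterator point, the four nested for loops can be flattened into one. This uses itertools.product, which lives in the standard library rather than among the builtins; the sketch just confirms it visits the (i, j, k, l) tuples in the same order as the nested loops:

```python
from itertools import product

dim = 3

# Four nested loops, in the same order as in index_permutation
nested = []
for i in range(dim):
    for j in range(dim):
        for k in range(dim):
            for l in range(dim):
                nested.append((i, j, k, l))

# One loop over range(dim) x range(dim) x range(dim) x range(dim);
# the rightmost index varies fastest, just like the innermost loop
flat = list(product(range(dim), repeat=4))

print(nested == flat)  # True
print(len(flat))       # 81
```

The loop header in index_permutation could then read `for i, j, k, l in product(range(dim), repeat=4):`.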
One thing I haven't done is attempt to eliminate that heavily nested loop. There are ways to do that with NumPy (it's what it's good at), but it seemed to me that this would be beside the point of the exercise.
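For completeness, a minimal sketch of the NumPy route, assuming tensor_indices is always a permutation of "ijkl". The function name index_permutation_vectorized is hypothetical. It relies on np.transpose: output axis n is taken from the tensor axis p whose mapped index is n, i.e. the inverse of mapped_indices:

```python
import numpy as np


def index_permutation_vectorized(tensor_indices: str, tensor: np.ndarray) -> np.ndarray:
    """Hypothetical loop-free version: result[i, j, k, l] == tensor[<indices picked by the spec>].

    Assumes tensor_indices is a permutation of "ijkl".
    """
    lookup = {"i": 0, "j": 1, "k": 2, "l": 3}
    mapped_indices = [lookup[x] for x in tensor_indices]  # e.g. "iljk" -> [0, 3, 1, 2]
    # Tensor axis p is indexed by loop index mapped_indices[p], so output axis n
    # must come from the p with mapped_indices[p] == n (the inverse permutation).
    axes = tuple(mapped_indices.index(n) for n in range(4))
    # np.transpose returns a view; call .copy() if you need an independent array.
    return np.transpose(tensor, axes)


T = np.arange(81).reshape(3, 3, 3, 3)
out = index_permutation_vectorized("iljk", T)
print(out[0, 1, 2, 0] == T[0, 0, 1, 2])  # True: out[i, j, k, l] == T[i, l, j, k]
```

An equivalent one-liner is np.einsum(tensor_indices + "->ijkl", tensor), which states the same permutation in einsum notation.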