Calculating multilabel recall for this problem

Time:12-03

I have a table with two columns; the two entries in each row are related:

Col1 Col2
a A
b B
a C
c A
b D
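To handle a variable number of labels per Col1 entry, it can help to first collapse the table into one set of related Col2 entries per unique Col1 entry. Below is a minimal sketch in plain Python using the example rows above. The names `pairs`, `queries`, `candidates` and `relevant` are hypothetical and not from the question.

```python
from collections import defaultdict

# Hypothetical test pairs, copied from the example table (Col1, Col2)
pairs = [("a", "A"), ("b", "B"), ("a", "C"), ("c", "A"), ("b", "D")]

# Give every unique Col1 entry (query) and Col2 entry (candidate) one index,
# so duplicates such as the two "a" rows share a single embedding row
queries = sorted({q for q, _ in pairs})
candidates = sorted({c for _, c in pairs})
q_index = {q: i for i, q in enumerate(queries)}
c_index = {c: j for j, c in enumerate(candidates)}

# relevant[i] is the set of candidate indices related to query i;
# its size differs per query, which is the "variable number of labels"
relevant = defaultdict(set)
for q, c in pairs:
    relevant[q_index[q]].add(c_index[c])

for q in queries:
    print(q, sorted(candidates[j] for j in relevant[q_index[q]]))
```

In a multi-hot view of the same data, the total number of labels would be `len(candidates)`, the number of distinct Col2 entries. It would not be a per-row count.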

Here a is related to A and C, b is related to B and D, and c is related to A, so the same Col1 entry can have several related labels in Col2. I trained a machine learning model to quantify the relationship between Col1 and Col2 by creating vector embeddings of the Col1 and Col2 entries and optimizing the cosine similarity between the two vectors. Now I want to evaluate the model by computing recall on a test set: for various values of N, what proportion of the positive relationships is retrieved at recall@N? Given normalized vector representations of all entries in each column, I can compute the pairwise cosine scores between them as:

cosine_distance = torch.mm(col1_feature, col2_feature.t())

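which gives a matrix of scores for all pairs that can be formed between Col1 and Col2. Because the features are L2-normalized, each entry is actually a cosine similarity (higher means more related), even though the variable is named cosine_distance: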

dist(a,A) dist(a,B) dist(a,C) dist(a,A) dist(a,D)
dist(b,A) dist(b,B) dist(b,C) dist(b,A) dist(b,D)
dist(a,A) dist(a,B) dist(a,C) dist(a,A) dist(a,D)
dist(c,A) dist(c,B) dist(c,C) dist(c,A) dist(c,D)
dist(b,A) dist(b,B) dist(b,C) dist(b,A) dist(b,D)
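This matrix is a plain matrix product of normalized features, so each entry is a cosine similarity in [-1, 1], not a distance. The small NumPy sketch below uses random, hypothetical embeddings to check this. NumPy stands in for PyTorch here, with `@` playing the role of `torch.mm`. With one row per unique entry, the result has shape (unique Col1 entries) × (unique Col2 entries).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical embeddings: 3 unique Col1 entries and 4 unique Col2 entries, dim 8
col1_raw = rng.normal(size=(3, 8))
col2_raw = rng.normal(size=(4, 8))

# L2-normalize each row, as assumed in the question
col1_feature = col1_raw / np.linalg.norm(col1_raw, axis=1, keepdims=True)
col2_feature = col2_raw / np.linalg.norm(col2_raw, axis=1, keepdims=True)

# Same operation as torch.mm(col1_feature, col2_feature.t())
scores = col1_feature @ col2_feature.T

# Explicit cosine similarity of one pair of raw vectors, for comparison
u, v = col1_raw[0], col2_raw[2]
cos_uv = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
print(scores.shape, np.isclose(scores[0, 2], cos_uv))
```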

I can then find the pairs with the highest scores to compute recall@k. My question is how to make this efficient for millions of rows. I found the torchmetrics.classification.MultilabelRecall module (https://torchmetrics.readthedocs.io/en/stable/classification/recall.html), which seems useful, but it requires specifying the number of labels. In my case, each unique Col1 entry can have a different number of labels. Any ideas?
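One way to avoid a fixed label count is to compute recall@k per query directly. For each Col1 entry, take the top-k Col2 entries by score, count how many are among that entry's related entries, and divide by the size of its own relevant set. The sketch below does this in NumPy and processes queries in chunks, so that only a (chunk_size × number of Col2 entries) block of scores is in memory at once. The function name `recall_at_k`, the `relevant` list-of-sets input and `chunk_size` are hypothetical choices, not part of torchmetrics. In PyTorch, `torch.topk(scores, k, dim=1)` could replace `np.argpartition`.

```python
import numpy as np

def recall_at_k(col1_feature, col2_feature, relevant, k, chunk_size=1024):
    """Mean recall@k over queries, computed in row chunks (k >= 1 assumed).

    col1_feature: (n_queries, d) L2-normalized query embeddings
    col2_feature: (n_candidates, d) L2-normalized candidate embeddings
    relevant: list of sets; relevant[i] holds candidate indices related to query i
    """
    k = min(k, col2_feature.shape[0])
    total = 0.0
    for start in range(0, col1_feature.shape[0], chunk_size):
        # Cosine similarities for one chunk of queries against all candidates
        scores = col1_feature[start:start + chunk_size] @ col2_feature.T
        # Indices of the k highest scores per row (in no particular order)
        top_k = np.argpartition(-scores, k - 1, axis=1)[:, :k]
        for row, retrieved in enumerate(top_k):
            rel = relevant[start + row]
            hits = len(rel.intersection(retrieved.tolist()))
            # Divide by this query's own number of labels, so no global
            # label count is needed
            total += hits / len(rel)
    return total / col1_feature.shape[0]

# Tiny hand-made example: A, B, C, D are orthonormal; a, b, c as in the table
col2 = np.eye(4)
col1 = np.array([[0.8, 0.0, 0.6, 0.0],   # a: related to A, C
                 [0.0, 0.8, 0.0, 0.6],   # b: related to B, D
                 [0.6, 0.0, 0.8, 0.0]])  # c: related to A, but closer to C
relevant = [{0, 2}, {1, 3}, {0}]
print(recall_at_k(col1, col2, relevant, k=1, chunk_size=2))
print(recall_at_k(col1, col2, relevant, k=2, chunk_size=2))
```

Here recall@1 is (0.5 + 0.5 + 0) / 3 and recall@2 is 1.0. This is still exact, brute-force search over every Col1/Col2 pair. Chunking bounds memory but not total compute.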

Answer:

You can use a clustering algorithm to group the entries in Col1 and Col2 into clusters. Then you can use the MultilabelRecall metric to calculate the recall for each cluster. This way, you don't have to specify the number of labels for each entry in Col1.

Answer:

If your table has a large number of rows, computing the cosine similarity between all pairs of Col1 and Col2 entries can be expensive. One way to make this more efficient is to use approximate nearest neighbor (ANN) algorithms, which can quickly find the closest vectors in a high-dimensional space. These algorithms typically build an index that allows efficient search, such as locality-sensitive hashing. A k-d tree is a related structure that performs exact search and tends to lose its advantage as the dimension grows. Once the index is built, you can use it to quickly find the Col2 entries closest to a given Col1 entry and then calculate recall@k from those entries.
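scikit-learn's `KDTree` ranks neighbors by Euclidean distance by default, while the model was trained on cosine similarity. For L2-normalized vectors the two orderings agree, because ||u − v||² = ||u||² + ||v||² − 2·u·v = 2 − 2·cos(u, v). The quick NumPy check below uses random, hypothetical unit vectors.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical unit-length embeddings: 1 query and 50 candidates in 16 dimensions
query = rng.normal(size=16)
query /= np.linalg.norm(query)
cands = rng.normal(size=(50, 16))
cands /= np.linalg.norm(cands, axis=1, keepdims=True)

cos_sim = cands @ query
sq_euclid = np.sum((cands - query) ** 2, axis=1)

# For unit vectors, squared Euclidean distance is 2 - 2 * cosine similarity
print(np.allclose(sq_euclid, 2 - 2 * cos_sim))
# So ascending Euclidean distance ranks candidates like descending cosine similarity
print(np.array_equal(np.argsort(sq_euclid), np.argsort(-cos_sim)))
```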

Here is an example of how you might use a nearest-neighbor index to calculate recall@k in your case. This code uses the k-d tree implementation in scikit-learn to index the Col2 vectors, finds the k nearest Col2 neighbors of each Col1 vector, and then averages recall@k over the Col1 entries.

import numpy as np
from sklearn.neighbors import KDTree

# Assumed inputs (one row per unique entry):
# col1_feature, col2_feature: L2-normalized NumPy arrays (use .cpu().numpy() on torch tensors);
#   for unit vectors, Euclidean nearest neighbors are also the cosine-similarity nearest neighbors
# col2: NumPy array of Col2 labels aligned with the rows of col2_feature
# true_labels_for_col1[i]: set of Col2 labels related to the i-th Col1 entry
# k: number of neighbors to retrieve

# Index only the Col2 vectors, since those are the candidates being retrieved
tree = KDTree(col2_feature)

# query returns a tuple (distances, indices), each of shape (n_col1, k)
distances, indices = tree.query(col1_feature, k=k)

# Sum recall@k over all vectors in Col1
recall_at_k = 0.0
for i, neighbor_indices in enumerate(indices):
    # Get the labels of the k nearest Col2 vectors
    neighbor_labels = col2[neighbor_indices]

    # Count the number of true labels among the nearest neighbors
    true_labels = 0
    for label in neighbor_labels:
        if label in true_labels_for_col1[i]:
            true_labels += 1

    # Recall@k divides by the number of true labels for this entry;
    # dividing by k would give precision@k instead
    recall_at_k += true_labels / len(true_labels_for_col1[i])

# Average recall@k over all vectors in Col1
average_recall_at_k = recall_at_k / len(col1_feature)