How to efficiently count all concatenations of 2-tuples into longer chains in Python

Time:04-10

Let us say that we would like to build a long (metal) chain which will be composed of smaller links, chained together. I know what the length of the chain should be: n. The links are represented as 2-tuples: (a, b). We may chain links together if and only if they share the same element at the side by which they would be chained.
I am given a list, links, of n-1 lists, where each inner list holds all the links available to me at the corresponding position of the chain. For example:

links = [
    [
        ('a', 1),
        ('a', 2),
        ('a', 3),
        ('b', 1),
    ],
    [
        (1, 'A'),
        (2, 'A'),
        (2, 'B'),
    ],
    [
        ('A', 'a'),
        ('B', 'a'),
        ('B', 'b'),
    ]
]

In this case the length of the final chain will be: n = 4.
Here we may generate these possible chains:

('a', 1, 'A', 'a')
('b', 1, 'A', 'a')
('a', 2, 'A', 'a')
('a', 2, 'B', 'a')
('a', 2, 'B', 'b')
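The listing above can be checked directly against the chaining rule with a small brute-force filter. It tries every combination of one link per position and keeps the combinations in which neighbouring links share the touching element. This is only a sketch for verifying tiny inputs, because the number of combinations grows exponentially. brute_force_chains is a hypothetical helper name.

```python
from itertools import product

# Toy input copied from the question.
links = [
    [('a', 1), ('a', 2), ('a', 3), ('b', 1)],
    [(1, 'A'), (2, 'A'), (2, 'B')],
    [('A', 'a'), ('B', 'a'), ('B', 'b')],
]

def brute_force_chains(links):
    """Hypothetical helper: enumerate every combination of one link per
    position and keep those where neighbouring links can be chained."""
    result = []
    for combo in product(*links):
        # Links i and i+1 connect only if the right side of link i equals
        # the left side of link i+1 (tiles cannot be rotated).
        if all(left[1] == right[0] for left, right in zip(combo, combo[1:])):
            # The chain is the first link's left side followed by every right side.
            result.append((combo[0][0],) + tuple(link[1] for link in combo))
    return result

for chain in brute_force_chains(links):
    print(chain)
```

It prints the same five chains, although not in the same order as the list above. The ('a', 3) link never appears, because no link at the next position starts with 3.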

This procedure is quite similar to forming a long line of domino tiles, except that I cannot rotate the tiles.

Given such an input list, my task is to count all possible distinct chains of length n that can be created. The case above is a simplified toy example. In reality the chain's length may be as high as 1000, and I may be able to use tens of different links at each position. However, I do know for certain that for each link available at position i, there exists at least one compatible link at position i-1.

I wrote a very naive solution which iterates over all links from beginning to end and merges them together, growing every possible version of the final chain:


    # Note: this code was buggy when I first posted it, but it is fixed now.

    # Initialise the chains with the links that can make up
    # the first position, then grow them iteratively.
    chains = links[0]

    # Search for all possible paths:
    # iterate over all remaining positions.
    for position in links[1:]:

        # Temporary list used to grow the chains.
        temp = []

        # Over each chain in the current set of chains...
        for chain in chains:

            # ...and over each link available at this position.
            for link in position:

                # Check whether the chain and the link are chainable.
                if chain[-1] == link[0]:

                    # Extend a copy of the chain with the link's second element.
                    temp.append(chain + (link[1],))

        # Overwrite the current list of chains.
        chains = temp

This solution works fine, i.e. I am quite convinced it returns the correct result. However, it is extremely slow, and I need to speed it up, preferably by ~100x. Therefore I think I need a smarter algorithm that counts all the possibilities rather than brute-force concatenation as above. Since I only need to count the chains, not enumerate them, maybe a backtracking procedure would work: start from each possible final link, multiply the possibilities along the way, and finally add up the results over all final links? I have some vague ideas but cannot really nail this down...
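The backtracking idea can be made concrete with a memoised recursion. The number of partial chains ending with a given link is the sum, over all compatible links at the previous position, of the number of partial chains ending with those links. Summing that quantity over all final links gives the answer. This is a sketch that assumes at least one position. The names count_chains_backtracking and ways_ending_with are hypothetical. With ~1000 positions, the recursion would exceed Python's default recursion limit, so an iterative version is preferable at that size.

```python
from functools import lru_cache

def count_chains_backtracking(links):
    """Hypothetical helper: count chains by working back from the final links.
    Assumes links is non-empty; recursion depth grows with len(links)."""

    @lru_cache(maxsize=None)
    def ways_ending_with(pos, link):
        # Number of valid partial chains covering positions 0..pos
        # whose link at position `pos` is `link`.
        if pos == 0:
            return 1
        # Sum over every compatible link at the previous position;
        # memoisation means each (pos, link) pair is computed only once.
        return sum(
            ways_ending_with(pos - 1, prev)
            for prev in links[pos - 1]
            if prev[1] == link[0]
        )

    last = len(links) - 1
    # Finally, add up over all possible final links.
    return sum(ways_ending_with(last, link) for link in links[last])

# Toy input from the question.
links = [
    [('a', 1), ('a', 2), ('a', 3), ('b', 1)],
    [(1, 'A'), (2, 'A'), (2, 'B')],
    [('A', 'a'), ('B', 'a'), ('B', 'b')],
]
print(count_chains_backtracking(links))  # 5
```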

Answer:

Since counting is enough, let's just do that, and then it takes a split second for large cases as well.

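from collections import defaultdict, Counter

def count_chains(links):
    # Before the first position, every start value begins exactly one chain.
    chains = defaultdict(lambda: 1)
    for position in links:
        # temp[b] = number of chains ending in b after this position.
        temp = Counter()
        for a, b in position:
            # Every chain ending in a can be extended by the link (a, b).
            temp[b] += chains[a]
        chains = temp
    return sum(chains.values())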

It does pretty much the same as yours, except that instead of chains being a list of chains ending in some b-values, I'm using a Counter of chains ending in some b-values: chains[b] tells me how many chains end in b. Counters (and defaultdicts) are dictionaries, so I don't have to search and check for matching connectors; I just look them up.
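To see what the Counter holds, the loop from count_chains can be run on the question's toy input with a snapshot taken after each position. trace_counts is a hypothetical helper written only for this illustration. It performs the same updates as count_chains.

```python
from collections import Counter, defaultdict

def trace_counts(links):
    """Hypothetical helper: same loop as count_chains, but records
    how many chains end in each value after every position."""
    chains = defaultdict(lambda: 1)
    snapshots = []
    for position in links:
        temp = Counter()
        for a, b in position:
            temp[b] += chains[a]
        chains = temp
        snapshots.append(dict(chains))
    return snapshots

links = [
    [('a', 1), ('a', 2), ('a', 3), ('b', 1)],
    [(1, 'A'), (2, 'A'), (2, 'B')],
    [('A', 'a'), ('B', 'a'), ('B', 'b')],
]
for i, snapshot in enumerate(trace_counts(links)):
    print(i, snapshot)
# 0 {1: 2, 2: 1, 3: 1}
# 1 {'A': 3, 'B': 1}
# 2 {'a': 4, 'b': 1}
```

The final counts sum to 5, which matches the five chains listed in the question. The entry 3: 1 after the first position is the dead end created by ('a', 3).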

Because every link is guaranteed to have a compatible link at the previous position, it might be better to go backwards so that we never track dead ends. However, I don't think it would help much, if at all (it depends on your data).
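A backward variant is a mirror image of count_chains. It walks the positions from last to first and counts suffixes by their starting value. Because of the stated guarantee, every suffix counted this way can always be extended back to position 0. This is a sketch; the name count_chains_backward is hypothetical.

```python
from collections import Counter, defaultdict

def count_chains_backward(links):
    """Hypothetical helper: like count_chains, but processes positions
    from last to first, counting suffixes by their first value."""
    # After the last position, every end value completes exactly one suffix.
    suffixes = defaultdict(lambda: 1)
    for position in reversed(links):
        # temp[a] = number of suffixes that start with value a at this position.
        temp = Counter()
        for a, b in position:
            # The link (a, b) can precede every suffix that starts with b.
            temp[a] += suffixes[b]
        suffixes = temp
    return sum(suffixes.values())

links = [
    [('a', 1), ('a', 2), ('a', 3), ('b', 1)],
    [(1, 'A'), (2, 'A'), (2, 'B')],
    [('A', 'a'), ('B', 'a'), ('B', 'b')],
]
print(count_chains_backward(links))  # 5
```

As the answer says, whether this beats the forward pass depends on how many dead ends the data contains.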

For example, for links = [[(1, 1), (1, 2), (2, 1), (2, 2)]] * 1000, it takes about 2 ms to compute the number of chains, which is:

21430172143725346418968500981200036211228096234110672148875007767407021022498722449863967576313917162551893458351062936503742905713846280871969155149397149607869135549648461970842149210124742283755908364306092949967163882534797535118331087892154125829142392955373084335320859663305248773674411336138752
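This number can be checked by hand. Every position offers both values 1 and 2 two outgoing links, so there are 4 chains of length 2, and each further position doubles the count. After k positions there are 2^(k+1) chains, which gives 2^1001 for 1000 positions. A quick check, reusing count_chains as written above:

```python
from collections import Counter, defaultdict

def count_chains(links):
    # Same function as in the answer above.
    chains = defaultdict(lambda: 1)
    for position in links:
        temp = Counter()
        for a, b in position:
            temp[b] += chains[a]
        chains = temp
    return sum(chains.values())

links = [[(1, 1), (1, 2), (2, 1), (2, 2)]] * 1000
result = count_chains(links)
print(result == 2 ** 1001)  # True
print(len(str(result)))     # 302
```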


Answer:

Here is my solution using a graph data structure. It should be more efficient than yours, which has cubic time complexity, O(n³).

links = [
    [
        ('a', 1),
        ('a', 2),
        ('b', 3),
        ('b', 1),
    ],
    [
        (1, 'A'),
        (2, 'A'),
        (2, 'B'),
    ],
    [
        ('A', 'a'),
        ('B', 'a'),
        ('B', 'b'),
    ]
]

class Node:
    def __init__(self, val, next=None):
        self.val = val
        self.next = next if next else []
        self.visited = False
    def __repr__(self):
        return f'({self.val})'
    
nodes = {}
start_nodes = set(i[0] for i in links[0]) # {'a', 'b'}

# constructing the graph
for i in links:
    for j in i:
        if j[0] not in nodes:
            nodes[j[0]] = Node(j[0])
        if j[1] not in nodes:
            nodes[j[1]] = Node(j[1])
            
        nodes[j[0]].next.append(nodes[j[1]])

def find_chain_with_length(start_node, node, length, valid_length, solution):
    solution[-1].append(node.val)
    
    if length + 1 == valid_length:
        solution.append(solution[-1][:-1])
        return
    
    # if already visited just return
    if node.visited:
        if solution[-1]:
            solution[-1].pop()
        return         
            
    # this is not the last node of a chain,
    # so we mark it visited
    node.visited = True
    for each_neighbor in node.next:
        find_chain_with_length(start_node, each_neighbor, length + 1, valid_length, solution)
    if solution[-1]:
        solution[-1].pop()
    # after visiting mark it unvisited
    node.visited = False
    
solution = [[]]
for each_start_node in start_nodes:
    find_chain_with_length(nodes[each_start_node], nodes[each_start_node],0, 4, solution)
solution.pop()

print(solution)
[['a', 1, 'A', 'a'], ['a', 2, 'A', 'a'], ['a', 2, 'B', 'a'], ['a', 2, 'B', 'b'], ['b', 1, 'A', 'a']]
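One possible variant of the graph idea is sketched below under my own assumptions; it is not this answer's code. It keys each node by a (position, value) pair, so that equal values at different positions stay distinct, and it counts paths through the resulting layered DAG instead of listing them. Because the layers are already in topological order, one left-to-right sweep suffices. count_paths_layered is a hypothetical name.

```python
from collections import defaultdict

def count_paths_layered(links):
    """Hypothetical helper: count start-to-end paths in a layered DAG
    whose nodes are (position, value) pairs."""
    if not links:
        return 0
    # Adjacency list: node (i, a) -> nodes (i + 1, b), one per link (a, b).
    edges = defaultdict(list)
    for i, position in enumerate(links):
        for a, b in position:
            edges[(i, a)].append((i + 1, b))

    # counts[i][v] = number of chain prefixes whose value at position i is v.
    counts = [defaultdict(int) for _ in range(len(links) + 1)]
    for a, _ in links[0]:
        counts[0][a] = 1  # each start value begins exactly one prefix

    # Sweep the layers in order, pushing each node's count to its successors.
    for i in range(len(links)):
        for a, c in counts[i].items():
            for _, b in edges[(i, a)]:
                counts[i + 1][b] += c

    return sum(counts[-1].values())

links = [
    [('a', 1), ('a', 2), ('a', 3), ('b', 1)],
    [(1, 'A'), (2, 'A'), (2, 'B')],
    [('A', 'a'), ('B', 'a'), ('B', 'b')],
]
print(count_paths_layered(links))  # 5
```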