Given xor and the range of 2 numbers, find their maximum possible sum

Time:07-27

I got this task:

You are given 5 integers a, b, c, d, k. Print the maximum value of x + y that satisfies the following conditions:

  • a<=x<=b
  • c<=y<=d
  • x^y=k ('^' sign denotes XOR operation)

Constraints:

  • 0 <= a <= b <= 10^18
  • 0 <= c <= d <= 10^18

Explanation:

  • x and y are the 2 numbers,
  • k is their XOR value,
  • [a,b] is range of x ,
  • [c,d] is range of y.

My Attempt:

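a, b, c, d, k = map(int, input().split())
best = -1  # stays -1 if no valid pair exists
# Brute force: try every pair in the inclusive ranges [a, b] x [c, d]
for x in range(a, b + 1):
    for y in range(c, d + 1):
        if x ^ y == k:
            # keep the best sum; the first match found is not necessarily the maximum
            best = max(best, x + y)
print(best)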

I know it's brute force, but it's the only algorithm I can think of, and it obviously won't work: the time complexity is O((b-a)*(d-c)), which in the worst case is about 10^36 operations. This approach needs to be optimized to logarithmic or constant time complexity.

From a similar question, I found that

X + Y = (X ^ Y) + 2 * (X & Y)

So,

ans = k + 2 * (X & Y)
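The identity behind this step can be checked exhaustively for small values. XOR adds the bits without carrying, and the carries happen exactly at the positions where both bits are 1, each worth twice its place value. A minimal Python check (the function name `identity_holds` is hypothetical):

```python
# Verify x + y == (x ^ y) + 2 * (x & y) for all pairs 0 <= x, y < limit.
# x ^ y is the carry-less sum; x & y marks the carry positions, and each
# carry contributes twice the place value of its bit.
def identity_holds(limit):
    return all(x + y == (x ^ y) + 2 * (x & y)
               for x in range(limit) for y in range(limit))

print(identity_holds(256))  # True
```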

So I need to find the maximum value of X & Y over all pairs that satisfy the range constraints and X ^ Y = k. But how do I do that?
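Before optimizing, it helps to have a correct reference for small ranges to test against. One observation not made in the post: once x is chosen, y is forced to be x ^ k, so a brute force only needs to loop over x. This is still O(b - a), far too slow for 10^18, but useful as an oracle. The function name `max_sum_brute` is hypothetical.

```python
def max_sum_brute(a, b, c, d, k):
    """Return the maximum x + y with a <= x <= b, c <= y <= d and x ^ y == k,
    or None if no such pair exists. Only practical for small ranges."""
    best = None
    for x in range(a, b + 1):
        y = x ^ k                    # y is fully determined by x and k
        if c <= y <= d:
            s = k + 2 * (x & y)      # equal to x + y by the identity above
            if best is None or s > best:
                best = s
    return best

print(max_sum_brute(0, 3, 0, 2, 2))  # 4  (x = 3, y = 1)
print(max_sum_brute(0, 1, 4, 5, 2))  # None
```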

Any help is appreciated, thanks.

Answer:

10^18 is about 2^60 (a little smaller). You can work on the bits and only check numbers that would give a valid xor result. This is already a lot better than your algorithm, but I don't know if it is good enough.
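The code below starts its search at `1L << 59` because every value up to 10^18 fits in 60 bits, so bit 59 is the highest bit that can ever be set. A quick Python check:

```python
LIMIT = 10**18
print(LIMIT.bit_length())         # 60
print(2**59 <= LIMIT < 2**60)     # True: bit 59 is the top bit that can be set
```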

public long solve(long a, long b, long c, long d, long k) {
    return solve(a, b, c, d, k, 0L, 0L, 1L << 59);
}

private long solve(long a, long b, long c, long d, long k, long x, long y, long mask) {
    if (mask == 0)
        return x >= a && x <= b && y >= c && y <= d ? x + y : -1L;
    if ((mask & k) == 0) {
        if ((mask | x) <= b && (mask | y) <= d) {
            long r = solve(a, b, c, d, k, x | mask, y | mask, mask >> 1);
            if (r > 0)
                return r;
        }
        if ((mask | x) > a && (mask | y) > c)
            return solve(a, b, c, d, k, x, y, mask >> 1);
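    } else {
        // Both choices add `mask` to the sum, and the first feasible branch is returned.
        // Caveat: the branches constrain the lower bits differently, so this greedy step
        // can miss the maximum (e.g. a=0, b=3, c=0, d=2, k=2 returns 2, but x=3, y=1 gives 4).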
        if ((mask | x) > a && (mask | y) <= d) {
            long r = solve(a, b, c, d, k, x, y | mask, mask >> 1);
            if (r > 0)
                return r;
        }
        if ((mask | x) <= b && (mask | y) > c)
            return solve(a, b, c, d, k, x | mask, y, mask >> 1);
    }
    return -1L;
}
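The pruning tests in this code are worth unpacking. Bits are assigned from the top down, so the partial x contains only bits above the current `mask`. The smallest completion with the current bit set is `x | mask`, and the largest completion with it cleared is `x | (mask - 1)`. So `(mask | x) <= b` asks "can this bit be 1 without exceeding b?", and `(mask | x) > a` asks "can this bit be 0 and still reach a?". A small Python check of the second equivalence (the helper `zero_bit_can_reach` is hypothetical):

```python
def zero_bit_can_reach(x, mask, a):
    """If the current bit is cleared and all lower bits stay free,
    can x still end up >= a? The largest reachable value is x | (mask - 1)."""
    return (x | (mask - 1)) >= a

# Exhaustive check: for prefixes x with no bits at or below `mask`,
# the test used in the Java code, (mask | x) > a, gives the same answer.
for bit in range(6):
    mask = 1 << bit
    for x in range(0, 128, 2 * mask):   # multiples of 2*mask: only higher bits set
        for a in range(128):
            assert zero_bit_can_reach(x, mask, a) == ((mask | x) > a)
print("ok")
```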

Answer:

Look at the numbers in base 2 as arrays of bits, from most significant to least significant.

There are 4 inequality constraints:

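  1. a<=x: satisfied if x equals a all the way to the end, or if at the first bit where they differ x has a 1 and a has a 0.
  2. x<=b: satisfied if x equals b all the way to the end, or if at the first bit where they differ x has a 0 and b has a 1.
  3. c<=y: satisfied if y equals c all the way to the end, or if at the first bit where they differ y has a 1 and c has a 0.
  4. y<=d: satisfied if y equals d all the way to the end, or if at the first bit where they differ y has a 0 and d has a 1.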
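All four rules have the same shape. A lower bound stays "open" (tight) while the chosen bits equal the bound's bits, fails if a chosen bit is smaller, and is satisfied for good once a chosen bit is larger; an upper bound is the mirror image. A minimal sketch of one step (the function names are hypothetical, not taken from the answer's code):

```python
def step_lower(is_open, bit, bound_bit):
    """One bit of a lower-bound check such as a <= x.
    Returns (still_valid, still_open)."""
    if not is_open:              # already strictly above the bound
        return True, False
    if bit < bound_bit:          # prefix fell below the bound
        return False, False
    return True, bit == bound_bit  # equal keeps it tight; greater closes it

def step_upper(is_open, bit, bound_bit):
    """One bit of an upper-bound check such as x <= b."""
    if not is_open:              # already strictly below the bound
        return True, False
    if bit > bound_bit:          # prefix rose above the bound
        return False, False
    return True, bit == bound_bit

# Check 4 <= 5 bit by bit, most significant first: 101 against 100.
ok, tight = True, True
for xb, ab in zip([1, 0, 1], [1, 0, 0]):
    ok, tight = step_lower(tight, xb, ab)
    if not ok:
        break
print(ok, tight)  # True False: 5 is strictly greater than 4
```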

So after each bit we have a state consisting of which inequality constraints are currently done or active. This state can be represented by a number in the range 0..15. For each state we only care about the largest sum of the values of the already set bits for x and y.
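One way to pack the four "still open" flags into a single number 0..15, matching the `a_open`, `b_open`, `c_open` and `d_open` constants in the code that follows: each constraint owns one bit, and the search starts in state 15 because before any bit is chosen every prefix equals the corresponding bound's (empty) prefix. The helper `describe` is hypothetical, for illustration only.

```python
A_OPEN, B_OPEN, C_OPEN, D_OPEN = 1, 2, 4, 8
START = A_OPEN | B_OPEN | C_OPEN | D_OPEN   # 15: all four constraints tight

def describe(state):
    """List which constraints are still tight in a 4-bit state."""
    names = [("a<=x", A_OPEN), ("x<=b", B_OPEN), ("c<=y", C_OPEN), ("y<=d", D_OPEN)]
    return [name for name, flag in names if state & flag]

print(START)                       # 15
print(describe(START & ~A_OPEN))   # ['x<=b', 'c<=y', 'y<=d']
```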

This is a perfect setup for dynamic programming.

def to_bits (n):
    answer = []
    while 0 < n:
        answer.append(n&1)
        n = n >> 1
    return answer

def solve (a, b, c, d, k):
    a_bits = to_bits(a)
    b_bits = to_bits(b)
    c_bits = to_bits(c)
    d_bits = to_bits(d)
    k_bits = to_bits(k)

    s = max(len(a_bits), len(b_bits), len(c_bits), len(d_bits), len(k_bits))
    while len(a_bits) < s:
        a_bits.append(0)
    while len(b_bits) < s:
        b_bits.append(0)
    while len(c_bits) < s:
        c_bits.append(0)
    while len(d_bits) < s:
        d_bits.append(0)
    while len(k_bits) < s:
        k_bits.append(0)

    a_open = 1
    b_open = 2
    c_open = 4
    d_open = 8
    best_by_state = {15: 0}
    for i in range(s-1, -1, -1):
        next_best_by_state = {}
        power = 2**i
        if 0 == k_bits[i]:
            choices = [(0, 0), (1, 1)]
        else:
            choices = [(0, 1), (1, 0)]
        for state, value in best_by_state.items():
            for choice in choices:
                next_state = state

                # Check all conditions, noting state changes.
                if (state & a_open):
                    if choice[0] < a_bits[i]:
                        continue
                    elif a_bits[i] < choice[0]:
                        next_state -= a_open

                if (state & b_open):
                    if b_bits[i] < choice[0]:
                        continue
                    elif choice[0] < b_bits[i]:
                        next_state -= b_open

                if (state & c_open):
                    if choice[1] < c_bits[i]:
                        continue
                    elif c_bits[i] < choice[1]:
                        next_state -= c_open

                if (state & d_open):
                    if d_bits[i] < choice[1]:
                        continue
                    elif choice[1] < d_bits[i]:
                        next_state -= d_open

                next_value = value + power * sum(choice)

                if next_best_by_state.get(next_state, -1) < next_value:
                    next_best_by_state[next_state] = next_value

        best_by_state = next_best_by_state

    possible = best_by_state.values()
    if 0 < len(possible):
        return max(possible)
    else:
        return None

print(solve(1000000000000000,2000000000000000,3000000000000000,3600000000000000,3333000333000333))

This program runs in time linear in the number of bits: at each bit position it processes at most 16 states with 2 choices each.
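To gain some confidence in the DP, it can be compared against a brute force on small random inputs. The sketch below is a condensed restructuring of the same state machine (shifts instead of bit lists, one helper for all four constraints; the names `_step`, `solve_dp` and `brute` are hypothetical), plus a brute force that uses the fact that y = x ^ k:

```python
import random

def _step(state, flag, bit, bound_bit, lower):
    """Update one tight flag; return the new state, or None if the bound is violated."""
    if not state & flag or bit == bound_bit:
        return state                 # already satisfied, or still tight
    if (bit < bound_bit) == lower:   # below a lower bound or above an upper bound
        return None
    return state & ~flag             # strictly inside: satisfied for good

def solve_dp(a, b, c, d, k):
    """Max x + y with a <= x <= b, c <= y <= d, x ^ y == k, or None."""
    best = {15: 0}                   # state -> best sum of the bits set so far
    for i in range(max(a, b, c, d, k).bit_length() - 1, -1, -1):
        ab, bb, cb, db, kb = ((v >> i) & 1 for v in (a, b, c, d, k))
        choices = [(0, 1), (1, 0)] if kb else [(0, 0), (1, 1)]
        nxt = {}
        for state, value in best.items():
            for xb, yb in choices:
                s = state
                for flag, bit, bound_bit, lower in ((1, xb, ab, True), (2, xb, bb, False),
                                                    (4, yb, cb, True), (8, yb, db, False)):
                    s = _step(s, flag, bit, bound_bit, lower)
                    if s is None:
                        break
                if s is None:
                    continue
                total = value + (xb + yb) * (1 << i)
                if nxt.get(s, -1) < total:
                    nxt[s] = total
        best = nxt
    return max(best.values()) if best else None

def brute(a, b, c, d, k):
    sums = [x + (x ^ k) for x in range(a, b + 1) if c <= x ^ k <= d]
    return max(sums) if sums else None

random.seed(1)
for _ in range(2000):
    a, b = sorted(random.randrange(64) for _ in range(2))
    c, d = sorted(random.randrange(64) for _ in range(2))
    k = random.randrange(64)
    assert solve_dp(a, b, c, d, k) == brute(a, b, c, d, k), (a, b, c, d, k)
print("all random tests passed")
```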
