Can someone help to spot an edge case with error for this algorithm?

Time:01-03

I'm solving the 'Non-overlapping Intervals' problem on LeetCode [https://leetcode.com/problems/non-overlapping-intervals/]. In short, we need to find the minimum number of intervals to delete so that the remaining intervals do not overlap (that number is the requested result).

My solution builds an augmented interval tree ([https://en.wikipedia.org/wiki/Interval_tree#Augmented_tree]) out of all the intervals (O(n log n) time, as long as the tree stays balanced). Then, in a second pass over the intervals, it measures how many other intervals each interval intersects (also about O(n log n)); the raw count includes one self-intersection, but I only use it as a relative metric. I sort all intervals by this 'number of intersections with others' metric. In the last step I take the intervals one by one from that sorted list, fewest intersections first, and build a non-overlapping set, checking each candidate explicitly against a second interval tree; every interval that fails the check goes into the result set of intervals to delete.

The full code of this solution is below, ready to run on LeetCode.

The approach is fast enough, BUT it sometimes returns a wrong result that is off by one. LeetCode gives little feedback beyond 'expected 810' versus my '811', so I'm still digging through those 811 intervals... :)

Even though I know other solutions to this problem, I'd like to find the case on which the described approach fails (it could be a useful edge case in its own right). So if someone has seen a similar problem or can spot it with 'fresh eyes', it would be much appreciated!
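One way to hunt for a failing input is a differential test on small random inputs: compare the heuristic against an exhaustive search. This is a sketch under assumptions, not the original code: `fewest_conflicts_first`, `min_deletions_brute` and `find_counterexample` are hypothetical names, the heuristic is re-implemented in O(n²) without the interval tree, and ties between equal counts may be broken differently than the original's `pop()` order.

```python
import itertools
import random
from typing import List, Optional

# Hypothetical helpers for illustration. Intervals are half-open [start, end),
# so [1, 2] and [2, 3] do not overlap (same convention as the tree code).


def overlaps(a: List[int], b: List[int]) -> bool:
    return not (a[0] >= b[1] or a[1] <= b[0])


def fewest_conflicts_first(intervals: List[List[int]]) -> int:
    """O(n^2) version of the question's heuristic, without the interval tree."""
    n = len(intervals)
    counts = [sum(overlaps(intervals[i], intervals[j]) for j in range(n) if j != i)
              for i in range(n)]
    kept: List[List[int]] = []
    # Fewest intersections first; ties keep input order (an assumption: the
    # original's pop() order may break ties differently).
    for i in sorted(range(n), key=lambda idx: counts[idx]):
        if not any(overlaps(intervals[i], k) for k in kept):
            kept.append(intervals[i])
    return n - len(kept)


def min_deletions_brute(intervals: List[List[int]]) -> int:
    """Exact answer for small inputs: size of the largest non-overlapping subset."""
    n = len(intervals)
    for k in range(n, 0, -1):
        for subset in itertools.combinations(intervals, k):
            s = sorted(subset)
            if all(s[i][1] <= s[i + 1][0] for i in range(k - 1)):
                return n - k
    return 0  # only reached for an empty input


def find_counterexample(trials: int = 2000, seed: int = 0) -> Optional[List[List[int]]]:
    """Generate random small inputs until the heuristic disagrees with brute force."""
    rng = random.Random(seed)
    for _ in range(trials):
        case = []
        for _ in range(rng.randint(2, 7)):
            lo = rng.randint(0, 10)
            case.append([lo, lo + rng.randint(1, 5)])
        if fewest_conflicts_first(case) != min_deletions_brute(case):
            return case
    return None


if __name__ == "__main__":
    print(find_counterexample())
```

Because the heuristic always ends up with a valid non-overlapping set, any disagreement means it deletes too many intervals, which matches the 811-versus-810 symptom.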

Thanks in advance for any constructive comments and ideas!

The solution code:

```python
from typing import Iterable, List, Optional


class Interval:
    def __init__(self, lo: int, hi: int):
        self.lo = lo
        self.hi = hi


class Node:
    def __init__(self, interval: Interval, left: Optional['Node'] = None, right: Optional['Node'] = None):
        self.left = left
        self.right = right
        self.interval = interval
        self.max_hi = interval.hi  # largest `hi` anywhere in this subtree


class IntervalTree:
    """Unbalanced BST keyed on `lo`, augmented with the subtree maximum of `hi`.
    Intervals are treated as half-open: [1, 2] and [2, 3] do not intersect."""

    def __init__(self):
        self.root = None

    def __add(self, interval: Interval, node: Optional[Node]) -> Node:
        if node is None:
            return Node(interval)

        if node.interval.lo > interval.lo:
            node.left = self.__add(interval, node.left)
        else:
            node.right = self.__add(interval, node.right)
        # Use -inf for a missing child: endpoints may be negative, so a default
        # of 0 can overstate max_hi and send __is_intersect down the wrong subtree.
        node.max_hi = max(node.left.max_hi if node.left else float('-inf'),
                          node.right.max_hi if node.right else float('-inf'),
                          node.interval.hi)
        return node

    def add(self, lo: int, hi: int):
        interval = Interval(lo, hi)
        self.root = self.__add(interval, self.root)

    def __is_intersect(self, interval: Interval, node: Optional[Node]) -> bool:
        if node is None:
            return False
        if not (node.interval.lo >= interval.hi or node.interval.hi <= interval.lo):
            return True
        # If the left subtree reaches past interval.lo, any intersection that
        # exists below this node can be found there; otherwise only on the right.
        if node.left and node.left.max_hi > interval.lo:
            return self.__is_intersect(interval, node.left)
        return self.__is_intersect(interval, node.right)

    def is_intersect(self, lo: int, hi: int) -> bool:
        interval = Interval(lo, hi)
        return self.__is_intersect(interval, self.root)

    def __all_intersect(self, interval: Interval, node: Optional[Node]) -> Iterable[Interval]:
        if node is None:
            return
        if not (node.interval.lo >= interval.hi or node.interval.hi <= interval.lo):
            yield node.interval
        if node.left and node.left.max_hi > interval.lo:
            yield from self.__all_intersect(interval, node.left)
        # Every interval in the right subtree starts at or after node.interval.lo,
        # so the right subtree can be skipped once that reaches interval.hi.
        if node.interval.lo < interval.hi:
            yield from self.__all_intersect(interval, node.right)

    def all_intersect(self, lo: int, hi: int) -> Iterable[Interval]:
        interval = Interval(lo, hi)
        yield from self.__all_intersect(interval, self.root)


class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        # Pass 1: build the tree, then count how many others each interval intersects.
        interval_tree = IntervalTree()
        for interval in intervals:
            interval_tree.add(interval[0], interval[1])
        ranged_intervals = []
        for interval in intervals:
            c = interval_tree.all_intersect(interval[0], interval[1])
            ranged_intervals.append((len(list(c)) - 1, interval))  # subtract the self-intersection

        # Pass 2: greedily keep intervals, fewest intersections first.
        interval_tree = IntervalTree()
        res = []
        ranged_intervals.sort(key=lambda t: t[0], reverse=True)
        while ranged_intervals:
            _, interval = ranged_intervals.pop()  # pops the smallest count
            if not interval_tree.is_intersect(interval[0], interval[1]):
                interval_tree.add(interval[0], interval[1])
            else:
                res.append(interval)  # overlaps a kept interval: delete it

        return len(res)
```

Answer:

To construct a counterexample for your algorithm, build a problem where selecting the segment with the fewest intersections ruins the solution, like this:

```
[----][----][----][----]
[-------][----][-------]
[-------]      [-------]
[-------]      [-------]
[-------]      [-------]
```

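Your algorithm will choose the center interval first, because it intersects only two others while every other interval intersects at least four. But the center interval is incompatible with the optimal solution: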

```
[----][----][----][----]
```
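To check the intersection counts behind this argument, the drawing can be turned into numbers. The coordinates are an assumption (one character of the drawing per unit, half-open intervals as in the question's code):

```python
from typing import List


def overlaps(a: List[int], b: List[int]) -> bool:
    # Half-open intervals: touching endpoints do not count as an overlap.
    return not (a[0] >= b[1] or a[1] <= b[0])


# Coordinates read off the drawing, one character per unit (an assumption).
top = [[0, 6], [6, 12], [12, 18], [18, 24]]
left = [[0, 9] for _ in range(4)]
center = [[9, 15]]
right = [[15, 24] for _ in range(4)]
intervals = top + left + center + right

# Number of *other* intervals that each interval intersects.
counts = [
    sum(overlaps(a, b) for j, b in enumerate(intervals) if j != i)
    for i, a in enumerate(intervals)
]

for interval, count in zip(intervals, counts):
    print(interval, count)
```

The center interval gets a count of 2, while every other interval gets 4 or 5, so a fewest-intersections-first greedy keeps it and thereby excludes two of the four top-row intervals.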

An algorithm that does work repeats the following while any overlaps remain:

  1. Find the left-most point of overlap
  2. Pick any two intervals that overlap that point, and delete the one that extends farthest to the right.
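Taken literally, the two steps can be transcribed into a deliberately slow sketch (roughly cubic time, only meant to make the rule concrete). `leftmost_overlap_point` and `erase_by_leftmost_overlap` are hypothetical names, and step 2 is instantiated by deleting the farthest-right interval among all those covering the point, which is one valid choice of "any two":

```python
from typing import List, Optional


def leftmost_overlap_point(intervals: List[List[int]]) -> Optional[int]:
    """Smallest x covered by at least two half-open intervals, or None."""
    best = None
    for i in range(len(intervals)):
        for j in range(i + 1, len(intervals)):
            # Two half-open intervals overlap on [max(starts), min(ends)).
            lo = max(intervals[i][0], intervals[j][0])
            hi = min(intervals[i][1], intervals[j][1])
            if lo < hi and (best is None or lo < best):
                best = lo
    return best


def erase_by_leftmost_overlap(intervals: List[List[int]]) -> int:
    remaining = [list(iv) for iv in intervals]
    deletes = 0
    x = leftmost_overlap_point(remaining)
    while x is not None:
        # All intervals covering x overlap each other (at least two exist);
        # delete the one that extends farthest to the right.
        covering = [iv for iv in remaining if iv[0] <= x < iv[1]]
        remaining.remove(max(covering, key=lambda iv: iv[1]))
        deletes += 1
        x = leftmost_overlap_point(remaining)
    return deletes
```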

This algorithm is also very simple to implement. You can do it in a single traversal through the list of intervals, sorted by start point:

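```python
from typing import List


class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        intervals.sort()  # by start point (ties broken by end point)
        extent = None     # right end of the most recently kept interval
        deletes = 0
        for interval in intervals:
            if extent is None or extent <= interval[0]:
                # No overlap with the kept interval: keep this one.
                extent = interval[1]
            else:
                # Overlap: delete whichever of the two extends farther right,
                # i.e. keep the smaller right end.
                deletes += 1
                extent = min(extent, interval[1])
        return deletes
```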
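In the single pass, `extent = min(extent, interval[1])` is where "delete the one that extends farthest to the right" happens: the interval that stays is always the one that ends first. For comparison, a commonly used equivalent formulation (the classic interval-scheduling greedy, swapped in here and not part of the answer above) sorts by end point instead, so no `min` is needed; `erase_overlap_by_end` is a hypothetical name:

```python
from typing import List


def erase_overlap_by_end(intervals: List[List[int]]) -> int:
    # Interval-scheduling greedy: scan by end point and keep every interval
    # that starts at or after the end of the last kept one.
    kept_end = None
    deletes = 0
    for lo, hi in sorted(intervals, key=lambda iv: iv[1]):
        if kept_end is None or lo >= kept_end:
            kept_end = hi
        else:
            deletes += 1
    return deletes
```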