I'm solving the 'Non-overlapping Intervals' problem on LeetCode [https://leetcode.com/problems/non-overlapping-intervals/]. In short, we need to find the minimum number of intervals to delete so that the remaining intervals do not overlap (the number of deletions is the requested result).
My solution is to build an augmented interval tree ([https://en.wikipedia.org/wiki/Interval_tree#Augmented_tree]) out of all the intervals (in O(n log n) time), then, in a second pass over the intervals, count how many other intervals each interval intersects (also in O(n log n) time). The count includes one self-intersection, but I only use it as a relative metric. I then sort all the intervals by this 'number of intersections with others' metric. In the last step, I take intervals one by one from the sorted list, fewest intersections first, and build a non-overlapping set (with an explicit overlap check against a second interval tree instance). The intervals that do not fit form the result set that should be deleted.
Below is the full code of the described solution, ready to try on LeetCode.
The approach is fast enough, BUT sometimes I get a wrong result that is off by 1. LeetCode doesn't give much feedback, it just throws back 'expected 810' against my result of '811'. So I'm still debugging, digging through the 811 intervals... :)
Even though I know other solutions to this problem, I'd like to find a case on which the described approach fails (it could be a useful edge case in itself). So if someone has seen a similar problem, or can spot the issue with 'fresh eyes', it would be much appreciated!
Thanks in advance for any constructive comments and ideas!
The solution code:
from typing import Iterable, List

class Interval:
    def __init__(self, lo: int, hi: int):
        self.lo = lo
        self.hi = hi

class Node:
    def __init__(self, interval: Interval, left: 'Node' = None, right: 'Node' = None):
        self.left = left
        self.right = right
        self.interval = interval
        self.max_hi = interval.hi

class IntervalTree:
    def __init__(self):
        self.root = None

    def __add(self, interval: Interval, node: Node) -> Node:
        if node is None:
            node = Node(interval)
            node.max_hi = interval.hi
            return node
        if node.interval.lo > interval.lo:
            node.left = self.__add(interval, node.left)
        else:
            node.right = self.__add(interval, node.right)
        node.max_hi = max(node.left.max_hi if node.left else 0, node.right.max_hi if node.right else 0, node.interval.hi)
        return node

    def add(self, lo: int, hi: int):
        interval = Interval(lo, hi)
        self.root = self.__add(interval, self.root)

    def __is_intersect(self, interval: Interval, node: Node) -> bool:
        if node is None:
            return False
        if not (node.interval.lo >= interval.hi or node.interval.hi <= interval.lo):
            # print(f'{interval.lo}-{interval.hi} intersects {node.interval.lo}-{node.interval.hi}')
            return True
        if node.left and node.left.max_hi > interval.lo:
            return self.__is_intersect(interval, node.left)
        return self.__is_intersect(interval, node.right)

    def is_intersect(self, lo: int, hi: int) -> bool:
        interval = Interval(lo, hi)
        return self.__is_intersect(interval, self.root)

    def __all_intersect(self, interval: Interval, node: Node) -> Iterable[Interval]:
        if node is None:
            yield from ()
        else:
            if not (node.interval.lo >= interval.hi or node.interval.hi <= interval.lo):
                # print(f'{interval.lo}-{interval.hi} intersects {node.interval.lo}-{node.interval.hi}')
                yield node.interval
            if node.left and node.left.max_hi > interval.lo:
                yield from self.__all_intersect(interval, node.left)
            yield from self.__all_intersect(interval, node.right)

    def all_intersect(self, lo: int, hi: int) -> Iterable[Interval]:
        interval = Interval(lo, hi)
        yield from self.__all_intersect(interval, self.root)

class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        ranged_intervals = []
        interval_tree = IntervalTree()
        for interval in intervals:
            interval_tree.add(interval[0], interval[1])
        for interval in intervals:
            c = interval_tree.all_intersect(interval[0], interval[1])
            ranged_intervals.append((len(list(c))-1, interval)) # decrement intersection to account self intersection
        interval_tree = IntervalTree()
        res = []
        ranged_intervals.sort(key=lambda t: t[0], reverse=True)
        while ranged_intervals:
            _, interval = ranged_intervals.pop()
            if not interval_tree.is_intersect(interval[0], interval[1]):
                interval_tree.add(interval[0], interval[1])
            else:
                res.append(interval)
        return len(res)
CodePudding user response:
To make a counter example for your algorithm, you can construct a problem where selecting the segment with the fewest number of intersections ruins the solution, like this:
[----][----][----][----]
[-------][----][-------]
[-------]      [-------]
[-------]      [-------]
[-------]      [-------]
Your algorithm will choose the center interval first, which is incompatible with the optimal solution:
[----][----][----][----]
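The picture above can be checked numerically. The following is a minimal sketch, not the code from the question: it replaces the interval tree with a plain quadratic overlap count, and the coordinates (one unit per character of the diagram) as well as the names `overlaps`, `fewest_overlaps_first` and `optimal_deletions` are assumptions made for illustration. Touching endpoints are treated as non-overlapping, as in the question's overlap test.

```python
from typing import List


def overlaps(a: List[int], b: List[int]) -> bool:
    # Same overlap test as in the question: touching endpoints do not overlap.
    return not (a[0] >= b[1] or a[1] <= b[0])


def fewest_overlaps_first(intervals: List[List[int]]) -> int:
    # Quadratic stand-in for the interval-tree heuristic (hypothetical helper).
    # Count overlaps with the other intervals; subtract 1 for the self-overlap.
    counts = [sum(overlaps(a, b) for b in intervals) - 1 for a in intervals]
    order = sorted(range(len(intervals)), key=lambda i: counts[i])
    kept: List[List[int]] = []
    for i in order:
        if not any(overlaps(intervals[i], k) for k in kept):
            kept.append(intervals[i])
    return len(intervals) - len(kept)


def optimal_deletions(intervals: List[List[int]]) -> int:
    # Sweep by start point; on overlap, drop the interval that ends later.
    deletes = 0
    extent = None
    for lo, hi in sorted(intervals):
        if extent is None or extent <= lo:
            extent = hi
        else:
            deletes += 1
            extent = min(extent, hi)
    return deletes


# Assumed coordinates for the diagram, one unit per character.
counterexample = (
    [[0, 6], [6, 12], [12, 18], [18, 24]]   # top row: the optimal set
    + [[0, 9], [9, 15], [15, 24]]           # second row, with the center interval
    + [[0, 9], [15, 24]] * 3                # three more copies of the long side intervals
)

print(fewest_overlaps_first(counterexample))  # 10
print(optimal_deletions(counterexample))      # 9
```

With these coordinates the center interval `[9, 15]` has only 2 overlaps, while every other interval has 4 or 5, so the heuristic keeps it first and ends up with 3 intervals instead of 4.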
An algorithm that does work is, while there are any overlaps:
- Find the left-most point of overlap
- Pick any two intervals that overlap that point, and delete the one that extends farthest to the right.
This algorithm is also very simple to implement. You can do it in a single traversal through the list of intervals, sorted by start point:
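from typing import List

class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        intervals.sort()
        extent = None  # right end of the interval currently kept
        deletes = 0
        for interval in intervals:
            if extent is None or extent <= interval[0]:
                # No overlap with the kept interval: keep this one.
                extent = interval[1]
            else:
                # Overlap: delete whichever of the two extends farther right.
                deletes += 1
                extent = min(extent, interval[1])
        return deletes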
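To hunt for small failing inputs instead of digging through a large test case, one option is to compare a solution against an exhaustive search on random small inputs. The sketch below assumes intervals with `lo < hi`; the helper names `brute_force_deletions` and `find_mismatch` and all parameter defaults are hypothetical, chosen only for this example.

```python
import random
from itertools import combinations
from typing import Callable, List, Optional


def brute_force_deletions(intervals: List[List[int]]) -> int:
    # Try removing 0, 1, 2, ... intervals until some remaining set is overlap-free.
    n = len(intervals)
    for removed in range(n + 1):
        for keep in combinations(intervals, n - removed):
            ordered = sorted(keep)
            # Sorted by start, the set is overlap-free iff neighbours do not overlap.
            if all(a[1] <= b[0] for a, b in zip(ordered, ordered[1:])):
                return removed
    return n


def find_mismatch(
    candidate: Callable[[List[List[int]]], int],
    trials: int = 300,
    max_n: int = 8,
    max_coord: int = 12,
    seed: int = 0,
) -> Optional[List[List[int]]]:
    # Return the first random input on which candidate disagrees with brute force.
    rng = random.Random(seed)
    for _ in range(trials):
        intervals = []
        for _ in range(rng.randint(1, max_n)):
            lo = rng.randrange(max_coord)
            intervals.append([lo, rng.randint(lo + 1, max_coord)])
        # Pass a copy, since the candidate may sort its argument in place.
        if candidate([list(i) for i in intervals]) != brute_force_deletions(intervals):
            return intervals
    return None


print(brute_force_deletions([[1, 2], [2, 3], [3, 4], [1, 3]]))  # 1
# Usage, assuming a Solution class is defined:
# print(find_mismatch(Solution().eraseOverlapIntervals))
```

The exhaustive search is exponential, so it only suits small inputs, but a mismatch on a handful of intervals is usually far easier to inspect than one among hundreds.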