I am trying to use an interval tree to solve this problem. Below is my attempt, but understandably it is not working, i.e. it is not returning all the intervals.
A cricket match is going to be held. The field is represented as a one-dimensional line. A cricketer, Mr. X, has favorite shots. Each shot has a particular range: the range of the ith shot is from A(i) to B(i), which means his favorite shot can land anywhere in this range. Each player on the opposing team can field only in a particular range: a player can field from A(i) to B(i). You are given the favorite shots of Mr. X and the ranges of the M players. The task is to find, summed over all players, the number of shots whose range overlaps each player's range.
A brute-force solution times out on some of the test cases. All I need is an idea.
class node:
    def __init__(self, low, high):
        self.left = None
        self.right = None
        self.highest = high
        self.low = low
        self.high = high
class interval:
    def __init__(self):
        self.head = None
        self.count = 0
    def add_node(self, node):
        if self.head == None:
            self.head = node
        else:
            if self.head.highest < node.high:
                self.head.highest = node.high
            self.__add_node(self.head, node)
    def __add_node(self, head, node):
        if node.low <= head.low:
            if head.left == None:
                head.left = node
            else:
                if head.left.highest < node.high:
                    head.left.highest = node.high
                self.__add_node(head.left, node)
        else:
            if head.right == None:
                head.right = node
            else:
                if head.right.highest < node.high:
                    head.right.highest = node.high
                self.__add_node(head.right, node)
    def search(self, node):
        self.count = 0
        return self._search(self.head, node)
    def _search(self, head, node):
        if node.low <= head.high and node.high >= head.low:
            self.count += 1
            print(self.count, head.high, head.low)
        if head.left != None and head.left.highest >= node.low:
            return self._search(head.left, node)
        elif head.right != None:
            return self._search(head.right, node)
        return self.count
data = input().split(" ")
N = int(data[0])
M = int(data[1])
intervals = interval()
for i in range(N):
    data = input().split(" ")
    p = node(int(data[0]), int(data[1]))
    intervals.add_node(p)
count = 0
for i in range(M):
    data = input().split(" ")
    count += intervals.search(node(int(data[0]), int(data[1])))
print(count)
The key to solving the problem is to realize that there's no need to compare each fielding range with each individual shot range, since only the total number of intersecting ranges needs to be known. The following algorithm achieves this in O((n + m) log n) time.
Take the shot ranges and create two sorted lists: one of start values and another of end values. The example problem has shots [[1, 2], [2, 3], [4, 5], [6, 7]], and after sorting we have two lists: [1, 2, 4, 6] and [2, 3, 5, 7]. Everything so far can be done in O(n log n) time.
Next, process the outfield players. The first player has range [1, 5]. When we binary-search the sorted end values [2, 3, 5, 7] for the start value 1, we find that every shot range ends at or after it, so 0 ranges end before the player's range begins. Next we binary-search the sorted start values [1, 2, 4, 6] for the end value 5 and find that 3 shot ranges start at or before it. A simple subtraction, 3 - 0, tells us that the first outfield player intersects 3 ranges. Repeating this for all M outfield players takes O(m log n) time.
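The two binary searches map directly onto Python's bisect module. This is a minimal sketch, assuming closed ranges where touching endpoints count as an overlap (as in the [1, 5] example); the function name total_overlaps is my own:

```python
from bisect import bisect_left, bisect_right


def total_overlaps(shots, players):
    """Sum, over all players, of how many shot ranges overlap each player's range.

    Ranges are closed intervals (a, b); a shot touching a player's range counts.
    """
    starts = sorted(a for a, _ in shots)
    ends = sorted(b for _, b in shots)
    total = 0
    for lo, hi in players:
        started = bisect_right(starts, hi)   # shots starting at or before hi
        ended_before = bisect_left(ends, lo)  # shots ending strictly before lo
        total += started - ended_before
    return total


shots = [(1, 2), (2, 3), (4, 5), (6, 7)]
print(total_overlaps(shots, [(1, 5)]))  # 3
```

Every shot that ends before lo also starts before hi, so it is already included in started; that is why a plain subtraction gives the number of overlapping shots.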
I did some homework and tried to solve it with an interval tree. But as you have realized, a traditional interval tree may not be suitable for this problem, because searching an interval tree returns only one match, while we need all matches. More precisely, we only need to count the matches; it's not necessary to find all of them.
So I added two fields to your node for the sake of pruning. I'm not familiar with Python; in Java it looks like this:
static class Node implements Comparable<Node> {
    Node left; // left child
    Node right; // right child
    int low; // low of current node
    int high; // high of current node
    int lowest; // lowest of current subtree
    int highest; // highest of current subtree
    int nodeCount; // node count of current subtree
    @Override
    public int compareTo(Node o) {
        return Integer.compare(low, o.low); // avoids overflow of low - o.low
    }
}
To make a balanced tree, I sort all the intervals and then build the tree recursively from the middle out to both sides (it may be better to use a red-black tree). This makes a big difference to performance, so I suggest adding it to your program.
With those preparations done, the search method looks like this:
private static int search(Node node, int low, int high) {
    // pruning 1: [low, high] covers the whole subtree, so it overlaps every node in it
    if (node.lowest >= low && node.highest <= high) {
        return node.nodeCount;
    }
    // pruning 2: [low, high] never overlaps the subtree
    if (node.highest < low || node.lowest > high) {
        return 0;
    }
    // can't decide: go through the left and right children,
    // counting whether the current node itself overlaps
    int c = (high < node.low || low > node.high ? 0 : 1);
    if (node.left != null) {
        c += search(node.left, low, high);
    }
    if (node.right != null) {
        c += search(node.right, low, high);
    }
    return c;
}
There are two main prunings, as the comments show: there is no need to go through the children when the current subtree is either entirely covered by the query interval or doesn't overlap it at all.
It works well in most cases and has been accepted by the system. It takes about 4000 ms to solve the most complicated test case (N=99600, M=98000). I'm still trying to optimize it further; I hope this helps.
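For the Python asker, here is a rough translation of the two pieces described above: a balanced build from the sorted interval list, then the pruned count. It is a sketch, not the answerer's code; it assumes closed ranges and the same field meanings as the Java Node, and the names build and count_overlaps are hypothetical:

```python
class Node:
    """Interval-tree node augmented with its subtree's span and size."""

    def __init__(self, low, high):
        self.low, self.high = low, high
        self.left = None
        self.right = None
        self.lowest = low    # smallest low endpoint in this subtree
        self.highest = high  # largest high endpoint in this subtree
        self.node_count = 1  # number of intervals in this subtree


def build(intervals):
    """Build a balanced tree by recursively taking the middle of the sorted list."""
    items = sorted(intervals)

    def rec(lo, hi):
        if lo > hi:
            return None
        mid = (lo + hi) // 2
        node = Node(*items[mid])
        node.left = rec(lo, mid - 1)
        node.right = rec(mid + 1, hi)
        for child in (node.left, node.right):
            if child is not None:
                node.lowest = min(node.lowest, child.lowest)
                node.highest = max(node.highest, child.highest)
                node.node_count += child.node_count
        return node

    return rec(0, len(items) - 1)


def count_overlaps(node, low, high):
    """Count intervals in the subtree that overlap the closed range [low, high]."""
    if node is None:
        return 0
    # Pruning 1: [low, high] contains the whole subtree span.
    if node.lowest >= low and node.highest <= high:
        return node.node_count
    # Pruning 2: [low, high] is disjoint from the subtree span.
    if node.highest < low or node.lowest > high:
        return 0
    c = 0 if (high < node.low or low > node.high) else 1
    return c + count_overlaps(node.left, low, high) + count_overlaps(node.right, low, high)


root = build([(1, 2), (2, 3), (4, 5), (6, 7)])
print(sum(count_overlaps(root, a, b) for a, b in [(1, 5), (2, 3), (4, 7), (5, 7)]))  # 9
```

Because the tree is built from the middle of the sorted list, its depth stays around log2(n), so the recursion never gets deep even for N close to 100000.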
I'm solving the 'Non-overlapping Intervals' problem on LeetCode [https://leetcode.com/problems/non-overlapping-intervals/]
In short, we need to find the minimum number of intervals to delete so that the remaining intervals don't overlap (the number to delete is the requested result).
My solution is to build an augmented interval tree ([https://en.wikipedia.org/wiki/Interval_tree#Augmented_tree]) out of all the intervals (O(n log n) time), then, in a second pass through the intervals, measure how many other intervals each interval intersects (also O(n log n) time; the count includes +1 for self-intersection, but I use it only as a relative metric), and sort all the intervals by this 'number of intersections with others' metric.
In the last step, I take intervals one by one from that sorted list and build a non-overlapping set (with an explicit non-overlap check using another interval tree instance); the intervals that fail the check form the result set to delete.
Below is the full code of the described solution, to play with on LeetCode.
The approach works sufficiently fast, BUT sometimes I get a wrong result that differs by 1. LeetCode doesn't give much feedback; it just throws back 'expected 810' instead of my result '811'. So I'm still debugging, digging through the 811 intervals... :)
Even though I know other solutions to this problem, I'd like to find the case on which the described approach fails (it could be a useful edge case in itself). So if someone has seen a similar problem or can spot it with 'fresh eyes', it would be most appreciated!
Thanks in advance for any constructive comments and ideas!
The solution code:
from typing import Iterable, List
class Interval:
    def __init__(self, lo: int, hi: int):
        self.lo = lo
        self.hi = hi
class Node:
    def __init__(self, interval: Interval, left: 'Node' = None, right: 'Node' = None):
        self.left = left
        self.right = right
        self.interval = interval
        self.max_hi = interval.hi
class IntervalTree:
    def __init__(self):
        self.root = None
    def __add(self, interval: Interval, node:Node) -> Node:
        if node is None:
            node = Node(interval)
            node.max_hi = interval.hi
            return node
        if node.interval.lo > interval.lo:
            node.left = self.__add(interval, node.left)
        else:
            node.right = self.__add(interval, node.right)
        # use -inf (not 0) for a missing child so negative coordinates don't inflate max_hi
        node.max_hi = max(node.left.max_hi if node.left else float('-inf'),
                          node.right.max_hi if node.right else float('-inf'),
                          node.interval.hi)
        return node
    def add(self, lo: int, hi: int):
        interval = Interval(lo, hi)
        self.root = self.__add(interval, self.root)
    def __is_intersect(self, interval: Interval, node: Node) -> bool:
        if node is None:
            return False
        if not (node.interval.lo >= interval.hi or node.interval.hi <= interval.lo):
            # print(f'{interval.lo}-{interval.hi} intersects {node.interval.lo}-{node.interval.hi}')
            return True
        if node.left and node.left.max_hi > interval.lo:
            return self.__is_intersect(interval, node.left)
        return self.__is_intersect(interval, node.right)
    def is_intersect(self, lo: int, hi: int) -> bool:
        interval = Interval(lo, hi)
        return self.__is_intersect(interval, self.root)
    def __all_intersect(self, interval: Interval, node: Node) -> Iterable[Interval]:
        if node is None:
            yield from ()
        else:
            if not (node.interval.lo >= interval.hi or node.interval.hi <= interval.lo):
                # print(f'{interval.lo}-{interval.hi} intersects {node.interval.lo}-{node.interval.hi}')
                yield node.interval
            if node.left and node.left.max_hi > interval.lo:
                yield from self.__all_intersect(interval, node.left)
            yield from self.__all_intersect(interval, node.right)
    def all_intersect(self, lo: int, hi: int) -> Iterable[Interval]:
        interval = Interval(lo, hi)
        yield from self.__all_intersect(interval, self.root)
class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        ranged_intervals = []
        interval_tree = IntervalTree()
        for interval in intervals:
            interval_tree.add(interval[0], interval[1])
        for interval in intervals:
            c = interval_tree.all_intersect(interval[0], interval[1])
            ranged_intervals.append((len(list(c))-1, interval)) # decrement intersection to account self intersection
        interval_tree = IntervalTree()
        res = []
        ranged_intervals.sort(key=lambda t: t[0], reverse=True)
        while ranged_intervals:
            _, interval = ranged_intervals.pop()
            if not interval_tree.is_intersect(interval[0], interval[1]):
                interval_tree.add(interval[0], interval[1])
            else:
                res.append(interval)
        return len(res)
To make a counterexample for your algorithm, you can construct a problem where selecting the segment with the fewest intersections ruins the solution, like this:
[----][----][----][----]
[-------][----][-------]
[-------]      [-------]
[-------]      [-------]
[-------]      [-------]
Your algorithm will choose the center interval first, which is incompatible with the optimal solution:
[----][----][----][----]
An algorithm that does work is to repeat the following while there are any overlaps:
1. Find the left-most point of overlap.
2. Pick any two intervals that overlap that point, and delete the one that extends farthest to the right.
This algorithm is also very simple to implement. You can do it in a single traversal through the list of intervals, sorted by start point:
from typing import List
class Solution:
    def eraseOverlapIntervals(self, intervals: List[List[int]]) -> int:
        intervals.sort()
        extent = None  # right end of the most recently kept interval
        deletes = 0
        for interval in intervals:
            if extent is None or extent <= interval[0]:
                extent = interval[1]
            else:
                # overlap: delete whichever of the two extends farther right
                deletes += 1
                extent = min(extent, interval[1])
        return deletes
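To check the counterexample numerically, here is a small self-contained sketch. The coordinates are my own choice to match the diagram (each [----] spans 6 units, each [-------] spans 9), intervals that merely touch are treated as non-overlapping (as LeetCode does), and fewest_intersections_first is a simplified O(n^2) stand-in for the question's interval-tree heuristic, not the original code:

```python
def overlaps(a, b):
    """Half-open overlap test: [1, 2] and [2, 3] only touch, so they don't overlap."""
    return a[0] < b[1] and b[0] < a[1]


def fewest_intersections_first(intervals):
    """Keep intervals in order of fewest intersections; return how many get deleted."""
    # Subtract 1 to discount each interval's overlap with itself.
    counts = [sum(overlaps(iv, other) for other in intervals) - 1 for iv in intervals]
    kept = []
    for i in sorted(range(len(intervals)), key=lambda i: counts[i]):
        if not any(overlaps(intervals[i], k) for k in kept):
            kept.append(intervals[i])
    return len(intervals) - len(kept)


def greedy_by_start(intervals):
    """Single pass over intervals sorted by start, as in the answer above."""
    extent = None
    deletes = 0
    for lo, hi in sorted(intervals):
        if extent is None or extent <= lo:
            extent = hi
        else:
            deletes += 1
            extent = min(extent, hi)
    return deletes


row1 = [[0, 6], [6, 12], [12, 18], [18, 24]]
row2 = [[0, 9], [9, 15], [15, 24]]
rows3_to_5 = [[0, 9], [15, 24]] * 3
example = row1 + row2 + rows3_to_5

print(fewest_intersections_first(example))  # 10
print(greedy_by_start(example))             # 9
```

The centre interval [9, 15] has only 2 intersections, so the heuristic keeps it first; after that it can keep at most 3 intervals, while the optimum keeps the 4 intervals of the top row.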
I tried implementing a tree as part of my Data Structures course. The code works but is extremely slow, taking almost double the time the course accepts. I have no experience with data structures and algorithms, but I need to optimize the program. If anyone has any tips, advice or criticism, I would greatly appreciate it.
The tree is not necessarily a binary tree.
Here is the code:
import sys
import threading
class Node:
    def __init__(self,value):
        self.value = value
        self.children = []
        self.parent = None
    def add_child(self,child):
        child.parent = self
        self.children.append(child)
def compute_height(n, parents):
    found = False
    indices = []
    for i in range(n):
        indices.append(i)
    for i in range(len(parents)):
        currentItem = parents[i]
        if currentItem == -1:
            root = Node(parents[i])
            startingIndex = i
            found = True
            break
    if found == False:
        root = Node(parents[0])
        startingIndex = 0
    return recursion(startingIndex,root,indices,parents)
def recursion(index,toWhomAdd,indexes,values):
    children = []
    for i in range(len(values)):
        if index == values[i]:
            children.append(indexes[i])
            newNode = Node(indexes[i])
            toWhomAdd.add_child(newNode)
            recursion(i, newNode, indexes, values)
    return toWhomAdd
def checkHeight(node):
    if node == '' or node == None or node == []:
        return 0
    counter = []
    for i in node.children:
        counter.append(checkHeight(i))
    if node.children != []:
        mostChildren = max(counter)
    else:
        mostChildren = 0
    return(1 + mostChildren)
def main():
    n = int(int(input()))
    parents = list(map(int, input().split()))
    root = compute_height(n, parents)
    print(checkHeight(root))
sys.setrecursionlimit(10**7)  # max depth of recursion
threading.stack_size(2**27)  # new thread will get stack of such size
threading.Thread(target=main).start()
Edit:
For this input (the first number is the number of nodes; the other numbers are the nodes' values):
5
4 -1 4 1 1
We expect this output (the height of the tree):
3
Another example:
Input:
5
-1 0 4 0 3
Output:
4
It looks like the value given for each node is the index of another node (its parent). This is stated nowhere in the question, but if that assumption is right, you don't really need to create the tree with Node instances. Just read the input into a list (which you already do), and the tree is already encoded in it.
So for example, the list [4, -1, 4, 1, 1] represents this tree, where the labels are the indices in this list:
    1
   / \
  4   3
 / \
0   2
The height of this tree — according to the definition given in Wikipedia — would be 2. But apparently the expected result is 3, which is the number of nodes (not edges) on the longest path from the root to a leaf, or — otherwise put — the number of levels in the tree.
The idea to use recursion is correct, but you can do it bottom-up (starting at any node): get the result for the parent recursively and add 1 to it. Apply the principle of dynamic programming by storing the result for each node in a separate list, which I called levels:
def get_num_levels(parents):
    levels = [0] * len(parents)
    def recur(node):
        if levels[node] == 0:  # this node's level hasn't been determined yet
            parent = parents[node]
            levels[node] = 1 if parent == -1 else recur(parent) + 1
        return levels[node]
    for node in range(len(parents)):
        recur(node)
    return max(levels)
And the main code could be as you had it:
def main():
    n = int(int(input()))
    parents = list(map(int, input().split()))
    print(get_num_levels(parents))
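If you'd rather not rely on raising the recursion limit and running main in a thread with a bigger stack, the same memoized bottom-up idea can be written without recursion. This is a sketch of my own (the name get_num_levels_iterative is hypothetical), assuming the input really is a valid parent list with exactly one -1:

```python
def get_num_levels_iterative(parents):
    """Number of levels in a tree given as a parent-index list (-1 marks the root)."""
    levels = [0] * len(parents)  # 0 means "not computed yet", as in get_num_levels
    for start in range(len(parents)):
        path = []
        node = start
        # Climb toward the root until we pass it (-1) or reach a node whose level is known.
        while node != -1 and levels[node] == 0:
            path.append(node)
            node = parents[node]
        level = 0 if node == -1 else levels[node]
        # Walk back down the recorded path, adding one level per step.
        for v in reversed(path):
            level += 1
            levels[v] = level
    return max(levels)


print(get_num_levels_iterative([4, -1, 4, 1, 1]))  # 3
print(get_num_levels_iterative([-1, 0, 4, 0, 3]))  # 4
```

Each node is pushed onto a path and assigned a level exactly once, so the total work stays O(n) even when the tree is one long chain.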
I have two lists of the same length. One contains values; the other contains the levels those values occupy in a sum tree.
For example:
[40,20,5,15,10,10] and [0,1,2,2,1,1]
Those lists correctly correspond because
- 40
- - 20
- - - 5
- - - 15
- - 10
- - 10
(20+10+10) == 40 and (5+15) == 20
I need to check whether a given list of values and its list of levels correspond correctly. So far I have put together this function, but for some reason it doesn't return True for correct lists array and numbers. Here numbers would be [40,20,5,15,10,10] and array would be [0,1,2,2,1,1].
def testsum(array, numbers):
    k = len(array)
    target = [0]*k
    subsum = [0]*k
    for x in range(0, k):
        if target[array[x]]!=subsum[array[x]]:
            return False
        target[array[x]]=numbers[x]
        subsum[array[x]]=0
        if array[x]>0:
            subsum[array[x]-1]+=numbers[x]
    for x in range(0, k):
        if(target[x]!=subsum[x]):
            print(x, target[x],subsum[x])
            return False
    return True
I got this running using itertools.takewhile to grab the subtree under each level. Toss that into a recursive function and assert that all recursions pass.
I've slightly improved my initial implementation by grabbing next_v and next_l and testing early whether the current node is a parent node, so the subtree is built only if there's something to build. That inequality check is much cheaper than iterating through the whole vs_ls zip.
import itertools
def testtree(values, levels):
    if len(values) == 1:
        # Last element, always true!
        return True
    vs_ls = zip(values, levels)
    test_v, test_l = next(vs_ls)
    next_v, next_l = next(vs_ls)
    if next_l > test_l:
        subtree = [v for v,l in itertools.takewhile(
            lambda v_l: v_l[1] > test_l,
            itertools.chain([(next_v, next_l)], vs_ls))
                   if l == test_l+1]
        if sum(subtree) != test_v and subtree:
            #TODO test if you can remove the "and subtree" check now!
            print("{} != {}".format(subtree, test_v))
            return False
    return testtree(values[1:], levels[1:])
if __name__ == "__main__":
    vs = [40, 20, 15, 5, 10, 10]
    ls = [0, 1, 2, 2, 1, 1]
    assert testtree(vs, ls) == True
Unfortunately, this adds a lot of complexity to the code: since it pulls the next value out of the iterator ahead of time, it needs an extra itertools.chain call to put it back. That's not ideal. Unless you expect very large lists for values and levels, it might be worthwhile to do vs_ls = list(zip(values, levels)) and approach this list-wise rather than iterator-wise, e.g.:
...
    vs_ls = list(zip(values, levels))
    test_v, test_l = vs_ls[0]
    next_v, next_l = vs_ls[1]
    ...
        subtree = [v for v,l in itertools.takewhile(
            lambda v_l: v_l[1] > test_l,
            vs_ls[1:]) if l == test_l+1]
I still think the fastest way is probably to iterate once with an approach almost like a state machine and grab all the possible subtrees, then check them all individually. Something like:
from collections import namedtuple
Tree = namedtuple("Tree", ["level_num", "parent", "children"])
# equivalent to
# # class Tree:
# #     def __init__(self, level_num: int,
# #                  parent: int,
# #                  children: list):
# #         self.level_num = level_num
# #         self.parent = parent
# #         self.children = children
def build_trees(values, levels):
    trees = []  # list of finished Trees
    pending_trees = []  # stack of Trees still collecting children
    cur_tree = None  # tree whose children sit at level cur_tree.level_num + 1
    last_v, last_l = None, None
    for v, l in zip(values, levels):
        if last_l is not None and l > last_l:
            # we've found a new tree rooted at the previous value
            if l != last_l + 1:
                # What do you do if you get levels like [0, 1, 3]??
                raise ValueError("Improper leveling: {}".format(levels))
            # Stash the old tree (if any) and start a new one.
            if cur_tree is not None:
                pending_trees.append(cur_tree)
            cur_tree = Tree(level_num=last_l, parent=last_v, children=[])
        else:
            # Same level or going back up: every tree at level >= l is finished.
            # Store each finished tree and grab the last one we stashed.
            while cur_tree is not None and cur_tree.level_num >= l:
                trees.append(cur_tree)
                cur_tree = pending_trees.pop() if pending_trees else None
        if cur_tree is not None and l == cur_tree.level_num + 1:
            # This is a child value in our current tree
            cur_tree.children.append(v)
        last_v, last_l = v, l
    # Close the pending trees
    if cur_tree is not None:
        trees.append(cur_tree)
    trees.extend(pending_trees)
    return trees
This should give you a list of Tree objects, each of which has the following attributes:
level_num := level number of parent (as found in levels)
parent := number representing the expected sum of the tree
children := list containing all the children in that level
After you do that, you should be able to simply check
all([sum(t.children) == t.parent for t in trees])
But note that I haven't been able to test this second approach.
I'm trying to create a depth-first search algorithm that assigns finishing times (the time when a vertex can no longer be expanded), which are used in algorithms such as Kosaraju's. I was able to write a recursive version of DFS fairly easily, but I'm having a hard time converting it to an iterative version.
I'm using an adjacency list to represent the graph: a dict mapping each vertex to the list of vertices it has edges to. For example, the input graph {1: [0, 4], 2: [1, 5], 3: [1], 4: [1, 3], 5: [2, 4], 6: [3, 4, 7, 8], 7: [5, 6], 8: [9], 9: [6, 11], 10: [9], 11: [10]} represents edges (1,0), (1,4), (2,1), (2,5), etc. The following is an implementation of iterative DFS that uses a simple stack (LIFO), but it doesn't compute finishing times. One of the key problems I faced is that, since vertices are popped, the algorithm has no way to trace back its path once a vertex has been fully expanded (unlike in recursion). How do I fix this?
def dfs(graph, vertex, finish, explored):
    global count
    stack = []
    stack.append(vertex)
    while len(stack) != 0:
        vertex = stack.pop()
        if explored[vertex] == False:
            explored[vertex] = True
            #add all outgoing edges to stack:
            if vertex in graph: #check if key exists in hash -- since not all vertices have outgoing edges
                for v in graph[vertex]:
                    stack.append(v)
            #this doesn't assign finishing times, it assigns the times when vertices are discovered:
            #finish[count] = vertex
            #count += 1
N.b. there is also an outer loop that complements DFS -- though, I don't think the problem lies there:
#outer loop:
for vertex in range(size, 0, -1):
    if explored[vertex] == False:
        dfs(hGraph, vertex, finish, explored)
Think of your stack as a stack of tasks, not vertices. There are two types of task you need to do: you need to expand vertices, and you need to add finishing times.
When you go to expand a vertex, you first add the task of computing a finishing time, then add expanding every child vertex.
When you go to add a finishing time, you can do so knowing that expansion finished.
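A minimal Python sketch of this task-stack idea, assuming the question's dict-of-lists graph (a missing key means no outgoing edges); the name finishing_order and the ("expand", v) / ("finish", v) task tuples are my own choices:

```python
def finishing_order(graph, vertices):
    """Return vertices in the order they finish, using an explicit stack of tasks.

    graph: dict mapping a vertex to its list of successors (missing key = no out-edges).
    vertices: the order in which the outer loop tries start vertices.
    """
    explored = set()
    order = []  # order[t] is the vertex that finished at time t
    for root in vertices:
        if root in explored:
            continue
        stack = [("expand", root)]
        while stack:
            task, v = stack.pop()
            if task == "finish":
                # Every task pushed while expanding v is done, so v is finished.
                order.append(v)
            elif v not in explored:
                explored.add(v)
                # Push the finish task first so it is popped after all of v's children.
                stack.append(("finish", v))
                for w in graph.get(v, []):
                    if w not in explored:
                        stack.append(("expand", w))
    return order


graph = {1: [0, 4], 2: [1, 5], 3: [1], 4: [1, 3], 5: [2, 4], 6: [3, 4, 7, 8],
         7: [5, 6], 8: [9], 9: [6, 11], 10: [9], 11: [10]}
finish = dict(enumerate(finishing_order(graph, range(11, -1, -1))))  # like finish[count] = vertex
```

A vertex's finish task sits below every task pushed while expanding it, so it is popped exactly when the recursive version would return from that vertex. An "expand" task for an already explored vertex is simply skipped.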
Here is a working solution that uses two stacks in the iterative subroutine. The list traceBack holds the explored vertices on the current path, and it is paired with a complementary 2D list, stack, whose entries hold the edges that have yet to be explored from the corresponding vertex. The two are linked: whenever we push onto (or pop from) traceBack, we also push onto (or pop from) stack.
count = 0
def outerLoop(hGraph, N):
    explored = [False for iii in range(N+1)]
    finish = {}
    for vertex in range(N, -1, -1):
        if explored[vertex] == False:
            dfs(vertex, hGraph, explored, finish)
    return finish
def dfs(vertex, graph, explored, finish):
    global count
    stack = [[]] #stack contains the possible edges to explore
    traceBack = []
    traceBack.append(vertex)
    while len(stack) > 0:
        explored[vertex] = True
        try:
            for n in graph[vertex]:
                if explored[n] == False:
                    if n not in stack[-1]: #to prevent double adding (when we backtrack to previous vertex)
                        stack[-1].append(n)
                else:
                    if n in stack[-1]: #make sure num exists in array before removing
                        stack[-1].remove(n)
        except (IndexError, KeyError): pass  # vertex has no outgoing edges (list or dict graph)
        if len(stack[-1]) == 0: #if current stack is empty then there are no outgoing edges:
            finish[count] = traceBack.pop() #thus, we can add finishing times
            count += 1
            if len(traceBack) > 0: #to prevent popping empty array
                vertex = traceBack[-1]
            stack.pop()
        else:
            vertex = stack[-1][-1] #pick last element in top of stack to be new vertex
            stack.append([])
            traceBack.append(vertex)
Here is one way. Every time we encounter one of the following conditions, we invoke the callback, i.e. record a finishing time:
The node has no outgoing edges (no way).
The parent of the node being traversed differs from the last parent (parent change). In that case we finish the last parent.
We reach the end of the stack (tree end). We finish the last parent.
Here is the code:
var dfs_with_finishing_time = function(graph, start, cb) {
    var explored = [];
    var parent = [];
    var i = 0;
    for(i = 0; i < graph.length; i++) {
        if(i in explored)
            continue;
        explored[i] = 1;
        var stack = [i];
        parent[i] = -1;
        var last_parent = -1;
        while(stack.length) {
            var u = stack.pop();
            var k = 0;
            var no_way = true;
            for(k = 0; k < graph.length; k++) {
                if(k in explored)
                    continue;
                if(!graph[u][k])
                    continue;
                stack.push(k);
                explored[k] = 1;
                parent[k] = u;
                no_way = false;
            }
            if(no_way) {
                cb(null, u+1); // no way, reversed post-ordering (finishing time)
            }
            if(last_parent != parent[u] && last_parent != -1) {
                cb(null, last_parent+1); // parent change, reversed post-ordering (finishing time)
            }
            last_parent = parent[u];
        }
        if(last_parent != -1) {
            cb(null, last_parent+1); // tree end, reversed post-ordering (finishing time)
        }
    }
}
Here is the first part of the code that I wrote for Kosaraju's algorithm.
###### reading the data #####
with open('data.txt') as req_file:
    ori_data = []
    for line in req_file:
        line = line.split()
        if line:
            line = [int(i) for i in line]
            ori_data.append(line)
###### forming the Grev ####
revscc_dic = {}
for temp in ori_data:
    if temp[1] not in revscc_dic:
        revscc_dic[temp[1]] = [temp[0]]
    else:
        revscc_dic[temp[1]].append(temp[0])
print(revscc_dic)
######## finding the G#####
scc_dic = {}
for temp in ori_data:
    if temp[0] not in scc_dic:
        scc_dic[temp[0]] = [temp[1]]
    else:
        scc_dic[temp[0]].append(temp[1])
print(scc_dic)
##### iterative dfs ####
path = []
for i in range(max(max(edge) for edge in ori_data), 0, -1):  # largest vertex label down to 1
    start = i
    q=[start]
    while q:
        v=q.pop(0)
        if v not in path:
            path.append(v)
            q=revscc_dic[v]+q
print(path)
The code reads the data and builds Grev and G correctly. I have also written code for an iterative DFS. How can I extend it to find the finishing times? I understand how to find finishing times with pen and paper, but I don't understand how to express them in code. How can I implement this? Only after this can I proceed with the next part of my code. Please help. Thanks in advance.
The data.txt file contains:
1 4
2 8
3 6
4 7
5 2
6 9
7 1
8 5
8 6
9 7
9 3
Please save it as data.txt.
With recursive dfs, it is easy to see when a given vertex has "finished" (i.e. when we have visited all of its children in the dfs tree). The finish time can be calculated just after the recursive call has returned.
However, with iterative DFS this is not so easy. Now that we process the stack (the list q) iteratively in a while loop, we have lost some of the nested structure that is associated with function calls. More precisely, we don't know when backtracking occurs. Unfortunately, there is no way to know when backtracking occurs without adding some additional information to our stack of vertices.
The quickest way to add finishing times to your dfs implementation is like so:
##### iterative dfs (with finish times) ####
path = []
time = 0
finish_time_dic = {}
for i in range(max(max(edge) for edge in ori_data), 0, -1):  # largest vertex label down to 1
    start = i
    q = [start]
    while q:
        v = q.pop(0)
        if v not in path:
            path.append(v)
            q = [v] + q
            for w in revscc_dic[v]:
                if w not in path: q = [w] + q
        else:
            if v not in finish_time_dic:
                finish_time_dic[v] = time
                time += 1
print(path)
print(finish_time_dic)
The trick used here is that when we pop v off the stack for the first time, we push it back onto the stack. This is done using q = [v] + q. We must push v onto the stack before we push its neighbours (the code that pushes v comes before the for loop that pushes v's neighbours), or else the trick doesn't work. Eventually we will pop v off the stack again, and at that point v has finished! Since we have seen v before, we go into the else case and compute a fresh finish time.
For the graph provided, finish_time_dic gives the correct finishing times:
{1: 6, 2: 1, 3: 3, 4: 7, 5: 0, 6: 4, 7: 8, 8: 2, 9: 5}
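Note that this DFS algorithm (with the finishing-time modification) is still O(V+E), even though each vertex is pushed onto the stack twice, provided that the "seen" test and each stack operation take constant time. As written, the code above falls short of that: v not in path scans a list, and q = [w] + q copies the whole stack, so both are O(V) per operation. A set and a list used with append/pop fix this. However, more elegant solutions exist.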
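As a sketch of the O(V+E) point, here is the same re-push trick with a set for the "seen" test and the end of a Python list as the top of the stack. The function name finish_times_linear is my own, and it rebuilds Grev from an edge list so that it runs standalone. Appending v and then its neighbours in order gives the same pop order as q.pop(0) with q = [v] + q and q = [w] + q:

```python
def finish_times_linear(edges, n):
    """Re-push trick with O(1) membership tests and O(1) stack operations."""
    rev = {}
    for u, v in edges:  # build Grev: edge u -> v becomes v -> u
        rev.setdefault(v, []).append(u)
    seen = set()
    path = []
    finish = {}
    time = 0
    for start in range(n, 0, -1):
        stack = [start]  # the top of the stack is the END of the list
        while stack:
            v = stack.pop()
            if v not in seen:
                seen.add(v)
                path.append(v)
                stack.append(v)  # marker: popped again once v has finished
                for w in rev.get(v, []):
                    if w not in seen:
                        stack.append(w)
            elif v not in finish:
                finish[v] = time
                time += 1
    return path, finish


edges = [(1, 4), (2, 8), (3, 6), (4, 7), (5, 2), (6, 9),
         (7, 1), (8, 5), (8, 6), (9, 7), (9, 3)]  # the data.txt edges
path, finish = finish_times_linear(edges, 9)
```

For the data.txt edges this reproduces the finishing times listed above ({1: 6, 2: 1, 3: 3, 4: 7, 5: 0, 6: 4, 7: 8, 8: 2, 9: 5}).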
I recommend reading Chapter 5 of Python Algorithms: Mastering Basic Algorithms in the Python Language by Magnus Lie Hetland (ISBN: 1430232374, 9781430232377). Questions 5-6 and 5-7 (on page 122) describe your problem exactly. The author answers these questions and gives an alternative solution to the problem.
Questions:
5-6 In recursive DFS, backtracking occurs when you return from one of the recursive calls. But where has the backtracking gone in the iterative version?
5-7 Write a nonrecursive version of DFS that can determine finish times.
Answers:
5-6 It’s not really represented at all in the iterative version. It just implicitly occurs once you’ve popped off all your “traversal descendants” from the stack.
5-7 As explained in Exercise 5-6, there is no point in the code where backtracking occurs in the iterative DFS, so we can’t just set the finish time at some specific place (like in the recursive one). Instead, we’d need to add a marker to the stack. For example, instead of adding the neighbors of u to the stack, we could add edges of the form (u, v), and before all of them, we’d push (u, None), indicating the backtracking point for u.
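Here is one way to code the book's suggestion in Python. It is my own sketch, not Hetland's published answer: the stack holds edges (u, v) plus a (u, None) marker pushed just before u's out-edges, discovery and finishing times share one clock (the CLRS convention), and parent records which edge discovered each vertex, information that a plain vertex stack loses. The function name dfs_with_markers is hypothetical, and None is assumed not to be a vertex label:

```python
def dfs_with_markers(graph, vertices):
    """Iterative DFS whose stack holds edges (u, v) and (u, None) backtracking markers."""
    explored = set()
    parent, discover, finish = {}, {}, {}
    clock = 0
    for root in vertices:
        if root in explored:
            continue
        stack = [(None, root)]  # pseudo-edge into the root, so the root's parent is None
        while stack:
            u, v = stack.pop()
            if v is None:
                # Backtracking point for u: every edge pushed after it has been handled.
                finish[u] = clock
                clock += 1
                continue
            if v in explored:
                continue  # (u, v) is not a tree edge
            explored.add(v)
            parent[v] = u  # (u, v) is the tree edge that discovered v
            discover[v] = clock
            clock += 1
            stack.append((v, None))
            for w in graph.get(v, []):
                stack.append((v, w))
    return parent, discover, finish


parent, discover, finish = dfs_with_markers({1: [2, 3], 2: [4], 3: [4], 4: []}, [1])
```

In this small example, 3 is explored before 2 because its edge is pushed last, and the edge (2, 4) is skipped because 4 was already reached through 3.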
Iterative DFS itself is not complicated, as the Wikipedia article shows. However, calculating the finish time of each node requires some tweaks to the algorithm: we only pop a node off the stack the second time we encounter it.
Here's my implementation which I feel demonstrates what's going on a bit more clearly:
step = 0  # time counter
def dfs_visit(g, v):
    """Run iterative DFS from node V"""
    global step
    total = 0
    stack = [v]  # create stack with starting vertex
    while stack:  # while stack is not empty
        step += 1
        v = stack[-1]  # peek top of stack
        if v.color:  # if already seen
            v = stack.pop()  # done with this node, pop it from stack
            if v.color == 1:  # if GRAY, finish this node
                v.time_finish = step
                v.color = 2  # BLACK, done
        else:  # seen for first time
            v.color = 1  # GRAY: discovered
            v.time_discover = step
            total += 1
            for w in v.child:  # for all neighbor (v, w)
                if not w.color:  # if not seen
                    stack.append(w)
    return total
def dfs(g):
    """Run DFS on graph"""
    global step
    step = 0  # reset step counter
    for k, v in g.nodes.items():
        if not v.color:
            dfs_visit(g, v)
I am following the conventions of the CLR algorithms book and use node coloring to designate each node's state during the DFS search. I feel this is easier to understand than using a separate list to track node state.
All nodes start out as white. When it's discovered during the search it is marked as gray. When we are done with it, it is marked as black.
Within the while loop, if a node is white we keep it in the stack, and change its color to gray. If it's gray we change its color to black, and set its finish time. If it's black we just ignore it.
It is possible for a node on the stack to be black (even with our coloring check before adding it to the stack). A white node can be added to the stack twice (via two different neighbors). One will eventually turn black. When we reach the 2nd instance on the stack, we need to make sure we don't change its already set finish time.
Here is some additional supporting code:
import heapq
from collections import defaultdict
class Node(object):
    def __init__(self, name=None):
        self.name = name
        self.child = []  # children | adjacency list
        self.color = 0  # 0: white [unvisited], 1: gray [found], 2: black [finished]
        self.time_discover = None  # DFS
        self.time_finish = None  # DFS
class Graph(object):
    def __init__(self):
        self.nodes = defaultdict(Node)  # maps node name -> Node
        self.max_heap = []  # nodes in decreasing finish time for SCC
    def build_max_heap(self):
        """Build list of nodes in max heap using DFS finish time"""
        for k, v in self.nodes.items():
            self.max_heap.append((0-v.time_finish, v))  # invert finish time for max heap
        heapq.heapify(self.max_heap)
To run DFS on the reverse graph, you can build a parent list similar to the child list for each Node when the edges file is processed, and use the parent list instead of the child list in dfs_visit().
To process Nodes in decreasing finish time for the last part of the SCC computation, you can build a max heap of Nodes and pop from it in the outer loop, calling dfs_visit() on each Node that hasn't been visited yet:
while g.max_heap:
    v = heapq.heappop(g.max_heap)[1]
    if not v.color:
        size = dfs_visit(g, v)
        scc_size.append(size)
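Putting the pieces together, here is a compact sketch of the whole Kosaraju computation on an edge list like data.txt. It is my own variant rather than any of the code in this thread: each stack frame holds a vertex plus an iterator over its remaining neighbours, so the explicit stack mirrors the recursive call stack and a vertex finishes exactly when its iterator is exhausted. The function name kosaraju and the (u, v) edge-list input are assumptions:

```python
from collections import defaultdict


def kosaraju(edges):
    """Return the strongly connected components of a directed graph given as (u, v) pairs."""
    graph, rev = defaultdict(list), defaultdict(list)
    vertices = set()
    for u, v in edges:
        graph[u].append(v)  # G
        rev[v].append(u)    # Grev
        vertices.update((u, v))

    def iterative_dfs(adj, root, seen, on_finish):
        # Each frame is (vertex, iterator over its remaining neighbours).
        seen.add(root)
        stack = [(root, iter(adj[root]))]
        while stack:
            v, neighbours = stack[-1]
            for w in neighbours:
                if w not in seen:
                    seen.add(w)
                    stack.append((w, iter(adj[w])))
                    break
            else:
                # Iterator exhausted: v has finished.
                stack.pop()
                on_finish(v)

    # Pass 1: finishing order on Grev, outer loop from the largest label down.
    order, seen = [], set()
    for v in sorted(vertices, reverse=True):
        if v not in seen:
            iterative_dfs(rev, v, seen, order.append)

    # Pass 2: DFS on G in decreasing finishing time; each DFS tree is one SCC.
    sccs, seen = [], set()
    for v in reversed(order):
        if v not in seen:
            component = []
            iterative_dfs(graph, v, seen, component.append)
            sccs.append(component)
    return sccs


edges = [(1, 4), (2, 8), (3, 6), (4, 7), (5, 2), (6, 9),
         (7, 1), (8, 5), (8, 6), (9, 7), (9, 3)]  # the data.txt edges
print(sorted(sorted(c) for c in kosaraju(edges)))  # [[1, 4, 7], [2, 5, 8], [3, 6, 9]]
```

Because nothing recurses, this needs neither sys.setrecursionlimit nor a larger thread stack, even for long chains of vertices.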
I had a few issues with the order produced by Lawson's version of the iterative DFS. Here is code for my version which has a 1-to-1 mapping with a recursive version of DFS.
n = len(graph)
time = 0
finish_times = [0] * (n + 1)
explored = [False] * (n + 1)
# Determine if every vertex connected to v
# has already been explored
def all_explored(G, v):
    if v in G:
        for w in G[v]:
            if not explored[w]:
                return False
    return True
# Loop through vertices in reverse order
for v in range(n, 0, -1):
    if not explored[v]:
        stack = [v]
        while stack:
            print(stack)
            v = stack[-1]
            explored[v] = True
            # If v still has outgoing edges to explore
            if not all_explored(graph_reversed, v):
                for w in graph_reversed[v]:
                    # Explore w before others attached to v
                    if not explored[w]:
                        stack.append(w)
                        break
            # We have explored vertices findable from v
            else:
                stack.pop()
                time += 1
                finish_times[v] = time
Here are the recursive and iterative implementations in Java:
int time = 0;
public void dfsRecursive(Vertex vertex) {
    time += 1;
    vertex.setVisited(true);
    vertex.setDiscovered(time);
    for (String neighbour : vertex.getNeighbours()) {
        if (!vertices.get(neighbour).getVisited()) {
            dfsRecursive(vertices.get(neighbour));
        }
    }
    time += 1;
    vertex.setFinished(time);
}
public void dfsIterative(Vertex vertex) {
    Stack<Vertex> stack = new Stack<>();
    stack.push(vertex);
    while (!stack.isEmpty()) {
        Vertex current = stack.pop();
        if (!current.getVisited()) {
            time += 1;
            current.setVisited(true);
            current.setDiscovered(time);
            stack.push(current);
            List<String> currentsNeigbours = current.getNeighbours();
            for (int i = currentsNeigbours.size() - 1; i >= 0; i--) {
                String currentNeigbour = currentsNeigbours.get(i);
                Vertex neighBour = vertices.get(currentNeigbour);
                if (!neighBour.getVisited())
                    stack.push(neighBour);
            }
        } else {
            if (current.getFinished() < 1) {
                time += 1;
                current.setFinished(time);
            }
        }
    }
}
First, you should know exactly what the finishing time is. In recursive DFS, a node v finishes once all of its adjacent nodes have finished.
With this in mind, you need an additional data structure to store that information.
adj[][]                      // graph
visited[] = []               // visited nodes
finished[] = []              // finished nodes
Stack st = new Stack         // normal stack
Stack backtrack = new Stack  // additional stack
function getFinishedTime() {
    for (node i in adj) {
        if (!visited.contains(i)) {
            st.push(i);
            visited.add(i);
            while (!st.isEmpty) {
                int j = st.pop();
                int[] unvisitedChild = getUnvisitedChild(j);
                if (!unvisitedChild.isEmpty) {
                    for (int c in unvisitedChild) {
                        st.push(c);
                        visited.add(c);
                    }
                    // store each entry as a pair: the parent node j and the list of its unvisited children
                    backtrack.push([j, unvisitedChild]);
                }
                else {
                    finished.add(j);
                    // once all of the children are finished, the parent node is finished too
                    while (!backtrack.isEmpty && finished.containsAll(backtrack.peek()[1])) {
                        parent = backtrack.pop()[0];
                        finished.add(parent);
                    }
                }
            }
        }
    }
}
function getUnvisitedChild(int i) {
    unvisitedChild[] = []
    for (int child in adj[i]) {
        if (!visited.contains(child))
            unvisitedChild.add(child);
    }
    return unvisitedChild;
}
and the resulting finishing order should be:
[5, 2, 8, 3, 6, 9, 1, 4, 7]