How do I test a sum tree? - python

I have 2 lists. One contains values, the other contains the levels those values hold in a sum tree (the lists have the same length).
For example:
[40,20,5,15,10,10] and [0,1,2,2,1,1]
Those lists correctly correspond because
- 40
- - 20
- - - 5
- - - 15
- - 10
- - 10
(20+10+10) == 40 and (5+15) == 20
I need to check whether a given list of values and the list of their levels correspond correctly. So far I have managed to put together this function, but for some reason it's not returning True for correct lists array and numbers. Input numbers here would be [40,20,5,15,10,10] and array would be [0,1,2,2,1,1]:
def testsum(array, numbers):
    k = len(array)
    target = [0]*k
    subsum = [0]*k
    for x in range(0, k):
        if target[array[x]] != subsum[array[x]]:
            return False
        target[array[x]] = numbers[x]
        subsum[array[x]] = 0
        if array[x] > 0:
            subsum[array[x]-1] += numbers[x]
    for x in range(0, k):
        if target[x] != subsum[x]:
            print(x, target[x], subsum[x])
            return False
    return True

I got this running using itertools.takewhile to grab the subtree under each level. Toss that into a recursive function and assert that all recursions pass.
I've slightly improved my initial implementation by grabbing a next_v and next_l and testing early to see if the current node is a parent node and only building subtree if there's something to build. That inequality check is much cheaper than iterating through the whole vs_ls zip.
import itertools

def testtree(values, levels):
    if len(values) == 1:
        # Last element, always true!
        return True
    vs_ls = zip(values, levels)
    test_v, test_l = next(vs_ls)
    next_v, next_l = next(vs_ls)
    if next_l > test_l:
        subtree = [v for v, l in itertools.takewhile(
                       lambda v_l: v_l[1] > test_l,
                       itertools.chain([(next_v, next_l)], vs_ls))
                   if l == test_l + 1]
        if sum(subtree) != test_v and subtree:
            # TODO test if you can remove the "and subtree" check now!
            print("{} != {}".format(subtree, test_v))
            return False
    return testtree(values[1:], levels[1:])

if __name__ == "__main__":
    vs = [40, 20, 15, 5, 10, 10]
    ls = [0, 1, 2, 2, 1, 1]
    assert testtree(vs, ls) == True
It unfortunately adds a lot of complexity to the code since it pulls out the first value that we need, which necessitates an extra itertools.chain call. That's not ideal. Unless you're expecting to get very large lists for values and levels, it might be worthwhile to do vs_ls = list(zip(values, levels)) and approach this list-wise rather than iterator-wise. e.g...
...
    vs_ls = list(zip(values, levels))
    test_v, test_l = vs_ls[0]
    next_v, next_l = vs_ls[1]
    ...
    subtree = [v for v, l in itertools.takewhile(
                   lambda v_l: v_l[1] > test_l,
                   vs_ls[1:]) if l == test_l + 1]
I still think the fastest way is probably to iterate once with an approach almost like a state machine and grab all the possible subtrees, then check them all individually. Something like:
from collections import namedtuple

Tree = namedtuple("Tree", ["level_num", "parent", "children"])
# equivalent to
# class Tree:
#     def __init__(self, level_num: int,
#                  parent: int,
#                  children: list):
#         self.level_num = level_num
#         self.parent = parent
#         self.children = children

def build_trees(values, levels):
    trees = []  # list of Trees
    pending_trees = []
    vs_ls = zip(values, levels)
    last_v, last_l = next(vs_ls)
    test_l = last_l + 1
    # cur_tree must exist before it is first stashed; start with a tree
    # rooted at the first value.
    cur_tree = Tree(level_num=last_l, parent=last_v, children=[])
    for v, l in zip(values, levels):
        if l > last_l:
            # we've found a new tree
            if l != last_l + 1:
                # What do you do if you get levels like [0, 1, 3]??
                raise ValueError("Improper leveling: {}".format(levels))
            test_l = l
            # Stash the old tree and start a new one.
            pending_trees.append(cur_tree)
            cur_tree = Tree(level_num=last_l, parent=last_v, children=[])
        elif l < test_l:
            # tree is finished
            # Store the finished tree and grab the last one we stashed.
            trees.append(cur_tree)
            try:
                cur_tree = pending_trees.pop()
            except IndexError:
                # No trees pending?? That's weird....
                # I can't think of any case that this should happen, so maybe
                # we should be raising ValueError here, but I'm not sure either
                cur_tree = Tree(level_num=-1, parent=-1, children=[])
        elif l == test_l:
            # This is a child value in our current tree
            cur_tree.children.append(v)
    # Close the pending trees
    trees.extend(pending_trees)
    return trees
This should give you a list of Tree objects, each of which has the following attributes:
level_num := level number of parent (as found in levels)
parent := number representing the expected sum of the tree
children := list containing all the children in that level
After you do that, you should be able to simply check
all([sum(t.children) == t.parent for t in trees])
But note that I haven't been able to test this second approach.


Finding and averaging elements of multiple dataframes that are within 10% of each other

I have objects that store values as dataframes. I have been able to compare if values from two dataframes are within 10% of each other. However, I am having difficulty extending this to multiple dataframes. Moreover, I am wondering how I should approach this problem if the dataframes are not the same size?
def add_well_peak(self, *other):
    if len(self.Bell) == len(other.Bell):  # if dataframes ARE the same size
        for k in range(len(self.Bell)):
            for j in range(len(other.Bell)):
                if int(self.Size[k]) - int(self.Size[k])*(1/10) <= int(other.Size[j]) <= int(self.Size[k]) + int(self.Size[k])*(1/10):
                    # average all
For example, in the image below, there are objects that contain dataframes (i.e., self, other1, other2). The colors represent matches (i.e., values that are within 10% of each other). If a match exists, then average the values. If a match does not exist, still include the unmatched number. I want to be able to generalize this for any number of objects greater than or equal to 2 (other1, other2, other3, other...). Any help would be appreciated. Please let me know if anything is unclear. This is my first time posting. Thanks again.
(image: matching data)
Results:
Using my solution on the dataframes of your image, I get the following:
Threshold outlier = 0.2:
0
0 1.000000
1 1493.500000
2 5191.333333
3 35785.333333
4 43586.500000
5 78486.000000
6 100000.000000
Threshold outlier = 0.5:
0 1
0 1.000000 NaN
1 1493.500000 NaN
2 5191.333333 NaN
3 43586.500000 35785.333333
4 78486.000000 100000.000000
Explanations:
The rows are the averaged peaks, and the columns represent the different values obtained for these peaks. I assumed the average coming from the biggest number of elements was the legitimate one, and the rest within the THRESHOLD_OUTLIER were the outliers (they should be sorted: the more probable a value is to be the legitimate peak, the further left it appears, so the 0th column is the most probable). For instance, on line 3 of the 0.5 outlier threshold results, 43586.500000 is an average coming from 3 dataframes, while 35785.333333 comes from only 2, thus the most probable is the first one.
Issues:
The solution is quite complicated. I assume a big part of it could be removed, but I can't see how for the moment, and as it works, I'll certainly leave the optimization to you.
Still, I tried to comment it as well as I could, and if you have any questions, do not hesitate!
Files:
CombinationLib.py
from __future__ import annotations
from typing import Dict, List
from Errors import *

class Combination():
    """
    Support class, to make things easier.
    Contains a string `self.combination` which is a binary number stored as a string.
    This allows testing every combination of values (i.e. "101" on the list `[1, 2, 3]`
    would signify grouping `1` and `3` together).
    There are some methods:
    - `__add__` overrides the `+` operator
    - `compute_degree` gives how many `1`s are in the combination
    - `overlaps` allows verifying whether combinations overlap (use the same value twice)
      (i.e. `100` and `011` don't overlap, while `101` and `001` do)
    """
    def __init__(self, combination:str) -> Combination:
        self.combination:str = combination
        self.degree:int = self.compute_degree()

    def __add__(self, other:Combination) -> Combination:
        if self.combination == None:
            return other.copy()
        if other.combination == None:
            return self.copy()
        if self.overlaps(other):
            raise CombinationsOverlapError()
        result = ""
        for c1, c2 in zip(self.combination, other.combination):
            result += "1" if (c1 == "1" or c2 == "1") else "0"
        return Combination(result)

    def __str__(self) -> str:
        return self.combination

    def compute_degree(self) -> int:
        if self.combination == None:
            return 0
        degree = 0
        for bit in self.combination:
            if bit == "1":
                degree += 1
        return degree

    def copy(self) -> Combination:
        return Combination(self.combination)

    def overlaps(self, other:Combination) -> bool:
        for c1, c2 in zip(self.combination, other.combination):
            if c1 == "1" and c1 == c2:
                return True
        return False

class CombinationNode():
    """
    The main class.
    The main idea was to build a tree of possible "combinations of combinations":
    100-011 => 111
    |---010-001 => 111
    |---001-010 => 111
    At each node, the combination applied to the current list of values has to be acceptable
    (all within THRESHOLD_AVERAGING).
    Also, the shorter a path, the better the solution, as it means it found a way to average
    a lot of the values with the minimum amount of outliers possible, maybe by grouping
    the outliers together in a way that makes sense, ...
    - `populate` fills the tree automatically, with every solution possible
    - `path` is used mainly on leaves, to obtain the path taken to arrive there.
    """
    def __init__(self, combination:Combination) -> CombinationNode:
        self.combination:Combination = combination
        self.children:List[CombinationNode] = []
        self.parent:CombinationNode = None
        self.total_combination:Combination = combination

    def __str__(self) -> str:
        list_paths = self.recur_paths()
        list_paths = [",".join([combi.combination.combination for combi in path]) for path in list_paths]
        return "\n".join(list_paths)

    def add_child(self, child:CombinationNode) -> None:
        if child.combination.degree > self.combination.degree and not self.total_combination.overlaps(child.combination):
            raise ChildDegreeExceedParentDegreeError(f"{child.combination} > {self.combination}")
        self.children.append(child)
        child.parent = self
        child.total_combination += self.total_combination

    def path(self) -> List[CombinationNode]:
        path = []
        current = self
        while current.parent != None:
            path.append(current)
            current = current.parent
        path.append(current)
        return path[::-1]

    def populate(self, combination_dict:Dict[int, List[Combination]]) -> None:
        missing_degrees = len(self.combination.combination) - self.total_combination.degree
        if missing_degrees == 0:
            return
        for i in range(min(self.combination.degree, missing_degrees), 0, -1):
            for combination in combination_dict[i]:
                if not self.total_combination.overlaps(combination):
                    self.add_child(CombinationNode(combination))
        for child in self.children:
            child.populate(combination_dict)

    def recur_paths(self) -> List[List[CombinationNode]]:
        if len(self.children) == 0:
            return [self.path()]
        paths = []
        for child in self.children:
            for path in child.recur_paths():
                paths.append(path)
        return paths
Errors.py
class ChildDegreeExceedParentDegreeError(Exception):
    pass

class CombinationsOverlapError(Exception):
    pass

class ToImplementError(Exception):
    pass

class UncompletePathError(Exception):
    pass
main.py
from typing import Dict, List, Set, Tuple, Union
import pandas as pd
from CombinationLib import *
best_depth:int = -1
best_path:List[CombinationNode] = []
THRESHOLD_OUTLIER = 0.2
THRESHOLD_AVERAGING = 0.1
def verif_averaging_pct(combination:Combination, values:List[float]) -> bool:
    """
    For a given combination of values, we must have all the values within
    THRESHOLD_AVERAGING of the average of the combination
    """
    avg = 0
    for c, v in zip(combination.combination, values):
        if c == "1":
            avg += v
    avg /= combination.degree
    for c, v in zip(combination.combination, values):
        if c == "1" and (v > avg*(1+THRESHOLD_AVERAGING) or v < avg*(1-THRESHOLD_AVERAGING)):
            return False
    return True
def recursive_check(node:CombinationNode, depth:int, values:List[Union[float, int]]) -> None:
    """
    Here is where we preferentially ask for a small number of bigger groups
    """
    global best_depth
    global best_path
    # If there are more groups than the current best way to do it, stop
    if best_depth != -1 and depth > best_depth:
        return
    # If all the values of the combination are not within THRESHOLD_AVERAGING, stop
    if not verif_averaging_pct(node.combination, values):
        return
    # If we finished the list of combinations, and this way is the best, keep it, stop
    if len(node.children) == 0:
        if best_depth == -1 or depth < best_depth:
            best_depth = depth
            best_path = node.path()
        return
    # If we are still not finished (not every value has been used), continue
    for cnode in node.children:
        recursive_check(cnode, depth+1, values)
def groups_from_list(values:List[Union[float, int]]) -> List[List[Union[float, int]]]:
    """
    From a list of values, get the smallest list of groups of elements
    within THRESHOLD_AVERAGING of each other.
    It implies that we will try and recursively find the biggest group possible
    within the unused values (i.e. groups with combinations of size [3, 1] are preferred
    over [2, 2])
    """
    global best_depth
    global best_path
    groups:List[List[float]] = []
    # Generate all the combinations (I used binary for this)
    combination_dict:Dict[int, List[Combination]] = {}
    for i in range(1, 2**len(values)):
        combination = format(i, f"0{len(values)}b")  # Here is the binary conversion
        counter = 0
        for c in combination:
            if c == "1":
                counter += 1
        if counter not in combination_dict:
            combination_dict[counter] = []
        combination_dict[counter].append(Combination(combination))
    # Generate all the combinations of combinations that use all values (without using one twice)
    combination_trees:List[List[CombinationNode]] = []
    for key in combination_dict:
        for combination in combination_dict[key]:
            cn = CombinationNode(combination)
            cn.populate(combination_dict)
            combination_trees.append(cn)
    best_depth = -1
    best_path = None
    for root in combination_trees:
        recursive_check(root, 0, values)
    # print(",".join([combination.combination.combination for combination in best_path]))
    for combination in best_path:
        temp = []
        for c, v in zip(combination.combination.combination, values):
            if c == "1":
                temp.append(v)
        groups.append(temp)
    return groups
def averages_from_groups(gs:List[List[Union[float, int]]]) -> List[float]:
    """Computing the averages of each group"""
    avgs:List[float] = []
    for group in gs:
        avg = 0
        for elt in group:
            avg += elt
        avg /= len(group)
        avgs.append(avg)
    return avgs
def end_check(ds:List[pd.DataFrame], ids:List[int]) -> bool:
    """Check if we finished consuming all the dataframes"""
    for d, i in zip(ds, ids):
        if i < len(d[0]):
            return False
    return True
def search(group:List[Union[float, int]], values_list:List[Union[float, int]]) -> List[int]:
    """Obtain all the indices corresponding to a set of values"""
    # We will get all the indices in values_list of the values in group.
    # If a value is present in group, all the occurrences of this value will be too,
    # so we can use a set and search every occurrence of each value.
    indices:List[int] = []
    group_set = set(group)
    for value in group_set:
        for i, v in enumerate(values_list):
            if value == v:
                indices.append(i)
    return indices
def threshold_grouper(total_list:List[Union[float, int]]) -> pd.DataFrame:
    """Building a 2D pd.DataFrame with the averages (x) and the outliers (y)"""
    result_list:List[List[Union[float, int]]] = [[total_list[0]]]
    result_index = 0
    total_index = 1
    while total_index < len(total_list):
        # Only checking if the bigger one is within THRESHOLD_OUTLIER of the smaller one.
        # If it is the case, the opposite is true too.
        # If yes, it is an outlier
        if result_list[result_index][0]*(1+THRESHOLD_OUTLIER) >= total_list[total_index]:
            result_list[result_index].append(total_list[total_index])
        # Else it is a new peak
        else:
            result_list.append([total_list[total_index]])
            result_index += 1
        total_index += 1
    result:pd.DataFrame = pd.DataFrame(result_list)
    return result
def dataframes_merger(dataframes:List[pd.DataFrame]) -> pd.DataFrame:
    """Merging the dataframes, with THRESHOLDS"""
    # Store the averages for the within-10% cells, in ascending order
    result = []
    # Keep tabs on where we are regarding each dataframe (needed for when we skip cells)
    curr_indices:List[int] = [0 for _ in range(len(dataframes))]
    # Repeat until all the cells in every dataframe have been seen once
    while not end_check(dataframes, curr_indices):
        # Get the values of the current indices in the dataframes
        curr_values = [dataframe[0][i] for dataframe, i in zip(dataframes, curr_indices)]
        # Get the largest 10% groups from the current list of values
        groups = groups_from_list(curr_values)
        # Compute the average of these groups
        avgs = averages_from_groups(groups)
        # Obtain the minimum average...
        avg_min = min(avgs)
        # ... and its index
        avg_min_index = avgs.index(avg_min)
        # Then get the group corresponding to the minimum average
        avg_min_group = groups[avg_min_index]
        # Get the indices of the values included in this group
        indices_to_increment = search(avg_min_group, curr_values)
        # Add the average to the result merged list
        result.append(avg_min)
        # For every element in the average we added, increment the corresponding index
        for index in indices_to_increment:
            curr_indices[index] += 1
    # Re-assemble the dataframe, taking the threshold% around average into account
    result = threshold_grouper(result)
    print(result)
df1 = pd.DataFrame([1, 1487, 5144, 35293, 78486, 100000])
df2 = pd.DataFrame([1, 1500, 5144, 36278, 45968, 100000])
df3 = pd.DataFrame([1, 5286, 35785, 41205, 100000])
dataframes_merger([df3, df2, df1])

Bad Tree design, Data Structure

I tried making a Tree as a part of my Data Structures course. The code works, but it is extremely slow, taking almost double the time that is accepted for the course. I do not have experience with Data Structures and Algorithms, but I need to optimize the program. If anyone has any tips, advice, or criticism, I would greatly appreciate it.
The tree is not necessarily a binary tree.
Here is the code:
import sys
import threading

class Node:
    def __init__(self, value):
        self.value = value
        self.children = []
        self.parent = None

    def add_child(self, child):
        child.parent = self
        self.children.append(child)

def compute_height(n, parents):
    found = False
    indices = []
    for i in range(n):
        indices.append(i)
    for i in range(len(parents)):
        currentItem = parents[i]
        if currentItem == -1:
            root = Node(parents[i])
            startingIndex = i
            found = True
            break
    if found == False:
        root = Node(parents[0])
        startingIndex = 0
    return recursion(startingIndex, root, indices, parents)

def recursion(index, toWhomAdd, indexes, values):
    children = []
    for i in range(len(values)):
        if index == values[i]:
            children.append(indexes[i])
            newNode = Node(indexes[i])
            toWhomAdd.add_child(newNode)
            recursion(i, newNode, indexes, values)
    return toWhomAdd

def checkHeight(node):
    if node == '' or node == None or node == []:
        return 0
    counter = []
    for i in node.children:
        counter.append(checkHeight(i))
    if node.children != []:
        mostChildren = max(counter)
    else:
        mostChildren = 0
    return 1 + mostChildren

def main():
    n = int(int(input()))
    parents = list(map(int, input().split()))
    root = compute_height(n, parents)
    print(checkHeight(root))

sys.setrecursionlimit(10**7)  # max depth of recursion
threading.stack_size(2**27)   # new thread will get stack of such size
threading.Thread(target=main).start()
Edit:
For this input (the first number being the number of nodes, and the other numbers the nodes' values):
5
4 -1 4 1 1
We expect this output (the height of the tree):
3
Another example:
Input:
5
-1 0 4 0 3
Output:
4
It looks like the value given for a node is a reference by index to another node (its parent). This is nowhere stated in the question, but if that assumption is right, you don't really need to create the tree with Node instances. Just read the input into a list (which you already do), and you actually have the tree encoded in it.
So for example, the list [4, -1, 4, 1, 1] represents this tree, where the labels are the indices in this list:
      1
     / \
    4   3
   / \
  0   2
The height of this tree — according to the definition given in Wikipedia — would be 2. But apparently the expected result is 3, which is the number of nodes (not edges) on the longest path from the root to a leaf, or — otherwise put — the number of levels in the tree.
The idea to use recursion is correct, but you can do it bottom up (starting at any node), getting the result of the parent recursively, and adding one to it. Use the principle of dynamic programming by storing the result for each node in a separate list, which I called levels:
def get_num_levels(parents):
    levels = [0] * len(parents)

    def recur(node):
        if levels[node] == 0:  # this node's level hasn't been determined yet
            parent = parents[node]
            levels[node] = 1 if parent == -1 else recur(parent) + 1
        return levels[node]

    for node in range(len(parents)):
        recur(node)
    return max(levels)
And the main code could be as you had it:
def main():
    n = int(int(input()))
    parents = list(map(int, input().split()))
    print(get_num_levels(parents))
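A quick sanity check against the two examples in the question (these asserts are mine, assuming the get_num_levels function above):

assert get_num_levels([4, -1, 4, 1, 1]) == 3
assert get_num_levels([-1, 0, 4, 0, 3]) == 4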

What is the most efficient way of getting the intersection of k sorted arrays?

Given k sorted arrays, what is the most efficient way of getting the intersection of these lists?
Example
INPUT:
[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]
Output:
[1,7]
There is a way to get the union of k sorted arrays in nlogk time, based on what I read in the Elements of Programming Interviews book. I was wondering if there is a way to do something similar for the intersection as well.
## merge sorted arrays in nlogk time [ regular appending and merging is nlogn time ]
import heapq

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))
    res = []
    # collect results in nlogK time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt:
            heapq.heappush(heap, (nxt, ary))
EDIT: obviously this is an algorithm question that I am trying to solve, so I cannot use any of the built-in functions like set intersection etc.
Exploiting sort order
Here is a single pass O(n) approach that doesn't require any special data structures or auxiliary memory beyond the fundamental requirement of one iterator per input.
from itertools import cycle, islice

def intersection(inputs):
    "Yield the intersection of elements from multiple sorted inputs."
    # intersection(['ABBCD', 'BBDE', 'BBBDDE']) --> B B D
    n = len(inputs)
    iters = cycle(map(iter, inputs))
    try:
        candidate = next(next(iters))
        while True:
            for it in islice(iters, n - 1):
                while (value := next(it)) < candidate:
                    pass
                if value != candidate:
                    candidate = value
                    break
            else:
                yield candidate
                candidate = next(next(iters))
    except StopIteration:
        return
Here's a sample session:
>>> data = [[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]
>>> list(intersection(data))
[1, 7]
>>> data = [[1,1,2,3], [1,1,4,4]]
>>> list(intersection(data))
[1, 1]
Algorithm in words
The algorithm starts by selecting the next value from the next iterator to be a candidate.
The main loop assumes a candidate has been selected and it loops over the next n - 1 iterators. For each of those iterators, it consumes values until it finds a value that is at least as large as the candidate. If that value is larger than the candidate, that value becomes the new candidate and the main loop starts again. If all n - 1 values are equal to the candidate, then the candidate is emitted and a new candidate is fetched.
When any input iterator is exhausted, the algorithm is complete.
Doing it without libraries (core language only)
The same algorithm works fine (though less beautifully) without using itertools. Just replace cycle and islice with their list based equivalents:
def intersection(inputs):
    "Yield the intersection of elements from multiple sorted inputs."
    # intersection(['ABBCD', 'BBDE', 'BBBDDE']) --> B B D
    n = len(inputs)
    iters = list(map(iter, inputs))
    curr_iter = 0
    try:
        it = iters[curr_iter]
        curr_iter = (curr_iter + 1) % n
        candidate = next(it)
        while True:
            for i in range(n - 1):
                it = iters[curr_iter]
                curr_iter = (curr_iter + 1) % n
                while (value := next(it)) < candidate:
                    pass
                if value != candidate:
                    candidate = value
                    break
            else:
                yield candidate
                it = iters[curr_iter]
                curr_iter = (curr_iter + 1) % n
                candidate = next(it)
    except StopIteration:
        return
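A quick check (mine) that this version matches the itertools one:

>>> list(intersection([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))
[1, 7]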
Yes, it is possible! I've modified your example code to do this.
My answer assumes that your question is about the algorithm - if you want the fastest-running code using sets, see other answers.
This maintains the O(n log(k)) time complexity: all the code between if lowest != elem or ary != times_seen: and unbench_all = False is O(log(k)). There is a nested loop inside the main loop (for unbenched in range(times_seen):) but this only runs times_seen times, and times_seen is initially 0 and is reset to 0 after every time this inner loop is run, and can only be incremented once per main loop iteration, so the inner loop cannot do more iterations in total than the main loop. Thus, since the code inside the inner loop is O(log(k)) and runs at most as many times as the outer loop, and the outer loop is O(log(k)) and runs n times, the algorithm is O(n log(k)).
This algorithm relies upon how tuples are compared in Python. It compares the first items of the tuples, and if they are equal, it compares the second items (i.e. (x, a) < (x, b) is true if and only if a < b).
In this algorithm, unlike in the example code in the question, when an item is popped from the heap, it is not necessarily pushed again in the same iteration. Since we need to check if all sub-lists contain the same number, after a number is popped from the heap, its sub-list is what I call "benched", meaning that it is not added back to the heap. This is because we need to check if other sub-lists contain the same item, so adding this sub-list's next item is not needed right now.
If a number is indeed in all sub-lists, then the heap will look something like [(2,0),(2,1),(2,2),(2,3)], with all the first elements of the tuples the same, so heappop will select the one with the lowest sub-list index. This means that first index 0 will be popped and times_seen will be incremented to 1, then index 1 will be popped and times_seen will be incremented to 2 - if ary is not equal to times_seen then the number is not in the intersection of all sub-lists. This leads to the condition if lowest != elem or ary != times_seen:, which decides when a number shouldn't be in the result. The else branch of this if statement is for when it still might be in the result.
The unbench_all boolean is for when all sub-lists need to be removed from the bench - this could be because:
- The current number is known to not be in the intersection of the sub-lists
- It is known to be in the intersection of the sub-lists
When unbench_all is True, all the sub-lists that were removed from the heap are re-added. It is known that these are the ones with indices in range(times_seen), since the algorithm removes items from the heap only if they have the same number, so they must have been removed in order of index, contiguously and starting from index 0, and there must be times_seen of them. This means that we don't need to store the indices of the benched sub-lists, only the number of sub-lists that have been benched.
import heapq

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))
    res = []
    # the number of times that the current number has been seen
    times_seen = 0
    # the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
    lowest = heap[0][0] if heap else None
    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        unbench_all = True
        if lowest != elem or ary != times_seen:
            if lowest == elem:
                heapq.heappop(heap)
                it = srtd_iters[ary]
                nxt = next(it, None)
                if nxt:
                    heapq.heappush(heap, (nxt, ary))
        else:
            heapq.heappop(heap)
            times_seen += 1
            if times_seen == len(srtd_arys):
                res.append(elem)
            else:
                unbench_all = False
        if unbench_all:
            for unbenched in range(times_seen):
                unbenched_it = srtd_iters[unbenched]
                nxt = next(unbenched_it, None)
                if nxt:
                    heapq.heappush(heap, (nxt, unbenched))
            times_seen = 0
            if heap:
                lowest = heap[0][0]
    return res

if __name__ == '__main__':
    a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
    a2 = [[1, 1], [1, 1, 2, 2, 3]]
    for arys in [a1, a2]:
        print(mergeArys(arys))
An equivalent algorithm can be written like this, if you prefer:
def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]
    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem:
            heapq.heappush(heap, (elem, idx))
    res = []
    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        lowest = elem
        keep_elem = True
        for i in range(len(srtd_arys)):
            elem, ary = heap[0]
            if lowest != elem or ary != i:
                if ary != i:
                    heapq.heappop(heap)
                    it = srtd_iters[ary]
                    nxt = next(it, None)
                    if nxt:
                        heapq.heappush(heap, (nxt, ary))
                keep_elem = False
                i -= 1
                break
            heapq.heappop(heap)
        if keep_elem:
            res.append(elem)
        for unbenched in range(i+1):
            unbenched_it = srtd_iters[unbenched]
            nxt = next(unbenched_it, None)
            if nxt:
                heapq.heappush(heap, (nxt, unbenched))
        if len(heap) < len(srtd_arys):
            heap = []
    return res
You can use built-in sets and set intersection:
d = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
result = set(d[0]).intersection(*d[1:])
{1, 7}
You can use reduce:
from functools import reduce
a = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
reduce(lambda x, y: x & set(y), a[1:], set(a[0]))
{1, 7}
I've come up with this algorithm. It doesn't exceed O(nk); I don't know if it's good enough for you. The point of this algorithm is that you can keep k indexes, one for each array, and each iteration you find the indexes of the next element in the intersection, increasing every index until you exceed the bounds of an array and there are no more items in the intersection. The trick is that since the arrays are sorted, you can look at two elements in two different arrays, and if one is bigger than the other you can instantly throw away the other, because you know you can't have a smaller number than the one you are looking at. The worst case of this algorithm is that every index will be increased to the bound, which takes kn time, since an index cannot decrease its value.
inter = []
for n in range(len(arrays[0])):
    if indexes[0] >= len(arrays[0]):
        return inter
    for i in range(1, k):
        if indexes[i] >= len(arrays[i]):
            return inter
        while indexes[i] < len(arrays[i]) and arrays[i][indexes[i]] < arrays[0][indexes[0]]:
            indexes[i] += 1
        while indexes[i] < len(arrays[i]) and indexes[0] < len(arrays[0]) and arrays[i][indexes[i]] > arrays[0][indexes[0]]:
            indexes[0] += 1
    if indexes[0] < len(arrays[0]):
        inter.append(arrays[0][indexes[0]])
    indexes = [idx + 1 for idx in indexes]
return inter
You said we can't use sets but how about dicts / hash tables? (yes I know they're basically the same thing) :D
If so, here's a fairly simple approach (please excuse the py2 syntax):
arrays = [[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]

counts = {}
for ar in arrays:
    last = None
    for i in ar:
        if i != last:
            counts[i] = counts.get(i, 0) + 1
            last = i

N = len(arrays)
intersection = [i for i, n in counts.iteritems() if n == N]
print intersection
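For the sample data this prints [1, 7] (modulo Python 2's arbitrary dict ordering); in Python 3, replace counts.iteritems() with counts.items() and use print(intersection).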
Same as Raymond Hettinger's solution but with more basic python code:
def intersection(arrays, unique: bool=False):
    result = []
    if not len(arrays) or any(not len(array) for array in arrays):
        return result
    pointers = [0] * len(arrays)
    target = arrays[0][0]
    start_step = 0
    current_step = 1
    while True:
        idx = current_step % len(arrays)
        array = arrays[idx]
        while pointers[idx] < len(array) and array[pointers[idx]] < target:
            pointers[idx] += 1
        if pointers[idx] < len(array) and array[pointers[idx]] > target:
            target = array[pointers[idx]]
            start_step = current_step
            current_step += 1
            continue
        if unique:
            while (
                pointers[idx] + 1 < len(array)
                and array[pointers[idx]] == array[pointers[idx] + 1]
            ):
                pointers[idx] += 1
        if (current_step - start_step) == len(arrays):
            result.append(target)
            for other_idx, other_array in enumerate(arrays):
                pointers[other_idx] += 1
            if pointers[idx] < len(array):
                target = array[pointers[idx]]
                start_step = current_step
        if pointers[idx] == len(array):
            return result
        current_step += 1
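For the question's sample data (a quick check of my own):

>>> intersection([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]])
[1, 7]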
Here's an O(n) answer (where n = sum(len(sublist) for sublist in data)).
from itertools import cycle

def intersection(data):
    result = []
    maxval = float("-inf")
    consecutive = 0
    try:
        for sublist in cycle(iter(sublist) for sublist in data):
            value = next(sublist)
            while value < maxval:
                value = next(sublist)
            if value > maxval:
                maxval = value
                consecutive = 0
                continue
            consecutive += 1
            if consecutive >= len(data) - 1:
                result.append(maxval)
                consecutive = 0
    except StopIteration:
        return result

print(intersection([[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]))
[1, 7]
Some of the above methods do not cover the cases where there are duplicates in every subset of the list. The code below implements this intersection, and it will be more efficient if there are lots of duplicates in the subsets of the list :) If you are not sure about duplicates, it is recommended to use Counter from collections (from collections import Counter). The custom counter function is made to increase the efficiency of handling large numbers of duplicates. But it still cannot beat Raymond Hettinger's implementation.
def counter(my_list):
    my_list = sorted(my_list)
    first_val, *all_val = my_list
    p_index = my_list.index(first_val)
    my_counter = {}
    for item in all_val:
        c_index = my_list.index(item)
        diff = abs(c_index - p_index)
        p_index = c_index
        my_counter[first_val] = diff
        first_val = item
    c_index = my_list.index(item)
    diff = len(my_list) - c_index
    my_counter[first_val] = diff
    return my_counter
def my_func(data):
    if not data or not isinstance(data, list):
        return
    # get the first value
    first_val, *all_val = data
    if not isinstance(first_val, list):
        return
    # count items in first value
    p = counter(first_val)  # counter({1: 2, 3: 1, 5: 1, 7: 1})
    # collect all common items and calculate the minimum occurrence in the intersection
    for val in all_val:
        # collecting common items
        c = counter(val)
        # calculate the minimum occurrence in the intersection
        inner_dict = {}
        for inner_val in set(c).intersection(set(p)):
            inner_dict[inner_val] = min(p[inner_val], c[inner_val])
        p = inner_dict
    # >>> p
    # {1: 2, 7: 1}
    # sort by keys of counter
    sorted_items = sorted(p.items(), key=lambda x: x[0])  # [(1, 2), (7, 1)]
    result = [i[0] for i in sorted_items for _ in range(i[1])]  # [1, 1, 7]
    return result
Here are some examples:
>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> my_func(data=data)
[1, 7]
>>> data = [[1,1,3,5,7],[1,1,3,5,7],[1,1,4,7,9]]
>>> my_func(data=data)
[1, 1, 7]
You can do the following using the functions heapq.merge, chain.from_iterable and groupby
from heapq import merge
from itertools import groupby, chain

ls = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]

def index_groups(lst):
    """[1, 1, 3, 5, 7] -> [(1, 0), (1, 1), (3, 0), (5, 0), (7, 0)]"""
    return chain.from_iterable(((e, i) for i, e in enumerate(group)) for k, group in groupby(lst))

iterables = (index_groups(li) for li in ls)
flat = merge(*iterables)
res = [k for (k, _), g in groupby(flat) if sum(1 for _ in g) == len(ls)]
print(res)
Output
[1, 7]
The idea is to give an extra value (using enumerate) to differentiate between equal values within the same list (see the function index_groups).
The complexity of this algorithm is O(n) where n is the sum of the lengths of each list in the input.
Note that the output for (an extra 1 in each list):
ls = [[1, 1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 1, 4, 7, 9]]
is:
[1, 1, 7]
You can use bit-masking with one-hot encoding. The inner lists become maxterms. You AND them together for the intersection and OR them for the union. Then you have to convert back, for which I've used a bit hack.
problem = [[1,3,5,7], [1,1,3,5,8,7], [1,4,7,9]]

debruijn = [0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
            31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9]
u32 = accum = (1 << 32) - 1

for vec in problem:
    maxterm = 0
    for v in vec:
        maxterm |= 1 << v
    accum &= maxterm

# https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogDeBruijn
result = []
while accum:
    power = accum
    accum &= accum - 1  # Peter Wegner CACM 3 (1960), 322
    power &= ~accum
    result.append(debruijn[((power * 0x077CB531) & u32) >> 27])

print(result)
This uses (simulates) 32-bit integers, so you can only have [0, 31] in your sets.
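For the problem list above, accum ends up with only bits 1 and 7 set, so this prints [1, 7].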
I am inexperienced at Python, so I timed it: one should definitely use set.intersection instead.
Here is the single-pass counting algorithm, a simplified version of what others have suggested.
import itertools

def intersection(iterables):
    target, count = None, 0
    for it in itertools.cycle(map(iter, iterables)):
        for value in it:
            if count == 0 or value > target:
                target, count = value, 1
                break
            if value == target:
                count += 1
                break
        else:  # exhausted iterator
            return
        if count >= len(iterables):
            yield target
            count = 0
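A quick check (mine) with the question's sample data:

>>> list(intersection([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))
[1, 7]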
Binary and exponential search haven't come up yet. They're easily recreated even with the "no builtins" constraint.
In practice, that would be much faster, and sub-linear. In the worst case - where the intersection isn't shrinking - the naive approach would repeat work. But there's a solution for that: integrate the binary search while splitting the arrays in half.
import bisect
import itertools

def intersection(seqs):
    seq = min(seqs, key=len)
    if not seq:
        return
    pivot = seq[len(seq) // 2]
    lows, counts, highs = [], [], []
    for seq in seqs:
        start = bisect.bisect_left(seq, pivot)
        stop = bisect.bisect_right(seq, pivot, start)
        lows.append(seq[:start])
        counts.append(stop - start)
        highs.append(seq[stop:])
    yield from intersection(lows)
    yield from itertools.repeat(pivot, min(counts))
    yield from intersection(highs)
Both handle duplicates. Both guarantee O(N) worst-case time (counting slicing as atomic). The latter will approach O(min_size) speed; by always splitting the smallest in half it essentially can't suffer from the bad luck of uneven splits.
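A quick check (mine) of the generator above on the question's sample data:

>>> list(intersection([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))
[1, 7]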
I couldn't help but notice that this is seems to be a variation on the Welfare Crook problem; see David Gries's book, The Science of Programming. Edsger Dijkstra also wrote an EWD about this, see Ascending Functions and the Welfare Crook.
The Welfare Crook
Suppose we have three long magnetic tapes, each containing a list of names in alphabetical order:
- all people working for IBM Yorktown
- students at Columbia University
- people on welfare in New York City
Practically speaking, all three lists are endless, so no upper bounds are given. It is known that at least one person is on all three lists. Write a program to locate the first such person.
Our intersection of the ordered lists problem is a generalization of the Welfare Crook problem.
Here's a (rather primitive?) Python solution to the Welfare Crook problem:
def find_welfare_crook(f, g, h, i, j, k):
    """f, g, and h are "ascending functions," i.e.,
    i <= j implies f[i] <= f[j] or, equivalently,
    f[i] < f[j] implies i < j, and the same goes for g and h.
    i, j, k define where to start the search in each list.
    """
    # This is an implementation of a solution to the Welfare Crook
    # problem presented in David Gries's book, The Science of Programming.
    # The surprising and beautiful thing is that the guard predicates are
    # so few and so simple.
    while True:
        if f[i] < g[j]:
            i += 1
        elif g[j] < h[k]:
            j += 1
        elif h[k] < f[i]:
            k += 1
        else:
            break
    return (i, j, k)
    # The other remarkable thing is how the negation of the guard
    # predicates works out to be: f[i] == g[j] and g[j] == h[k].
Generalization to Intersection of K Lists
This generalizes to K lists, and here's what I devised; I don't know how Pythonic this is, but it is pretty compact:
def findIntersectionLofL(lofl):
    """Generalized findIntersection function which operates on a "list of lists." """
    K = len(lofl)
    indices = [0 for i in range(K)]
    result = []
    #
    try:
        while True:
            # idea is to maintain the indices via a construct like the following:
            allEqual = True
            for i in range(K):
                if lofl[i][indices[i]] < lofl[(i+1)%K][indices[(i+1)%K]]:
                    indices[i] += 1
                    allEqual = False
            # When the above iteration finishes, if all of the list
            # items indexed by the indices are equal, then another
            # item common to all of the lists must be added to the result.
            if allEqual:
                result.append(lofl[0][indices[0]])
                while lofl[0][indices[0]] == lofl[1][indices[1]]:
                    indices[0] += 1
    except IndexError as e:
        # Eventually, the foregoing iteration will advance one of the
        # indices past the end of one of the lists, and when that happens
        # an IndexError exception will be raised. This means the algorithm
        # is finished.
        return result
This solution does not keep repeated items. Changing the program to include all of the repeated items by changing what the program does in the conditional at the end of the "while True" loop is an exercise left to the reader.
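For the question's sample data (a quick check of my own):

>>> findIntersectionLofL([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]])
[1, 7]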
Improved Performance
Comments from @greybeard prompted the refinements shown below: pre-computation of the "array index moduli" (the "(i+1)%K" expressions), and changes to the inner iteration's structure to further remove overhead:
def findIntersectionLofLunRolled(lofl):
    """Generalized findIntersection function which operates on a "list of lists."
    Accepts a list-of-lists, lofl. Each of the lists must be ordered.
    Returns the list of each element which appears in all of the lists at least once.
    """
    K = len(lofl)
    indices = [0] * K
    result = []
    lt = [(i, (i+1) % K) for i in range(K)]  # avoids evaluation of index exprs inside the loop
    #
    try:
        while True:
            allUnEqual = True
            while allUnEqual:
                allUnEqual = False
                for i, j in lt:
                    if lofl[i][indices[i]] < lofl[j][indices[j]]:
                        indices[i] += 1
                        allUnEqual = True
            # Now all of the lofl[i][indices[i]], for all i, are the same value.
            # Store that value in the result, and then advance all of the indices
            # past that common value:
            v = lofl[0][indices[0]]
            result.append(v)
            for i, j in lt:
                while lofl[i][indices[i]] == v:
                    indices[i] += 1
    except IndexError as e:
        # Eventually, the foregoing iteration will advance one of the
        # indices past the end of one of the lists, and when that happens
        # an IndexError exception will be raised. This means the algorithm
        # is finished.
        return result

How to sort an array with n elements in which k elements are out of place in O(n + k log k)?

I was asked this in an interview today, and am starting to believe it is not solvable.
Given a sorted array of size n, select k elements in the array, and reshuffle them back into the array, resulting in a new "nk-sorted" array.
Find the k (or fewer) elements that have moved in that new array.
Here is (Python) code that creates such arrays, but I don't care about language for this.
import numpy as np

def __generate_unsorted_array(size, is_integer=False, max_int_value=100000):
    return np.random.randint(max_int_value, size=size) if is_integer else np.random.rand(size)

def generate_nk_unsorted_array(n, k, is_integer=False, max_int_value=100000):
    assert k <= n
    unsorted_n_array = __generate_unsorted_array(n - k, is_integer, max_int_value=max_int_value)
    sorted_n_array = sorted(unsorted_n_array)
    random_k_array = __generate_unsorted_array(k, is_integer, max_int_value=max_int_value)
    insertion_inds = np.random.choice(n - k + 1, k, replace=True)  # can put two unsorted next to each other.
    nk_unsorted_array = np.insert(sorted_n_array, insertion_inds, random_k_array)
    return list(nk_unsorted_array)
Is this doable under the complexity constraint?
This is only part of the question. The whole question required sorting the "nk-sorted" array in O(n + klogk).
Note: This is a conceptual solution. It is coded in Python, but because of the way Python implements List, it does not actually run in the required complexity. See @soyuzzzz's answer for an actual solution in Python within the complexity requirement.
Accepted @soyuzzzz's answer over this one.
Original answer (works, but the complexity is only correct assuming Linked list implementation for Python's List, which is not the case):
This sorts a nk-unsorted array in O(n + klogk), assuming the array should be ascending.
- Find elements which are not sorted by traversing the array.
- If such an element is found (it is larger than the following one), then either it or the following one is out of order (or both).
- Keep both of them aside, and remove them from the array.
- Continue traversing the newly obtained array (after removal), from the index which comes before the found element.
- This puts aside 2k elements in O(n) time.
- Sort the 2k elements: O(klogk).
- Merge two sorted lists which have n total elements: O(n).
Total: O(n + klogk)
Code:
def merge_sorted_lists(la, lb):
    if la is None or la == []:
        return lb
    if lb is None or lb == []:
        return la
    a_ind = b_ind = 0
    a_len = len(la)
    b_len = len(lb)
    merged = []
    while a_ind < a_len and b_ind < b_len:
        a_value = la[a_ind]
        b_value = lb[b_ind]
        if a_value < b_value:
            merged.append(la[a_ind])
            a_ind += 1
        else:
            merged.append(lb[b_ind])
            b_ind += 1
    # get the leftovers into merged
    while a_ind < a_len:
        merged.append(la[a_ind])
        a_ind += 1
    while b_ind < b_len:
        merged.append(lb[b_ind])
        b_ind += 1
    return merged
and
def sort_nk_unsorted_list(nk_unsorted_list):
    working_copy = nk_unsorted_list.copy()  # just for ease of testing
    requires_resorting = []
    current_list_length = len(working_copy)
    i = 0
    while i < current_list_length - 1 and 1 < current_list_length:
        if i == -1:
            i = 0
        first = working_copy[i]
        second = working_copy[i + 1]
        if second < first:
            requires_resorting.append(first)
            requires_resorting.append(second)
            del working_copy[i + 1]
            del working_copy[i]
            i -= 2
            current_list_length -= 2
        i += 1
    sorted_2k_elements = sorted(requires_resorting)
    sorted_nk_list = merge_sorted_lists(sorted_2k_elements, working_copy)
    return sorted_nk_list
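A quick check of the two functions above (the example input is mine, not from the original post):

>>> sort_nk_unsorted_list([1, 2, 6, 3, 4, 5])
[1, 2, 3, 4, 5, 6]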
Even though @Gulzar's solution is correct, it doesn't actually give us O(n + k * log k).
The problem is in the sort_nk_unsorted_list function. Unfortunately, deleting an arbitrary item from a Python list is not constant time. It's actually O(n). That gives the overall algorithm a complexity of O(n + nk + k * log k)
What we can do to address this is use a different data structure. If you use a doubly-linked list, removing an item from that list is actually O(1). Unfortunately, Python does not come with one by default.
Here's my solution that achieves O(n + k * log k).
The entry-point function to solve the problem:
def sort(my_list):
    in_order, out_of_order = separate_in_order_from_out_of_order(my_list)
    out_of_order.sort()
    return merge(in_order, out_of_order)
The function that separates the in-order elements from the out-of-order elements:
def separate_in_order_from_out_of_order(my_list):
    list_dll = DoublyLinkedList.from_list(my_list)
    out_of_order = []
    current = list_dll.head
    while current.next is not None:
        if current.value > current.next.value:
            out_of_order.append(current.value)
            out_of_order.append(current.next.value)
            previous = current.prev
            current.next.remove()
            current.remove()
            current = previous
        else:
            current = current.next
    in_order = list_dll.to_list()
    return in_order, out_of_order
The function to merge the two separated lists:
def merge(first, second):
    """
    Merges two [sorted] lists into a sorted list.
    Runtime complexity: O(n)
    Space complexity: O(n)
    """
    i, j = 0, 0
    result = []
    while i < len(first) and j < len(second):
        if first[i] < second[j]:
            result.append(first[i])
            i += 1
        else:
            result.append(second[j])
            j += 1
    result.extend(first[i:len(first)])
    result.extend(second[j:len(second)])
    return result
And last, this is the DoublyLinkedList implementation (I used a sentinel node to make things easier):
import math

class DoublyLinkedNode:
    def __init__(self, value):
        self.value = value
        self.next = None
        self.prev = None

    def remove(self):
        if self.prev:
            self.prev.next = self.next
        if self.next:
            self.next.prev = self.prev

class DoublyLinkedList:
    def __init__(self, head):
        self.head = head

    @staticmethod
    def from_list(lst):
        sentinel = DoublyLinkedNode(-math.inf)
        previous = sentinel
        for item in lst:
            node = DoublyLinkedNode(item)
            node.prev = previous
            previous.next = node
            previous = node
        return DoublyLinkedList(sentinel)

    def to_list(self):
        result = []
        current = self.head.next
        while current is not None:
            result.append(current.value)
            current = current.next
        return result
And these are the unit tests I used to validate the code:
import unittest

class TestSort(unittest.TestCase):
    def test_sort(self):
        test_cases = [
            # (input, expected result)
            ([1, 2, 3, 4, 10, 5, 6], [1, 2, 3, 4, 5, 6, 10]),
            ([1, 2, 5, 4, 10, 6, 0], [0, 1, 2, 4, 5, 6, 10]),
            ([1], [1]),
            ([1, 3, 2], [1, 2, 3]),
            ([], [])
        ]
        for (test_input, expected) in test_cases:
            result = sort(test_input)
            self.assertEqual(expected, result)

Python - speed up pathfinding

This is my pathfinding function:
def get_distance(x1, y1, x2, y2):
    neighbors = [(-1,0), (1,0), (0,-1), (0,1)]
    old_nodes = [(square_pos[x1,y1], 0)]
    new_nodes = []
    for i in range(50):
        for node in old_nodes:
            if node[0].x == x2 and node[0].y == y2:
                return node[1]
            for neighbor in neighbors:
                try:
                    square = square_pos[node[0].x+neighbor[0], node[0].y+neighbor[1]]
                    if square.lightcycle == None:
                        new_nodes.append((square, node[1]))
                except KeyError:
                    pass
        old_nodes = []
        old_nodes = list(new_nodes)
        new_nodes = []
        nodes = []
    return 50
The problem is that the AI takes too long to respond (the response time must be <= 100 ms).
This is just a python way of doing https://en.wikipedia.org/wiki/Pathfinding#Sample_algorithm
You should replace your algorithm with A*-search with the Manhattan distance as a heuristic.
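For illustration, here is a minimal sketch of A* with a Manhattan-distance heuristic on a 4-connected grid. This is a standalone example, not the poster's code: it assumes a plain 2D list of 0/1 cells (1 = blocked) instead of the question's square_pos mapping, and the name astar_length is made up for the sketch.

import heapq

def astar_length(grid, start, goal):
    # A* on a 4-connected grid of (x, y) cells; returns the number of
    # steps from start to goal, or None if the goal is unreachable.
    def manhattan(a, b):
        return abs(a[0] - b[0]) + abs(a[1] - b[1])
    open_heap = [(manhattan(start, goal), 0, start)]  # entries are (f, g, (x, y))
    best_g = {start: 0}
    while open_heap:
        f, g, node = heapq.heappop(open_heap)
        if node == goal:
            return g
        if g > best_g.get(node, float("inf")):
            continue  # stale heap entry; a shorter route was already found
        x, y = node
        for nx, ny in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
            if 0 <= ny < len(grid) and 0 <= nx < len(grid[0]) and not grid[ny][nx]:
                ng = g + 1
                if ng < best_g.get((nx, ny), float("inf")):
                    best_g[(nx, ny)] = ng
                    heapq.heappush(open_heap, (ng + manhattan((nx, ny), goal), ng, (nx, ny)))
    return None

For example, astar_length([[0, 0, 0], [1, 1, 0], [0, 0, 0]], (0, 0), (0, 2)) returns 6, walking around the wall in the middle row. The Manhattan heuristic never overestimates on a 4-connected grid, so A* still finds optimal path lengths while expanding far fewer nodes than plain breadth-first search.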
One reasonably fast solution is to implement the Dijkstra algorithm (that I have already implemented in that question):
Build the original map. It's a masked array where the walker cannot walk on masked elements:
%pylab inline
map_size = (20,20)
MAP = np.ma.masked_array(np.zeros(map_size), np.random.choice([0,1], size=map_size))
matshow(MAP)
Below is the Dijkstra algorithm:
def dijkstra(V):
    mask = V.mask
    visit_mask = mask.copy()  # mask visited cells
    m = numpy.ones_like(V) * numpy.inf
    connectivity = [(i, j) for i in [-1, 0, 1] for j in [-1, 0, 1] if (not (i == j == 0))]
    cc = unravel_index(V.argmin(), m.shape)  # current_cell
    m[cc] = 0
    P = {}  # dictionary of predecessors
    # while (~visit_mask).sum() > 0:
    for _ in range(V.size):
        # print(cc)
        neighbors = [tuple(e) for e in asarray(cc) - connectivity
                     if e[0] > 0 and e[1] > 0 and e[0] < V.shape[0] and e[1] < V.shape[1]]
        neighbors = [e for e in neighbors if not visit_mask[e]]
        tentative_distance = [(V[e] - V[cc])**2 for e in neighbors]
        for i, e in enumerate(neighbors):
            d = tentative_distance[i] + m[cc]
            if d < m[e]:
                m[e] = d
                P[e] = cc
        visit_mask[cc] = True
        m_mask = ma.masked_array(m, visit_mask)
        cc = unravel_index(m_mask.argmin(), m.shape)
    return m, P

def shortestPath(start, end, P):
    Path = []
    step = end
    while 1:
        Path.append(step)
        if step == start:
            break
        if step in P:
            step = P[step]
        else:
            break
    Path.reverse()
    return asarray(Path)
And the result:
start = (2,8)
stop = (17,19)
D, P = dijkstra(MAP)
path = shortestPath(start, stop, P)
imshow(MAP, interpolation='nearest')
plot(path[:,1], path[:,0], 'ro-', linewidth=2.5)
Below some timing statistics:
%timeit dijkstra(MAP)
#10 loops, best of 3: 32.6 ms per loop
The biggest issue with your code is that you don't do anything to avoid the same coordinates being visited multiple times. This means that the number of nodes you visit is guaranteed to grow exponentially, since it can keep going back and forth over the first few nodes many times.
The best way to avoid duplication is to maintain a set of the coordinates we've added to the queue (though if your node values are hashable, you might be able to add them directly to the set instead of coordinate tuples). Since we're doing a breadth-first search, we'll always reach a given coordinate by (one of) the shortest path(s), so we never need to worry about finding a better route later on.
Try something like this:
def get_distance(x1, y1, x2, y2):
    neighbors = [(-1,0), (1,0), (0,-1), (0,1)]
    nodes = [(square_pos[x1,y1], 0)]
    seen = set([(x1, y1)])
    for node, path_length in nodes:
        if path_length == 50:
            break
        if node.x == x2 and node.y == y2:
            return path_length
        for nx, ny in neighbors:
            try:
                square = square_pos[node.x + nx, node.y + ny]
                if square.lightcycle == None and (square.x, square.y) not in seen:
                    nodes.append((square, path_length + 1))
                    seen.add((square.x, square.y))
            except KeyError:
                pass
    return 50
I've also simplified the loop a bit. Rather than switching out the list after each depth, you can just use one loop and add to its end as you're iterating over the earlier values. I still abort if a path hasn't been found with fewer than 50 steps (using the distance stored in the 2-tuple, rather than the number of passes of the outer loop). A further improvement might be to use a collections.deque for the queue, since you could efficiently pop from one end while appending to the other end. It probably won't make a huge difference, but might avoid a little bit of memory usage.
I also avoided most of the indexing by one and zero in favor of unpacking into separate variable names in the for loops. I think this is much easier to read, and it avoids confusion since the two different kinds of 2-tuples had different meanings (one is a node, distance tuple, the other is x, y).
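To illustrate that last point about the queue, here is a minimal sketch of the deque-based variant. It is the same search as the code above (square_pos and the lightcycle attribute are assumed from the question), just with an O(1) popleft instead of iterating over a growing list:

from collections import deque

def get_distance(x1, y1, x2, y2):
    neighbors = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    queue = deque([(square_pos[x1, y1], 0)])  # FIFO queue of (node, distance)
    seen = {(x1, y1)}
    while queue:
        node, path_length = queue.popleft()  # O(1) at both ends of a deque
        if path_length == 50:
            break
        if node.x == x2 and node.y == y2:
            return path_length
        for nx, ny in neighbors:
            try:
                square = square_pos[node.x + nx, node.y + ny]
            except KeyError:
                continue
            if square.lightcycle == None and (square.x, square.y) not in seen:
                seen.add((square.x, square.y))
                queue.append((square, path_length + 1))
    return 50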
