# file astar.py
# Dr. Schwesinger, Spring 2018
"""A* path planning on an 8-connected occupancy grid, with matplotlib plots."""
import numpy as np
import matplotlib.pyplot as plt

plt.style.use('classic')

# The eight moves of an 8-connected grid paired with their Euclidean length:
# 1 for axis-aligned steps, sqrt(2) for diagonal steps.
_STEPS = [((i, j), np.sqrt(i * i + j * j))
          for i in (-1, 0, 1) for j in (-1, 0, 1) if (i, j) != (0, 0)]


def _heuristic(shape, goal):
    """Return the straight-line (Euclidean) distance from every cell to goal."""
    xs, ys = np.indices(shape)
    dx = xs - goal[0]
    dy = ys - goal[1]
    # Integer squares are exact, so this matches a per-cell norm bit for bit.
    return np.sqrt(dx * dx + dy * dy)


def _free_neighbors(M, cell):
    """Yield (neighbor, step_cost) for each in-bounds neighbor of cell with M != 1.

    Only the target cell is checked, so a diagonal step is allowed even when
    both orthogonal cells it passes between are obstacles.
    """
    rows, cols = M.shape
    for (i, j), step in _STEPS:
        x, y = cell[0] + i, cell[1] + j
        if 0 <= x < rows and 0 <= y < cols and M[x, y] != 1:
            yield (x, y), step


def _reconstruct_path(predecessors, goal):
    """Follow predecessor links back from goal.

    Returns (path, length): path lists cells from the search start to goal,
    and length is the summed Euclidean length of its steps.
    """
    path = [goal]
    length = 0
    node = goal
    # The start cell is the only reached cell whose predecessor stays -1.
    while predecessors[node][0] >= 0:
        prev = (int(predecessors[node][0]), int(predecessors[node][1]))
        length += np.linalg.norm(np.subtract(node, prev))
        path.append(prev)
        node = prev
    path.reverse()
    return path, length


def Astar(M, start, goal, plot=True):
    """Find a shortest 8-connected path from start to goal with A*.

    Parameters
    ----------
    M : 2-D numpy array
        Occupancy grid; cells equal to 1 are obstacles, any other value is free.
    start, goal : pair of ints
        (x, y) cell indices into M (x indexes axis 0, y indexes axis 1).
        Lists and numpy integers are accepted and normalized to int tuples.
    plot : bool, optional
        When True (the default) two figures are drawn: the cost + heuristic of
        every reached cell, and the map overlaid with reached (yellow),
        expanded (blue) and path (red) cells.  Call plt.show() to display them.

    Returns
    -------
    list of (int, int)
        Cells from start to goal inclusive, or [] when goal is unreachable.

    Prints the path length (when a path is found) and the number of
    expanded nodes.
    """
    # Normalize so indexing is a plain tuple index (a list would trigger
    # numpy fancy indexing) and the goal test compares like with like.
    start = tuple(int(v) for v in start)
    goal = tuple(int(v) for v in goal)

    # The Euclidean distance never exceeds the true path length, so the
    # heuristic is admissible for these step costs.
    H = _heuristic(M.shape, goal)
    costs = np.full(M.shape, np.inf)
    closed_set = np.zeros(M.shape, dtype=bool)
    predecessors = np.full(M.shape + (2,), -1, dtype=int)
    costs[start] = 0

    reached_goal = False
    while True:
        # f = g + h over open cells; closed cells are masked out with inf.
        f = np.where(closed_set, np.inf, costs) + H
        cell = np.unravel_index(f.argmin(), f.shape)
        if f[cell] == np.inf:
            break  # no reachable open cell left
        cell = tuple(int(v) for v in cell)
        closed_set[cell] = True
        if cell == goal:
            reached_goal = True
            break
        for neighbor, step in _free_neighbors(M, cell):
            new_cost = costs[cell] + step
            if new_cost < costs[neighbor]:
                costs[neighbor] = new_cost
                predecessors[neighbor] = cell

    if plot:
        # cost + heuristic of reached cells (unreached cells are inf/masked)
        plot_map(costs + H, start, goal)

    result = []
    if reached_goal:
        result, path_length = _reconstruct_path(predecessors, goal)
        print("Path length", path_length)

    if plot:
        plot_map(M, start, goal)
    print("Nodes expanded:", np.count_nonzero(closed_set))
    if plot:
        reached = np.where(costs != np.inf)
        expanded = np.where(closed_set)
        path_xy = np.array(result, dtype=int).reshape(-1, 2)
        # Expanded cells are drawn over reached ones, so the yellow markers
        # left visible are the open frontier.
        plt.plot(reached[0], reached[1], 'oy')
        plt.plot(expanded[0], expanded[1], 'ob')
        plt.plot(path_xy[:, 0], path_xy[:, 1], 'or')
    return result


def plot_map(M, start, goal):
    """Open a new figure showing grid M with start (red) and goal (green).

    M is transposed so its axis 0 runs along the plot's x axis, matching the
    (x, y) convention of start and goal.  The axis limits put y = 0 at the
    bottom of the figure.
    """
    plt.figure()
    plt.imshow(M.T, cmap='gray_r', interpolation='none', origin='upper')
    plt.axis([0, M.shape[0] - 1, 0, M.shape[1] - 1])
    plt.plot(start[0], start[1], 'ro')
    plt.plot(goal[0], goal[1], 'go')
    plt.xlabel('x')
    plt.ylabel('y')


if __name__ == '__main__':
    M = np.loadtxt('map.txt')
    start = (22, 33)
    goal = (40, 15)
    path = Astar(M, start, goal)
    print(path)
    plt.show()