#!/usr/bin/env python3

import os
import sys
import numpy
import random

import skimage.graph

from matplotlib import pyplot as plt


def to_pixel(x, y, shape, scale=2):
    """
    Map a physical position onto the nearest pixel of a grid.

    The physical origin corresponds to the pixel (shape[0]//2, shape[1]//2);
    physical offsets are multiplied by `scale` (pixels per physical unit) and
    rounded with Python's built-in round() before being added to that center.

    Parameters
    ----------
    x, y : float
        Physical coordinates.
    shape : sequence of int
        Shape of the pixel grid; only the first two entries are used.
    scale : number, optional
        Conversion factor from physical units to pixels (default 2).

    Returns
    -------
    tuple of int
        The (x, y) pixel indices.
    """

    center_x = shape[0] // 2
    center_y = shape[1] // 2
    pixel_x = center_x + int(round(x * scale))
    pixel_y = center_y + int(round(y * scale))
    return pixel_x, pixel_y


def from_pixel(x, y, shape, scale=2):
    """
    Map pixel indices back to physical coordinates.

    This is the inverse of the pixel mapping used by to_pixel(): the pixel
    (shape[0]//2, shape[1]//2) maps to the physical origin and offsets are
    divided by `scale` (pixels per physical unit).  Works element-wise when
    x and y are numpy arrays.

    Parameters
    ----------
    x, y : int or array of int
        Pixel indices.
    shape : sequence of int
        Shape of the pixel grid; only the first two entries are used.
    scale : number, optional
        Conversion factor from physical units to pixels (default 2).

    Returns
    -------
    tuple
        The physical (x, y) coordinates.
    """

    origin_x, origin_y = shape[0] // 2, shape[1] // 2
    return (x - origin_x) / scale, (y - origin_y) / scale


def _path_to_physical(path, shape, scale):
    """
    Convert a path of (row, column) pixel tuples, as returned by
    skimage.graph.route_through_array, into arrays of physical x and y
    coordinates using from_pixel().
    """

    pixels = numpy.asarray(path).reshape(-1, 2)
    return from_pixel(pixels[:,0], pixels[:,1], shape, scale=scale)


def _path_length(x, y):
    """
    Return the length of the polyline through the points (x[k], y[k]).
    Paths with fewer than two points have zero length.
    """

    return float(numpy.hypot(numpy.diff(x), numpy.diff(y)).sum())


def main(args):
    """
    Lay out trenches connecting every stand to the shelter and plot them.

    `args` is the list of command-line arguments without the program name.
    args[0] must name a whitespace-delimited text file readable by
    numpy.loadtxt.  Columns 1 and 2 are used as the east-west and
    north-south stand positions in meters; column 0 is not used here.

    For each of NITER stand orderings, the stands are connected to the
    shelter one at a time with a minimum-cost path through a pixel grid.
    After each path is routed, the cost of every pixel on it is scaled by
    REUSE_FACTOR_NUM/REUSE_FACTOR_DEN (never going below NODE_COST), so
    later paths are encouraged to reuse existing trenches.  The ordering
    with the smallest summed path length is reported and plotted.  Shared
    trench segments are counted once per stand whose path uses them.
    """

    # Define the parameters of the problem
    CONV_SCALE = 2          # Conversion scale from physical to pixel
    REUSE_FACTOR_NUM = 7    # Trench reuse numerator
    REUSE_FACTOR_DEN = 8    # Trench reuse denominator
    NITER = 100             # Number of random orderings to try
    NODE_COST = 1           # Cost for going through a stand/shelter
    FIELD_COST = 1000       # Cost for everything else
    assert(REUSE_FACTOR_NUM <= REUSE_FACTOR_DEN)
    assert(NITER >= 10)
    assert(isinstance(NODE_COST, int))
    assert(isinstance(FIELD_COST, int))

    if not args:
        sys.exit(f"Usage: {os.path.basename(sys.argv[0])} <stand_position_file>")

    # Load in the stand data.  ndmin=2 keeps a single-stand file 2-D so the
    # column indexing below still works.
    filename = args[0]
    data = numpy.loadtxt(filename, ndmin=2)
    if data.shape[0] == 0 or data.shape[1] < 3:
        sys.exit(f"Expected at least one row with three or more columns in {filename}")
    nstand = data.shape[0]
    print(f"Load positions for {nstand} stands")

    # Shift the average antenna position to be at (0,0)
    center = data[:,[1,2]].mean(axis=0)
    data[:,1] -= center[0]
    data[:,2] -= center[1]
    print(f"Shifted array center by {-center[0]:.2f}, {-center[1]:.2f}")

    # Define the location of the shelter entry panel in the same coordinate
    # system as the antennas
    shelter = [-48, 55.8]
    print(f"Shelter is located at {shelter[0]:.2f}, {shelter[1]:.2f}, a distance of "
          f"{numpy.hypot(shelter[0], shelter[1]):.2f} from the center")

    # Setup the graph.  The grid is centered on (0,0), so it must span twice
    # the largest absolute coordinate.  The peak-to-peak range alone is not
    # enough when the stands are not symmetric about their mean position.
    dx = max([110, 2*numpy.abs(data[:,1]).max(), 2*abs(shelter[0])])
    dy = max([110, 2*numpy.abs(data[:,2]).max(), 2*abs(shelter[1])])
    d = max([dx, dy])
    npx = int(numpy.ceil(d/10))*10 * CONV_SCALE
    npx += (npx+1) % 2      # force an odd size so there is a single center pixel
    print(f"Using a {npx} by {npx} grid for the trench layout")

    # Convert the antenna positions to pixels that will become nodes on the
    # graph.  The builtin bool is used because numpy.bool was removed in
    # NumPy 1.24.
    graph = numpy.zeros((npx,npx), dtype=bool)
    stand_pixels = [to_pixel(x, y, graph.shape, scale=CONV_SCALE)
                    for x, y in data[:,1:3]]
    for px, py in stand_pixels:
        graph[px,py] = True

    # Convert the position of the shelter to pixels.  This will also become a
    # node on the graph
    sx, sy = to_pixel(shelter[0], shelter[1], graph.shape, scale=CONV_SCALE)
    graph[sx,sy] = True

    # Optimize
    best_order = None
    best_length = numpy.inf
    best_paths = None
    for i in range(NITER):
        ## Randomize the antenna list
        if i == 0:
            ### Special case 0: the original order
            order = list(range(nstand))
        elif i == 1:
            ### Special case 1: an order based on distance from the shelter
            shelter_d = (data[:,1]-shelter[0])**2 + (data[:,2]-shelter[1])**2
            order = [int(j) for j in numpy.argsort(shelter_d)]
        else:
            ### The normal case - a randomized order
            order = random.sample(list(range(nstand)), nstand)
        print(f"Running order realization {i+1} of {NITER}", end='')

        ## Create an initial cost function where the cost is low at a node and
        ## high everywhere else
        ## NOTE:  This can also be updated to add topology constraints with even
        ##        higher costs than FIELD_COST
        costs = numpy.where(graph, NODE_COST, FIELD_COST)

        ## Find the lowest cost paths from the shelter to each antenna.  In the
        ## process, update the costs so that we favor reusing an existing trench
        total_length = 0.0
        paths = []
        for j in order:
            ### Find the lowest cost path from the shelter to the antenna
            ### (fully_connected=True allows diagonal steps)
            path, _ = skimage.graph.route_through_array(costs,
                                                        start=(sx,sy),
                                                        end=stand_pixels[j],
                                                        fully_connected=True)
            paths.append(path)

            ### Make the pixels on this path cheaper for later paths, but never
            ### cheaper than a node
            rows, cols = numpy.asarray(path).T
            costs[rows,cols] = numpy.maximum(NODE_COST,
                                             costs[rows,cols]*REUSE_FACTOR_NUM//REUSE_FACTOR_DEN)

            ### Accumulate the physical length of this path
            total_length += _path_length(*_path_to_physical(path, graph.shape, CONV_SCALE))

        ## Check the quality
        if total_length < best_length:
            best_order = order
            best_length = total_length
            best_paths = paths
            print(f" -> {total_length:.3f}")
        else:
            print(' -> not as good')

    # Report on the best
    ## Start the plot
    fig = plt.figure()
    ax = fig.gca()
    ax.scatter(data[:,1], data[:,2], marker='o', color='blue')
    ax.scatter(shelter[0], shelter[1], marker='o', color='orange')

    ## Fill in the paths, reporting them in the stand order of the input file
    path_for_stand = dict(zip(best_order, best_paths))
    total_length = 0.0
    for i in range(nstand):
        x, y = _path_to_physical(path_for_stand[i], graph.shape, CONV_SCALE)
        ax.plot(x, y, color='k', alpha=0.6)

        length = _path_length(x, y)
        print(f"  Antenna {i+1} -> {length:.3f} in length")
        total_length += length

    print(f"   Total Length-> {total_length:.3f}")

    ax.set_xlabel('East-West [m]')
    ax.set_ylabel('North-South [m]')
    ax.set_xlim((-dx/2,dx/2))
    ax.set_ylim((-dy/2,dy/2))
    plt.show()


if __name__ == '__main__':
    # Pass every command-line argument after the program name to main()
    cli_args = sys.argv[1:]
    main(cli_args)
