#include "stdafx.h"
#include "AStarShortestPathSolver.h"

#include <cmath>
#include <utility>

#include <boost/graph/adjacency_list.hpp>
#include <boost/graph/astar_search.hpp>

using namespace boost;
using namespace delaunator;

//#define GRDRAW

// Debug helper: draws a temporary vector from pa to pb in the given color via
// acedGrDraw. It only does anything when GRDRAW is defined; otherwise it is a no-op.
static void grdraw(const AcGePoint3d& pa, const AcGePoint3d& pb, int color)
{
#ifndef GRDRAW
    (void)pa;
    (void)pb;
    (void)color;
#else
    acedGrDraw(asDblArray(pa), asDblArray(pb), color, 0);
#endif
}

// Returns the DIAMETER (2 * circumradius) of the circle circumscribing triangle
// pa-pb-pc. Only the x/y coordinates are used; z is ignored.
//
// Scale note: the previous implementation doubled every side length before applying
// R = a*b*c / (4*area), which evaluates to 2R. That scale is kept so existing alpha
// thresholds compared against this value keep their meaning.
//
// With true side lengths a, b, c and |cross| = |(pb-pa) x (pc-pa)| = 2*area:
//     2R = a*b*c / (2*area) = a*b*c / |cross|
// The cross-product area avoids Heron's formula. Heron's formula loses precision
// (catastrophic cancellation) on thin triangles and can take sqrt of a slightly
// negative value, producing NaN.
//
// Degenerate inputs: collinear distinct points give +inf; coincident points give NaN.
// Both make a "circum(...) < alpha" test false.
static double circum(const AcGePoint3d& pa, const AcGePoint3d& pb, const AcGePoint3d& pc)
{
    const double abx = pb[0] - pa[0];
    const double aby = pb[1] - pa[1];
    const double bcx = pc[0] - pb[0];
    const double bcy = pc[1] - pb[1];
    const double cax = pa[0] - pc[0];
    const double cay = pa[1] - pc[1];

    const double a = std::sqrt(abx * abx + aby * aby);
    const double b = std::sqrt(bcx * bcx + bcy * bcy);
    const double c = std::sqrt(cax * cax + cay * cay);

    // z-component of (pb - pa) x (pc - pa); its magnitude is twice the triangle area.
    const double acx = pc[0] - pa[0];
    const double acy = pc[1] - pa[1];
    const double cross = std::abs(abx * acy - aby * acx);

    return a * b * c / cross;
}

static int64_t getIndex(const std::vector<AcGePoint3d>& v, const AcGePoint3d& K)
{
    AcGeTol tol;
    tol.setEqualPoint(0.001);//TODO use LUPREC?
    const auto it = find_if(v.begin(), v.end(), [&](const AcGePoint3d& val) { return K.isEqualTo(val, tol); });
    if (it != v.end())
        return it - v.begin();
    else
        return -1;
}

// A* heuristic for boost::astar_search: the distance (AcGePoint3d::distanceTo) from a
// vertex's location to the goal's location. It is admissible when edge weights are
// the same point-to-point distances.
//
// LocMap: anything indexable by a vertex descriptor that yields an AcGePoint3d, e.g. a
// `const AcGePoint3d*` into a point array, or a std::vector<AcGePoint3d>.
// astar_search takes the heuristic by value and copies it internally, so prefer a
// cheap-to-copy LocMap (a pointer) over a whole container.
template <class Graph, class CostType, class LocMap>
class distance_heuristic : public astar_heuristic<Graph, CostType>
{
public:
    typedef typename graph_traits<Graph>::vertex_descriptor Vertex;
    distance_heuristic(LocMap l, Vertex goal)
        : m_location(std::move(l)), m_goal(goal) // move: avoid a second copy of l
    {
    }
    // Estimated remaining cost from u to the goal. Does not modify the heuristic.
    CostType operator()(Vertex u) const
    {
        return m_location[m_goal].distanceTo(m_location[u]);
    }
private:
    LocMap m_location;
    Vertex m_goal;
};

// Thrown by astar_goal_visitor to abort astar_search once the goal is reached. This
// is the early-exit technique used in the Boost.Graph A* examples.
struct found_goal {}; // exception for termination

// A* visitor that stops the search by throwing found_goal when the goal vertex is
// examined (popped from the search queue).
template <class Vertex>
class astar_goal_visitor : public boost::default_astar_visitor
{
public:
    explicit astar_goal_visitor(Vertex goal) : m_goal(goal)
    {
    }
    template <class Graph>
    void examine_vertex(Vertex u, Graph&)
    {
        if (u == m_goal)
            throw found_goal();
    }
private:
    Vertex m_goal;
};

// Stores the solver inputs:
//   _startidx, _endidx : start/end vertex indices into _inputPoints
//                        (-1 is rejected by solveShortestPathAStar as "not found");
//   _alpha             : threshold compared against circum() to filter triangles
//                        (a negative value disables the filter);
//   _inputPoints       : the candidate vertices.
// NOTE(review): _inputPoints is taken by const reference, so std::move() at the call
// site cannot avoid a copy if the inputPoints member is a value. Confirm against the
// member declaration in the header.
AStarShortestPathSolver::AStarShortestPathSolver(int64_t _startidx, int64_t _endidx, double _alpha, const std::vector<AcGePoint3d>& _inputPoints)
    : startidx(_startidx), endidx(_endidx), alpha(_alpha), inputPoints(_inputPoints)
{
}

bool AStarShortestPathSolver::visited(const std::pair<size_t, size_t>& item) const
{
    auto edge = item;
    if (edgemap.contains(edge))
        return true;
    std::swap(edge.first, edge.second);
    if (edgemap.contains(edge))
        return true;
    return false;
}

void AStarShortestPathSolver::setVisited(const std::pair<size_t, size_t>& item)
{
    auto edge = item;
    edgemap.insert(edge);
    std::swap(edge.first, edge.second);
    edgemap.insert(edge);
}

// Appends x to the lastedgemap entry for the unordered vertex pair {i, j}. The key is
// normalized to (min, max) so both orientations share one entry. Callers pass the
// third vertex of a triangle, i.e. the vertex opposite edge i-j.
void AStarShortestPathSolver::intertLastEdge(size_t i, size_t j, size_t x)
{
    const size_t lo = (j < i) ? j : i;
    const size_t hi = (j < i) ? i : j;
    lastedgemap[{ lo, hi }].push_back({ x });
}

// Builds a Delaunay triangulation of inputPoints (x/y only) into pDelaunator.
// Returns true on success. If anything throws, prints "DelaunaySolver failed" and
// returns false.
// NOTE(review): if Delaunator is delaunator-cpp, it keeps `coords` as a const reference
// to the buffer passed in. That buffer is local here, so only triangles/halfedges are
// safe to use after this returns. TODO confirm, or keep the buffer as a member.
bool AStarShortestPathSolver::initTriangles()
{
    try
    {
        std::vector<double> xy;
        xy.reserve(inputPoints.size() * 2);
        for (const auto& pt : inputPoints)
        {
            xy.emplace_back(pt.x);
            xy.emplace_back(pt.y);
        }
        pDelaunator.reset(new Delaunator(xy));
        return true;
    }
    catch (...)
    {
        acutPrintf(_T("\nDelaunaySolver failed"));
        return false;
    }
}

std::vector<AcGePoint3d> AStarShortestPathSolver::solveShortestPathAStar()
{
    using EdgeNode = std::pair< size_t, size_t>;
    using Graph = adjacency_list < vecS, vecS, undirectedS, no_property, property < edge_weight_t, double > >;
    using Edge = graph_traits <Graph>::edge_descriptor;
    typedef Graph::vertex_descriptor vertex;
    typedef property_map<Graph, edge_weight_t>::type WeightMap;

    std::vector<AcGePoint3d> path;
    if (!initTriangles())
        return path;

    if (startidx == -1)
        return path;

    if (endidx == -1)
        return path;

    std::vector<EdgeNode> edges;
    edges.reserve(pDelaunator->triangles.size() * 3);

    std::vector<double> weights;
    weights.reserve(pDelaunator->triangles.size() * 3);

    lastedgemap.reserve(inputPoints.size() * 3);

    for (size_t i = 0; i < pDelaunator->triangles.size(); i += 3)
    {
        auto t0 = pDelaunator->triangles[i + 0];
        auto t1 = pDelaunator->triangles[i + 1];
        auto t2 = pDelaunator->triangles[i + 2];

        const auto& p0 = inputPoints[t0];
        const auto& p1 = inputPoints[t1];
        const auto& p2 = inputPoints[t2];

        if (alpha < 0 || circum(p0, p1, p2) < alpha)
        {
            //not quite what I want, but it works
            //if two triangles have a matching edge i.e t1, t2
            //then create an edge between the two t0s
            intertLastEdge(t0, t1, t2);
            intertLastEdge(t1, t2, t0);
            intertLastEdge(t2, t0, t1);

            if (!visited({ t0 ,t1 }))
            {
                //grdraw(p0, p1, 3);
                edges.push_back({ t0 ,t1 });
                weights.push_back({ p0.distanceTo(p1) });
                setVisited({ t0 ,t1 });
            }
            if (!visited({ t1 ,t2 }))
            {
                //grdraw(p1, p2, 4);
                edges.push_back({ t1 ,t2 });
                weights.push_back({ p1.distanceTo(p2) });
                setVisited({ t1 ,t2 });
            }
            if (!visited({ t2 ,t0 }))
            {
                //grdraw(p2, p0, 1);
                edges.push_back({ t0 ,t2 });
                weights.push_back({ p2.distanceTo(p0) });
                setVisited({ t2 ,t0 });
            }
        }
    }

    for (const auto& item : lastedgemap)
    {
        if (item.second.size() == 2)
        {
            edges.push_back({ item.second[0] , item.second[1] });
            //grdraw(inputPoints[item.second[0]], inputPoints[item.second[1]], 3);
            weights.push_back(inputPoints[item.second[0]].distanceTo(inputPoints[item.second[1]]));
        }
    }

    bool found = false;
    const vertex start = startidx;
    const vertex goal = endidx;

    Graph g(edges.begin(), edges.end(), weights.begin(), inputPoints.size());
    std::vector<Graph::vertex_descriptor> p(num_vertices(g));
    std::vector<double> d(num_vertices(g));
    try
    {
        astar_search(g, start,
            distance_heuristic<Graph, double, std::vector<AcGePoint3d>>(inputPoints, goal),
            predecessor_map(&p[0]).distance_map(&d[0]).
            visitor(astar_goal_visitor<vertex>(goal)));
    }
    catch (found_goal fg)
    {
        found = true;//ugh
    }
    if (found)
    {
        path.push_back(inputPoints[startidx]);
        std::list<vertex> shortest_path;
        for (vertex v = goal;; v = p[v])
        {
            shortest_path.push_front(v);
            if (p[v] == v)
                break;
        }
        std::list<vertex>::iterator spi = shortest_path.begin();
        for (++spi; spi != shortest_path.end(); ++spi)
            path.push_back(inputPoints[*spi]);
    }
    else
    {
        path.push_back(inputPoints[startidx]);
        path.push_back(inputPoints[endidx]);
    }
    return path;
}

int AStarShortestPathSolver::shortestPathAstarLispFunc()
{
    AcGePoint3d sp;
    AcGePoint3d ep;
    double alpha = 360;

    size_t narg = 0;
    std::vector<AcGePoint3d> points;
    {
        AcResBufPtr pArgs(acedGetArgs());
        for (resbuf* pTail = pArgs.get(); pTail != nullptr; pTail = pTail->rbnext)
        {
            switch (pTail->restype)
            {
                case RTSHORT:
                    alpha = pTail->resval.rint;
                    break;
                case RTLONG:
                    alpha = pTail->resval.rlong;
                    break;
                case RTREAL:
                    alpha = pTail->resval.rreal;
                    break;
                case RTPOINT:
                case RT3DPOINT:
                {
                    switch (narg)
                    {
                        case 1:
                            sp = asPnt3d(pTail->resval.rpoint);
                            break;
                        case 2:
                            ep = asPnt3d(pTail->resval.rpoint);
                            break;
                        default:
                            points.push_back(asPnt3d(pTail->resval.rpoint));
                            break;
                    }
                    break;
                }
                break;
                default:
                    break;
            }
            narg++;
        }
    }
    if (points.size() > 2)
    {
        std::sort(std::execution::par, points.begin(), points.end(), [](const AcGePoint3d& a, const AcGePoint3d& b) -> bool
        {
            if (a.y == b.y)
                return a.x < b.x;
            return a.y < b.y;
        });
        std::sort(std::execution::par, points.begin(), points.end(), [](const AcGePoint3d& a, const AcGePoint3d& b) -> bool
        {
            if (a.x == b.x)
                return a.y < b.y;
            return a.x < b.x;
        });

        auto startidx = getIndex(points, sp);
        if (startidx == -1)
        {
            acedRetNil();
            return (RSRSLT);
        }

        auto endidx = getIndex(points, ep);
        if (endidx == -1)
        {
            acedRetNil();
            return (RSRSLT);
        }

        AStarShortestPathSolver solver(startidx, endidx, alpha, std::move(points));
        const auto& path = solver.solveShortestPathAStar();
        if (path.size())
        {
            AcResBufPtr pRes(acutNewRb(RTLB));
            resbuf* pResTail = pRes.get();
            for (auto& item : path)
            {
                pResTail = pResTail->rbnext = acutNewRb(RT3DPOINT);
                memcpy(pResTail->resval.rpoint, asDblArray(item), sizeof(pResTail->resval.rpoint));
            }
            pResTail = pResTail->rbnext = acutNewRb(RTLE);
            acedRetList(pRes.get());
        }
    }
    return (RSRSLT);
}

