#include "stdafx.h"
#include "TestSolver.h"
#include <ppl.h>

#include "ortools/constraint_solver/routing.h"
#include "ortools/constraint_solver/routing_enums.pb.h"
#include "ortools/constraint_solver/routing_index_manager.h"
#include "ortools/constraint_solver/routing_parameters.h"


using namespace concurrency;
using namespace operations_research;

std::vector<AcGePoint2d> TestSolver::solveTest(std::vector<AcGePoint2d> inputPoints, double scale)
{
    std::vector<AcGePoint2d> path;
    {
        PerfTimer timer(_T("Sort section"));
        std::sort(std::execution::par, inputPoints.begin(), inputPoints.end(), [](const AcGePoint2d& a, const AcGePoint2d& b) -> bool
        {
            if (a.y == b.y)
                return a.x < b.x;
            return a.y < b.y;
        });
        std::sort(std::execution::par, inputPoints.begin(), inputPoints.end(), [](const AcGePoint2d& a, const AcGePoint2d& b) -> bool
        {
            if (a.x == b.x)
                return a.y < b.y;
            return a.x < b.x;
        });
    }
    {
        std::vector<std::vector<int64_t>> distances = std::vector<std::vector<int64_t>>(inputPoints.size(), std::vector<int64_t>(inputPoints.size(), int64_t{ 0 }));
        {
            PerfTimer timer(_T("Distance section"));
            parallel_for(size_t(0), inputPoints.size(), [&](size_t fromNode)
            {
                for (size_t toNode = 0; toNode < inputPoints.size(); toNode++)
                {
                    if (fromNode != toNode)
                    {
                        distances[fromNode][toNode] = static_cast<int64_t>(inputPoints[fromNode].distanceTo(inputPoints[toNode]) * scale);
                    }
                }
            });
        }
        const operations_research::RoutingIndexManager::NodeIndex depot{ 0 };
        RoutingIndexManager manager(inputPoints.size(), 1, depot);
        RoutingModel routing(manager);

        int64_t transit_callback_index = routing.RegisterTransitCallback([&](int64_t from_index, int64_t to_index) -> int64_t
        {
            auto from_node = manager.IndexToNode(from_index).value();
            auto to_node = manager.IndexToNode(to_index).value();
            return distances[from_node][to_node];
        });
        routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index);
        RoutingSearchParameters searchParameters = DefaultRoutingSearchParameters();
        searchParameters.set_first_solution_strategy(FirstSolutionStrategy::AUTOMATIC);

        const Assignment* solution = routing.SolveWithParameters(searchParameters);
        if (solution != nullptr)
        {
            path.reserve(inputPoints.size());
            int64_t index = routing.Start(0);
            while (routing.IsEnd(index) == false)
            {
                path.emplace_back(inputPoints[manager.IndexToNode(index).value()]);
                int64_t previous_index = index;
                index = solution->Value(routing.NextVar(index));
            }
        }
        acutPrintf(_T("\nCount = %ld"), distances.size());
    }
    return path;
}

std::vector<AcGePoint2d> TestSolver::solveTest2(std::vector<AcGePoint2d> inputPoints, double scale)
{
    std::vector<AcGePoint2d> path;
    {
        PerfTimer timer(_T("Sort section"));
        std::sort(std::execution::par, inputPoints.begin(), inputPoints.end(), [](const AcGePoint2d& a, const AcGePoint2d& b) -> bool
        {
            if (a.y == b.y)
                return a.x < b.x;
            return a.y < b.y;
        });
        std::sort(std::execution::par, inputPoints.begin(), inputPoints.end(), [](const AcGePoint2d& a, const AcGePoint2d& b) -> bool
        {
            if (a.x == b.x)
                return a.y < b.y;
            return a.x < b.x;
        });
    }
    {
        std::unordered_map<size_t, int64_t> distmap;
        {
            PerfTimer timer(_T("Distance section"));
            for (size_t fromNode = 0; fromNode < inputPoints.size(); fromNode++)
            {
                for (size_t toNode = 0; toNode < inputPoints.size(); toNode++)
                {
                    if (fromNode != toNode)
                    {
                        size_t key = SizeTpairHasher::makehashKeyFromIndexes(fromNode, toNode, inputPoints.size());
                        if (!distmap.contains(key))
                            distmap.emplace(key, static_cast<int64_t>(inputPoints[fromNode].distanceTo(inputPoints[toNode]) * scale));
                    }
                }
            }
        }
        const operations_research::RoutingIndexManager::NodeIndex depot{ 0 };
        RoutingIndexManager manager(inputPoints.size(), 1, depot);
        RoutingModel routing(manager);

        int64_t transit_callback_index = routing.RegisterTransitCallback([&](int64_t from_index, int64_t to_index) -> int64_t
        {
            auto from_node = manager.IndexToNode(from_index).value();
            auto to_node = manager.IndexToNode(to_index).value();
            if (from_node == to_node)
                return 0;

            const size_t key = SizeTpairHasher::makehashKeyFromIndexes(from_node, to_node, inputPoints.size());
            if (distmap.contains(key))
                return distmap.at(key);

            acutPrintf(_T("*"));//cache miss
            return static_cast<int64_t>(inputPoints[from_node].distanceTo(inputPoints[to_node]) * scale);
        });
        routing.SetArcCostEvaluatorOfAllVehicles(transit_callback_index);
        RoutingSearchParameters searchParameters = DefaultRoutingSearchParameters();
        searchParameters.set_first_solution_strategy(FirstSolutionStrategy::AUTOMATIC);

        const Assignment* solution = routing.SolveWithParameters(searchParameters);
        if (solution != nullptr)
        {
            path.reserve(inputPoints.size());
            int64_t index = routing.Start(0);
            while (routing.IsEnd(index) == false)
            {
                path.emplace_back(inputPoints[manager.IndexToNode(index).value()]);
                int64_t previous_index = index;
                index = solution->Value(routing.NextVar(index));
            }
        }
        acutPrintf(_T("\nCount = %ld"), distmap.size());
    }
    return path;
}

int TestSolver::testSolverLispFunc()
{
    PerfTimer timer(_T("Total"));
    std::vector<AcGePoint2d> points;
    {
        PerfTimer timer(_T("Lisp section"));
        AcResBufPtr pArgs(acedGetArgs());
        for (resbuf* pTail = pArgs.get(); pTail != nullptr; pTail = pTail->rbnext)
        {
            switch (pTail->restype)
            {
                case RTPOINT:
                case RT3DPOINT:
                    points.push_back(asPnt2d(pTail->resval.rpoint));
                    break;
                default:
                    break;
            }
        }
    }
    if (points.size() > 2)
    {
        std::vector<AcGePoint2d> path{ solveTest(std::move(points), 100) };
        if (path.size())
        {
            AcResBufPtr pRes(acutNewRb(RTLB));
            resbuf* pResTail = pRes.get();
            for (auto& item : path)
            {
                pResTail = pResTail->rbnext = acutNewRb(RTPOINT);
                memcpy(pResTail->resval.rpoint, asDblArray(item), sizeof(pResTail->resval.rpoint));
            }
            pResTail = pResTail->rbnext = acutNewRb(RTLE);
            acedRetList(pRes.get());
        }
    }
    return (RSRSLT);
}

int TestSolver::testSolverLispFunc2()
{
    PerfTimer timer(_T("Total"));
    std::vector<AcGePoint2d> points;
    {
        PerfTimer timer(_T("Lisp section"));
        AcResBufPtr pArgs(acedGetArgs());
        for (resbuf* pTail = pArgs.get(); pTail != nullptr; pTail = pTail->rbnext)
        {
            switch (pTail->restype)
            {
                case RTPOINT:
                case RT3DPOINT:
                    points.push_back(asPnt2d(pTail->resval.rpoint));
                    break;
                default:
                    break;
            }
        }
    }
    if (points.size() > 2)
    {
        std::vector<AcGePoint2d> path{ solveTest2(std::move(points), 100) };
        if (path.size())
        {
            AcResBufPtr pRes(acutNewRb(RTLB));
            resbuf* pResTail = pRes.get();
            for (auto& item : path)
            {
                pResTail = pResTail->rbnext = acutNewRb(RTPOINT);
                memcpy(pResTail->resval.rpoint, asDblArray(item), sizeof(pResTail->resval.rpoint));
            }
            pResTail = pResTail->rbnext = acutNewRb(RTLE);
            acedRetList(pRes.get());
        }
    }
    return (RSRSLT);
}