#include <language/modules/SchemeModule.hpp>

#include <analysis/GaussLegendreQuadratureDescriptor.hpp>
#include <analysis/GaussLobattoQuadratureDescriptor.hpp>
#include <analysis/GaussQuadratureDescriptor.hpp>
#include <language/modules/BinaryOperatorRegisterForVh.hpp>
#include <language/modules/MathFunctionRegisterForVh.hpp>
#include <language/modules/UnaryOperatorRegisterForVh.hpp>
#include <language/utils/BinaryOperatorProcessorBuilder.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/TypeDescriptor.hpp>
#include <mesh/Connectivity.hpp>
#include <mesh/IBoundaryDescriptor.hpp>
#include <mesh/IZoneDescriptor.hpp>
#include <mesh/Mesh.hpp>
#include <mesh/MeshData.hpp>
#include <mesh/MeshDataManager.hpp>
#include <mesh/MeshRandomizer.hpp>
#include <mesh/NumberedBoundaryDescriptor.hpp>
#include <scheme/AcousticSolver.hpp>
#include <scheme/AxisBoundaryConditionDescriptor.hpp>
#include <scheme/DirichletBoundaryConditionDescriptor.hpp>
#include <scheme/DiscreteFunctionDescriptorP0.hpp>
#include <scheme/DiscreteFunctionDescriptorP0Vector.hpp>
#include <scheme/DiscreteFunctionIntegrator.hpp>
#include <scheme/DiscreteFunctionInterpoler.hpp>
#include <scheme/DiscreteFunctionP0.hpp>
#include <scheme/DiscreteFunctionUtils.hpp>
#include <scheme/DiscreteFunctionVectorIntegrator.hpp>
#include <scheme/DiscreteFunctionVectorInterpoler.hpp>
#include <scheme/ExternalBoundaryConditionDescriptor.hpp>
#include <scheme/FixedBoundaryConditionDescriptor.hpp>
#include <scheme/FourierBoundaryConditionDescriptor.hpp>
#include <scheme/FreeBoundaryConditionDescriptor.hpp>
#include <scheme/IBoundaryConditionDescriptor.hpp>
#include <scheme/IDiscreteFunction.hpp>
#include <scheme/IDiscreteFunctionDescriptor.hpp>
#include <scheme/NeumannBoundaryConditionDescriptor.hpp>
#include <scheme/ScalarDiamondScheme.hpp>
#include <scheme/SymmetryBoundaryConditionDescriptor.hpp>
#include <scheme/VectorDiamondScheme.hpp>
#include <utils/Socket.hpp>

#include <memory>

SchemeModule::SchemeModule()
{
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IDiscreteFunction>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IDiscreteFunctionDescriptor>>);
  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IQuadratureDescriptor>>);

  this->_addTypeDescriptor(ast_node_data_type_from<std::shared_ptr<const IBoundaryConditionDescriptor>>);

  this->_addBuiltinFunction("P0", std::function(

                                    []() -> std::shared_ptr<const IDiscreteFunctionDescriptor> {
                                      return std::make_shared<DiscreteFunctionDescriptorP0>();
                                    }

                                    ));

  this->_addBuiltinFunction("P0Vector", std::function(

                                          []() -> std::shared_ptr<const IDiscreteFunctionDescriptor> {
                                            return std::make_shared<DiscreteFunctionDescriptorP0Vector>();
                                          }

                                          ));

  this->_addBuiltinFunction("Gauss", std::function(

                                       [](uint64_t degree) -> std::shared_ptr<const IQuadratureDescriptor> {
                                         return std::make_shared<GaussQuadratureDescriptor>(degree);
                                       }

                                       ));

  this->_addBuiltinFunction("GaussLobatto", std::function(

                                              [](uint64_t degree) -> std::shared_ptr<const IQuadratureDescriptor> {
                                                return std::make_shared<GaussLobattoQuadratureDescriptor>(degree);
                                              }

                                              ));

  this->_addBuiltinFunction("GaussLegendre", std::function(

                                               [](uint64_t degree) -> std::shared_ptr<const IQuadratureDescriptor> {
                                                 return std::make_shared<GaussLegendreQuadratureDescriptor>(degree);
                                               }

                                               ));

  this->_addBuiltinFunction("integrate",
                            std::function(

                              [](std::shared_ptr<const IMesh> mesh,
                                 const std::vector<std::shared_ptr<const IZoneDescriptor>>& integration_zone_list,
                                 std::shared_ptr<const IQuadratureDescriptor> quadrature_descriptor,
                                 std::shared_ptr<const IDiscreteFunctionDescriptor> discrete_function_descriptor,
                                 const std::vector<FunctionSymbolId>& function_id_list)
                                -> std::shared_ptr<const IDiscreteFunction> {
                                return DiscreteFunctionVectorIntegrator{mesh, integration_zone_list,
                                                                        quadrature_descriptor,
                                                                        discrete_function_descriptor, function_id_list}
                                  .integrate();
                              }

                              ));

  this->_addBuiltinFunction("integrate",
                            std::function(

                              [](std::shared_ptr<const IMesh> mesh,
                                 std::shared_ptr<const IQuadratureDescriptor> quadrature_descriptor,
                                 std::shared_ptr<const IDiscreteFunctionDescriptor> discrete_function_descriptor,
                                 const std::vector<FunctionSymbolId>& function_id_list)
                                -> std::shared_ptr<const IDiscreteFunction> {
                                return DiscreteFunctionVectorIntegrator{mesh, quadrature_descriptor,
                                                                        discrete_function_descriptor, function_id_list}
                                  .integrate();
                              }

                              ));

  this->_addBuiltinFunction(
    "integrate",
    std::function(

      [](std::shared_ptr<const IMesh> mesh,
         const std::vector<std::shared_ptr<const IZoneDescriptor>>& integration_zone_list,
         std::shared_ptr<const IQuadratureDescriptor> quadrature_descriptor,
         const FunctionSymbolId& function_id) -> std::shared_ptr<const IDiscreteFunction> {
        return DiscreteFunctionIntegrator{mesh, integration_zone_list, quadrature_descriptor, function_id}.integrate();
      }

      ));

  this->_addBuiltinFunction("integrate",
                            std::function(

                              [](std::shared_ptr<const IMesh> mesh,
                                 std::shared_ptr<const IQuadratureDescriptor> quadrature_descriptor,
                                 const FunctionSymbolId& function_id) -> std::shared_ptr<const IDiscreteFunction> {
                                return DiscreteFunctionIntegrator{mesh, quadrature_descriptor, function_id}.integrate();
                              }

                              ));

  this->_addBuiltinFunction("interpolate",
                            std::function(

                              [](std::shared_ptr<const IMesh> mesh,
                                 const std::vector<std::shared_ptr<const IZoneDescriptor>>& interpolation_zone_list,
                                 std::shared_ptr<const IDiscreteFunctionDescriptor> discrete_function_descriptor,
                                 const std::vector<FunctionSymbolId>& function_id_list)
                                -> std::shared_ptr<const IDiscreteFunction> {
                                switch (discrete_function_descriptor->type()) {
                                case DiscreteFunctionType::P0: {
                                  if (function_id_list.size() != 1) {
                                    throw NormalError("invalid function descriptor type");
                                  }
                                  return DiscreteFunctionInterpoler{mesh, interpolation_zone_list,
                                                                    discrete_function_descriptor, function_id_list[0]}
                                    .interpolate();
                                }
                                case DiscreteFunctionType::P0Vector: {
                                  return DiscreteFunctionVectorInterpoler{mesh, interpolation_zone_list,
                                                                          discrete_function_descriptor,
                                                                          function_id_list}
                                    .interpolate();
                                }
                                default: {
                                  throw NormalError("invalid function descriptor type");
                                }
                                }
                              }

                              ));

  this->_addBuiltinFunction(
    "interpolate",
    std::function(

      [](std::shared_ptr<const IMesh> mesh,
         std::shared_ptr<const IDiscreteFunctionDescriptor> discrete_function_descriptor,
         const std::vector<FunctionSymbolId>& function_id_list) -> std::shared_ptr<const IDiscreteFunction> {
        switch (discrete_function_descriptor->type()) {
        case DiscreteFunctionType::P0: {
          if (function_id_list.size() != 1) {
            throw NormalError("invalid function descriptor type");
          }
          return DiscreteFunctionInterpoler{mesh, discrete_function_descriptor, function_id_list[0]}.interpolate();
        }
        case DiscreteFunctionType::P0Vector: {
          return DiscreteFunctionVectorInterpoler{mesh, discrete_function_descriptor, function_id_list}.interpolate();
        }
        default: {
          throw NormalError("invalid function descriptor type");
        }
        }
      }

      ));

  this->_addBuiltinFunction("randomizeMesh",
                            std::function(

                              [](std::shared_ptr<const IMesh> p_mesh,
                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                   bc_descriptor_list) -> std::shared_ptr<const IMesh> {
                                MeshRandomizerHandler handler;
                                return handler.getRandomizedMesh(*p_mesh, bc_descriptor_list);
                              }

                              ));

  this->_addBuiltinFunction("randomizeMesh",
                            std::function(

                              [](std::shared_ptr<const IMesh> p_mesh,
                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                   bc_descriptor_list,
                                 const FunctionSymbolId& function_symbol_id) -> std::shared_ptr<const IMesh> {
                                MeshRandomizerHandler handler;
                                return handler.getRandomizedMesh(*p_mesh, bc_descriptor_list, function_symbol_id);
                              }

                              ));

  this->_addBuiltinFunction("fixed", std::function(

                                       [](std::shared_ptr<const IBoundaryDescriptor> boundary)
                                         -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                         return std::make_shared<FixedBoundaryConditionDescriptor>(boundary);
                                       }

                                       ));

  this->_addBuiltinFunction("axis", std::function(

                                      [](std::shared_ptr<const IBoundaryDescriptor> boundary)
                                        -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                        return std::make_shared<AxisBoundaryConditionDescriptor>(boundary);
                                      }

                                      ));

  this->_addBuiltinFunction("symmetry", std::function(

                                          [](std::shared_ptr<const IBoundaryDescriptor> boundary)
                                            -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                            return std::make_shared<SymmetryBoundaryConditionDescriptor>(boundary);
                                          }

                                          ));

  this->_addBuiltinFunction("pressure", std::function(

                                          [](std::shared_ptr<const IBoundaryDescriptor> boundary,
                                             const FunctionSymbolId& pressure_id)
                                            -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                            return std::make_shared<DirichletBoundaryConditionDescriptor>("pressure",
                                                                                                          boundary,
                                                                                                          pressure_id);
                                          }

                                          ));

  this->_addBuiltinFunction("velocity", std::function(

                                          [](std::shared_ptr<const IBoundaryDescriptor> boundary,
                                             const FunctionSymbolId& velocity_id)
                                            -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                            return std::make_shared<DirichletBoundaryConditionDescriptor>("velocity",
                                                                                                          boundary,
                                                                                                          velocity_id);
                                          }

                                          ));

  this->_addBuiltinFunction("dirichlet",
                            std::function(

                              [](std::shared_ptr<const IBoundaryDescriptor> boundary,
                                 const FunctionSymbolId& g_id) -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                return std::make_shared<DirichletBoundaryConditionDescriptor>("dirichlet", boundary,
                                                                                              g_id);
                              }

                              ));

  this->_addBuiltinFunction("normal_strain",
                            std::function(

                              [](std::shared_ptr<const IBoundaryDescriptor> boundary,
                                 const FunctionSymbolId& g_id) -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                return std::make_shared<DirichletBoundaryConditionDescriptor>("normal_strain", boundary,
                                                                                              g_id);
                              }

                              ));

  this->_addBuiltinFunction("fourier",
                            std::function(

                              [](std::shared_ptr<const IBoundaryDescriptor> boundary, const FunctionSymbolId& alpha_id,
                                 const FunctionSymbolId& g_id) -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                return std::make_shared<FourierBoundaryConditionDescriptor>("fourier", boundary,
                                                                                            alpha_id, g_id);
                              }

                              ));

  this->_addBuiltinFunction("neumann",
                            std::function(

                              [](std::shared_ptr<const IBoundaryDescriptor> boundary,
                                 const FunctionSymbolId& g_id) -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                return std::make_shared<NeumannBoundaryConditionDescriptor>("neumann", boundary, g_id);
                              }

                              ));

  this->_addBuiltinFunction("external_fsi_velocity",
                            std::function(

                              [](std::shared_ptr<const IBoundaryDescriptor> boundary,
                                 const std::shared_ptr<const Socket>& socket)
                                -> std::shared_ptr<const IBoundaryConditionDescriptor> {
                                return std::make_shared<ExternalBoundaryConditionDescriptor>("external_fsi_velocity",
                                                                                             boundary, socket);
                              }

                              ));

  this->_addBuiltinFunction("glace_fluxes", std::function(

                                              [](const std::shared_ptr<const IDiscreteFunction>& rho,
                                                 const std::shared_ptr<const IDiscreteFunction>& u,
                                                 const std::shared_ptr<const IDiscreteFunction>& c,
                                                 const std::shared_ptr<const IDiscreteFunction>& p,
                                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                                   bc_descriptor_list)
                                                -> std::tuple<std::shared_ptr<const ItemValueVariant>,
                                                              std::shared_ptr<const SubItemValuePerItemVariant>> {
                                                return AcousticSolverHandler{getCommonMesh({rho, c, u, p})}
                                                  .solver()
                                                  .compute_fluxes(AcousticSolverHandler::SolverType::Glace, rho, c, u,
                                                                  p, bc_descriptor_list);
                                              }

                                              ));

  this->_addBuiltinFunction("glace_solver",
                            std::function(

                              [](const std::shared_ptr<const IDiscreteFunction>& rho,
                                 const std::shared_ptr<const IDiscreteFunction>& u,
                                 const std::shared_ptr<const IDiscreteFunction>& E,
                                 const std::shared_ptr<const IDiscreteFunction>& c,
                                 const std::shared_ptr<const IDiscreteFunction>& p,
                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                   bc_descriptor_list,
                                 const double& dt)
                                -> std::tuple<std::shared_ptr<const IMesh>, std::shared_ptr<const IDiscreteFunction>,
                                              std::shared_ptr<const IDiscreteFunction>,
                                              std::shared_ptr<const IDiscreteFunction>> {
                                return AcousticSolverHandler{getCommonMesh({rho, u, E, c, p})}
                                  .solver()
                                  .apply(AcousticSolverHandler::SolverType::Glace, dt, rho, u, E, c, p,
                                         bc_descriptor_list);
                              }

                              ));

  this->_addBuiltinFunction("eucclhyd_fluxes",
                            std::function(

                              [](const std::shared_ptr<const IDiscreteFunction>& rho,
                                 const std::shared_ptr<const IDiscreteFunction>& u,
                                 const std::shared_ptr<const IDiscreteFunction>& c,
                                 const std::shared_ptr<const IDiscreteFunction>& p,
                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                   bc_descriptor_list)
                                -> std::tuple<std::shared_ptr<const ItemValueVariant>,
                                              std::shared_ptr<const SubItemValuePerItemVariant>> {
                                return AcousticSolverHandler{getCommonMesh({rho, c, u, p})}
                                  .solver()
                                  .compute_fluxes(AcousticSolverHandler::SolverType::Eucclhyd, rho, c, u, p,
                                                  bc_descriptor_list);
                              }

                              ));

  this->_addBuiltinFunction("eucclhyd_solver",
                            std::function(

                              [](const std::shared_ptr<const IDiscreteFunction>& rho,
                                 const std::shared_ptr<const IDiscreteFunction>& u,
                                 const std::shared_ptr<const IDiscreteFunction>& E,
                                 const std::shared_ptr<const IDiscreteFunction>& c,
                                 const std::shared_ptr<const IDiscreteFunction>& p,
                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                   bc_descriptor_list,
                                 const double& dt)
                                -> std::tuple<std::shared_ptr<const IMesh>, std::shared_ptr<const IDiscreteFunction>,
                                              std::shared_ptr<const IDiscreteFunction>,
                                              std::shared_ptr<const IDiscreteFunction>> {
                                return AcousticSolverHandler{getCommonMesh({rho, u, E, c, p})}
                                  .solver()
                                  .apply(AcousticSolverHandler::SolverType::Eucclhyd, dt, rho, u, E, c, p,
                                         bc_descriptor_list);
                              }

                              ));

  this->_addBuiltinFunction("apply_acoustic_fluxes",
                            std::function(

                              [](const std::shared_ptr<const IDiscreteFunction>& rho,            //
                                 const std::shared_ptr<const IDiscreteFunction>& u,              //
                                 const std::shared_ptr<const IDiscreteFunction>& E,              //
                                 const std::shared_ptr<const ItemValueVariant>& ur,              //
                                 const std::shared_ptr<const SubItemValuePerItemVariant>& Fjr,   //
                                 const double& dt)
                                -> std::tuple<std::shared_ptr<const IMesh>, std::shared_ptr<const IDiscreteFunction>,
                                              std::shared_ptr<const IDiscreteFunction>,
                                              std::shared_ptr<const IDiscreteFunction>> {
                                return AcousticSolverHandler{getCommonMesh({rho, u, E})}   //
                                  .solver()
                                  .apply_fluxes(dt, rho, u, E, ur, Fjr);
                              }

                              ));

  this->_addBuiltinFunction(
    "parabolicheat",
    std::function(

      [](const std::shared_ptr<const IDiscreteFunction>& alpha,
         const std::shared_ptr<const IDiscreteFunction>& mub_dual,
         const std::shared_ptr<const IDiscreteFunction>& mu_dual, const std::shared_ptr<const IDiscreteFunction>& f,
         const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>& bc_descriptor_list)
        -> std::shared_ptr<const IDiscreteFunction> {
        return ScalarDiamondSchemeHandler{alpha, mub_dual, mu_dual, f, bc_descriptor_list}.solution();
      }

      ));

  this->_addBuiltinFunction("unsteadyelasticity",
                            std::function(

                              [](const std::shared_ptr<const IDiscreteFunction> alpha,
                                 const std::shared_ptr<const IDiscreteFunction> lambdab,
                                 const std::shared_ptr<const IDiscreteFunction> mub,
                                 const std::shared_ptr<const IDiscreteFunction> lambda,
                                 const std::shared_ptr<const IDiscreteFunction> mu,
                                 const std::shared_ptr<const IDiscreteFunction> f,
                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                   bc_descriptor_list) -> std::shared_ptr<const IDiscreteFunction> {
                                return VectorDiamondSchemeHandler{alpha, lambdab,           mub, lambda, mu,
                                                                  f,     bc_descriptor_list}
                                  .solution();
                              }

                              ));

  this->_addBuiltinFunction("moleculardiffusion",
                            std::function(

                              [](const std::shared_ptr<const IDiscreteFunction> alpha,
                                 const std::shared_ptr<const IDiscreteFunction> lambdab,
                                 const std::shared_ptr<const IDiscreteFunction> mub,
                                 const std::shared_ptr<const IDiscreteFunction> lambda,
                                 const std::shared_ptr<const IDiscreteFunction> mu,
                                 const std::shared_ptr<const IDiscreteFunction> f,
                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                   bc_descriptor_list) -> std::tuple<std::shared_ptr<const IDiscreteFunction>,
                                                                     std::shared_ptr<const IDiscreteFunction>> {
                                return VectorDiamondSchemeHandler{alpha, lambdab,           mub, lambda, mu,
                                                                  f,     bc_descriptor_list}
                                  .apply();
                              }

                              ));

  this->_addBuiltinFunction("energybalance",
                            std::function(

                              [](const std::shared_ptr<const IDiscreteFunction> lambdab,
                                 const std::shared_ptr<const IDiscreteFunction> mub,
                                 const std::shared_ptr<const IDiscreteFunction> U,
                                 const std::shared_ptr<const IDiscreteFunction> dual_U,
                                 const std::shared_ptr<const IDiscreteFunction> source,
                                 const std::vector<std::shared_ptr<const IBoundaryConditionDescriptor>>&
                                   bc_descriptor_list) -> std::tuple<std::shared_ptr<const IDiscreteFunction>,
                                                                     std::shared_ptr<const IDiscreteFunction>> {
                                return EnergyComputerHandler{lambdab, mub, U, dual_U, source, bc_descriptor_list}
                                  .computeEnergyUpdate();
                              }

                              ));

  this->_addBuiltinFunction("lagrangian",
                            std::function(

                              [](const std::shared_ptr<const IMesh>& mesh,
                                 const std::shared_ptr<const IDiscreteFunction>& v)
                                -> std::shared_ptr<const IDiscreteFunction> { return shallowCopy(mesh, v); }

                              ));

  this->_addBuiltinFunction("acoustic_dt",
                            std::function(

                              [](const std::shared_ptr<const IDiscreteFunction>& c) -> double { return acoustic_dt(c); }

                              ));

  this
    ->_addBuiltinFunction("cell_volume",
                          std::function(

                            [](const std::shared_ptr<const IMesh>& i_mesh) -> std::shared_ptr<const IDiscreteFunction> {
                              switch (i_mesh->dimension()) {
                              case 1: {
                                constexpr size_t Dimension = 1;
                                using MeshType             = Mesh<Connectivity<Dimension>>;
                                std::shared_ptr<const MeshType> mesh =
                                  std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);

                                return std::make_shared<const DiscreteFunctionP0<
                                  Dimension, double>>(mesh, copy(MeshDataManager::instance().getMeshData(*mesh).Vj()));
                              }
                              case 2: {
                                constexpr size_t Dimension = 2;
                                using MeshType             = Mesh<Connectivity<Dimension>>;
                                std::shared_ptr<const MeshType> mesh =
                                  std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);

                                return std::make_shared<const DiscreteFunctionP0<
                                  Dimension, double>>(mesh, copy(MeshDataManager::instance().getMeshData(*mesh).Vj()));
                              }
                              case 3: {
                                constexpr size_t Dimension = 3;
                                using MeshType             = Mesh<Connectivity<Dimension>>;
                                std::shared_ptr<const MeshType> mesh =
                                  std::dynamic_pointer_cast<const Mesh<Connectivity<Dimension>>>(i_mesh);

                                return std::make_shared<const DiscreteFunctionP0<
                                  Dimension, double>>(mesh, copy(MeshDataManager::instance().getMeshData(*mesh).Vj()));
                              }
                              default: {
                                throw UnexpectedError("invalid mesh dimension");
                              }
                              }
                            }

                            ));

  MathFunctionRegisterForVh{*this};
}

void
SchemeModule::registerOperators() const
{
  BinaryOperatorRegisterForVh{};
  UnaryOperatorRegisterForVh{};
}