#include <language/ast/ASTNodeBuiltinFunctionExpressionBuilder.hpp>

#include <language/PEGGrammar.hpp>
#include <language/ast/ASTNodeDataTypeFlattener.hpp>
#include <language/node_processor/BuiltinFunctionProcessor.hpp>
#include <language/utils/ASTNodeNaturalConversionChecker.hpp>
#include <language/utils/ParseError.hpp>
#include <language/utils/SymbolTable.hpp>

#include <stdexcept>

PUGS_INLINE std::unique_ptr<IFunctionArgumentConverter>
ASTNodeBuiltinFunctionExpressionBuilder::_getArgumentConverter(const ASTNodeDataType& parameter_type,
                                                               const ASTNodeSubDataType& argument_node_sub_data_type,
                                                               const size_t argument_number)
{
  auto get_function_argument_converter_for =
    [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> {
    using ParameterT = std::decay_t<decltype(parameter_v)>;
    switch (argument_node_sub_data_type.m_data_type) {
    case ASTNodeDataType::bool_t: {
      return std::make_unique<FunctionArgumentConverter<ParameterT, bool>>(argument_number);
    }
    case ASTNodeDataType::unsigned_int_t: {
      return std::make_unique<FunctionArgumentConverter<ParameterT, uint64_t>>(argument_number);
    }
    case ASTNodeDataType::int_t: {
      return std::make_unique<FunctionArgumentConverter<ParameterT, int64_t>>(argument_number);
    }
    case ASTNodeDataType::double_t: {
      return std::make_unique<FunctionArgumentConverter<ParameterT, double>>(argument_number);
    }
      // LCOV_EXCL_START
    default: {
      throw ParseError("unexpected error: invalid argument type for function",
                       std::vector{argument_node_sub_data_type.m_parent_node.begin()});
    }
      // LCOV_EXCL_STOP
    }
  };

  auto get_function_argument_converter_for_vector =
    [&](const auto& parameter_v) -> std::unique_ptr<IFunctionArgumentConverter> {
    using ParameterT = std::decay_t<decltype(parameter_v)>;

    if constexpr (std::is_same_v<ParameterT, TinyVector<1>>) {
      switch (argument_node_sub_data_type.m_data_type) {
      case ASTNodeDataType::vector_t: {
        if (argument_node_sub_data_type.m_data_type.dimension() == 1) {
          return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, ParameterT>>(argument_number);
        } else {
          // LCOV_EXCL_START
          throw ParseError("unexpected error: invalid argument dimension",
                           std::vector{argument_node_sub_data_type.m_parent_node.begin()});
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::bool_t: {
        return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, bool>>(argument_number);
      }
      case ASTNodeDataType::int_t: {
        return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, int64_t>>(argument_number);
      }
      case ASTNodeDataType::unsigned_int_t: {
        return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, uint64_t>>(argument_number);
      }
      case ASTNodeDataType::double_t: {
        return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, double>>(argument_number);
      }
        // LCOV_EXCL_START
      default: {
        throw ParseError("unexpected error: invalid argument type",
                         std::vector{argument_node_sub_data_type.m_parent_node.begin()});
      }
        // LCOV_EXCL_STOP
      }
    } else {
      switch (argument_node_sub_data_type.m_data_type) {
      case ASTNodeDataType::vector_t: {
        if (argument_node_sub_data_type.m_data_type.dimension() == parameter_v.dimension()) {
          return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, ParameterT>>(argument_number);
        } else {
          // LCOV_EXCL_START
          throw ParseError("unexpected error: invalid argument dimension",
                           std::vector{argument_node_sub_data_type.m_parent_node.begin()});
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::list_t: {
        if (argument_node_sub_data_type.m_parent_node.children.size() == parameter_v.dimension()) {
          return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, ParameterT>>(argument_number);
        } else {
          // LCOV_EXCL_START
          throw ParseError("unexpected error: invalid argument dimension",
                           std::vector{argument_node_sub_data_type.m_parent_node.begin()});
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::int_t: {
        if (argument_node_sub_data_type.m_parent_node.is_type<language::integer>()) {
          if (std::stoi(argument_node_sub_data_type.m_parent_node.string()) == 0) {
            return std::make_unique<FunctionTinyVectorArgumentConverter<ParameterT, ZeroType>>(argument_number);
          }
        }
        [[fallthrough]];
      }
        // LCOV_EXCL_START
      default: {
        throw ParseError("unexpected error: invalid argument type",
                         std::vector{argument_node_sub_data_type.m_parent_node.begin()});
      }
        // LCOV_EXCL_STOP
      }
    }
  };

  auto get_function_argument_to_string_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> {
    return std::make_unique<FunctionArgumentToStringConverter>(argument_number);
  };

  auto get_function_argument_to_type_id_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> {
    switch (argument_node_sub_data_type.m_data_type) {
    case ASTNodeDataType::type_id_t: {
      return std::make_unique<FunctionArgumentConverter<EmbeddedData, EmbeddedData>>(argument_number);
    }
      // LCOV_EXCL_START
    default: {
      throw ParseError("unexpected error: invalid argument type for function",
                       std::vector{argument_node_sub_data_type.m_parent_node.begin()});
    }
      // LCOV_EXCL_STOP
    }
  };

  auto get_function_argument_to_tuple_converter =
    [&](const auto& parameter_content_v) -> std::unique_ptr<IFunctionArgumentConverter> {
    using ParameterContentT   = std::decay_t<decltype(parameter_content_v)>;
    const auto& arg_data_type = argument_node_sub_data_type.m_data_type;
    switch (arg_data_type) {
    case ASTNodeDataType::tuple_t: {
      const auto& tuple_content_type = arg_data_type.contentType();
      switch (tuple_content_type) {
      case ASTNodeDataType::type_id_t: {
        return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, EmbeddedData>>(argument_number);
      }
      case ASTNodeDataType::bool_t: {
        return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, bool>>(argument_number);
      }
      case ASTNodeDataType::unsigned_int_t: {
        return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, uint64_t>>(argument_number);
      }
      case ASTNodeDataType::int_t: {
        return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, int64_t>>(argument_number);
      }
      case ASTNodeDataType::double_t: {
        return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, double>>(argument_number);
      }
        // LCOV_EXCL_START
      default: {
        throw UnexpectedError(dataTypeName(tuple_content_type) + " unexpected tuple content type of argument ");
      }
        // LCOV_EXCL_STOP
      }
    }
    case ASTNodeDataType::list_t: {
      return std::make_unique<FunctionListArgumentConverter<ParameterContentT, ParameterContentT>>(argument_number);
    }
    case ASTNodeDataType::type_id_t: {
      return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, EmbeddedData>>(argument_number);
    }
    case ASTNodeDataType::bool_t: {
      return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, bool>>(argument_number);
    }
    case ASTNodeDataType::unsigned_int_t: {
      return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, uint64_t>>(argument_number);
    }
    case ASTNodeDataType::int_t: {
      return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, int64_t>>(argument_number);
    }
    case ASTNodeDataType::double_t: {
      return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, double>>(argument_number);
    }
    case ASTNodeDataType::vector_t: {
      switch (arg_data_type.dimension()) {
      case 1: {
        return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyVector<1>>>(argument_number);
      }
      case 2: {
        return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyVector<2>>>(argument_number);
      }
      case 3: {
        return std::make_unique<FunctionTupleArgumentConverter<ParameterContentT, TinyVector<3>>>(argument_number);
      }
        // LCOV_EXCL_START
      default: {
        throw UnexpectedError(dataTypeName(arg_data_type) + " unexpected dimension of vector");
      }
        // LCOV_EXCL_STOP
      }
    }
      // LCOV_EXCL_START
    default: {
      throw UnexpectedError(dataTypeName(arg_data_type) + " argument to tuple ");
    }
      // LCOV_EXCL_STOP
    }
  };

  auto get_function_argument_to_function_id_converter = [&]() -> std::unique_ptr<IFunctionArgumentConverter> {
    switch (argument_node_sub_data_type.m_data_type) {
    case ASTNodeDataType::function_t: {
      const ASTNode& parent_node = argument_node_sub_data_type.m_parent_node;
      auto symbol_table          = parent_node.m_symbol_table;

      return std::make_unique<FunctionArgumentToFunctionSymbolIdConverter>(argument_number, symbol_table);
    }
      // LCOV_EXCL_START
    default: {
      throw ParseError("unexpected error: invalid argument type for function",
                       std::vector{argument_node_sub_data_type.m_parent_node.begin()});
    }
      // LCOV_EXCL_STOP
    }
  };

  auto get_function_argument_converter_for_argument_type = [&]() {
    switch (parameter_type) {
    case ASTNodeDataType::bool_t: {
      return get_function_argument_converter_for(bool{});
    }
    case ASTNodeDataType::unsigned_int_t: {
      return get_function_argument_converter_for(uint64_t{});
    }
    case ASTNodeDataType::int_t: {
      return get_function_argument_converter_for(int64_t{});
    }
    case ASTNodeDataType::double_t: {
      return get_function_argument_converter_for(double{});
    }
    case ASTNodeDataType::vector_t: {
      switch (parameter_type.dimension()) {
      case 1: {
        return get_function_argument_converter_for_vector(TinyVector<1>{});
      }
      case 2: {
        return get_function_argument_converter_for_vector(TinyVector<2>{});
      }
      case 3: {
        return get_function_argument_converter_for_vector(TinyVector<3>{});
      }
        // LCOV_EXCL_START
      default: {
        throw ParseError("unexpected error: undefined parameter type for function",
                         std::vector{argument_node_sub_data_type.m_parent_node.begin()});
      }
        // LCOV_EXCL_STOP
      }
    }
    case ASTNodeDataType::string_t: {
      return get_function_argument_to_string_converter();
    }
    case ASTNodeDataType::type_id_t: {
      return get_function_argument_to_type_id_converter();
    }
    case ASTNodeDataType::function_t: {
      return get_function_argument_to_function_id_converter();
    }
    case ASTNodeDataType::tuple_t: {
      switch (parameter_type.contentType()) {
      case ASTNodeDataType::type_id_t: {
        return get_function_argument_to_tuple_converter(EmbeddedData{});
      }
      case ASTNodeDataType::bool_t: {
        return get_function_argument_to_tuple_converter(bool{});
      }
      case ASTNodeDataType::unsigned_int_t: {
        return get_function_argument_to_tuple_converter(uint64_t{});
      }
      case ASTNodeDataType::int_t: {
        return get_function_argument_to_tuple_converter(int64_t{});
      }
      case ASTNodeDataType::double_t: {
        return get_function_argument_to_tuple_converter(double{});
      }
      case ASTNodeDataType::vector_t: {
        switch (parameter_type.contentType().dimension()) {
        case 1: {
          return get_function_argument_to_tuple_converter(TinyVector<1>{});
        }
        case 2: {
          return get_function_argument_to_tuple_converter(TinyVector<2>{});
        }
        case 3: {
          return get_function_argument_to_tuple_converter(TinyVector<3>{});
        }
        // LCOV_EXCL_START
        default: {
          throw ParseError("unexpected error: unexpected tuple content for function: '" + dataTypeName(parameter_type) +
                             "'",
                           std::vector{argument_node_sub_data_type.m_parent_node.begin()});
        }
          // LCOV_EXCL_STOP
        }
      }
      case ASTNodeDataType::string_t: {
        return get_function_argument_to_string_converter();
      }
        // LCOV_EXCL_START
      default: {
        throw ParseError("unexpected error: unexpected tuple content type for function",
                         std::vector{argument_node_sub_data_type.m_parent_node.begin()});
      }
        // LCOV_EXCL_STOP
      }
    }
      // LCOV_EXCL_START
    default: {
      throw ParseError("unexpected error: undefined parameter type for function",
                       std::vector{argument_node_sub_data_type.m_parent_node.begin()});
    }
      // LCOV_EXCL_STOP
    }
  };

  ASTNodeNaturalConversionChecker<AllowRToR1Conversion>{argument_node_sub_data_type, parameter_type};

  return get_function_argument_converter_for_argument_type();
}

PUGS_INLINE
void
ASTNodeBuiltinFunctionExpressionBuilder::_storeArgumentProcessor(
  const std::vector<ASTNodeDataType>& parameter_type_list,
  const ASTNodeDataTypeFlattener::FlattenedDataTypeList& flattened_datatype_list,
  const size_t argument_number,
  BuiltinFunctionProcessor& processor)
{
  processor.addArgumentConverter(this->_getArgumentConverter(parameter_type_list[argument_number],
                                                             flattened_datatype_list[argument_number],
                                                             argument_number));
}

PUGS_INLINE
void
ASTNodeBuiltinFunctionExpressionBuilder::_buildArgumentProcessors(
  const std::vector<ASTNodeDataType>& parameter_type_list,
  ASTNode& node,
  BuiltinFunctionProcessor& processor)
{
  ASTNode& argument_nodes = *node.children[1];

  ASTNodeDataTypeFlattener::FlattenedDataTypeList flattened_datatype_list;
  ASTNodeDataTypeFlattener{argument_nodes, flattened_datatype_list};

  const size_t arguments_number  = flattened_datatype_list.size();
  const size_t parameters_number = parameter_type_list.size();

  if (arguments_number != parameters_number) {
    std::ostringstream error_message;
    error_message << "bad number of arguments: expecting " << rang::fgB::yellow << parameters_number
                  << rang::style::reset << ", provided " << rang::fgB::yellow << arguments_number << rang::style::reset;
    throw ParseError(error_message.str(), argument_nodes.begin());
  }

  for (size_t i = 0; i < arguments_number; ++i) {
    this->_storeArgumentProcessor(parameter_type_list, flattened_datatype_list, i, processor);
  }
}

ASTNodeBuiltinFunctionExpressionBuilder::ASTNodeBuiltinFunctionExpressionBuilder(ASTNode& node)
{
  auto [i_function_symbol, found] = node.m_symbol_table->find(node.children[0]->string(), node.begin());
  Assert(found);
  Assert(i_function_symbol->attributes().dataType() == ASTNodeDataType::builtin_function_t);

  uint64_t builtin_function_id = std::get<uint64_t>(i_function_symbol->attributes().value());

  auto& builtin_function_embedder_table     = node.m_symbol_table->builtinFunctionEmbedderTable();
  std::shared_ptr builtin_function_embedder = builtin_function_embedder_table[builtin_function_id];

  std::vector<ASTNodeDataType> builtin_function_parameter_type_list =
    builtin_function_embedder->getParameterDataTypes();

  ASTNode& argument_nodes                    = *node.children[1];
  std::unique_ptr builtin_function_processor = std::make_unique<BuiltinFunctionProcessor>(argument_nodes);

  this->_buildArgumentProcessors(builtin_function_parameter_type_list, node, *builtin_function_processor);

  builtin_function_processor->setFunctionExpressionProcessor(
    std::make_unique<BuiltinFunctionExpressionProcessor>(builtin_function_embedder));

  node.m_node_processor = std::move(builtin_function_processor);
}