#include <language/modules/DevUtilsModule.hpp>

#include <dev/ParallelChecker.hpp>
#include <language/utils/ASTDotPrinter.hpp>
#include <language/utils/ASTExecutionInfo.hpp>
#include <language/utils/ASTPrinter.hpp>
#include <language/utils/BuiltinFunctionEmbedder.hpp>
#include <language/utils/SymbolTable.hpp>

#include <fstream>
#include <sstream>

// Specializations of ast_node_data_type_from for the shared_ptr types that
// appear in the builtin-function signatures registered by this module. Each one
// is bound to a type_id ASTNodeDataType carrying the given language-level name
// (presumably consulted by BuiltinFunctionEmbedder when the builtins below are
// registered — NOTE(review): confirm).

// Language type "Vh" <-> std::shared_ptr<const DiscreteFunctionVariant>.
class DiscreteFunctionVariant;
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const DiscreteFunctionVariant>> =
  ASTNodeDataType::build<ASTNodeDataType::type_id_t>("Vh");

// Language type "item_value" <-> std::shared_ptr<const ItemValueVariant>.
class ItemValueVariant;
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemValueVariant>> =
  ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_value");

// Language type "item_array" <-> std::shared_ptr<const ItemArrayVariant>.
// NOTE(review): no builtin registered in this file takes an ItemArrayVariant,
// so this specialization appears unused here — confirm whether a parallel_check
// overload for item arrays was intended, or whether this can be removed.
class ItemArrayVariant;
template <>
inline ASTNodeDataType ast_node_data_type_from<std::shared_ptr<const ItemArrayVariant>> =
  ASTNodeDataType::build<ASTNodeDataType::type_id_t>("item_array");

// Registers the developer-oriented builtin functions of this module:
//  - getAST()                 : returns the ASTPrinter dump of the root node of
//                               the current ASTExecutionInfo
//  - saveASTDot(filename)     : writes the ASTDotPrinter output of that same root
//                               node to the file 'filename'; throws NormalError
//                               if the file cannot be created or written
//  - getFunctionAST(f)        : returns the ASTPrinter dumps of the domain mapping
//                               node and of the definition node of function f
//  - parallel_check(u, name)  : forwards u (a discrete function or an item value)
//                               and name to parallel_check(), together with the
//                               current ASTBacktrace source location
DevUtilsModule::DevUtilsModule()
{
  this->_addBuiltinFunction("getAST", std::function(

                                        []() -> std::string {
                                          const auto& root_node = ASTExecutionInfo::current().rootNode();

                                          std::ostringstream os;
                                          os << ASTPrinter{root_node};

                                          return os.str();
                                        }

                                        ));

  this->_addBuiltinFunction("saveASTDot", std::function(

                                            [](const std::string& dot_filename) -> void {
                                              const auto& root_node = ASTExecutionInfo::current().rootNode();

                                              std::ofstream fout(dot_filename);

                                              if (not fout) {
                                                std::ostringstream os;
                                                os << "could not create file '" << dot_filename << "'\n";
                                                throw NormalError(os.str());
                                              }

                                              ASTDotPrinter dot_printer{root_node};
                                              fout << dot_printer;

                                              // Close explicitly: output still sitting in the file buffer is
                                              // only flushed on close, and a failure there would otherwise be
                                              // silently ignored by the ofstream destructor. close() sets
                                              // failbit if flushing or closing fails.
                                              fout.close();

                                              if (not fout) {
                                                std::ostringstream os;
                                                os << "could not write AST to '" << dot_filename << "'\n";
                                                throw NormalError(os.str());
                                              }
                                            }

                                            ));

  this->_addBuiltinFunction("getFunctionAST", std::function(

                                                [](const FunctionSymbolId& function_symbol_id) -> std::string {
                                                  const auto& function_descriptor = function_symbol_id.descriptor();

                                                  std::ostringstream os;
                                                  os << function_descriptor.name() << ": domain mapping\n";
                                                  os << ASTPrinter(function_descriptor.domainMappingNode());
                                                  os << function_descriptor.name() << ": definition\n";
                                                  os << ASTPrinter(function_descriptor.definitionNode());

                                                  return os.str();
                                                }

                                                ));

  this->_addBuiltinFunction("parallel_check",
                            std::function(

                              [](const std::shared_ptr<const DiscreteFunctionVariant>& discrete_function,
                                 const std::string& name) -> void {
                                parallel_check(*discrete_function, name, ASTBacktrace::getInstance().sourceLocation());
                              }

                              ));

  this->_addBuiltinFunction("parallel_check",
                            std::function(

                              [](const std::shared_ptr<const ItemValueVariant>& item_value,
                                 const std::string& name) -> void {
                                parallel_check(*item_value, name, ASTBacktrace::getInstance().sourceLocation());
                              }

                              ));
}

void
DevUtilsModule::registerOperators() const
{
  // No operators are registered by this module.
}