#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "SS_Assembler.hpp"
#include "SS_Internal.hpp"
#include "MemoryBuffer.hpp"


// Capacity, in chars including the terminating zero, of the fixed-size line
// and token buffers used by ReadLine() and Tokenize().
#define MAX_LINE 4096


// ---- BEGIN HASH TABLE ---- //

// One entry of a bucket chain. 'index' is a position in g_Opcodes or
// g_SystemFunctions; 'next' is the following entry in the same bucket.
struct HashNode
{
  int       index;
  HashNode* next;
};

// 256-bucket lookup tables for opcode and system-function names. They are
// filled once by CreateHashTables() when the first sAssembler is constructed;
// this file never deletes the nodes, so they live until the process exits.
static bool      HashTablesCreated = false;
static HashNode* OpcodeHashTable[256];
static HashNode* SystemFunctionHashTable[256];

static void CreateHashTables()
{
  // OPCODES
  for (int i = 0; i < g_NumOpcodes; i++)
  {
    const SS_OPCODE* opcode = g_Opcodes + i;

    // calculate index of opcode into hash
    int index = 0;
    const char* p = opcode->name;
    while (*p)
      index += *p++;
    index %= 256;

    HashNode* prev = NULL;
    HashNode* q = OpcodeHashTable[index];
    while (q)
    {
      prev = q;
      q = q->next;
    }

    HashNode* node = new HashNode;
    node->next = NULL;
    node->index = i;

    if (prev == NULL)
      OpcodeHashTable[index] = node;
    else
      prev->next = node;
  }

  // SYSTEM FUNCTIONS
  for (int i = 0; i < g_NumSystemFunctions; i++)
  {
    const SS_SYSTEM_FUNCTION* function = g_SystemFunctions + i;

    // calculate index of opcode into hash
    int index = 0;
    const char* p = function->name;
    while (*p)
      index += *p++;
    index %= 256;

    HashNode* prev = NULL;
    HashNode* q = SystemFunctionHashTable[index];
    while (q)
    {
      prev = q;
      q = q->next;
    }

    HashNode* node = new HashNode;
    node->next = NULL;
    node->index = i;

    if (prev == NULL)
      SystemFunctionHashTable[index] = node;
    else
      prev->next = node;
  }
};

// ---- END HASH TABLE ---- //


// pseudo-mnemonics
//   depend
//   string
//   global
//   function
//   endfunction
//   parameter
//   lc
//   ld
//   label


////////////////////////////////////////////////////////////////////////////////

sAssembler::sAssembler()
: Pass(NO_PASS)
, inbuffer(NULL)
, inbufferlocation(0)
, outbuffer(NULL)
, CurrentLocation(0)
, CurrentLine(0)
{
  // Previously left uninitialized. It is assigned here rather than in the
  // initializer list so as not to depend on the member declaration order in
  // the header. Execute() sets the real value.
  inbuffersize = 0;

  // The opcode/system-function lookup tables are file-level statics shared
  // by every assembler; build them on first construction only.
  if (!HashTablesCreated)
  {
    HashTablesCreated = true;
    CreateHashTables();
  }
}

////////////////////////////////////////////////////////////////////////////////

sAssembler::~sAssembler()
{
  // Nothing to release here. Execute() deletes its own working buffers when
  // it finishes, and the name lookup tables are shared file-level statics
  // rather than per-assembler state.
}

////////////////////////////////////////////////////////////////////////////////

void
sAssembler::Execute(CBuffer* in, CBuffer* out)
{
  // Load the whole source into memory; ReadLine() scans it once per pass.
  in->Seek(0);
  inbuffersize = in->Size();
  inbuffer = new char[inbuffersize];
  inbufferlocation = 0;

  // The assembled image is built in a memory buffer (output cache) and is
  // copied to 'out' only after both passes have completed.
  outbuffer = NULL;
  byte* bytes = NULL;

  try
  {
    in->Read(inbuffersize, (byte*)inbuffer);

    outbuffer = new CMemoryBuffer;

    // assemble the file
    Dependancies.clear();
    FirstPass();
    SecondPass();

    // copy the memory buffer (cache) to the output file
    bytes = new byte[outbuffer->Size()];
    outbuffer->Seek(0);
    outbuffer->Read(outbuffer->Size(), bytes);
    out->Write(outbuffer->Size(), bytes);
  }
  catch (...)
  {
    // Error() reports assembly failures by throwing. Release the working
    // buffers before passing the exception on, instead of leaking them.
    delete[] bytes;
    delete[] inbuffer;
    delete outbuffer;
    inbuffer = NULL;
    outbuffer = NULL;
    throw;
  }

  // free temporary buffers and don't leave dangling pointers in the members
  delete[] bytes;
  delete[] inbuffer;
  delete outbuffer;
  inbuffer = NULL;
  outbuffer = NULL;
}

////////////////////////////////////////////////////////////////////////////////

void
sAssembler::FirstPass()
{
  inbufferlocation = 0;

  CurrentLocation = 0;
  CurrentLine = 0;

  // read a line and tokenize it
  sString line;
  while (ReadLine(line))
  {
    CurrentLine++;

    sVector<sString> tokens;
    Tokenize(line, tokens);

    // if it's an empty line, skip it
    if (tokens.size() == 0)
      continue;

    const SS_OPCODE* opcode = GetOpcode(tokens[0].c_str());
    VerifyNumParameters(tokens, opcode);

    // dependancy
    if (tokens[0] == "depend")
    {
      AddDependancy(tokens[1] + ".ssx");
    }

    // define global variable
    else if (tokens[0] == "global")
    {
      Globals.push_back(Label(tokens[1], CurrentLocation));
      CurrentLocation++;
    }

    // string literal
    else if (tokens[0] == "string")
    {
      StringLiterals.push_back(Label(tokens[1], CurrentLocation));

      int string_length = tokens[2].length() - 2 + 1; // remove quotations and add terminating zero
      int aligned_length = (string_length + 3) / 4;   // everything in SS is dword-aligned
      CurrentLocation += aligned_length;
    }

    // begin function
    else if (tokens[0] == "function")
    {
      Functions.push_back(Label(tokens[1], CurrentLocation));
    }

    // end function
    else if (tokens[0] == "endfunction")
    {
      Parameters.clear();
      CurrentLocation++;  // ret instruction
    }

    // define function parameter
    else if (tokens[0] == "parameter")
    {
      Parameters.push_back(tokens[1]);
      CurrentLocation++;
    }

    // create local
    else if (tokens[0] == "lc")
    {
      Locals.push_back(tokens[1]);
      CurrentLocation++;
    }

    // destroy local
    else if (tokens[0] == "ld")
    {
      if (Locals[Locals.size() - 1] != tokens[1])
        Error("Must destroy locals in reverse order");
      Locals.pop_back();
      CurrentLocation++;
    }

    // local label
    else if (tokens[0] == "label")
    {
      Labels.push_back(Label(tokens[1], CurrentLocation));
    }

    // opcodes
    else if (opcode)
    {
      CurrentLocation++;  // opcode

      // calculate operand sizes
      for (int i = 0; i < opcode->num_parameters; i++)
        CurrentLocation += GetOperandSize(tokens[i + 1]);
    }

    else
    {
      Error("Unknown line");
    }

  }
}

////////////////////////////////////////////////////////////////////////////////

void
sAssembler::SecondPass()
{
  inbufferlocation = 0;

  // write the output file header
  outbuffer->Write(4, (byte*)".ssx");        // signature
  outbuffer->WriteWord(2);                   // version
  outbuffer->WriteWord(Functions.size());    // num_labels
  outbuffer->WriteDword(CurrentLocation);    // codesize
  outbuffer->WriteWord(Dependancies.size()); // num_dependancies
  // reserved space
  for (int i = 0; i < 50; i++)
    outbuffer->WriteByte(0);

  // write function labels
  for (int i = 0; i < Functions.size(); i++)
  {
    outbuffer->WriteWord(Functions[i].name.length() + 1);
    outbuffer->Write(Functions[i].name.length() + 1, (const byte*)Functions[i].name.c_str());
    outbuffer->WriteDword(Functions[i].location);
  }

  // write dependancies
  for (int i = 0; i < Dependancies.size(); i++)
  {
    outbuffer->WriteWord(Dependancies[i].length() + 1);
    outbuffer->Write(Dependancies[i].length() + 1, (const byte*)Dependancies[i].c_str());
  }

  CurrentLocation = 0;
  CurrentLine = 0;

  // read a line and tokenize it
  sString line;
  while (ReadLine(line))
  {
    CurrentLine++;

    sVector<sString> tokens;
    Tokenize(line, tokens);

    // if it's an empty line, skip it
    if (tokens.size() == 0)
      continue;

    const SS_OPCODE* opcode = GetOpcode(tokens[0].c_str());
    VerifyNumParameters(tokens, opcode);

    // dependancy
    if (tokens[0] == "depend")
    {
    }

    // define global variable
    else if (tokens[0] == "global")
    {
      outbuffer->WriteDword(0);
    }

    // string literal
    else if (tokens[0] == "string")
    {
      int string_length = tokens[2].length() - 2 + 1; // remove quotations and add terminating zero
      int aligned_length = (string_length + 3) / 4;   // everything in SS is dword-aligned

      for (int i = 0; i < tokens[2].length() - 2; i++)
        outbuffer->WriteByte(tokens[2][i + 1]);
      outbuffer->WriteByte(0); // terminating zero
      for (int i = 0; i < (aligned_length * 4 - string_length); i++)
        outbuffer->WriteByte(0);
    }

    // begin function
    else if (tokens[0] == "function")
    {
    }

    // end function
    else if (tokens[0] == "endfunction")
    {
      Parameters.clear();
      outbuffer->WriteDword(GetOpcode("ret")->raw_value);
    }

    // define function parameter
    else if (tokens[0] == "parameter")
    {
      Parameters.push_back(tokens[1]);
      outbuffer->WriteDword(opcode->raw_value);
    }

    // create local
    else if (tokens[0] == "lc")
    {
      Locals.push_back(tokens[1]);
      outbuffer->WriteDword(opcode->raw_value);
    }

    // destroy local
    else if (tokens[0] == "ld")
    {
      if (Locals[Locals.size() - 1] != tokens[1])
        Error("Must destroy locals in reverse order");
      Locals.pop_back();
      outbuffer->WriteDword(opcode->raw_value);
    }

    // local label
    else if (tokens[0] == "label")
    {
    }
    
    // opcodes
    else if (opcode) 
    {
      // output opcode
      outbuffer->WriteDword(opcode->raw_value);

      // output operands
      for (int i = 0; i < opcode->num_parameters; i++)
      {
        sString& _t = tokens[i + 1];
        char* t = newstr(_t.c_str());

        char* token = strtok(t, ":");
        dword reference = 0;
        while (token != NULL)
        {
          outbuffer->WriteDword(GetOperandType(token) | reference);
          outbuffer->WriteDword(GetOperandCode(token));

          token = strtok(NULL, ":");
          reference = VP_REFERENCE;
        }

        delete[] t;
      }
    }

    else
    {
      Error("Unknown line");
    }

  }
}

////////////////////////////////////////////////////////////////////////////////

// Read the next line of the in-memory source into 'line'.
// The '\n' terminator is consumed and '\r' characters are dropped.
// Returns false only when the input is exhausted. A final line with no
// trailing '\n' is still returned: previously it was skipped and 'line'
// kept the previous line's text.
// Lines must fit in MAX_LINE - 1 characters; longer lines raise an error
// instead of overflowing the stack buffer.
bool
sAssembler::ReadLine(sString& line)
{
  if (inbufferlocation >= inbuffersize)
    return false;

  char l[MAX_LINE];
  int length = 0;

  while (inbufferlocation < inbuffersize)
  {
    char c = inbuffer[inbufferlocation];
    inbufferlocation++;

    if (c == '\n')
      break;
    if (c == '\r')
      continue;

    if (length >= MAX_LINE - 1)
    {
      // The caller only advances CurrentLine after a successful read. Advance
      // it here so the error (which throws) names the offending line.
      CurrentLine++;
      Error("Line too long");
    }
    l[length++] = c;
  }

  l[length] = 0;
  line = l;
  return true;
}

////////////////////////////////////////////////////////////////////////////////

// True for the characters the tokenizer treats as separators:
// space, tab, newline and carriage return.
inline bool IsWhitespace(char c)
{
  switch (c)
  {
    case ' ':
    case '\t':
    case '\n':
    case '\r':
      return true;
    default:
      return false;
  }
}

// Return a pointer to the first non-whitespace character at or after p.
// p must point into a NUL-terminated string; the NUL is not whitespace, so the
// scan always stops.
inline const char* SkipWhitespace(const char* p)
{
  for (; IsWhitespace(*p); ++p)
  {
  }
  return p;
}

void
sAssembler::Tokenize(const sString& line, sVector<sString>& tokens)
{
  const char* p = line.c_str();
  while (*p)
  {
    p = SkipWhitespace(p);
    if (*p == 0)
      break;

    char token[MAX_LINE];  // max token length
    memset(token, 0, sizeof(token));
    char* t = token;

    if (*p == '"')
    {
      *t++ = *p++;
      while (*p != '"')
        *t++ += *p++;
      *t++ += *p++;
    }
    else
    {
      while (!IsWhitespace(*p) && *p != 0)
      {
        // if we have a comment, skip the rest of the line
        if (*p == ';')
        {
          if (t > token)
            tokens.push_back(token);
          return;
        }

        *t++ += *p++;
      }
    }

    tokens.push_back(token);
  }
}

////////////////////////////////////////////////////////////////////////////////

// Check that a tokenized line has the operand count its mnemonic requires.
// Pseudo-mnemonics use the fixed counts in the table below; anything else
// must be a known opcode and uses the opcode's num_parameters.
// Raises an error for an unknown mnemonic or a wrong operand count.
void
sAssembler::VerifyNumParameters(const sVector<sString>& tokens, const SS_OPCODE* opcode)
{
  struct PseudoOp
  {
    const char* mnemonic;
    int         operands;
  };
  static const PseudoOp pseudo_ops[] =
  {
    { "depend",      1 },
    { "string",      2 },
    { "global",      1 },
    { "function",    1 },
    { "endfunction", 0 },
    { "parameter",   1 },
    { "lc",          1 },
    { "ld",          1 },
    { "label",       1 },
  };
  const int pseudo_count = sizeof(pseudo_ops) / sizeof(pseudo_ops[0]);

  int expected = -1;
  for (int k = 0; k < pseudo_count && expected < 0; k++)
  {
    if (tokens[0] == pseudo_ops[k].mnemonic)
      expected = pseudo_ops[k].operands;
  }

  if (expected < 0)
  {
    if (!opcode)
      Error("Unknown opcode");
    expected = opcode->num_parameters;
  }

  if (tokens.size() != expected + 1)
    Error(tokens[0] + ": Incorrect number of parameters");
}

////////////////////////////////////////////////////////////////////////////////

// Look up an opcode by mnemonic in OpcodeHashTable.
// Returns a pointer into g_Opcodes, or NULL if 's' is not a known opcode.
const SS_OPCODE*
sAssembler::GetOpcode(const char* s)
{
  // Bucket = sum of the name's bytes modulo 256, as in CreateHashTables().
  // Bytes are read as unsigned. This gives the same bucket as a signed sum
  // whenever that sum is non-negative. It also stops a mnemonic containing
  // non-ASCII bytes (it comes straight from the source file) from producing
  // a negative, out-of-bounds table index.
  unsigned int sum = 0;
  for (const char* p = s; *p; p++)
    sum += (unsigned char)*p;

  for (HashNode* q = OpcodeHashTable[sum % 256]; q; q = q->next)
  {
    if (strcmp(s, g_Opcodes[q->index].name) == 0)
      return g_Opcodes + q->index;
  }

  return NULL;
}

////////////////////////////////////////////////////////////////////////////////

// Number of dwords an operand occupies in the output image.
// SecondPass writes each ':'-separated component of an operand as two dwords
// (type, code). It splits with strtok(), which skips empty components, so
// this counts the non-empty components the same way. Counting every ':'
// instead would size operands such as "a::b" or "r0:" differently in the two
// passes, and every later label address would be wrong.
int
sAssembler::GetOperandSize(const sString& s)
{
  int components = 0;
  bool in_component = false;

  for (const char* p = s.c_str(); *p; p++)
  {
    if (*p == ':')
    {
      in_component = false;
    }
    else if (!in_component)
    {
      in_component = true;
      components++;
    }
  }

  return components * 2;
}

////////////////////////////////////////////////////////////////////////////////

// True when s is exactly a register name: 'r' followed by one digit 0-7.
// Short-circuit evaluation guarantees s[1] is non-NUL before s[2] is read.
inline bool IsRegister(const char* s)
{
  const bool starts_with_r = (s[0] == 'r');
  return starts_with_r
      && s[1] >= '0' && s[1] <= '7'
      && s[2] == '\0';
}

// Classify one operand component as a VP_* operand type.
// The checks are ordered: register, then literal (number or _systemfunction),
// label, local, and finally parameter. Raises an error when nothing matches.
sdword
sAssembler::GetOperandType(const sString& s)
{
  sdword type = 0;

  if (IsRegister(s.c_str()))
    type = VP_REGISTER;
  else if (IsLiteral(s))
    type = VP_LITERAL;
  else if (IsLabel(s))
    type = VP_LABEL;
  else if (IsLocal(s))
    type = VP_LOCAL;
  else if (IsParameter(s))
    type = VP_PARAMETER;
  else
    Error("Unknown operand type");

  return type;
}

////////////////////////////////////////////////////////////////////////////////

// Encode one operand component as its code dword. The operand kinds are
// checked in the same order as GetOperandType():
//   register rN        -> N
//   literal            -> value from aLiteral()
//   label              -> recorded location
//   local / parameter  -> index in declaration order
sdword
sAssembler::GetOperandCode(const sString& s)
{
  const char* text = s.c_str();

  if (IsRegister(text))
    return text[1] - '0';

  if (IsLiteral(s))
    return aLiteral(s);

  if (IsLabel(s))
  {
    // search order matches IsLabel(): strings, globals, functions, labels
    for (int n = 0; n < StringLiterals.size(); n++)
    {
      if (StringLiterals[n].name == s)
        return StringLiterals[n].location;
    }
    for (int n = 0; n < Globals.size(); n++)
    {
      if (Globals[n].name == s)
        return Globals[n].location;
    }
    for (int n = 0; n < Functions.size(); n++)
    {
      if (Functions[n].name == s)
        return Functions[n].location;
    }
    for (int n = 0; n < Labels.size(); n++)
    {
      if (Labels[n].name == s)
        return Labels[n].location;
    }
    Error("Internal Assembler Error");
    return 0;
  }

  if (IsLocal(s))
  {
    for (int slot = 0; slot < Locals.size(); slot++)
    {
      if (Locals[slot] == s)
        return slot;
    }
    Error("Internal Assembler Error");
    return 0;
  }

  if (IsParameter(s))
  {
    for (int slot = 0; slot < Parameters.size(); slot++)
    {
      if (Parameters[slot] == s)
        return slot;
    }
    Error("Internal Assembler Error");
    return 0;
  }

  return 0;
}

////////////////////////////////////////////////////////////////////////////////

// True if s is a literal operand. Two forms are accepted:
// - a system function written as '_' followed by its name;
// - a number: an optional leading '-', then digits with at most one '.'.
//   At least one digit is required, so "-", "." and "" are no longer
//   mistaken for numbers.
bool
sAssembler::IsLiteral(const sString& s)
{
  const char* text = s.c_str();

  // system functions are considered literals
  if (text[0] == '_')
  {
    // Bucket = byte sum of the name after '_', modulo 256. Bytes are read as
    // unsigned so non-ASCII input cannot yield a negative table index.
    unsigned int sum = 0;
    for (const char* p = text + 1; *p; p++)
      sum += (unsigned char)*p;

    for (HashNode* q = SystemFunctionHashTable[sum % 256]; q; q = q->next)
    {
      if (s == sString("_") + g_SystemFunctions[q->index].name)
        return true;
    }
  }

  // numeric literal
  if (*text == '-')
    text++;

  bool seen_point = false;
  bool seen_digit = false;
  for (; *text; text++)
  {
    if (*text == '.')
    {
      if (seen_point)
        return false;
      seen_point = true;
    }
    else if (isdigit((unsigned char)*text))  // negative char into isdigit is UB
    {
      seen_digit = true;
    }
    else
    {
      return false;
    }
  }

  return seen_digit;
}

////////////////////////////////////////////////////////////////////////////////

// Value of a literal operand as an sdword:
// - _systemfunction -> that function's ordinal
// - integer         -> atoi() value
// - decimal         -> the raw bit pattern of the (float) value
// Anything that is not a well-formed number evaluates to 0.
sdword
sAssembler::aLiteral(const sString& s)
{
  const char* text = s.c_str();

  // system functions are considered literals
  if (text[0] == '_')
  {
    // Bucket = byte sum of the name after '_', modulo 256. Bytes are read as
    // unsigned so non-ASCII input cannot yield a negative table index.
    unsigned int sum = 0;
    for (const char* p = text + 1; *p; p++)
      sum += (unsigned char)*p;

    for (HashNode* q = SystemFunctionHashTable[sum % 256]; q; q = q->next)
    {
      if (s == sString("_") + g_SystemFunctions[q->index].name)
        return g_SystemFunctions[q->index].ordinal;
    }
  }

  // test for numeric values
  bool is_integer = true;
  const char* d = text;
  if (*d == '-')
    d++;

  for (; *d; d++)
  {
    if (*d == '.')
    {
      if (!is_integer)
        return 0;
      is_integer = false;
    }
    else if (!isdigit((unsigned char)*d))  // negative char into isdigit is UB
    {
      return 0;
    }
  }

  if (is_integer)
    return atoi(text);

  // Store the float's bits in the operand dword. memcpy avoids the
  // strict-aliasing violation of writing through a float* into an sdword.
  float value = (float)atof(text);
  sdword bits = 0;
  memcpy(&bits, &value, sizeof(value));
  return bits;
}

////////////////////////////////////////////////////////////////////////////////

// True if s names any symbol FirstPass gave a location: a string literal, a
// global, a function, or a 'label'.
bool
sAssembler::IsLabel(const sString& s)
{
  bool found = false;

  for (int n = 0; !found && n < StringLiterals.size(); n++)
    found = (StringLiterals[n].name == s);
  for (int n = 0; !found && n < Globals.size(); n++)
    found = (Globals[n].name == s);
  for (int n = 0; !found && n < Functions.size(); n++)
    found = (Functions[n].name == s);
  for (int n = 0; !found && n < Labels.size(); n++)
    found = (Labels[n].name == s);

  return found;
}

////////////////////////////////////////////////////////////////////////////////

// True if s names a local created by 'lc' and not yet destroyed by 'ld'.
bool
sAssembler::IsLocal(const sString& s)
{
  for (int n = (int)Locals.size() - 1; n >= 0; n--)
  {
    if (Locals[n] == s)
      return true;
  }
  return false;
}

////////////////////////////////////////////////////////////////////////////////

// True if s names a 'parameter' declared since the last 'endfunction'.
bool
sAssembler::IsParameter(const sString& s)
{
  for (int n = (int)Parameters.size() - 1; n >= 0; n--)
  {
    if (Parameters[n] == s)
      return true;
  }
  return false;
}

////////////////////////////////////////////////////////////////////////////////

void
sAssembler::AddDependancy(const sString& dependancy)
{
  Dependancies.push_back(dependancy);
}

////////////////////////////////////////////////////////////////////////////////

void
sAssembler::Error(const sString& message)
{
  char CurrentLine_string[80];
  sprintf(CurrentLine_string, "%d", CurrentLine);
  throw sScriptException(message + " - Line: " + CurrentLine_string);
}

////////////////////////////////////////////////////////////////////////////////
