#include <stdbool.h>
#include "./aoc.h"

// Capacity, in ints, of the global `rom` array below.
#define ROM_LENGTH 1000
// Capacity, in ints, of each computer's `output_buffer`.
#define OUTPUT_BUFFER_LENGTH 1000

// Operand mode digit: the operand is an address whose contents are the value.
#define POSITION_MODE 0
// Operand mode digit: the operand itself is the value. Not referenced by name
// in this section; the instruction handlers treat any mode other than
// POSITION_MODE as a literal operand.
#define IMMEDIATE_MODE 1

// Program image shared by all computers; filled by ic_load_rom_from_input()
// and copied into a computer's working memory by ic_reset().
int rom[ROM_LENGTH];
// Number of values currently stored in `rom`.
int rom_count;

// State of a single virtual computer.
typedef struct IC {
  // Program image that ic_reset() copies into `data`.
  int *rom;
  // Working memory; every peek/poke and instruction fetch goes through it.
  int *data;
  // Index into `data` of the next value ic_read()/ic_write() will consume.
  int program_counter;
  // Set by ic_halt(); ic_run() keeps executing until this becomes true.
  bool halted;
  // Value the input instruction stores into memory.
  int input;
  // Values emitted by the output instruction, in emission order.
  int output_buffer[OUTPUT_BUFFER_LENGTH];
  // Number of valid entries in `output_buffer`.
  int output_buffer_count;
} IC;

// Return the value stored in working memory at `address`.
// No bounds checking is performed.
int ic_peek(IC *c, int address) {
  const int *cell = c->data + address;
  return *cell;
}

// Store `value` into working memory at `address`.
// No bounds checking is performed.
void ic_poke(IC *c, int address, int value) {
  int *cell = c->data + address;
  *cell = value;
}

// Fetch the value at the program counter and advance the counter by one.
int ic_read(IC *c) {
  const int fetch_address = c->program_counter;
  c->program_counter = fetch_address + 1;
  return ic_peek(c, fetch_address);
}

// Consume the operand at the program counter, treat it as a destination
// address, and store `value` there. The program counter advances by one.
void ic_write(IC *c, int value) {
  const int operand_slot = c->program_counter;
  c->program_counter = operand_slot + 1;

  const int destination = ic_peek(c, operand_slot);
  ic_poke(c, destination, value);
}

// Set the program counter to `address`, so the next ic_read() fetches from
// there. The address is not validated.
void ic_jump(IC *c, int address) {
  c->program_counter = address;
}

// Mark the computer as halted; ic_run() stops once the current instruction
// handler returns.
void ic_halt(IC *c) {
  c->halted = true;
}

// Add instruction: reads two operands, then writes their sum to the address
// given by the third operand.
//
// `modes` holds the operand mode digits, least significant digit first. An
// operand in POSITION_MODE is dereferenced; any other mode uses the operand
// value directly.
void ic_instruction_add(IC *c, int modes) {
  const bool lhs_by_position = (modes % 10) == POSITION_MODE;
  const bool rhs_by_position = ((modes / 10) % 10) == POSITION_MODE;

  int lhs = ic_read(c);
  if (lhs_by_position) {
    lhs = ic_peek(c, lhs);
  }

  int rhs = ic_read(c);
  if (rhs_by_position) {
    rhs = ic_peek(c, rhs);
  }

  ic_write(c, lhs + rhs);
}

// Multiply instruction: reads two operands, then writes their product to the
// address given by the third operand.
//
// `modes` holds the operand mode digits, least significant digit first. An
// operand in POSITION_MODE is dereferenced; any other mode uses the operand
// value directly.
void ic_instruction_multiply(IC *c, int modes) {
  const bool lhs_by_position = (modes % 10) == POSITION_MODE;
  const bool rhs_by_position = ((modes / 10) % 10) == POSITION_MODE;

  int lhs = ic_read(c);
  if (lhs_by_position) {
    lhs = ic_peek(c, lhs);
  }

  int rhs = ic_read(c);
  if (rhs_by_position) {
    rhs = ic_peek(c, rhs);
  }

  ic_write(c, lhs * rhs);
}

// Input instruction: stores the computer's `input` value at the address given
// by the instruction's single operand. The mode digits are ignored.
void ic_instruction_input(IC *c, int _modes) {
  ic_write(c, c->input);
}

// Output instruction: reads one operand and appends its value to the
// computer's output buffer.
//
// `modes` holds the operand mode digits; only the lowest digit is used. An
// operand in POSITION_MODE is dereferenced; any other mode uses the operand
// value directly.
//
// If the output buffer is already full, a message is printed, the computer is
// halted, and nothing is read or stored.
void ic_instruction_output(IC *c, int modes) {
  if (c->output_buffer_count >= OUTPUT_BUFFER_LENGTH) {
    printf("Output buffer overflow!\n");
    ic_halt(c);
    // Must not fall through: the store below would write past the buffer.
    return;
  }

  int mode = modes % 10;
  int to_output = ic_read(c);
  if (mode == POSITION_MODE) to_output = ic_peek(c, to_output);
  c->output_buffer[c->output_buffer_count++] = to_output;
}

// Jump-if-true instruction: reads a predicate operand and a target operand.
// When the predicate is non-zero the program counter is set to the target;
// otherwise execution continues with the instruction after the target operand.
//
// `modes` holds the operand mode digits, least significant digit first. Both
// operands honour their own mode digit: POSITION_MODE dereferences the
// operand, any other mode uses the operand value directly.
void ic_instruction_jump_if_true(IC *c, int modes) {
  int predicate_mode = modes % 10;
  modes /= 10;
  int predicate = ic_read(c);
  if (predicate_mode == POSITION_MODE) predicate = ic_peek(c, predicate);
  printf("Predicate: %d\n", predicate);
  if (predicate == 0) {
    // Not jumping: step over the unused target operand.
    ic_read(c);
    return;
  }

  // The target uses the second mode digit; forcing position mode here would
  // send an immediate-mode jump to data[target] instead of target.
  int target_mode = modes % 10;
  int target = ic_read(c);
  if (target_mode == POSITION_MODE) target = ic_peek(c, target);
  ic_jump(c, target);
}

// Jump-if-false instruction: reads a predicate operand and a target operand.
// When the predicate is zero the program counter is set to the target;
// otherwise execution continues with the instruction after the target operand.
//
// `modes` holds the operand mode digits, least significant digit first. Both
// operands honour their own mode digit: POSITION_MODE dereferences the
// operand, any other mode uses the operand value directly.
void ic_instruction_jump_if_false(IC *c, int modes) {
  int predicate_mode = modes % 10;
  modes /= 10;
  int predicate = ic_read(c);
  if (predicate_mode == POSITION_MODE) predicate = ic_peek(c, predicate);
  if (predicate != 0) {
    // Not jumping: step over the unused target operand.
    ic_read(c);
    return;
  }

  // The target uses the second mode digit; forcing position mode here would
  // send an immediate-mode jump to data[target] instead of target.
  int target_mode = modes % 10;
  int target = ic_read(c);
  if (target_mode == POSITION_MODE) target = ic_peek(c, target);
  ic_jump(c, target);
}

// Less-than instruction: reads two operands and stores 1 at the destination
// address if the first is strictly less than the second, otherwise stores 0.
//
// `modes` holds the operand mode digits, least significant digit first. The
// two source operands are dereferenced when in POSITION_MODE. The third
// operand is always taken verbatim as the destination address; its mode digit
// is ignored.
void ic_instruction_less_than(IC *c, int modes) {
  const int lhs_mode = modes % 10;
  const int rhs_mode = (modes / 10) % 10;

  int lhs = ic_read(c);
  if (lhs_mode == POSITION_MODE) {
    lhs = ic_peek(c, lhs);
  }

  int rhs = ic_read(c);
  if (rhs_mode == POSITION_MODE) {
    rhs = ic_peek(c, rhs);
  }

  const int destination = ic_read(c);
  ic_poke(c, destination, (lhs < rhs) ? 1 : 0);
}

// Equals instruction: reads two operands and stores 1 at the destination
// address if they are equal, otherwise stores 0.
//
// `modes` holds the operand mode digits, least significant digit first. The
// two source operands are dereferenced when in POSITION_MODE. The third
// operand is always taken verbatim as the destination address; its mode digit
// is ignored.
void ic_instruction_equals(IC *c, int modes) {
  const int lhs_mode = modes % 10;
  const int rhs_mode = (modes / 10) % 10;

  int lhs = ic_read(c);
  if (lhs_mode == POSITION_MODE) {
    lhs = ic_peek(c, lhs);
  }

  int rhs = ic_read(c);
  if (rhs_mode == POSITION_MODE) {
    rhs = ic_peek(c, rhs);
  }

  const int destination = ic_read(c);
  ic_poke(c, destination, (lhs == rhs) ? 1 : 0);
}

// Halt instruction: marks the computer as halted. Takes no operands; the mode
// digits are ignored.
void ic_instruction_halt(IC *c, int _modes) {
  ic_halt(c);
}

// Fetch, decode and execute the instruction at the program counter.
//
// The fetched word is split into an opcode (lowest two decimal digits) and
// the operand mode digits (everything above), and the matching handler is
// called. A trace line with the instruction and its address is printed before
// dispatch.
//
// Returns 0 when a known opcode was dispatched. Returns -1 when the opcode is
// not recognised, in which case a message is printed and the computer is
// halted.
int ic_execute_instruction(IC *c) {
  int instruction = ic_read(c);
  int opcode = instruction % 100;
  int modes = instruction / 100;
  int status = 0;

  printf("Running %d: opcode %d with modes %d\n", instruction, opcode, modes);
  // ic_read() already advanced the counter; subtract one to report the
  // address the instruction was fetched from.
  printf("PC: %d\n", c->program_counter - 1);

  switch (opcode) {
  case 1:
    ic_instruction_add(c, modes);
    break;
  case 2:
    ic_instruction_multiply(c, modes);
    break;
  case 3:
    ic_instruction_input(c, modes);
    break;
  case 4:
    ic_instruction_output(c, modes);
    break;
  case 5:
    ic_instruction_jump_if_true(c, modes);
    break;
  case 6:
    ic_instruction_jump_if_false(c, modes);
    break;
  case 7:
    ic_instruction_less_than(c, modes);
    break;
  case 8:
    ic_instruction_equals(c, modes);
    break;
  case 99:
    ic_instruction_halt(c, modes);
    break;
  default:
    printf("Invalid opcode [%d] encountered at %d\n", opcode, c->program_counter - 1);
    ic_instruction_halt(c, 0);
    status = -1;
    break;
  }

  return status;
}

// Dump the computer's state to stdout: the program counter, the first
// `rom_count` cells of working memory (eight per row), and the contents of
// the output buffer.
void ic_print(IC *c) {
  printf("PC: %d\n", c->program_counter);

  int cell = 0;
  while (cell < rom_count) {
    printf("%6d ", c->data[cell]);
    cell++;
    // Break the row after every eighth cell.
    if (cell % 8 == 0) {
      printf("\n");
    }
  }
  printf("\n");

  printf("Output buffer: ");
  int slot = 0;
  while (slot < c->output_buffer_count) {
    printf("[%d] ", c->output_buffer[slot]);
    slot++;
  }
  printf("\n");
}

// Restore the computer to its initial state: working memory is overwritten
// with the first `rom_count` values of the program image, the program counter
// returns to 0, the output buffer is emptied, and the computer is left
// halted until ic_run() is called. The `input` field is not touched.
void ic_reset(IC *c) {
  c->output_buffer_count = 0;
  c->program_counter = 0;
  c->halted = true;

  memcpy(c->data, c->rom, sizeof(int) * rom_count);
}

// Clear the halted flag and execute instructions one at a time until some
// instruction halts the computer.
void ic_run(IC *c) {
  c->halted = false;

  for (;;) {
    if (c->halted) {
      break;
    }
    ic_execute_instruction(c);
  }
}

IC *ic_new_computer(void) {
  IC *c = malloc(sizeof(IC));
  c->rom = rom;
  c->data = malloc(rom_count * sizeof(int));
  ic_reset(c);
  return c;
}

// Parse the puzzle input as a comma-separated list of integers and append
// each value to the global `rom` array.
//
// Loading stops with a message once `rom` is full; values beyond its capacity
// are discarded. If no input could be read, a message is printed and `rom` is
// left unchanged.
//
// Returns the number of values now stored in `rom`.
int ic_load_rom_from_input(void) {
  const int capacity = (int)(sizeof rom / sizeof rom[0]);

  char *input = aoc_read_input();
  // NOTE(review): assumes aoc_read_input() may return NULL on failure —
  // confirm against its definition. Ownership of the buffer is also not
  // visible here, so it is not freed.
  if (input == NULL) {
    printf("Failed to read input!\n");
    return rom_count;
  }

  char *token = strtok(input, ",");
  while (token != NULL) {
    if (rom_count >= capacity) {
      // Storing another value would write past the end of `rom`.
      printf("ROM overflow!\n");
      break;
    }
    rom[rom_count++] = atoi(token);
    token = strtok(NULL, ",");
  }

  return rom_count;
}