#define MODULE "spiflash"
#define DEBUG 1

#include "common.h"
#include "spiflash.h"
#include "spz.h"
#include "fw.h"
#include "jtag.h"

/*
 * SPI flash parameters
 */
#define JTAG_SPIFLASH_HZ 10000000	/* Max 26 MHz due to ESP32 */

/*
 * JTAG engine settings used while talking to the SPI flash. Each JTAG
 * signal is routed to the SPI signal named in the comment next to it.
 */
static const struct jtag_config jtag_config_spiflash = {
    .hz      = JTAG_SPIFLASH_HZ,

    /* Clock and chip select */
    .pin_tck = 11,		/* SCLK */
    .pin_tms = 10,		/* CS# */

    /* Data lines */
    .pin_tdi = 12,		/* MOSI */
    .pin_tdo = 13,		/* MISO */

    .be      = true		/* Bit order within bytes */
};

/*
 * Build the opcode + address header of an addressed SPI command in
 * cmdbuf.
 *
 * addr:   flash address, emitted most significant byte first
 * cmd24:  opcode to use with a 3-byte address
 * cmd32:  opcode to use with a 4-byte address
 * cmdbuf: output buffer; up to 5 bytes are written
 *
 * When the addressing mode evaluates to zero (which is what
 * SPIFLASH_ADDR_DYNAMIC is tested against here), the 3-byte form is
 * picked for addresses below 16 MiB and the 4-byte form otherwise.
 *
 * Returns a pointer to the first byte past the address.
 */
static void *spiflash_setup_addrcmd(uint32_t addr,
				    uint8_t cmd24, uint8_t cmd32,
				    void *cmdbuf)
{
    uint8_t *out = cmdbuf;
    enum spiflash_addr_mode mode = SPIFLASH_ADDR_DYNAMIC;

    if (!mode) {
	if (addr < (1 << 24))
	    mode = SPIFLASH_ADDR_24BIT;
	else
	    mode = SPIFLASH_ADDR_32BIT;
    }

    if (mode != SPIFLASH_ADDR_24BIT) {
	/* 4-byte addressing: opcode followed by address bits 31:24 */
	out[0] = cmd32;
	out[1] = addr >> 24;
	out += 2;
    } else {
	out[0] = cmd24;
	out += 1;
    }

    /* The low three address bytes are common to both forms */
    out[0] = addr >> 16;
    out[1] = addr >> 8;
    out[2] = addr;

    return out + 3;
}

/*
 * Debug helper: hex dump the command bytes about to be sent.
 * Expects `cmdbuf` (pointer to the command bytes) and `cmdlen`
 * (number of bytes) to be in scope at the point of use.
 */
# define SHOW_COMMAND()					\
  do {							\
      /* No conversions in this format, so no arguments */	\
      MSG("command: ");					\
      for (size_t i = 0; i < cmdlen; i++)		\
	  CMSG(" %02x", ((const uint8_t *)cmdbuf)[i]);	\
      CMSG("\n");					\
  } while(0)

/*
 * Send a bare one-byte command with no address and no data phase.
 *
 * Exactly 8 bits are shifted out, taken from the start of the storage
 * of `cmd`. JIO_CS is passed on this (only) transfer; presumably it
 * ends the chip-select cycle -- see jtag_io() to confirm.
 *
 * Always returns 0.
 */
static int spiflash_simple_command(uint32_t cmd)
{
    const unsigned int opcode_bits = 8;

    jtag_io(opcode_bits, JIO_CS, &cmd, NULL);

    return 0;
}

/*
 * Send a command that has no data phase.
 *
 * cmdbuf: command bytes (opcode plus any address bytes)
 * cmdlen: number of bytes in cmdbuf
 *
 * The whole command goes out in a single jtag_io() transfer with
 * JIO_CS set. Always returns 0.
 */
static int spiflash_plain_command(const void *cmdbuf, size_t cmdlen)
{
#if DEBUG > 1
    /* There is no data buffer here; print only what we actually have */
    MSG("plain: cmdbuf = %p (%zu)\n", cmdbuf, cmdlen);
    SHOW_COMMAND();
#endif

    jtag_io(cmdlen << 3, JIO_CS, cmdbuf, NULL);
    return 0;
}

/*
 * Send a command followed by outgoing data.
 *
 * cmdbuf/cmdlen:   command bytes and their count
 * databuf/datalen: payload bytes and their count
 *
 * With no payload this degenerates to spiflash_plain_command().
 * Otherwise the command is shifted out first without JIO_CS, and the
 * payload follows in a second transfer that carries JIO_CS.
 *
 * Always returns 0.
 */
static int spiflash_output_command(const void *cmdbuf, size_t cmdlen,
				    const void *databuf, size_t datalen)
{
    if (datalen == 0)
	return spiflash_plain_command(cmdbuf, cmdlen);

#if DEBUG > 1
    MSG("output: cmdbuf = %p (%zu), databuf = %p (%zu)\n",
	cmdbuf, cmdlen, databuf, datalen);
    SHOW_COMMAND();
#endif

    const size_t cmd_bits  = cmdlen << 3;
    const size_t data_bits = datalen << 3;

    jtag_io(cmd_bits, 0, cmdbuf, NULL);
    jtag_io(data_bits, JIO_CS, databuf, NULL);

    return 0;
}

/*
 * Send a command and read back the response.
 *
 * cmdbuf/cmdlen:   command bytes and their count
 * databuf/datalen: buffer receiving the response and its size in bytes
 *
 * With nothing to read this degenerates to spiflash_plain_command().
 * Otherwise the command is shifted out first without JIO_CS, and the
 * response is captured by a second transfer that carries JIO_CS.
 *
 * Always returns 0.
 */
static int spiflash_input_command(const void *cmdbuf, size_t cmdlen,
				  void *databuf, size_t datalen)
{
    if (datalen == 0)
	return spiflash_plain_command(cmdbuf, cmdlen);

#if DEBUG > 1
    MSG("input: cmdbuf = %p (%zu), databuf = %p (%zu)\n",
	cmdbuf, cmdlen, databuf, datalen);
    SHOW_COMMAND();
#endif

    const size_t cmd_bits  = cmdlen << 3;
    const size_t data_bits = datalen << 3;

    jtag_io(cmd_bits, 0, cmdbuf, NULL);
    jtag_io(data_bits, JIO_CS, NULL, databuf);

    return 0;
}

/*
 * Read one status register.
 *
 * reg: the read-status opcode to issue; 8 bits are shifted out from
 *      the start of its storage.
 *
 * Returns the 32-bit receive word, which starts out as zero and has
 * 8 bits shifted into the start of its storage.
 */
static int spiflash_read_status(uint32_t reg)
{
    uint32_t response = 0;

    /* Opcode phase */
    jtag_io(8, 0, (const uint8_t *)&reg, NULL);
    /* Response phase; JIO_CS is set on this final transfer */
    jtag_io(8, JIO_CS, NULL, (uint8_t *)&response);

    return response;
}

/*
 * Poll status register 1 until (SR1 & mask) == val.
 *
 * The register is polled up to 100000 times, calling yield() after
 * each unsuccessful poll. This is a poll count, not a time-based
 * timeout; a real timeout function is still needed.
 *
 * Returns 0 once the condition is met, -1 if the polls run out.
 */
static int spiflash_wait_status(uint8_t mask, uint8_t val)
{
    unsigned int polls_left;

#if DEBUG > 1
    MSG("waiting for status %02x/%02x... ", mask, val);
#endif

    for (polls_left = 100000; polls_left != 0; polls_left--) {
	const uint8_t sr1 = spiflash_read_status(ROM_READ_SR1);

	if ((sr1 & mask) != val) {
	    yield();
	    continue;
	}

#if DEBUG > 1
	CMSG("ok\n");
#endif
	return 0;
    }

#if DEBUG > 1
    CMSG("timeout\n");
#endif
    return -1;
}

/*
 * Read len bytes of flash starting at addr into buffer.
 *
 * Uses the fast-read command (24- or 32-bit address form as chosen by
 * spiflash_setup_addrcmd()) followed by one zero byte for the dummy
 * cycles. The per-command length limit is currently the maximum
 * size_t value, so in practice the read is a single command.
 *
 * Returns 0 on success, or the first nonzero value returned by
 * spiflash_input_command().
 */
static int spiflash_read(uint32_t addr, void *buffer, size_t len)
{
    const size_t chunk_limit = -1;	/* i.e. the largest size_t */
    uint32_t cmdbuf[2];			/* Room for opcode + 4 address + dummy */
    uint8_t *dst = buffer;

    while (len != 0) {
	const size_t chunk = (len > chunk_limit) ? chunk_limit : len;
	uint8_t *tail;
	int rv;

	tail = spiflash_setup_addrcmd(addr, ROM_FAST_READ,
				      ROM_FAST_READ_32BIT, cmdbuf);
	*tail++ = 0;			/* Dummy cycles */

	rv = spiflash_input_command(cmdbuf, tail - (uint8_t *)cmdbuf,
				    dst, chunk);
	if (rv)
	    return rv;

	addr += chunk;
	dst  += chunk;
	len  -= chunk;
    }

    return 0;
}

/*
 * Issue the write-enable command.
 *
 * First waits for SR1 bit 0 to read as zero, then sends
 * ROM_WRITE_ENABLE, then waits for the low two SR1 bits to read as
 * binary 10 (presumably: not busy, write enable latch set -- confirm
 * against the flash data sheet).
 *
 * Returns 0 on success or the nonzero result of the failing wait.
 */
static int spiflash_write_enable(void)
{
    const int busy = spiflash_wait_status(1, 0);

    if (busy)
	return busy;

    spiflash_simple_command(ROM_WRITE_ENABLE);

    return spiflash_wait_status(3, 2);
}

/*
 * Program one full sector, one page at a time.
 *
 * addr:   flash address of the first page to program
 * buffer: SPIFLASH_SECTOR_SIZE bytes of data to write
 *
 * For each page: write-enable, send the page program command with
 * SPIFLASH_PAGE_SIZE bytes of data, then wait for the low two SR1
 * bits to clear before moving on.
 *
 * Returns 0 on success or the nonzero result of the first failing
 * write-enable or status wait.
 */
static int spiflash_program_sector(uint32_t addr, const void *buffer)
{
    uint32_t cmdbuf[2];
    const uint8_t *src = buffer;
    int pages;

    for (pages = SPIFLASH_SECTOR_SIZE / SPIFLASH_PAGE_SIZE;
	 pages > 0; pages--) {
	uint8_t *tail;
	int rv;

	rv = spiflash_write_enable();
	if (rv)
	    return rv;

	tail = spiflash_setup_addrcmd(addr, ROM_PAGE_PROGRAM,
				      ROM_PAGE_PROGRAM_32BIT, cmdbuf);
	spiflash_output_command(cmdbuf, tail - (uint8_t *)cmdbuf,
				src, SPIFLASH_PAGE_SIZE);

	rv = spiflash_wait_status(3, 0);
	if (rv)
	    return rv;

	src  += SPIFLASH_PAGE_SIZE;
	addr += SPIFLASH_PAGE_SIZE;
    }

    return 0;
}

/*
 * Erase the sector at addr using the 4K erase command.
 *
 * Write-enables the device, sends the erase command (24- or 32-bit
 * address form as chosen by spiflash_setup_addrcmd()) and then waits
 * for the low two SR1 bits to clear.
 *
 * Returns 0 on success, or the nonzero result of the write-enable or
 * of the final status wait.
 */
static int spiflash_erase_sector(uint32_t addr)
{
    uint32_t cmdbuf[2];
    uint8_t *tail;
    const int wren = spiflash_write_enable();

    if (wren)
	return wren;

    tail = spiflash_setup_addrcmd(addr, ROM_ERASE_4K,
				  ROM_ERASE_4K_32BIT, cmdbuf);
    spiflash_plain_command(cmdbuf, tail - (uint8_t *)cmdbuf);

    return spiflash_wait_status(3, 0);
}

/*
 * Outcome of comparing what is in flash with what should be there.
 *
 * The numeric order matters to callers that compare states with
 * relational operators: each value is "worse" than the one before it.
 */
enum flashmem_status {
    FMS_DONE,			/* All done, no programming needed */
    FMS_PROGRAM,		/* Can be programmed */
    FMS_ERASE,			/* Needs erase before programming */
    FMS_NOTCHECKED		/* Not checked yet */
};

/*
 * Work out what it takes to turn the flash contents into the desired
 * contents.
 *
 * from: current flash contents
 * to:   desired flash contents
 * len:  size of both buffers in bytes
 *
 * Both buffers are assumed to be aligned full block buffers; they are
 * walked one 32-bit word at a time.
 *
 * Returns FMS_ERASE if any bit would have to go from 0 to 1,
 * otherwise FMS_PROGRAM if the buffers differ at all, otherwise
 * FMS_DONE.
 */
static enum flashmem_status
spiflash_memcmp(const void *from, const void *to, size_t len)
{
    const uint32_t *cur  = from;
    const uint32_t *want = to;
    const char * const end = (const char *)from + len;
    uint32_t differing = 0;	/* OR of all bits that differ */
    uint32_t rising    = 0;	/* OR of all bits that must go 0 -> 1 */

    for (; (const char *)cur < end; cur++, want++) {
	const uint32_t have   = *cur;
	const uint32_t target = *want;

	differing |= have ^ target;
	rising    |= ~have & target;
    }

    if (rising)
	return FMS_ERASE;
    if (differing)
	return FMS_PROGRAM;
    return FMS_DONE;
}

/*
 * Bring one flash sector to the desired contents.
 *
 * spz:  stream state; spz->vbuf holds the current sector contents on
 *       entry, spz->dbuf the desired contents
 * addr: flash address of the sector
 *
 * The sector is compared, then erased or programmed as the comparison
 * dictates, then read back into spz->vbuf and compared again. Every
 * pass must yield a strictly lower flashmem_status than the previous
 * one; if it does not, or an erase/program operation reports failure,
 * the sector is given up on. One progress letter is printed per pass.
 *
 * Returns 0 on success, FWUPDATE_ERR_PROGRAM_FAILED if the final
 * state was FMS_PROGRAM, FWUPDATE_ERR_ERASE_FAILED otherwise. The
 * result is also stored in spz->err unless an error is already
 * recorded there.
 */
static int spiflash_write_sector(spz_stream *spz, unsigned int addr)
{
    enum flashmem_status status = FMS_NOTCHECKED;
    int rv;

    MSG("flash sector at 0x%06x: ", addr);

    for (;;) {
	const enum flashmem_status prev = status;
	int op_failed = 0;

	status = spiflash_memcmp(spz->vbuf, spz->dbuf, SPIFLASH_SECTOR_SIZE);

	/* No progress since the previous pass: give up */
	if (status >= prev) {
	    CMSG("X [%u>%u]", prev, status);
	    break;
	}

	if (status == FMS_DONE) {
	    CMSG("V");
	    break;
	}

	if (status == FMS_ERASE) {
	    CMSG("E");
	    op_failed = spiflash_erase_sector(addr);
	} else if (status == FMS_PROGRAM) {
	    CMSG("P");
	    op_failed = spiflash_program_sector(addr, spz->dbuf);
	}

	if (op_failed)
	    break;

	/*
	 * Read the sector back for the next comparison. The buffer is
	 * filled with a marker pattern first.
	 */
	memset(spz->vbuf, 0xdd, SPIFLASH_SECTOR_SIZE);
	spiflash_read(addr, spz->vbuf, SPIFLASH_SECTOR_SIZE);
    }

    if (status == FMS_DONE) {
	CMSG(" OK\n");
	rv = 0;
    } else {
	CMSG(" FAILED\n");
	if (status == FMS_PROGRAM)
	    rv = FWUPDATE_ERR_PROGRAM_FAILED;
	else
	    rv = FWUPDATE_ERR_ERASE_FAILED;
    }

    /* Keep the first error recorded on the stream */
    if (!spz->err)
	spz->err = rv;

    return rv;
}

/*
 * Read the JEDEC ID and print it.
 *
 * Sends the one-byte ROM_JEDEC_ID command and reads three bytes back
 * into a zeroed 32-bit word; the three low-order bytes of that word
 * are printed as vendor, type and capacity, in that order.
 *
 * Always returns 0.
 */
static int spiflash_read_jedec_id(void)
{
    const uint32_t opcode = ROM_JEDEC_ID;
    uint32_t id = 0;

    spiflash_input_command((const uint8_t *)&opcode, 1, (uint8_t *)&id, 3);

    const uint8_t vendor   = (uint8_t)id;
    const uint8_t type     = (uint8_t)(id >> 8);
    const uint8_t capacity = (uint8_t)(id >> 16);

    MSG("JEDEC ID: vendor %02x type %02x capacity %02x\n",
	vendor, type, capacity);
    return 0;
}

/*
 * Print the contents of the three status registers.
 *
 * The registers are read into locals first so that the bus
 * transactions happen in a fixed order (SR1, SR2, SR3); the order in
 * which function call arguments are evaluated is unspecified in C.
 */
static void spiflash_show_status(void)
{
    const int sr1 = spiflash_read_status(ROM_READ_SR1);
    const int sr2 = spiflash_read_status(ROM_READ_SR2);
    const int sr3 = spiflash_read_status(ROM_READ_SR3);

    MSG("status regs: %02x %02x %02x\n", sr1, sr2, sr3);
}

/*
 * Input pins sampled by spiflash_write_spz() before touching the flash.
 */
/* Polled while waiting for the FPGA bypass to become ready */
#define PIN_FPGA_READY		9
/* Selects the board version reported in the log: nonzero = v1, zero = v2 */
#define PIN_FPGA_BOARD_ID	1

/*
 * Write a block of data to the SPI flash, starting at spz->header.addr.
 *
 * spz:       stream state; supplies the target address, the sector
 *            buffers vbuf (current contents) and dbuf (desired
 *            contents), and the sticky error field spz->err
 * data:      data to write, or NULL to pull the data from the spz
 *            stream with spz_read_data()
 * data_left: number of bytes in data; ignored when data is NULL, in
 *            which case spz->header.len is used instead
 *
 * Set data and data_left if the data to be written is not from the spz.
 *
 * The write proceeds one sector at a time. Each sector is read into
 * vbuf; any bytes of the sector outside the range being written are
 * copied from vbuf into dbuf so they are preserved, and the sector is
 * then handed to spiflash_write_sector().
 *
 * Returns spz->err: zero on success, otherwise the first error
 * recorded on the stream. Returns immediately without touching the
 * hardware if there is nothing to write or an error is already
 * pending.
 */
int spiflash_write_spz(spz_stream *spz,
		       const void *data, unsigned int data_left)
{
    esp_err_t rv;
    const uint8_t *dptr = data;
    unsigned int addr   = spz->header.addr;

    if (!dptr)
	data_left = spz->header.len;

    if (!data_left || spz->err)
	return spz->err;

    pinMode(PIN_FPGA_READY,    INPUT);
    pinMode(PIN_FPGA_BOARD_ID, INPUT);

    /*
     * The wait loop treats LOW as "ready", so only announce and enter
     * the wait when the pin is not already LOW. The guard has to use
     * the same condition as the loop, or the loop can never run.
     */
    if (digitalRead(PIN_FPGA_READY) != LOW) {
	MSG("waiting for FPGA bypass to be ready..");
	while (digitalRead(PIN_FPGA_READY) != LOW) {
	    CMSG(".");
	    yield();
	}
	CMSG("\n");
    }
    MSG("FPGA bypass ready, board version v%c.\n",
	digitalRead(PIN_FPGA_BOARD_ID) ? '1' : '2');

    jtag_enable(&jtag_config_spiflash);

    spiflash_read_jedec_id();
    spiflash_show_status();

    while (data_left && !spz->err) {
	/* Bytes of this sector before and after the range being written */
	unsigned int pre_padding  = addr & (SPIFLASH_SECTOR_SIZE-1);
	unsigned int post_padding;
	unsigned int bytes;

	bytes = SPIFLASH_SECTOR_SIZE - pre_padding;
	post_padding = 0;
	if (bytes > data_left) {
	    post_padding = bytes - data_left;
	    bytes = data_left;
	}

	/* Round down to the start of the sector */
	addr -= pre_padding;

	/* Read the current content of this block into vbuf */
	memset(spz->vbuf, 0xee, SPIFLASH_SECTOR_SIZE);
	rv = spiflash_read(addr, spz->vbuf, SPIFLASH_SECTOR_SIZE);
	if (rv)
	    goto err;

	/* Copy any invariant chunk */
	if (pre_padding)
	    memcpy(spz->dbuf, spz->vbuf, pre_padding);
	if (post_padding)
	    memcpy(spz->dbuf+SPIFLASH_SECTOR_SIZE-post_padding,
		   spz->vbuf+SPIFLASH_SECTOR_SIZE-post_padding,
		   post_padding);

	/* Fill in the new data, from memory or from the stream */
	if (dptr) {
	    memcpy(spz->dbuf+pre_padding, dptr, bytes);
	    dptr += bytes;
	} else {
	    rv = spz_read_data(spz, spz->dbuf+pre_padding, bytes);
	    if (rv != (int)bytes) {
		MSG("needed %u bytes got %d\n", bytes, rv);
		rv = Z_DATA_ERROR;
		goto err;
	    }
	}

	rv = spiflash_write_sector(spz, addr);
	if (rv) {
	    spz->err = rv;
	    goto err;
	}

	/* Advance past the bytes just written */
	addr += pre_padding + bytes;
	data_left -= bytes;
    }
    rv = 0;

err:
    /* Record the first error only */
    if (!spz->err)
	spz->err = rv;

    jtag_disable(NULL);

    return spz->err;
}