/*
 * Simple write-back disk cache between FatFs and the SD card driver
 */

#include "common.h"
#include "list.h"
#include "sdcard.h"
#include "systime.h"
#include "console.h"

/*
 * Cache geometry.  The cache holds CACHE_BLOCKS entries; each entry covers
 * CACHE_BLOCK_SECTORS (= 1 << CACHE_BLOCK_BITS) consecutive sectors of the
 * card, i.e. CACHE_BLOCK_SIZE bytes of data per entry.
 */
#define CACHE_BLOCKS		1024
#define CACHE_BLOCK_BITS	5
#define CACHE_BLOCK_SECTORS	(1 << CACHE_BLOCK_BITS)
#define CACHE_BLOCK_SIZE	(SECTOR_SIZE << CACHE_BLOCK_BITS)

/* Number of hash chain heads in cache_hash[] (twice the entry count). */
#define CACHE_HASH_SIZE		(CACHE_BLOCKS*2)

/* Cache block number: a sector number shifted right by CACHE_BLOCK_BITS. */
typedef sector_t block_t;
/* Sentinel block number for a cache entry that currently holds no block. */
#define NO_BLOCK ((block_t)(-1))

/*
 * State of a cache entry.  FL_DIRTY includes the FL_VALID bit, so a dirty
 * entry also holds valid data; the code tests for "needs write-back" with
 * flags == FL_DIRTY.
 */
enum block_status {
    FL_INVALID   = 0,	/* Entry holds no usable data */
    FL_VALID     = 1,	/* Data read from, or written back to, the card */
    FL_DIRTY_BIT = 2,
    FL_DIRTY     = FL_VALID|FL_DIRTY_BIT	/* Modified since last write-back */
};

/* One cache entry: CACHE_BLOCK_SECTORS consecutive sectors of the card. */
struct cache_block {
    struct dll hash;		/* Link in hash chain or free list */
    struct dll lru;		/* Link in LRU chain */
    block_t block;		/* Cached block number, or NO_BLOCK */
    enum block_status flags;	/* Status flags */
    /* Sector data; for the card's final (truncated) block only a prefix is used. */
    char data[CACHE_BLOCK_SIZE] __attribute__((aligned(4)));
};

/*
 * Cache storage.  NOTE(review): __dram_noinit presumably places these in
 * DRAM without start-up zeroing -- confirm.  Either way, nothing relies on
 * their initial contents: disk_cache_init() sets up every entry.
 */
static struct dll __dram_noinit cache_hash[CACHE_HASH_SIZE];
static struct cache_block __dram_noinit disk_cache[CACHE_BLOCKS];

/* All entries, via their lru links; disk_cache_get() evicts lru_list.prev. */
static struct dll lru_list;
/* Hash links of entries that are not in any hash chain. */
static struct dll free_list;

/*
 * Reset the cache to empty: all hash chains empty, and every entry marked
 * invalid and placed on both the free list and the LRU list.  Any dirty
 * data still in the cache is discarded, not written back.
 */
void disk_cache_init(void)
{
    size_t i;

    for (i = 0; i < CACHE_HASH_SIZE; i++)
	dll_init(&cache_hash[i]);

    dll_init(&lru_list);
    dll_init(&free_list);

    for (i = 0; i < CACHE_BLOCKS; i++) {
	struct cache_block *cb = &disk_cache[i];

	cb->flags = FL_INVALID;
	cb->block = NO_BLOCK;
	dll_insert_head(&lru_list, &cb->lru);
	dll_insert_head(&free_list, &cb->hash);
    }
}

/*
 * Map a cache block number to the head of its hash chain in cache_hash[].
 *
 * The block number is multiplied by a constant in 64-bit arithmetic, and
 * the high and low 32-bit halves are summed before being reduced modulo
 * the table size.  The reduction uses CACHE_HASH_SIZE (the real size of
 * cache_hash[]); reducing by CACHE_BLOCKS left half the chains unused.
 */
static inline __attribute__((const)) struct dll *hash_slot(block_t block)
{
    uint64_t m = UINT64_C(0x34f1f85d) * block;
    uint32_t hash = (uint32_t)(m >> 32) + (uint32_t)m;

    return &cache_hash[hash % CACHE_HASH_SIZE];
}

/*
 * Look up a block in the hash table.
 * Returns the cache entry currently holding @block, or NULL if it is not cached.
 */
static struct cache_block *disk_cache_find(block_t block)
{
    struct dll *const head = hash_slot(block);
    struct dll *node = head->next;

    while (node != head) {
	struct cache_block *cb = container_of(node, struct cache_block, hash);

	if (cb->block == block)
	    return cb;
	node = node->next;
    }

    return NULL;
}

/*
 * Discard the contents of a cache entry without writing it back.
 *
 * The entry is marked invalid, its hash link is moved from its current
 * chain onto the free list, and it is demoted in the LRU list
 * (presumably toward the tail that disk_cache_get() evicts from).
 */
static void invalidate_block(struct cache_block *bp)
{
    struct dll *link = &bp->hash;

    bp->flags = FL_INVALID;
    bp->block = NO_BLOCK;

    dll_remove(link);
    dll_insert_head(&free_list, link);

    dll_demote(&lru_list, &bp->lru);
}

/*
 * Write a dirty cache entry back to the card.
 *
 * Only the sectors that exist on the card are written; the final block
 * may be truncated.  Returns RES_OK if the entry was clean or the write
 * succeeded.  If the write fails, the entry is invalidated (its dirty data
 * is discarded) and RES_ERROR is returned.
 */
static DRESULT sync_block(struct cache_block *bp)
{
    if (bp->flags != FL_DIRTY)
	return RES_OK;

    const sector_t disk_sectors = sdc.lbasize;
    const sector_t first = bp->block << CACHE_BLOCK_BITS;
    const sector_t count = min(CACHE_BLOCK_SECTORS, disk_sectors - first);

    if (sdcard_write_sectors(bp->data, first, count) != (int)count) {
	invalidate_block(bp); /* Or...? */
	return RES_ERROR;
    }

    bp->flags = FL_VALID;
    return RES_OK;
}

/*
 * Write back a cache entry if it is dirty, then invalidate it.
 *
 * Returns the write-back result.  On failure the entry has already been
 * invalidated by sync_block(), so it ends up invalid either way.
 */
static DRESULT clear_block(struct cache_block *bp)
{
    const DRESULT rv = sync_block(bp);

    if (rv == RES_OK)
	invalidate_block(bp);

    return rv;
}

static DRESULT sync_all(void)
{
    DRESULT rv = RES_OK;
    struct dll *bhp;
    struct cache_block *bp;

    for (bhp = lru_list.next; bhp != &lru_list; bhp = bhp->next) {
	bp = container_of(bhp, struct cache_block, lru);

	if (bp->flags == FL_DIRTY)
	    rv |= sync_block(bp);
    }

    return rv;
}

/*
 * Return the cache entry for @block, loading it if necessary, and mark it
 * most recently used.
 *
 * @block:   cache block number (sector number >> CACHE_BLOCK_BITS)
 * @do_read: if the block is not cached, read its contents from the card.
 *           When false, the entry is returned FL_INVALID with undefined
 *           data.  This is meant for callers that overwrite the whole block
 *           and then set the flags themselves.
 *
 * Returns the entry, or NULL if the block starts beyond the end of the
 * card or the read failed.
 */
static struct cache_block *disk_cache_get(block_t block, bool do_read)
{
    const sector_t size = sdc.lbasize;
    struct cache_block *bp;

    bp = disk_cache_find(block);

    if (!bp) {
	/* Block not in cache, need to get it */
	sector_t sector = block << CACHE_BLOCK_BITS;
	int sectors = CACHE_BLOCK_SECTORS;

	if (sector >= size)
	    return NULL;

	if (sector + sectors > size)
	    sectors = size - sector; /* Truncated final block */

	/*
	 * Evict the least recently used entry.  If its write-back fails,
	 * sync_block() has already invalidated it (its dirty data is lost),
	 * so the entry can still be reused here.
	 */
	bp = container_of(lru_list.prev, struct cache_block, lru);
	(void)clear_block(bp);

	if (do_read) {
	    if (sdcard_read_sectors(bp->data, sector, sectors) != sectors)
		return NULL;	/* bp stays invalid, on the free list */

	    bp->flags = FL_VALID;
	}

	/*
	 * clear_block() left bp->hash on free_list.  Unlink it before
	 * adding it to the hash chain; otherwise free_list keeps pointing
	 * at a node that now belongs to a hash chain, and a later
	 * dll_remove() on a free-list neighbour corrupts that chain.
	 */
	bp->block = block;
	dll_remove(&bp->hash);
	dll_insert_head(hash_slot(block), &bp->hash);
    }

    dll_promote(&lru_list, &bp->lru);
    return bp;
}

/* --------------------------------------------------------------------------
 *  Interface to fatfs
 * ------------------------------------------------------------------------- */

/*
 * FatFs hook: initialize drive @drive (only drive 0 exists).
 *
 * Runs sdcard_init().  Whenever the STA_NOINIT bit changes compared with
 * the previous sdc.status, the cache is reset (any dirty entries are
 * discarded).  The new status is stored in sdc.fsstatus and returned.
 */
DSTATUS disk_initialize(BYTE drive)
{
    DSTATUS prev, now;

    if (drive != 0)
	return STA_NOINIT | STA_NODISK;

    prev = sdc.status;
    now = sdcard_init();

    if ((prev & STA_NOINIT) != (now & STA_NOINIT))
	disk_cache_init();

    return sdc.fsstatus = now;
}

/*
 * FatFs hook: miscellaneous drive control.
 *
 * Returns a DRESULT code: RES_PARERR for a bad drive number, an unknown
 * command, or a missing output buffer; RES_NOTRDY if the card is not
 * initialized; otherwise the result of the command.
 */
DRESULT disk_ioctl(BYTE drive, BYTE command, void *buffer)
{
    /* DRESULT, not DSTATUS flags: this function returns a result code. */
    if (drive != 0)
	return RES_PARERR;

    if (sdc.status & STA_NOINIT)
	return RES_NOTRDY;

    switch (command) {
    case CTRL_SYNC:
	/* Flush all dirty cache entries to the card. */
	return sync_all();
    case GET_SECTOR_SIZE:
	if (!buffer)
	    return RES_PARERR;
	*(WORD *)buffer = SECTOR_SIZE;
	return RES_OK;
    case GET_SECTOR_COUNT:
	if (!buffer)
	    return RES_PARERR;
	*(DWORD *)buffer = sdc.lbasize;
	return RES_OK;
    case GET_BLOCK_SIZE:
	if (!buffer)
	    return RES_PARERR;
	/*
	 * Reports the cache block size in sectors.
	 * NOTE(review): this is not queried from the card's real erase
	 * block size -- confirm this is intended.
	 */
	*(DWORD *)buffer = CACHE_BLOCK_SECTORS;
	return RES_OK;
    default:
	return RES_PARERR;
    }
}

/*
 * FatFs hook: read @sectorcount sectors starting at @sectornumber into
 * @buffer, going through the cache.
 *
 * Returns RES_PARERR for a bad drive number or a range that extends past
 * the end of the card, RES_NOTRDY if the card is not initialized,
 * RES_ERROR if a cache fill fails, and RES_OK otherwise.
 */
DRESULT disk_read(BYTE drive, BYTE *buffer,
		  LBA_t sectornumber, UINT sectorcount)
{
    if (drive != 0)
	return RES_PARERR;

    if (sdc.status & STA_NOINIT)
	return RES_NOTRDY;

    if (!sectorcount)
	return RES_OK;

    /* Reject requests that run past the end of the card. */
    const sector_t size = sdc.lbasize;
    if (sectornumber >= size || sectorcount > size - sectornumber)
	return RES_PARERR;

    block_t block = sectornumber >> CACHE_BLOCK_BITS;
    block_t last = (sectornumber + sectorcount - 1) >> CACHE_BLOCK_BITS;
    size_t offset = (sectornumber & (CACHE_BLOCK_SECTORS-1)) << SECTOR_SHIFT;
    size_t len = (size_t)sectorcount << SECTOR_SHIFT;

    while (block <= last) {
	struct cache_block *bp = disk_cache_get(block, true);
	if (!bp)
	    return RES_ERROR;

	size_t bytes = min(CACHE_BLOCK_SIZE - offset, len);

	memcpy(buffer, bp->data + offset, bytes);
	buffer += bytes;	/* Next block's data goes after this one */
	len -= bytes;
	block++;
	offset = 0;
    }

    return RES_OK;
}

/*
 * FatFs hook: write @sectorcount sectors from @buffer starting at
 * @sectornumber.  The data is only stored in the cache and marked dirty;
 * it reaches the card on eviction or on CTRL_SYNC.
 *
 * Returns RES_PARERR for a bad drive number or a range that extends past
 * the end of the card, RES_NOTRDY if the card is not initialized,
 * RES_ERROR if a needed cache fill fails, and RES_OK otherwise.
 */
DRESULT disk_write(BYTE drive, const BYTE *buffer, LBA_t sectornumber,
		   UINT sectorcount)
{
    if (drive != 0)
	return RES_PARERR;

    if (sdc.status & STA_NOINIT)
	return RES_NOTRDY;

    if (!sectorcount)
	return RES_OK;

    /*
     * Reject requests that run past the end of the card.  Without this
     * check, a write past a truncated final block made block_bytes - offset
     * underflow, and the memcpy below could overrun bp->data.
     */
    const sector_t size = sdc.lbasize;
    if (sectornumber >= size || sectorcount > size - sectornumber)
	return RES_PARERR;

    block_t block = sectornumber >> CACHE_BLOCK_BITS;
    block_t last = (sectornumber + sectorcount - 1) >> CACHE_BLOCK_BITS;
    size_t offset = (sectornumber & (CACHE_BLOCK_SECTORS-1)) << SECTOR_SHIFT;
    size_t len = (size_t)sectorcount << SECTOR_SHIFT;

    while (block <= last) {
	sector_t sector = block << CACHE_BLOCK_BITS;
	sector_t sectors = min(CACHE_BLOCK_SECTORS, size - sector);
	size_t block_bytes = (size_t)sectors << SECTOR_SHIFT;
	size_t bytes = min(block_bytes - offset, len);
	struct cache_block *bp;

	/* Old contents are only needed when this write covers part of the block. */
	bp = disk_cache_get(block, bytes < block_bytes);
	if (!bp)
	    return RES_ERROR;

	memcpy(bp->data + offset, buffer, bytes);
	bp->flags = FL_DIRTY;

	buffer += bytes;	/* Next block's data follows this one */
	len -= bytes;
	block++;
	offset = 0;
    }

    return RES_OK;
}

/*
 * FatFs hook: current date and time for file timestamps.
 * SYSCLOCK_DATETIME is already in FAT's packed date/time format.
 */
DWORD get_fattime(void)
{
    const DWORD now = SYSCLOCK_DATETIME;

    return now;
}