import { HIGHLIGHT_TOOTH_COLOR } from '../ModelViewer/defaultModelColors';
import { inferno, magma, plasma, turbo, viridis } from './ColorData';
import _ from 'lodash';
import * as THREE from 'three';

/** A color as three numeric channels. Whether they are 0-1 floats or 0-255 bytes depends on the producer. */
type RgbValue = { r: number; g: number; b: number };

/**
 * A mapping function that, given the range [`vMin`, `vMax`], produces an RGB color for `v`.
 *
 * Implementations may return `undefined` when they have no color for `v` (some maps in this module do so
 * for values outside the range).
 */
export type ColorMapFn = (vMin: number, vMax: number, v: number) => RgbValue | undefined;
/**
 * A color map over a single parameter. The maps in `SimpleColorMaps` treat [0, 1] as the full scale and
 * clamp values outside it.
 */
export type SimpleColorMapFn = (v: number) => RgbValue;

// NOTE(review): despite the `U8` in its name, this is a THREE.Color instance cast to RgbValue; THREE.Color
// channels are normally floats in [0, 1] rather than 0-255 bytes — confirm which scale callers expect.
export const DEFAULT_MODEL_RGB_U8_COLOR = new THREE.Color(HIGHLIGHT_TOOTH_COLOR) as RgbValue;
// 0-255 color that modelClearanceColorMap returns for values above vMax.
export const DEFAULT_MODEL_CLEARANCE_COLOR = { r: 252, g: 242, b: 204 };

/**
 * Restricts `x` to the closed interval [`lo`, `hi`].
 * If the bounds are inverted (`lo > hi`) the result is always `lo`; a NaN `x` propagates.
 */
function clamp(x: number, lo: number, hi: number) {
    const cappedAbove = Math.min(x, hi);
    return Math.max(lo, cappedAbove);
}

/**
 * Applies a power-law `gamma` to a unit-range intensity and scales it to a rounded 0-255 value.
 * The input is not clamped: a negative `x` with a fractional `gamma` yields NaN.
 */
function toByte(x: number, gamma: number = 1) {
    const adjusted = Math.pow(x, gamma);
    return Math.round(adjusted * 255);
}

type ColorTuple = readonly [number, number, number];

/** Converts an `[r, g, b]` tuple into an `{ r, g, b }` object. */
function toRgb([r, g, b]: ColorTuple) {
    return { r, g, b };
}

/**
 * Builds a color map that linearly interpolates between the evenly spaced entries of `colorTable`.
 *
 * `t <= 0` (and NaN) maps to the first entry, `t >= 1` to the last. Values in between are blended between
 * the two neighbouring entries and rounded per channel, so the output uses the same scale as the table.
 * A single-entry table always yields that entry.
 *
 * @param colorTable non-empty list of colors ordered from `t = 0` to `t = 1`
 */
function fromTable(colorTable: readonly ColorTuple[]) {
    const lastIndex = colorTable.length - 1;
    return (t: number) => {
        // `!(t > 0)` rather than `t <= 0` so NaN lands here too; otherwise it would become a NaN index
        // and the destructuring below would throw on `undefined`.
        if (!(t > 0) || lastIndex < 1) {
            return toRgb(colorTable[0] as ColorTuple);
        }
        if (t >= 1) {
            return toRgb(colorTable[lastIndex] as ColorTuple);
        }
        // Here 0 < t < 1, so 0 <= u < lastIndex and both `i` and `i + 1` are valid indices.
        const u = t * lastIndex;
        const i = Math.trunc(u);
        const frac = u - i;
        const [r0, g0, b0] = colorTable[i] as ColorTuple;
        const [r1, g1, b1] = colorTable[i + 1] as ColorTuple;
        return {
            r: Math.round(r0 + (r1 - r0) * frac),
            g: Math.round(g0 + (g1 - g0) * frac),
            b: Math.round(b0 + (b1 - b0) * frac),
        };
    };
}

/**
 * Keys of `SimpleColorMaps`. `as const` keeps them as a readonly tuple of string literals so the record's
 * key type can be derived from this list.
 */
export const SIMPLE_COLOR_MAP_KEYS = ['jet', 'greenToRed', 'rgbSweep', 'magma', 'inferno', 'plasma', 'viridis', 'turbo'] as const;

/**
 * Named color maps over a single parameter `t`, each treating [0, 1] as its full scale.
 *
 * The formula-based maps (`jet`, `greenToRed`, `rgbSweep`) return 0-255 channels. The table-based maps return
 * channels on the scale of the imported ColorData tables (presumably 0-255; the tables are not visible here).
 */
export const SimpleColorMaps: Record<(typeof SIMPLE_COLOR_MAP_KEYS)[number], SimpleColorMapFn> = {
    // Blue → cyan → green → yellow → red, each channel gamma-adjusted with exponent 0.45.
    jet: t => {
        const v = clamp(t, 0, 1);
        let r = 1.0;
        let g = 1.0;
        let b = 1.0;

        if (v < 0.25) {
            r = 0;
            g = 4 * v;
        } else if (v < 0.5) {
            r = 0;
            b = 1 - 4 * (v - 0.25);
        } else if (v < 0.75) {
            r = 4 * (v - 0.5);
            b = 0;
        } else {
            g = 1 - 4 * (v - 0.75);
            b = 0;
        }

        r = toByte(r, 0.45);
        g = toByte(g, 0.45);
        b = toByte(b, 0.45);

        return { r, g, b };
    },

    // Green at t = 0, yellow at t = 0.5, red at t = 1 (gamma 0.4).
    greenToRed: t => {
        const u = clamp(2 * t, 0, 1);
        const v = clamp(2 - 2 * t, 0, 1);
        return { r: toByte(u, 0.4), g: toByte(v, 0.4), b: 0 };
    },

    // Red at t = 0, green at t = 0.5, blue at t = 1 (gamma 0.75): the same sweep rgbThicknessColorMap uses.
    rgbSweep: t => {
        // Red fades out over the first half and blue fades in over the second; green is the remainder.
        // Red and blue are never both non-zero, so green stays within [0, 1]. A negative green would make
        // toByte return NaN, because Math.pow of a negative base with a fractional exponent is NaN.
        const red = clamp(1 - 2 * t, 0, 1);
        const blue = clamp(2 * t - 1, 0, 1);
        const green = 1 - red - blue;
        return { r: toByte(red, 0.75), g: toByte(green, 0.75), b: toByte(blue, 0.75) };
    },

    magma: fromTable(magma),
    inferno: fromTable(inferno),
    plasma: fromTable(plasma),
    viridis: fromTable(viridis),
    turbo: fromTable(turbo),
};

/**
 * Jet color map over [`vMin`, `vMax`] with 0-255 channels: blue → cyan → green → yellow → red as `v` rises.
 *
 * `v` is first pulled 1e-6 inside each end of the range. Each channel is then raised to the 0.45 power before
 * scaling, which is close to, but not exactly, the square root the name suggests.
 */
export const jetColorMapSqrt: ColorMapFn = (vMin, vMax, v) => {
    const EPSILON = 0.000001;
    const x = Math.min(Math.max(v, vMin + EPSILON), vMax - EPSILON);
    const span = vMax - vMin;
    const ratio = (x - vMin) / span;

    // Piecewise-linear [red, green, blue] in [0, 1]. The final fallthrough also catches a NaN ratio.
    const channels = (): [number, number, number] => {
        if (ratio < 0.25) {
            return [0, (4 * (x - vMin)) / span, 1.0];
        }
        if (ratio < 0.5) {
            return [0, 1.0, 1 + (4 * (vMin + 0.25 * span - x)) / span];
        }
        if (ratio < 0.75) {
            return [(4 * (x - vMin - 0.5 * span)) / span, 1.0, 0];
        }
        return [1.0, 1 + (4 * (vMin + 0.75 * span - x)) / span, 0];
    };

    const [red, green, blue] = channels();
    const toChannelByte = (c: number) => Math.round(255 * Math.pow(c, 0.45));
    return { r: toChannelByte(red), g: toChannelByte(green), b: toChannelByte(blue) };
};

/**
 * Occlusion color map with 0-255 channels: red for small values fading to green for values near `vMax`.
 *
 * `v` is first pulled 1e-6 inside each end of (`vMin`, `vMax`); a result below 0 is reported as pure blue.
 * The interpolation ratio is measured over [0, `vMax`] rather than [`vMin`, `vMax`].
 */
export const rgbOcclusionColorMap: ColorMapFn = (vMin, vMax, v) => {
    const EPSILON = 0.000001;
    const bracketed = Math.min(Math.max(v, vMin + EPSILON), vMax - EPSILON);

    if (bracketed < 0) {
        return { r: 0, g: 0, b: 255 };
    }

    // Rescaled relative to 0, not vMin (where `bracketed - vMin` might be expected).
    const ratio = bracketed / vMax;
    const shade = (x: number) => Math.round(255 * Math.min(1.0, Math.pow(x, 0.4)));

    return { r: shade(2 * (1 - ratio)), g: shade(2 * ratio), b: 0 };
};

/**
 * Occlusion color map with 0-255 channels that treats `vMin` as a hard cutoff.
 *
 * Values below `vMin` are pure blue. Otherwise red holds at full strength over the lower half of the range
 * and fades out over the upper half, while green ramps up over the lower half and then holds (gamma 0.4).
 * Values above `vMax` use the `vMax` color.
 */
export const rgbOcclusionColorMap2: ColorMapFn = (vMin, vMax, v) => {
    if (v < vMin) {
        return { r: 0, g: 0, b: 255 };
    }

    const normalized = (v - vMin) / (vMax - vMin);
    const ratio = Math.min(1, normalized);
    const shade = (x: number) => Math.round(255 * Math.min(1.0, Math.pow(x, 0.4)));

    return { r: shade(2 * (1 - ratio)), g: shade(2 * ratio), b: 0 };
};

/**
 * Thickness color map with 0-255 channels: red → green → blue across [`vMin`, `vMax`] (gamma 0.75).
 *
 * Values below `vMin` are solid red; values above `vMax` have no color and yield `undefined`.
 */
export const rgbThicknessColorMap: ColorMapFn = (vMin, vMax, v) => {
    const shade = (x: number) => Math.round(255 * Math.pow(x, 0.75));

    // Checked first so an inverted range (vMin > vMax) still reports red for v below vMin.
    if (v < vMin) {
        return { r: shade(1.0), g: shade(0), b: shade(0) };
    }
    if (v > vMax) {
        return undefined;
    }

    const EPSILON = 0.000001;
    const bracketed = Math.min(Math.max(v, vMin + EPSILON), vMax - EPSILON);
    const ratio = (2 * (bracketed - vMin)) / (vMax - vMin);

    const red = Math.min(Math.max(0, 1 - ratio), 1);
    const blue = Math.min(Math.max(0, ratio - 1), 1);
    const green = 1.0 - blue - red;

    return { r: shade(red), g: shade(green), b: shade(blue) };
};

// Linearly interpolates between (x0, y0) and (x1, y1), holding y0 for x <= x0 and y1 for x >= x1.
function lerp(x: number, x0: number, x1: number, y0: number, y1: number): number {
    if (x <= x0) {
        return y0;
    }

    if (x >= x1) {
        return y1;
    }

    return y0 + ((y1 - y0) * (x - x0)) / (x1 - x0);
}

// Converts a color with components expressed in the floating point range [0.0, 1.0] to the integer range [0, 255].
// A missing color passes through as `undefined`.
function toEightBit(c?: RgbValue): RgbValue | undefined {
    if (!c) {
        return undefined;
    }
    return { r: Math.round(255 * c.r), g: Math.round(255 * c.g), b: Math.round(255 * c.b) };
}

// Evenly spaced breakpoints of the orange/purple diverging map, with channels in [0, 1].
// Kept at module scope so the table is not re-allocated on every call.
const OR_PU_COLOR_TABLE: readonly RgbValue[] = [
    { r: 0.17647058823529413, g: 0.0, b: 0.29411764705882354 }, // purple
    { r: 0.32941176470588235, g: 0.15294117647058825, b: 0.53333333333333333 },
    { r: 0.50196078431372548, g: 0.45098039215686275, b: 0.67450980392156867 },
    { r: 0.69803921568627447, g: 0.6705882352941176, b: 0.82352941176470584 },
    { r: 0.84705882352941175, g: 0.85490196078431369, b: 0.92156862745098034 },
    { r: 0.96862745098039216, g: 0.96862745098039216, b: 0.96862745098039216 },
    { r: 0.99607843137254903, g: 0.8784313725490196, b: 0.71372549019607845 },
    { r: 0.99215686274509807, g: 0.72156862745098038, b: 0.38823529411764707 },
    { r: 0.8784313725490196, g: 0.50980392156862742, b: 0.07843137254901961 },
    { r: 0.70196078431372544, g: 0.34509803921568627, b: 0.02352941176470588 },
    { r: 0.49803921568627452, g: 0.23137254901960785, b: 0.03137254901960784 }, // orange
];

// A linear segmented color map, diverging to purple at the low end and orange at the high end.
// Returns 0-255 channels. Values at or beyond either end of [vMin, vMax] get the end colors, and a NaN
// input yields `undefined`.
export const orPuColorMap: ColorMapFn = (vMin, vMax, v) => {
    if (v <= vMin) {
        return toEightBit(OR_PU_COLOR_TABLE[0]);
    }

    const lastColorIndex = OR_PU_COLOR_TABLE.length - 1;
    if (v >= vMax) {
        return toEightBit(OR_PU_COLOR_TABLE[lastColorIndex]);
    }

    const scaledRatio = (lastColorIndex * (v - vMin)) / (vMax - vMin);
    // Rounding in `v - vMin` can make scaledRatio exactly lastColorIndex even though v < vMax. Cap the
    // segment so `c1` always exists; alpha then becomes 1 and lerp yields the last color.
    const breakpoint = Math.min(Math.floor(scaledRatio), lastColorIndex - 1);

    const alpha = scaledRatio - breakpoint;
    const c0 = OR_PU_COLOR_TABLE[breakpoint];
    const c1 = OR_PU_COLOR_TABLE[breakpoint + 1];

    // Only a NaN scaledRatio (e.g. a NaN `v`) reaches here without both endpoints.
    if (!c0 || !c1) {
        return undefined;
    }

    return toEightBit({
        r: lerp(alpha, 0, 1, c0.r, c1.r),
        g: lerp(alpha, 0, 1, c0.g, c1.g),
        b: lerp(alpha, 0, 1, c0.b, c1.b),
    });
};

// Discrete bands (0-255 channels) for alignmentDistanceColorMap, from blue through green to red.
// KEEP IN SYNC WITH SHADER FUNCTION alignmentColorMap AlignmentShader.ts
// Kept at module scope so the table is not re-allocated on every call.
const ALIGNMENT_DISTANCE_COLOR_TABLE: readonly RgbValue[] = [
    { r: 0.0, g: 0.0, b: 245.0 },
    { r: 25.0, g: 66.0, b: 191.0 },
    { r: 36.0, g: 86.0, b: 197.0 },
    { r: 43.0, g: 101.0, b: 203.0 },
    { r: 63.0, g: 140.0, b: 213.0 },
    { r: 76.0, g: 166.0, b: 219.0 },
    { r: 84.0, g: 184.0, b: 223.0 },
    { r: 98.0, g: 213.0, b: 229.0 }, //blue
    { r: 98.0, g: 213.0, b: 63.0 }, // green
    { r: 98.0, g: 213.0, b: 63.0 },
    { r: 98.0, g: 213.0, b: 63.0 },
    { r: 98.0, g: 213.0, b: 63.0 }, // green
    { r: 235.0, g: 212.0, b: 71.0 }, // yellow
    { r: 235.0, g: 188.0, b: 64.0 },
    { r: 236.0, g: 170.0, b: 59.0 },
    { r: 236.0, g: 147.0, b: 53.0 },
    { r: 235.0, g: 123.0, b: 48.0 },
    { r: 235.0, g: 82.0, b: 40.0 },
    { r: 234.0, g: 59.0, b: 37.0 },
    { r: 234.0, g: 51.0, b: 35.0 },
];

// A discretely segmented color map that goes from blue to green to red.
// KEEP IN SYNC WITH SHADER FUNCTION alignmentColorMap AlignmentShader.ts
// Used mainly for the alignment distance.
// ColorMapFn is strongly typed, so we don't have a great avenue to change the width of the green band in the
// middle.  TODO, extend ColorMapFn
// Returns a fresh 0-255 color object per call, or `undefined` for a NaN `v`.
export const alignmentDistanceColorMap: ColorMapFn = (vMin, vMax, v) => {
    const lastColorIndex = ALIGNMENT_DISTANCE_COLOR_TABLE.length - 1;

    let index: number;
    if (v <= vMin) {
        index = 0;
    } else if (v >= vMax) {
        index = lastColorIndex;
    } else {
        index = Math.floor((lastColorIndex * (v - vMin)) / (vMax - vMin));
    }

    const color = ALIGNMENT_DISTANCE_COLOR_TABLE[index];
    // Copy so a caller mutating the result cannot corrupt the shared table.
    return color ? { ...color } : undefined;
};

/**
 * Produces a clearance color with channels in the range 0-255 for `v` relative to [`vMin`, `vMax`].
 *
 * Below `vMin` the result is pure red, and above `vMax` it is `DEFAULT_MODEL_CLEARANCE_COLOR`. In between,
 * red holds at full strength over the lower half and fades over the upper half, while green ramps up over
 * the lower half and then holds (both with a 0.4 gamma).
 *
 * To normalize the output of this function into the range 0-1, divide each color channel individually by 255.
 */
export const modelClearanceColorMap: ColorMapFn = (vMin, vMax, v) => {
    if (v < vMin) {
        return { r: 255, g: 0, b: 0 };
    }
    if (v > vMax) {
        return DEFAULT_MODEL_CLEARANCE_COLOR;
    }

    const t = (v - vMin) / (vMax - vMin);
    const channel = (x: number) => Math.round(255 * Math.min(1.0, Math.pow(x, 0.4)));

    return { r: channel(2 * (1 - t)), g: channel(2 * t), b: 0 };
};

/**
 * Grayscale map: `v` in [`vMin`, `vMax`] becomes an unrounded gray level from 0 (black) to 255 (white).
 * Values outside the range have no color and yield `undefined`.
 */
export const bwColorMap: ColorMapFn = (vMin, vMax, v) => {
    const outOfRange = v < vMin || v > vMax;
    if (outOfRange) {
        return undefined;
    }

    const level = 255.0 * ((v - vMin) / (vMax - vMin));
    return { r: level, g: level, b: level };
};
