/**
 * This shader extends `DandyMeshPhysicalShader` and displays the occlusal distance heatmap
 */
import { MODEL_GRAY_COLOR } from './Colors';
import type { DandyMeshPhysicalShaderParams } from './DandyMeshPhysicalShader';
import {
    vertexShader as baseVertexShader,
    fragmentShader as baseFragmentShader,
    DEFAULT_DANDY_MESH_PHYSICAL_SHADER_PARAMS,
} from './DandyMeshPhysicalShader';
import * as THREE from 'three';

// Declares the per-vertex `occlusion_layer` attribute just ahead of the
// `<common>` chunk, which is kept in place.
const occlusionAttributeChunk = `
attribute float occlusion_layer;
#include <common>
`;

// Declares the `vOcclusionLayer` varying right after the stock colour
// declarations so it can be handed to the fragment stage.
const occlusionVaryingChunk = `
#include <color_pars_vertex>
varying float vOcclusionLayer;
`;

// Wraps the stock `<color_vertex>` chunk: the varying starts at 1.0 and is then
// multiplied by the attribute once the stock chunk has run.
const occlusionAssignChunk = `
vOcclusionLayer = 1.0;

#include <color_vertex>

vOcclusionLayer *= occlusion_layer;
`;

/**
 * Vertex shader source: the base Dandy physical vertex shader with the three
 * occlusion-layer chunks above spliced in at their `#include` anchors.
 *
 * Each `replace` uses a string pattern, so only the first occurrence of an
 * anchor is patched and a missing anchor leaves the source unchanged.
 */
export const vertexShader = baseVertexShader
    .replace('#include <common>', occlusionAttributeChunk)
    .replace('#include <color_pars_vertex>', occlusionVaryingChunk)
    .replace('#include <color_vertex>', occlusionAssignChunk);

// Heatmap declarations spliced around the stock `<color_pars_fragment>` chunk:
// the `vMin` / `vMax` / `showHeatmap` uniforms, the `vOcclusionLayer` varying and
// the `occlusalColorMap` helper.
//
// `occlusalColorMap(v, ...)` returns:
//   - the supplied default colour when `v > vMax` or `v < -100.0`
//     (NOTE(review): -100.0 looks like a "no data" sentinel — confirm);
//   - pure blue when `v < vMin`;
//   - otherwise a blend driven by `ratio = (v - vMin) / (vMax - vMin)`:
//     red at ratio 0, yellow at 0.5, green at 1.
// NOTE(review): `vMax == vMin` makes the ratio a division by zero for
// `v == vMin`; confirm callers never pass a degenerate range.
const heatmapDeclarationsChunk = `
uniform float vMin;
uniform float vMax;
uniform bool showHeatmap;

#include <color_pars_fragment>

varying float vOcclusionLayer;

// define our color maps functions
vec3 occlusalColorMap(float v, float vMin, float vMax, vec3 DEFAULT_MODEL_RGB_U8_COLOR)
{
    if(v > vMax || v < -100.0){ //out of range
        return DEFAULT_MODEL_RGB_U8_COLOR;
    }

    float R = 0.0;
    float G = 0.0;
    float B = 0.0;
    if (v < vMin) {
        R = 0.0;
        G = 0.0;
        B = 1.0;
    }else{
        float ratio = min(1.0, (v - vMin) / (vMax - vMin));

        R = min(1.0, pow(2.0 * (1.0 - ratio), 0.4));
        G = min(1.0, pow(2.0 * ratio, 0.4));
        B = 0.0;
    }

    vec3 result = vec3(R, G, B);
    return result;
}
`;

// Runs after the stock `<color_fragment>` chunk: when `showHeatmap` is set, the
// diffuse colour is multiplied by the colour-map result. The default colour is
// white, so out-of-range fragments keep their original diffuse colour.
const heatmapApplyChunk = `
#include <color_fragment>

vec3 DEFAULT_MODEL_RGB_U8_COLOR = vec3(1.0, 1.0, 1.0); 
vec3 multiplyColor = DEFAULT_MODEL_RGB_U8_COLOR;

if(showHeatmap){
    multiplyColor = occlusalColorMap(vOcclusionLayer, vMin, vMax, DEFAULT_MODEL_RGB_U8_COLOR);
    diffuseColor.rgb *= multiplyColor;
}
`;

/**
 * Fragment shader source: the base Dandy physical fragment shader with the
 * heatmap declarations and the heatmap tinting step spliced in at their
 * `#include` anchors.
 *
 * Each `replace` uses a string pattern, so only the first occurrence of an
 * anchor is patched and a missing anchor leaves the source unchanged.
 */
export const fragmentShader = baseFragmentShader
    .replace('#include <color_pars_fragment>', heatmapDeclarationsChunk)
    .replace('#include <color_fragment>', heatmapApplyChunk);

/**
 * Options accepted by `createOcclusalHeatmapShader`: everything in
 * `DandyMeshPhysicalShaderParams` plus the heatmap-specific fields below.
 */
export type OcclusalHeatmapShaderParams = DandyMeshPhysicalShaderParams & {
    /** Value of the `showHeatmap` uniform; the fragment shader only tints the diffuse colour when it is true. */
    showHeatmap: boolean;
    /** Bounds written to the `vMin` / `vMax` uniforms used by the colour map. */
    heatMapRange: { min: number; max: number };
    /** Value of the `opacity` uniform; any value below 1.0 also sets the `transparent` uniform to true. */
    opacity: number;
    /** Optional texture stored in the `map` uniform and echoed on the returned shader description. */
    map?: THREE.Texture;
};

/**
 * Defaults that `createOcclusalHeatmapShader` merges underneath the caller's
 * params: the base physical-shader defaults, with the heatmap enabled, a colour
 * range of [-0.1, 0.4] and full opacity. No `map` is set here.
 *
 * NOTE(review): the unit of `heatMapRange` is not visible in this file
 * (presumably millimetres of occlusal distance) — confirm.
 */
export const DEFAULT_OCCLUSAL_SHADER_PARAMS: OcclusalHeatmapShaderParams = {
    // Spread first so the heatmap-specific values below take precedence.
    ...DEFAULT_DANDY_MESH_PHYSICAL_SHADER_PARAMS,
    showHeatmap: true,
    heatMapRange: { min: -0.1, max: 0.4 },
    opacity: 1.0,
};

/**
 * Shader description returned by `createOcclusalHeatmapShader`.
 *
 * NOTE(review): the field names suggest this is passed to a
 * `THREE.ShaderMaterial` — confirm against callers.
 */
export interface OcclusalHeatmapShader {
    /** GLSL source of the patched vertex shader exported by this module. */
    vertexShader: string;
    /** GLSL source of the patched fragment shader exported by this module. */
    fragmentShader: string;
    /** Uniform table: a clone of the three.js physical uniforms plus the overrides set by the factory. */
    uniforms: Record<string, THREE.IUniform>;
    /** Always `true`. */
    lights: true;
    /** Always `false`. */
    flatShading: false;
    /** The same texture (if any) that was placed in the `map` uniform. */
    map?: THREE.Texture;
}

/**
 * Builds the shader description (sources, uniforms and flags) for the occlusal
 * distance heatmap.
 *
 * Every field of `params` overrides the matching entry of
 * `DEFAULT_OCCLUSAL_SHADER_PARAMS`. A field that is missing or explicitly
 * `undefined` falls back to its default.
 *
 * @param params - Optional overrides for the default shader parameters.
 * @returns The patched vertex/fragment sources, a fresh uniform table, the
 *          fixed `lights` / `flatShading` flags and the resolved `map` texture.
 */
export function createOcclusalHeatmapShader(params: Partial<OcclusalHeatmapShaderParams> = {}): OcclusalHeatmapShader {
    // A plain `{ ...defaults, ...params }` merge copies explicit `undefined`
    // values (legal in a `Partial<>`), which would clobber a default and crash
    // on `heatMapRange.min`. Resolving each field separately avoids that while
    // leaving every other value (including `null`) untouched.
    function orDefault<T>(value: T | undefined, fallback: T): T {
        return value === undefined ? fallback : value;
    }

    const defaults = DEFAULT_OCCLUSAL_SHADER_PARAMS;
    const showHeatmap = orDefault(params.showHeatmap, defaults.showHeatmap);
    const heatMapRange = orDefault(params.heatMapRange, defaults.heatMapRange);
    const opacity = orDefault(params.opacity, defaults.opacity);
    const sRGBToLinear = orDefault(params.sRGBToLinear, defaults.sRGBToLinear);
    const saturation = orDefault(params.saturation, defaults.saturation);
    const lightness = orDefault(params.lightness, defaults.lightness);
    const map = orDefault(params.map, defaults.map);

    // N.B. UniformsUtils.merge() cannot be used with Texture values as cloning
    // a texture breaks it (not sure why). Only the stock physical uniforms are
    // cloned; our own entries, including `map`, are added as plain objects.
    const uniforms: Record<string, THREE.IUniform> = {
        ...THREE.UniformsUtils.clone(THREE.ShaderLib.physical.uniforms),
        opacity: { value: opacity },
        transparent: { value: opacity < 1.0 },
        diffuse: { value: MODEL_GRAY_COLOR },
        sRGBToLinear: { value: sRGBToLinear },
        saturation: { value: saturation },
        lightness: { value: lightness },
        roughness: { value: 0.15 },
        reflectivity: { value: 0.2 },
        clearcoat: { value: 0.135 },
        clearcoatRoughness: { value: 0.175 },
        showHeatmap: { value: showHeatmap },
        vMin: { value: heatMapRange.min },
        vMax: { value: heatMapRange.max },
        map: { value: map },
    };

    return {
        vertexShader,
        fragmentShader,
        uniforms,
        lights: true,
        flatShading: false,
        map,
    };
}
