// We create a custom light setup in the shader so that we can
// light the scan mesh objects independently of other objects
import { getHeatmapRange } from '../../ModelAppearance/ModelAppearance.utils';
import type { MeshPhysicalHSVMaterialProps } from '../ModelMeshes';
import { SCAN_MATERIAL_BASE_COLOR } from '../ModelMeshes';
import { HeatMapType } from '@orthly/forceps';
import React from 'react';
import * as THREE from 'three';

// Lighting modifications.
// Fragment-shader preamble. onBeforeCompileHook substitutes it for the first
// `void main() {` in the MeshPhysicalMaterial fragment shader, so it:
// - declares the color-management uniforms `saturation`, `lightness` and
//   `doSRGBToLinear` (bound to React-managed THREE.Uniform objects);
// - defines the rgb2hsv / hsv2rgb helpers used by shaderPatchMain;
// - re-emits the `void main() {` line it replaced, so the original main body
//   still follows unchanged.
const shaderPatchDefinitions = `uniform float saturation;
uniform float lightness;
uniform bool doSRGBToLinear;
vec3 rgb2hsv(vec3 c)
{
    vec4 K = vec4(0.0, -1.0 / 3.0, 2.0 / 3.0, -1.0);
    vec4 p = mix(vec4(c.bg, K.wz), vec4(c.gb, K.xy), step(c.b, c.g));
    vec4 q = mix(vec4(p.xyw, c.r), vec4(c.r, p.yzx), step(p.x, c.r));
    float d = q.x - min(q.w, q.y);
    float e = 1.0e-10;
    return vec3(abs(q.z + (q.w - q.y) / (6.0 * d + e)), d / (q.x + e), q.x);
}
vec3 hsv2rgb(vec3 c)
{
    vec4 K = vec4(1.0, 2.0 / 3.0, 1.0 / 3.0, 3.0);
    vec3 p = abs(fract(c.xxx + K.xyz) * 6.0 - K.www);
    return c.z * mix(K.xxx, clamp(p - K.xxx, 0.0, 1.0), c.y);
}
void main() {
`;

// Replacement for the fragment shader's final
// `\tgl_FragColor = vec4( outgoingLight, diffuseColor.a );` line.
// Before writing gl_FragColor it:
// - converts the lit color to HSV;
// - multiplies saturation by the `saturation` uniform and value by the
//   `lightness` uniform, each clamped to [0, 1];
// - converts back to RGB;
// - optionally runs sRGBToLinear when the `doSRGBToLinear` uniform is true.
// NOTE(review): sRGBToLinear is expected to come from three.js's encoding
// shader chunks. Confirm it exists in the installed three.js version.
// NOTE(review): the search string used in onBeforeCompileHook (including its
// leading tab) must match the three.js shader source exactly. Otherwise this
// patch is silently skipped.
const shaderPatchMain = `
    vec3 hsv = rgb2hsv(outgoingLight);
    hsv.y = clamp(hsv.y * saturation, 0.0, 1.0);
    hsv.z = clamp(hsv.z * lightness, 0.0, 1.0);

    outgoingLight = hsv2rgb(hsv);

    gl_FragColor = vec4( outgoingLight, diffuseColor.a );
    if (doSRGBToLinear) {
        gl_FragColor = sRGBToLinear(gl_FragColor);
    }
`;

// Replacement for `#include <lights_fragment_begin>` in the fragment shader.
// Differences from the stock chunk, as written here:
// - Scene point lights are NOT applied. Instead, three fixed point lights
//   (positions (65,0,0), (-65,0,0), (0,-80,0); color 0.2 grey) light the
//   scan, so it is lit independently of the rest of the scene.
// - These fixed lights have no shadow maps, so no shadow lookup is done for
//   them. Indexing the scene's point-shadow arrays with the fixed-light index
//   would read the wrong or out-of-range shadow data. Because the loop bound
//   is a GLSL const (not a literal), three.js never unrolls this loop, so
//   UNROLLED_LOOP_INDEX would be undefined there.
// - Ambient irradiance uses a fixed vec3(0.8) color instead of the scene
//   ambient uniform.
// - Spot, directional, rect-area and hemisphere lights are handled as in the
//   copied three.js code.
// NOTE(review): assumes three.js's PointLight struct field order
// (position, color, distance, decay) and the GeometricContext API of the
// installed three.js version. Re-check both when upgrading three.js.
// A distance of 0.0 presumably means "no distance falloff" in three.js.
const updatedLightingShader = `
/**
 * This is a template that can be used to light a material, it uses pluggable
 * RenderEquations (RE)for specific lighting scenarios.
 *
 * Instructions for use:
 * - Ensure that both RE_Direct, RE_IndirectDiffuse and RE_IndirectSpecular are defined
 * - If you have defined an RE_IndirectSpecular, you need to also provide a Material_LightProbeLOD. <---- ???
 * - Create a material parameter that is to be passed as the third parameter to your lighting functions.
 *
 * TODO:
 * - Add area light support.
 * - Add sphere light support.
 * - Add diffuse light probe (irradiance cubemap) support.
 */
GeometricContext geometry;
geometry.position = - vViewPosition;
geometry.normal = normal;
geometry.viewDir = ( isOrthographic ) ? vec3( 0, 0, 1 ) : normalize( vViewPosition );
#ifdef CLEARCOAT
	geometry.clearcoatNormal = clearcoatNormal;
#endif
IncidentLight directLight;
const int NUM_NEW_LIGHTS = 3;
PointLight light1 = PointLight(vec3(65.0, 0.0, 0.0), vec3(0.2, 0.2, 0.2), 0.0, 1.0);
PointLight light2 = PointLight(vec3(-65.0, 0.0, 0.0), vec3(0.2, 0.2, 0.2), 0.0, 1.0);
PointLight light3 = PointLight(vec3(0.0, -80.0, 0.0), vec3(0.2, 0.2, 0.2), 0.0, 1.0);
PointLight updatedPointLights[ NUM_NEW_LIGHTS ];
updatedPointLights[0] = light1;
updatedPointLights[1] = light2;
updatedPointLights[2] = light3;
#if defined( RE_Direct )
	// Fixed scan lights: they have no shadow maps, so no shadow lookup here.
	// Plain loop: the const bound is not unrolled by three.js.
	PointLight pointLight;
	for ( int i = 0; i < NUM_NEW_LIGHTS; i ++ ) {
		pointLight = updatedPointLights[ i ];
		getPointDirectLightIrradiance( pointLight, geometry, directLight );
		RE_Direct( directLight, geometry, material, reflectedLight );
	}
#endif
#if ( NUM_SPOT_LIGHTS > 0 ) && defined( RE_Direct )
	SpotLight spotLight;
	#if defined( USE_SHADOWMAP ) && NUM_SPOT_LIGHT_SHADOWS > 0
	SpotLightShadow spotLightShadow;
	#endif
	#pragma unroll_loop_start
	for ( int i = 0; i < NUM_SPOT_LIGHTS; i ++ ) {
		spotLight = spotLights[ i ];
		getSpotDirectLightIrradiance( spotLight, geometry, directLight );
		#if defined( USE_SHADOWMAP ) && ( UNROLLED_LOOP_INDEX < NUM_SPOT_LIGHT_SHADOWS )
		spotLightShadow = spotLightShadows[ i ];
		directLight.color *= all( bvec2( directLight.visible, receiveShadow ) ) ? getShadow( spotShadowMap[ i ], spotLightShadow.shadowMapSize, spotLightShadow.shadowBias, spotLightShadow.shadowRadius, vSpotShadowCoord[ i ] ) : 1.0;
		#endif
		RE_Direct( directLight, geometry, material, reflectedLight );
	}
	#pragma unroll_loop_end
#endif
#if ( NUM_DIR_LIGHTS > 0 ) && defined( RE_Direct )
	DirectionalLight directionalLight;
	#if defined( USE_SHADOWMAP ) && NUM_DIR_LIGHT_SHADOWS > 0
	DirectionalLightShadow directionalLightShadow;
	#endif
	#pragma unroll_loop_start
	for ( int i = 0; i < NUM_DIR_LIGHTS; i ++ ) {
		directionalLight = directionalLights[ i ];
		getDirectionalDirectLightIrradiance( directionalLight, geometry, directLight );
		#if defined( USE_SHADOWMAP ) && ( UNROLLED_LOOP_INDEX < NUM_DIR_LIGHT_SHADOWS )
		directionalLightShadow = directionalLightShadows[ i ];
		directLight.color *= all( bvec2( directLight.visible, receiveShadow ) ) ? getShadow( directionalShadowMap[ i ], directionalLightShadow.shadowMapSize, directionalLightShadow.shadowBias, directionalLightShadow.shadowRadius, vDirectionalShadowCoord[ i ] ) : 1.0;
		#endif
		RE_Direct( directLight, geometry, material, reflectedLight );
	}
	#pragma unroll_loop_end
#endif
#if ( NUM_RECT_AREA_LIGHTS > 0 ) && defined( RE_Direct_RectArea )
	RectAreaLight rectAreaLight;
	#pragma unroll_loop_start
	for ( int i = 0; i < NUM_RECT_AREA_LIGHTS; i ++ ) {
		rectAreaLight = rectAreaLights[ i ];
		RE_Direct_RectArea( rectAreaLight, geometry, material, reflectedLight );
	}
	#pragma unroll_loop_end
#endif
#if defined( RE_IndirectDiffuse )
	vec3 iblIrradiance = vec3( 0.0 );
	vec3 updatedAmbientLightColor = vec3( 0.8, 0.8, 0.8 );
	vec3 irradiance = getAmbientLightIrradiance( updatedAmbientLightColor );
	irradiance += getLightProbeIrradiance( lightProbe, geometry );
	#if ( NUM_HEMI_LIGHTS > 0 )
		#pragma unroll_loop_start
		for ( int i = 0; i < NUM_HEMI_LIGHTS; i ++ ) {
			irradiance += getHemisphereLightIrradiance( hemisphereLights[ i ], geometry );
		}
		#pragma unroll_loop_end
	#endif
#endif
#if defined( RE_IndirectSpecular )
	vec3 radiance = vec3( 0.0 );
	vec3 clearcoatRadiance = vec3( 0.0 );
#endif
`;

// Modifications to show the color maps.
// onBeforeCompileHook (inside ScanMeshShaderMaterialDesign below) explains
// which three.js chunk each of the following strings replaces.
//
// Replacement for `#include <common>` in the vertex shader. It declares the
// per-vertex heat-map attributes and then re-includes <common>.
// NOTE(review): geometries that lack these attributes presumably read the
// WebGL default attribute value (0). Confirm that is the intended "no data"
// value.
const define_distance_attribute = `
attribute float occlusion_layer;
attribute float vertex_displacement;
attribute float surface_displacement;
#include <common>
`;

// Replacement for `#include <color_pars_vertex>`. It:
// - declares the vertex-color varying `vColor` (vec4 with USE_COLOR_ALPHA,
//   vec3 with USE_COLOR / USE_INSTANCING_COLOR);
// - adds three float varyings that forward the heat-map attributes to the
//   fragment shader.
// NOTE(review): the vColor part appears to mirror the stock three.js chunk.
// Keep it in sync when upgrading three.js.
const color_pars_vertex = `
#if defined( USE_COLOR_ALPHA )
	varying vec4 vColor;
#elif defined( USE_COLOR ) || defined( USE_INSTANCING_COLOR )
	varying vec3 vColor;
#endif
varying float vOcclusionLayer;
varying float vVertexDisplacement;
varying float vSurfaceDisplacement;
`;

// Replacement for `#include <color_vertex>`. It:
// - computes vColor from the `color` attribute and/or `instanceColor`;
// - copies the three heat-map attributes into their varyings.
// The "initialise to 1.0, then *=" pattern makes each varying equal to its
// attribute value.
const color_vertex = `
vOcclusionLayer = 1.0;
vVertexDisplacement = 1.0;
vSurfaceDisplacement = 1.0;
#if defined( USE_COLOR_ALPHA )
	vColor = vec4( 1.0 );
#elif defined( USE_COLOR ) || defined( USE_INSTANCING_COLOR )
	vColor = vec3( 1.0 );
#endif
#ifdef USE_COLOR
	vColor *= color;
#endif
#ifdef USE_INSTANCING_COLOR
	vColor.xyz *= instanceColor.xyz;
#endif
vOcclusionLayer *= occlusion_layer;
vVertexDisplacement *= vertex_displacement;
vSurfaceDisplacement *= surface_displacement;
`;

// Replacement for `#include <color_pars_fragment>`. It declares:
// - the heat-map uniforms (activeHeatMap, vMin, vMax, showHeatmap);
// - the vColor and heat-map varyings written by the vertex shader;
// - three color-map functions:
//   - occlusalColorMap: blue below vMin, a red -> green ramp from vMin to
//     vMax, and the default color above vMax or below -100.
//   - highlightValid: green for any value in [-100, 100] other than exactly
//     0, and the default color otherwise.
//   - surfaceDisplacementColorMap: 11-stop purple -> orange table over
//     [vMin, vMax], clamped at both ends. It returns the default color for
//     |v| > 100 or an empty/inverted range.
// occlusalColorMap guards the degenerate range vMin == vMax, which would
// otherwise divide 0 by 0 and produce NaN color channels.
const color_pars_fragment = `
uniform int activeHeatMap;
uniform float vMin;
uniform float vMax;
uniform bool showHeatmap;

#if defined( USE_COLOR_ALPHA )
	varying vec4 vColor;
#elif defined( USE_COLOR )
	varying vec3 vColor;
#endif
varying float vOcclusionLayer;
varying float vVertexDisplacement;
varying float vSurfaceDisplacement;

// define our color maps functions
vec3 occlusalColorMap(float v, float vMin, float vMax, vec3 DEFAULT_MODEL_RGB_U8_COLOR)
{
    if(v > vMax || v < -100.0){ //out of range
        return DEFAULT_MODEL_RGB_U8_COLOR;
    }

    float R = 0.0;
    float G = 0.0;
    float B = 0.0;
    if (v < vMin) {
        R = 0.0;
        G = 0.0;
        B = 1.0;
    }else{
        // Here vMin <= v <= vMax, so span >= 0; span == 0 would give 0/0 = NaN.
        float span = vMax - vMin;
        float ratio = span > 0.0 ? min(1.0, (v - vMin) / span) : 0.0;

        R = min(1.0, pow(2.0 * (1.0 - ratio), 0.4));
        G = min(1.0, pow(2.0 * ratio, 0.4));
        B = 0.0;
    }

    vec3 result = vec3(R, G, B);
    return result;
}
vec3 highlightValid(float v, vec3 DEFAULT_MODEL_RGB_U8_COLOR)
{
    if (v > 100.0 || v < -100.0 || v == 0.0) {
        return DEFAULT_MODEL_RGB_U8_COLOR;
    }
    return vec3(0.0, 1.0, 0.0);
}
vec3 surfaceDisplacementColorMap(float v, float vMin, float vMax, vec3 defaultColor)
{
    vec3 colorTable[11];
    colorTable[0] = vec3(0.17647058823529413,  0.0                ,  0.29411764705882354);  // purple
    colorTable[1] = vec3(0.32941176470588235,  0.15294117647058825,  0.53333333333333333);
    colorTable[2] = vec3(0.50196078431372548,  0.45098039215686275,  0.67450980392156867);
    colorTable[3] = vec3(0.69803921568627447,  0.6705882352941176 ,  0.82352941176470584);
    colorTable[4] = vec3(0.84705882352941175,  0.85490196078431369,  0.92156862745098034);
    colorTable[5] = vec3(0.96862745098039216,  0.96862745098039216,  0.96862745098039216);
    colorTable[6] = vec3(0.99607843137254903,  0.8784313725490196 ,  0.71372549019607845);
    colorTable[7] = vec3(0.99215686274509807,  0.72156862745098038,  0.38823529411764707);
    colorTable[8] = vec3(0.8784313725490196 ,  0.50980392156862742,  0.07843137254901961);
    colorTable[9] = vec3(0.70196078431372544,  0.34509803921568627,  0.02352941176470588);
    colorTable[10] = vec3(0.49803921568627452,  0.23137254901960785,  0.03137254901960784); // orange

    if (v > 100.0 || v < -100.0) {
        return defaultColor;
    }

    if (vMax <= vMin) {
        return defaultColor;
    }

    float scaledRatio = 10.0 * (v - vMin) / (vMax - vMin);
    if (scaledRatio <= 0.0) {
        return colorTable[0];
    }

    if (scaledRatio >= 10.0) {
        return colorTable[10];
    }

    int breakpoint = int(scaledRatio);
    float alpha = scaledRatio - float(breakpoint);
    return mix(colorTable[breakpoint], colorTable[breakpoint + 1], alpha);
}
`;

// Replacement for `#include <color_fragment>`. It:
// - applies vertex colors to diffuseColor;
// - when the `showHeatmap` uniform is true, multiplies diffuseColor.rgb by
//   the color from the heat map selected by `activeHeatMap`:
//   - 3: occlusalColorMap on vOcclusionLayer;
//   - 5: highlightValid on vVertexDisplacement;
//   - 6: surfaceDisplacementColorMap on vSurfaceDisplacement.
// Any other id leaves multiplyColor white, so diffuseColor is unchanged.
// NOTE(review): 3/5/6 are presumably the numeric HeatMapType enum values for
// these maps. Confirm against @orthly/forceps, because a renumbered enum would
// silently disable the heat maps.
const color_fragment = `
#if defined( USE_COLOR_ALPHA )
	diffuseColor *= vColor;
#elif defined( USE_COLOR )
	diffuseColor.rgb *= vColor;
#endif

vec3 DEFAULT_MODEL_RGB_U8_COLOR = vec3(1.0, 1.0, 1.0);
vec3 multiplyColor = DEFAULT_MODEL_RGB_U8_COLOR;

if(showHeatmap){
    if (activeHeatMap == 3) {
        multiplyColor = occlusalColorMap(vOcclusionLayer, vMin, vMax, DEFAULT_MODEL_RGB_U8_COLOR);
    } else if (activeHeatMap == 5) {
        multiplyColor = highlightValid(vVertexDisplacement, DEFAULT_MODEL_RGB_U8_COLOR);
    } else if (activeHeatMap == 6) {
        multiplyColor = surfaceDisplacementColorMap(vSurfaceDisplacement, vMin, vMax, DEFAULT_MODEL_RGB_U8_COLOR);
    }
    diffuseColor.rgb *= multiplyColor;
}
`;

/**
 * This material wraps the default THREE.js MeshPhysicalMaterial,
 * adds some color management controls to it, and applies
 * a custom lighting scheme so that the scans can look attractive
 * independently of the restorations.
 *
 * It does this by patching the material's shader source code
 * with literal string find/replace.
 *
 */

// These are the very specific properties used to get as close as possible to
// the look of 3Shape Trios scans.
// ScanMeshShaderMaterialDesign spreads its props over these defaults, so any
// prop passed explicitly wins.
// - sRGBToLinear, saturation, lightness, showHeatmap and activeHeatMap feed
//   shader uniforms.
// - The remaining keys are forwarded to <meshPhysicalMaterial>.
export const _DEFAULT_3SHAPE_SCAN_MESH_SHADER_PROPS = {
    sRGBToLinear: false,
    saturation: 1.0,
    lightness: 1.0,
    roughness: 0.15,
    reflectivity: 0.2,
    clearcoat: 0.135,
    clearcoatRoughness: 0.175,
    color: SCAN_MATERIAL_BASE_COLOR,
    side: THREE.DoubleSide,
    flatShading: false,
    showHeatmap: false,
    activeHeatMap: HeatMapType.SurfaceDisplacement,
};

// Props accepted by ScanMeshShaderMaterialDesign; currently an alias of MeshPhysicalHSVMaterialProps.
export type ScanMeshShaderMaterialDesignProps = MeshPhysicalHSVMaterialProps;

/**
 * Returns a THREE.Uniform created once per component instance.
 *
 * Its `value` is updated in an effect whenever `value` changes. The returned
 * object is stable across renders, so it can be handed to the shader once in
 * `onBeforeCompile` and mutated afterwards without recompiling the material.
 */
function useSyncedUniform<T>(value: T): THREE.Uniform {
    // Lazy initializer: avoids allocating a throwaway Uniform on every render.
    const [uniform] = React.useState(() => new THREE.Uniform(value));
    React.useEffect(() => {
        uniform.value = value;
    }, [uniform, value]);
    return uniform;
}

/**
 * MeshPhysicalMaterial for scan meshes with HSV color controls, a fixed custom
 * light rig and optional per-vertex heat maps. It works by patching the stock
 * shader in `onBeforeCompile`.
 *
 * Props are spread over `_DEFAULT_3SHAPE_SCAN_MESH_SHADER_PROPS` (props win).
 * - saturation, lightness, sRGBToLinear, showHeatmap, activeHeatMap and
 *   heatMapRange drive shader uniforms.
 * - Every other prop is forwarded to `<meshPhysicalMaterial>`.
 */
export const ScanMeshShaderMaterialDesign: React.FC<ScanMeshShaderMaterialDesignProps> = props => {
    const {
        saturation,
        lightness,
        sRGBToLinear,
        showHeatmap,
        activeHeatMap,
        // Only used to compute the heat-map range. It is not a MeshPhysicalMaterial
        // option, so keep it out of the props spread onto <meshPhysicalMaterial>.
        heatMapRange: heatMapRangeProp,
        ...meshPhysicalMaterialProps
    } = {
        ..._DEFAULT_3SHAPE_SCAN_MESH_SHADER_PROPS,
        ...props,
    };
    const heatMapRange = getHeatmapRange({ activeHeatMap, heatMapRange: heatMapRangeProp });

    const saturationUniform = useSyncedUniform(saturation ?? 1.0);
    const lightnessUniform = useSyncedUniform(lightness ?? 1.0);
    const sRGBToLinearUniform = useSyncedUniform(sRGBToLinear ?? false);
    const showHeatmapUniform = useSyncedUniform(showHeatmap ?? false);
    // Falsy ids fall back to 1, which matches none of the ids handled by the
    // fragment shader's heat-map switch (so no heat-map tint is applied).
    const activeHeatMapUniform = useSyncedUniform(activeHeatMap ? activeHeatMap : 1);
    // Track the primitive bounds rather than the range object, so the uniforms
    // are only touched when a bound actually changes.
    const vMinUniform = useSyncedUniform(heatMapRange.min);
    const vMaxUniform = useSyncedUniform(heatMapRange.max);

    const onBeforeCompileHook = React.useCallback(
        shader => {
            // See MeshPhysicalHSVMaterial for the technical description of this approach.

            // Bind the stable Uniform objects so later value changes reach the shader.
            shader.uniforms['lightness'] = lightnessUniform;
            shader.uniforms['saturation'] = saturationUniform;
            shader.uniforms['doSRGBToLinear'] = sRGBToLinearUniform;

            shader.uniforms['showHeatmap'] = showHeatmapUniform;
            shader.uniforms['activeHeatMap'] = activeHeatMapUniform;
            shader.uniforms['vMin'] = vMinUniform;
            shader.uniforms['vMax'] = vMaxUniform;

            // Additions to the vertex shader:
            // 1- the heat-map (distance) attributes
            // 2- varyings that forward those attributes to the fragment shader
            let vs = shader.vertexShader;
            vs = vs.replace('#include <common>', define_distance_attribute);
            vs = vs.replace('#include <color_pars_vertex>', color_pars_vertex);
            vs = vs.replace('#include <color_vertex>', color_vertex);
            shader.vertexShader = vs;

            let fs = shader.fragmentShader;
            // Literal find/replace is fragile: String.prototype.replace silently does
            // nothing when the search text is absent, e.g. after a three.js upgrade.
            // If this becomes a problem, consider parsing the shader instead.
            fs = fs.replace('#include <lights_fragment_begin>', updatedLightingShader);
            fs = fs.replace('void main() {', shaderPatchDefinitions);
            fs = fs.replace('\tgl_FragColor = vec4( outgoingLight, diffuseColor.a );', shaderPatchMain);

            // Heat-map modifications of the fragment shader:
            // 1- declare the varyings and the three heat-map color functions
            // 2- pick one of them based on the activeHeatMap uniform
            fs = fs.replace('#include <color_pars_fragment>', color_pars_fragment);
            fs = fs.replace('#include <color_fragment>', color_fragment);

            shader.fragmentShader = fs;
        },
        [
            lightnessUniform,
            saturationUniform,
            sRGBToLinearUniform,
            showHeatmapUniform,
            activeHeatMapUniform,
            vMinUniform,
            vMaxUniform,
        ],
    );

    const materialRef = React.useRef<THREE.MeshPhysicalMaterial>();

    React.useEffect(() => {
        const material = materialRef.current;
        if (material) {
            // Ideally only the affected uniforms would be marked for update.
            // Since this patches a stock MeshPhysicalMaterial rather than a custom
            // material, three.js exposes no finer-grained flag, so the whole
            // material is marked for a rebuild when `map` or `vertexColors` change.
            material.needsUpdate = true;
        }
    }, [meshPhysicalMaterialProps.map, meshPhysicalMaterialProps.vertexColors]);

    return (
        <meshPhysicalMaterial ref={materialRef} {...meshPhysicalMaterialProps} onBeforeCompile={onBeforeCompileHook} />
    );
};
