/* eslint-disable max-lines */
import { isNamedScanGeometry } from '../DandyFinishing/Naming.util';
import { logger } from '../Utils/Logger';
import { AttributeName, ATTRIBUTE_MAP_INVALID_VALUE } from './BufferAttributeConstants';
import { HeatMapType } from './HeatMap.types';
import { ensureBoundingBox, getFaceNormal, getTriangleByIndex } from './Mesh3d.util';
import { ensureMeshIndex, isEmptyPositionOrIndex } from './MeshIndex';
import _ from 'lodash';
import * as THREE from 'three';
import type { HitPointInfo, MeshBVH } from 'three-mesh-bvh';

/**
 * Search radius for the contact, curtains and (scaled by 1.5) occlusal distance computations in this file:
 * ray hits / closest points at or beyond this distance are not recorded.
 * NOTE(review): units presumably millimetres, like NEARBY_THRESHOLD_MM — confirm.
 */
const MAX_DISTANCE = 0.5;

/**
 * Inputs for an undercut computation.
 *
 * NOTE(review): this interface is not consumed anywhere in the visible part of this file; the field
 * descriptions below are inferred from the field names only — confirm against the consumer.
 */
export interface UndercutComputationInputs {
    /** Geometry to analyse (presumably the scan model) — TODO confirm. */
    scanModel?: THREE.BufferGeometry;
    /** Optional curtains geometry — TODO confirm how the consumer uses it. */
    curtainsGeometry?: THREE.BufferGeometry;
    /** Direction treated as "up" (presumably the insertion axis) — TODO confirm. */
    upAxis?: THREE.Vector3;
}

/**
 * Padding, in millimetres, added to a model's bounding box when deciding which proximal/occlusal
 * models are close enough to be measured during attribute initialization.
 */
const NEARBY_THRESHOLD_MM = 5;

/**
 * Broad-phase filter: keeps the entries of `others` whose bounding box intersects `geometry`'s bounding
 * box grown by `threshold` on every side. `geometry` itself is always excluded.
 * @param geometry Reference geometry
 * @param others Candidate geometries
 * @param threshold Amount by which the reference bounding box is expanded
 * @returns The candidates that pass the test, in their original order
 */
function filterDifferentNearby(
    geometry: THREE.BufferGeometry,
    others: THREE.BufferGeometry[],
    threshold: number,
): THREE.BufferGeometry[] {
    // Clone so the box returned by ensureBoundingBox is not modified by the expansion.
    const searchRegion = ensureBoundingBox(geometry).clone();
    searchRegion.expandByScalar(threshold);

    const isNearbyOther = (candidate: THREE.BufferGeometry): boolean =>
        candidate !== geometry && ensureBoundingBox(candidate).intersectsBox(searchRegion);

    return others.filter(isNearbyOther);
}

/**
 * Lazily creates and fills every distance/displacement attribute that `geometry` does not have yet.
 * Attributes that already exist are left untouched.
 *
 * @param geometry The geometry whose distance attributes to initialize
 * @param proximalModels Geometry of models adjacent to `geometry`. Only models whose bounding boxes come
 *   within `NEARBY_THRESHOLD_MM` are measured. The first entry accepted by `isNamedScanGeometry` is also
 *   used for the cement gap attribute.
 * @param occlusalModels Geometry of models opposite to `geometry`, filtered the same way
 * @param curtainsGeometry Optional curtains mesh; the curtains attribute is only initialized when provided
 * @param allowCurtainsBackFace Forwarded to `computeCurtainsDistance` as `allowBackFace`
 * @returns Milliseconds spent if any attribute was initialized, otherwise 0
 */
export function ensureDistanceAttributesInitialized(
    geometry: THREE.BufferGeometry,
    proximalModels: THREE.BufferGeometry[],
    occlusalModels: THREE.BufferGeometry[],
    curtainsGeometry?: THREE.BufferGeometry,
    allowCurtainsBackFace: boolean = false,
): number {
    const startedAt = performance.now();
    const isMissing = (attributeName: string): boolean => !geometry.hasAttribute(attributeName);
    let initializedSomething = false;

    // Thickness is measured against the geometry itself.
    if (isMissing(AttributeName.ThicknessDistance)) {
        resetDistanceField(geometry, AttributeName.ThicknessDistance);
        recomputeThicknessAfterUpdate(geometry);
        initializedSomething = true;
    }

    if (isMissing(AttributeName.ProximalDistance)) {
        resetDistanceField(geometry, AttributeName.ProximalDistance);
        for (const neighbor of filterDifferentNearby(geometry, proximalModels, NEARBY_THRESHOLD_MM)) {
            computeContactDistanceToNeighborModel(geometry, neighbor);
        }
        initializedSomething = true;
    }

    if (isMissing(AttributeName.OcclusalDistance)) {
        resetDistanceField(geometry, AttributeName.OcclusalDistance);
        for (const antagonist of filterDifferentNearby(geometry, occlusalModels, NEARBY_THRESHOLD_MM)) {
            computeOcclusalDistanceToOpposalModel(geometry, antagonist);
        }
        initializedSomething = true;
    }

    // computeCurtainsDistance reads ProximalDistance, which is guaranteed to exist by this point.
    if (curtainsGeometry !== undefined && isMissing(AttributeName.CurtainsDistance)) {
        resetDistanceField(geometry, AttributeName.CurtainsDistance);
        computeCurtainsDistance(geometry, curtainsGeometry, undefined, allowCurtainsBackFace);
        initializedSomething = true;
    }

    const proximalScan = proximalModels.find(isNamedScanGeometry);
    if (proximalScan !== undefined && isMissing(AttributeName.CementGapDistance)) {
        resetDistanceField(geometry, AttributeName.CementGapDistance);
        computeCementGapDistance(geometry, proximalScan);
        initializedSomething = true;
    }

    // Always run: it initializes its own attributes and reports whether it did anything.
    if (ensureDisplacementAttributesInitialized(geometry)) {
        initializedSomething = true;
    }

    return initializedSomething ? performance.now() - startedAt : 0;
}

/**
 * Creates the alignment distance attribute on `geometry` (every vertex set to the invalid marker by
 * `resetDistanceField`) if it is not already present. An existing attribute is left as is.
 * @param geometry The geometry to check
 */
export function ensureAlignmentDistanceAttributeInitialized(geometry: THREE.BufferGeometry): void {
    if (geometry.hasAttribute(AttributeName.AlignmentDistance)) {
        return;
    }
    resetDistanceField(geometry, AttributeName.AlignmentDistance);
}

/**
 * Creates the sculpt mask attribute on `geometry` if it is missing. The new attribute is built by
 * `resetDistanceField`, so every vertex starts at `ATTRIBUTE_MAP_INVALID_VALUE`. An existing attribute
 * is left as is.
 * @param geometry The geometry to check
 */
export function ensureSculptMaskAttributeInitialized(geometry: THREE.BufferGeometry): void {
    if (geometry.hasAttribute(AttributeName.SculptMask)) {
        return;
    }
    resetDistanceField(geometry, AttributeName.SculptMask);
}

/**
 * Creates the vertex and surface sculpt displacement attributes on `geometry` when they are missing
 * (filled with the invalid marker by `resetDistanceField`). Existing attributes are not touched.
 * @param geometry The geometry whose sculpt displacement attributes to initialize
 * @returns True if at least one attribute was missing and has been created; false otherwise
 */
export function ensureDisplacementAttributesInitialized(geometry: THREE.BufferGeometry): boolean {
    let initializedAny = false;

    for (const attributeName of [AttributeName.VertexDisplacement, AttributeName.SurfaceDisplacement]) {
        if (geometry.hasAttribute(attributeName)) {
            continue;
        }
        resetDistanceField(geometry, attributeName);
        initializedAny = true;
    }

    return initializedAny;
}

/**
 * Recomputes, for the given vertices only, every distance attribute whose heat map is active.
 * The passes run in a fixed order: thickness, proximal (+ curtains), occlusal, vertex displacement,
 * surface displacement, cement gap.
 *
 * @param activeHeatMaps Heat maps whose backing attributes should be refreshed
 * @param geometry The geometry whose distance attributes to update
 * @param subsetVertices Indices of the vertices whose distance attributes should be recomputed
 * @param proximalModels Geometry of models adjacent to `geometry`; an entry identical to `geometry` is skipped
 * @param occlusalModels Geometry of models opposite to `geometry`; an entry identical to `geometry` is skipped
 * @param originalGeometry The state of `geometry` before any edits were made. Without it, the displacement
 *   attributes are only reset for the subset, not recomputed.
 * @param curtainsGeometry Optional curtains mesh, refreshed together with the proximal heat map
 * @param allowCurtainsBackFace Forwarded to `computeCurtainsDistance` as `allowBackFace`
 * @throws Error when the cement gap heat map is active but no proximal model passes `isNamedScanGeometry`
 */
export function recomputeActiveDistanceForASubset(
    activeHeatMaps: HeatMapType[],
    geometry: THREE.BufferGeometry,
    subsetVertices: number[],
    proximalModels: THREE.BufferGeometry[],
    occlusalModels: THREE.BufferGeometry[],
    originalGeometry?: THREE.BufferGeometry,
    curtainsGeometry?: THREE.BufferGeometry,
    allowCurtainsBackFace: boolean = false,
) {
    const isActive = (heatMap: HeatMapType): boolean => activeHeatMaps.includes(heatMap);

    // Thickness is measured against the edited model itself.
    if (isActive(HeatMapType.Thickness)) {
        recomputeThicknessAfterUpdate(geometry, subsetVertices);
    }

    // Distances to the other models on the same arch, plus the curtains shown alongside them.
    if (isActive(HeatMapType.Proximal)) {
        resetDistanceFieldForASubset(geometry, AttributeName.ProximalDistance, subsetVertices);
        for (const neighbor of proximalModels) {
            if (neighbor === geometry) {
                continue;
            }
            computeContactDistanceToNeighborModel(geometry, neighbor, subsetVertices);
        }

        if (curtainsGeometry !== undefined) {
            resetDistanceFieldForASubset(geometry, AttributeName.CurtainsDistance, subsetVertices);
            computeCurtainsDistance(geometry, curtainsGeometry, subsetVertices, allowCurtainsBackFace);
        }
    }

    // Distances to the models on the opposite arch.
    if (isActive(HeatMapType.Occlusal)) {
        resetDistanceFieldForASubset(geometry, AttributeName.OcclusalDistance, subsetVertices);
        for (const antagonist of occlusalModels) {
            if (antagonist === geometry) {
                continue;
            }
            computeOcclusalDistanceToOpposalModel(geometry, antagonist, subsetVertices);
        }
    }

    // Sculpt displacements need the pre-edit geometry; without it the values are only reset.
    if (isActive(HeatMapType.VertexDisplacement)) {
        resetDistanceFieldForASubset(geometry, AttributeName.VertexDisplacement, subsetVertices);
        if (originalGeometry) {
            computeVertexDisplacementForASubset(geometry, originalGeometry, subsetVertices);
        }
    }

    if (isActive(HeatMapType.SurfaceDisplacement)) {
        resetDistanceFieldForASubset(geometry, AttributeName.SurfaceDisplacement, subsetVertices);
        if (originalGeometry) {
            computeSurfaceDisplacement(geometry, originalGeometry, subsetVertices);
        }
    }

    if (isActive(HeatMapType.CementGap)) {
        const proximalScan = proximalModels.find(isNamedScanGeometry);
        if (!proximalScan) {
            throw new Error('Failed to find proximal scan');
        }
        resetDistanceFieldForASubset(geometry, AttributeName.CementGapDistance, subsetVertices);
        computeCementGapDistance(geometry, proximalScan, subsetVertices);
    }
}

/**
 * Writes the signed contact distance between `sourceModel` and one neighboring model into the source's
 * `ProximalDistance` attribute.
 *
 * For every processed vertex, a ray is cast from a point just outside the vertex (offset along its
 * normal) in the direction opposite the normal.
 * - Ray hit: if the nearest hit is on a target face whose normal opposes the vertex normal, and it is
 *   closer than `MAX_DISTANCE`, the hit distance is stored as a positive value.
 * - Closest point: otherwise, if the closest point on the target lies within `MAX_DISTANCE`, its
 *   distance is stored negated.
 * - Neither: the vertex keeps its current value.
 *
 * NOTE(review): when several neighbors are processed in turn, the last one that produces a value for a
 * vertex wins; values from different neighbors are not merged.
 *
 * @param sourceModel Geometry whose proximal distance attribute is updated; needs position and normal attributes
 * @param targetModel Neighboring geometry to measure against
 * @param rawSubsetVertices Vertex indices to process; all vertices when omitted
 */
export function computeContactDistanceToNeighborModel(
    sourceModel: THREE.BufferGeometry,
    targetModel: THREE.BufferGeometry,
    rawSubsetVertices?: number[],
): void {
    const positionsAttribute = sourceModel.attributes.position;
    const normalsAttribute = sourceModel.attributes.normal;
    if (!positionsAttribute || !normalsAttribute) {
        logger.info('Neighbor Bailing early');
        return;
    }

    const sourceAttribute = sourceModel.attributes[AttributeName.ProximalDistance];

    // ensure this geometry has an index
    const targetIndex: MeshBVH = ensureMeshIndex(targetModel);

    const subsetVertices = rawSubsetVertices ?? _.range(positionsAttribute.count);

    // Scratch objects reused for every vertex to avoid allocations inside this hot loop.
    const p = new THREE.Vector3();
    const n = new THREE.Vector3();
    const ray = new THREE.Ray();

    for (const i of subsetVertices) {
        p.fromBufferAttribute(positionsAttribute, i);
        n.fromBufferAttribute(normalsAttribute, i);
        ray.origin.copy(p).addScaledVector(n, 0.0001);
        ray.direction.copy(n).negate();

        const intersection = _.minBy(targetIndex.raycast(ray, THREE.DoubleSide), hit => hit.distance);
        if (
            intersection?.point &&
            intersection.face &&
            intersection.face.normal.dot(n) < 0 &&
            intersection.distance < MAX_DISTANCE
        ) {
            sourceAttribute?.setX(i, intersection.distance);
            continue;
        }

        // Bounding the search by MAX_DISTANCE lets the BVH prune far-away nodes; farther results are
        // discarded below anyway. A null result means nothing was found within the bound.
        const closest: HitPointInfo | null = targetIndex.closestPointToPoint(p, undefined, 0, MAX_DISTANCE);
        if (closest !== null && closest.distance < MAX_DISTANCE) {
            sourceAttribute?.setX(i, -1 * closest.distance);
        }
        // Otherwise the existing value is left untouched.
    }

    if (sourceAttribute) {
        sourceAttribute.needsUpdate = true;
    }
}

/**
 * Writes the signed distance between `sourceModel` and one opposing-arch model into the source's
 * `OcclusalDistance` attribute. Both searches use a radius of 1.5 × `MAX_DISTANCE`.
 *
 * For every processed vertex, a ray is cast from a point just inside the vertex (offset against its
 * normal) in the direction opposite the normal.
 * - Ray hit: if the nearest hit is on a target face whose normal opposes the vertex normal, and it is
 *   within the radius, the hit distance is stored negated.
 * - Closest point: otherwise, the closest point on the target is stored as a positive distance when all
 *   of the following hold:
 *   - it is within the radius;
 *   - it is smaller than the value already stored, or nothing valid is stored yet;
 *   - the face it lies on has a normal opposing the vertex normal.
 * - Neither: the vertex keeps its current value.
 *
 * @param sourceModel Geometry whose occlusal distance attribute is updated; needs position and normal attributes
 * @param targetModel Opposing geometry to measure against
 * @param rawSubsetVertices Vertex indices to process; all vertices when omitted
 */
export function computeOcclusalDistanceToOpposalModel(
    sourceModel: THREE.BufferGeometry,
    targetModel: THREE.BufferGeometry,
    rawSubsetVertices?: number[],
): void {
    const positionsAttribute = sourceModel.attributes.position;
    const normalsAttribute = sourceModel.attributes.normal;
    if (!positionsAttribute || !normalsAttribute) {
        logger.info('Occlusal Bailing early');
        return;
    }

    const sourceAttribute = sourceModel.attributes[AttributeName.OcclusalDistance];
    const sourceDistances: ArrayLike<number> = sourceAttribute?.array ?? [];

    // ensure this geometry has an index
    const targetIndex: MeshBVH = ensureMeshIndex(targetModel);

    // Occlusal distances are tracked out to 150% of MAX_DISTANCE.
    const maxOcclusalDistance = 1.5 * MAX_DISTANCE;
    const subsetVertices = rawSubsetVertices ?? _.range(positionsAttribute.count);

    // Scratch objects reused for every vertex to avoid allocations inside this hot loop.
    const p = new THREE.Vector3();
    const n = new THREE.Vector3();
    const ray = new THREE.Ray();

    for (const i of subsetVertices) {
        p.fromBufferAttribute(positionsAttribute, i);
        n.fromBufferAttribute(normalsAttribute, i);
        ray.origin.copy(p).addScaledVector(n, -0.0001);
        ray.direction.copy(n).negate();

        const intersection = _.minBy(targetIndex.raycast(ray, THREE.DoubleSide), hit => hit.distance);
        // check if inside, signed distance matters
        if (
            intersection?.point &&
            intersection.face &&
            intersection.face.normal.dot(n) < 0 &&
            intersection.distance < maxOcclusalDistance
        ) {
            sourceAttribute?.setX(i, -1 * intersection.distance);
            continue;
        }

        // Bounding the search lets the BVH prune far-away nodes; farther results are discarded anyway.
        const closest: HitPointInfo | null = targetIndex.closestPointToPoint(p, undefined, 0, maxOcclusalDistance);
        if (closest === null || closest.distance >= maxOcclusalDistance) {
            // we only bother to store data if d < 150% of MAX_DISTANCE; keep the existing value
            continue;
        }

        const d = closest.distance;
        const sourceDistance: number = sourceDistances[i] ?? ATTRIBUTE_MAP_INVALID_VALUE;
        // Only replace an unset value or a larger one. With several opposing models this keeps the
        // smallest distance, and it never overwrites a negative value stored by the ray test. The explicit
        // sentinel check keeps this working whatever the sign of ATTRIBUTE_MAP_INVALID_VALUE.
        const improvesStored = sourceDistance === ATTRIBUTE_MAP_INVALID_VALUE || d < sourceDistance;
        // The normal test is to make sure it's not the intaglio of the crown, which could be < 0.5mm away.
        // The face normal is only looked up once the cheaper distance checks have passed.
        if (improvesStored && getFaceNormal(targetModel, closest.faceIndex).dot(n) < 0) {
            sourceAttribute?.setX(i, d);
        }
    }

    if (sourceAttribute) {
        sourceAttribute.needsUpdate = true;
    }
}

/**
 * Recomputes the `ThicknessDistance` attribute of `sourceModel` against the model's own BVH.
 *
 * For each processed vertex, a ray is cast from a point just inside the vertex (offset against its
 * normal) in the direction opposite the normal. The nearest hit distance is stored; vertices whose ray
 * hits nothing get `ATTRIBUTE_MAP_INVALID_VALUE`.
 *
 * @param sourceModel Geometry to measure; needs position and normal attributes
 * @param rawSubsetVertices Vertex indices to recompute; all vertices when omitted
 */
export function recomputeThicknessAfterUpdate(sourceModel: THREE.BufferGeometry, rawSubsetVertices?: number[]): void {
    const positionsAttribute = sourceModel.attributes.position;
    const normalsAttribute = sourceModel.attributes.normal;
    if (!positionsAttribute || !normalsAttribute) {
        logger.info('Thickness Bailing early');
        return;
    }

    const sourceAttribute = sourceModel.attributes[AttributeName.ThicknessDistance];

    // ensure this geometry has an index
    const targetIndex: MeshBVH = ensureMeshIndex(sourceModel);

    const subsetVertices = rawSubsetVertices ?? _.range(positionsAttribute.count);

    // Scratch objects reused for every vertex to avoid allocations inside this hot loop.
    const p = new THREE.Vector3();
    const n = new THREE.Vector3();
    const ray = new THREE.Ray();

    for (const i of subsetVertices) {
        p.fromBufferAttribute(positionsAttribute, i);
        n.fromBufferAttribute(normalsAttribute, i);
        ray.origin.copy(p).addScaledVector(n, -0.0001);
        ray.direction.copy(n).negate();

        const intersection = _.minBy(targetIndex.raycast(ray, THREE.DoubleSide), hit => hit.distance);
        sourceAttribute?.setX(i, intersection?.point ? intersection.distance : ATTRIBUTE_MAP_INVALID_VALUE);
    }

    if (sourceAttribute) {
        sourceAttribute.needsUpdate = true;
    }
}

/**
 * Smallest sculpt displacement, in millimetres, that counts as a real edit. Floating-point noise can
 * produce tiny non-zero displacements on untouched vertices, so anything below this threshold is
 * recorded as unmodified.
 */
const MIN_DISPLACEMENT_MM = 0.001;

/**
 * Stores, for each listed vertex, how far it has moved (Euclidean distance) from the same vertex index
 * in `originalGeometry`, in the `VertexDisplacement` attribute. Moves shorter than `MIN_DISPLACEMENT_MM`
 * are stored as `ATTRIBUTE_MAP_INVALID_VALUE`, i.e. the vertex is treated as unmodified.
 * Does nothing when any required attribute is missing.
 * NOTE(review): assumes `geometry` and `originalGeometry` share vertex indexing — confirm with callers.
 * @param geometry The geometry whose vertex displacement attribute to update
 * @param originalGeometry The state of `geometry` before any edits were made
 * @param subsetVertices The indices of the vertices whose vertex displacement to calculate
 */
export function computeVertexDisplacementForASubset(
    geometry: THREE.BufferGeometry,
    originalGeometry: THREE.BufferGeometry,
    subsetVertices: number[],
) {
    const currentPositions = geometry.getAttribute(AttributeName.Position);
    const displacementAttribute = geometry.getAttribute(AttributeName.VertexDisplacement);
    const originalPositions = originalGeometry.getAttribute(AttributeName.Position);
    if (!currentPositions || !displacementAttribute || !originalPositions) {
        return;
    }

    const current = new THREE.Vector3();
    const original = new THREE.Vector3();
    for (const vertexIndex of subsetVertices) {
        current.fromBufferAttribute(currentPositions, vertexIndex);
        original.fromBufferAttribute(originalPositions, vertexIndex);

        const moved = current.distanceTo(original);
        // Displacements under the threshold are marked as "unmodified".
        displacementAttribute.setX(vertexIndex, moved >= MIN_DISPLACEMENT_MM ? moved : ATTRIBUTE_MAP_INVALID_VALUE);
    }

    displacementAttribute.needsUpdate = true;
}

/**
 * Computes the signed sculpting surface displacement of `geometry` relative to `originalGeometry` and
 * stores it in the `SurfaceDisplacement` attribute.
 *
 * For each processed vertex, the closest point on the original surface is looked up.
 * - No closest point found: the vertex keeps its current value.
 * - Distance below `MIN_DISPLACEMENT_MM`: `ATTRIBUTE_MAP_INVALID_VALUE` is stored (unmodified).
 * - Otherwise: the distance is stored negated when the original surface lies on the side the vertex
 *   normal points to (non-negative dot product), and positive otherwise.
 *
 * Does nothing when position, normal or surface displacement attributes are missing.
 * @param geometry The geometry whose surface displacement attribute to update
 * @param originalGeometry The state of `geometry` before any edits were made
 * @param rawSubsetVertices The indices of the vertices whose surface displacement to calculate. If not
 * supplied, all vertices are updated.
 */
export function computeSurfaceDisplacement(
    geometry: THREE.BufferGeometry,
    originalGeometry: THREE.BufferGeometry,
    rawSubsetVertices?: number[],
) {
    const positions = geometry.getAttribute(AttributeName.Position);
    const normals = geometry.getAttribute(AttributeName.Normal);
    const surfaceDisplacements = geometry.getAttribute(AttributeName.SurfaceDisplacement);
    if (!positions || !normals || !surfaceDisplacements) {
        return;
    }

    const originalBvh = ensureMeshIndex(originalGeometry);
    const vertexIndices = rawSubsetVertices ?? _.range(positions.count);

    const vertex = new THREE.Vector3();
    const vertexNormal = new THREE.Vector3();
    const towardOriginal = new THREE.Vector3();

    for (const vertexIndex of vertexIndices) {
        vertex.fromBufferAttribute(positions, vertexIndex);

        const closest = originalBvh.closestPointToPoint(vertex);
        if (closest) {
            if (closest.distance < MIN_DISPLACEMENT_MM) {
                surfaceDisplacements.setX(vertexIndex, ATTRIBUTE_MAP_INVALID_VALUE);
            } else {
                towardOriginal.subVectors(closest.point, vertex);
                vertexNormal.fromBufferAttribute(normals, vertexIndex);
                const originalIsAlongNormal = towardOriginal.dot(vertexNormal) >= 0;
                surfaceDisplacements.setX(vertexIndex, originalIsAlongNormal ? -closest.distance : closest.distance);
            }
        }
    }

    surfaceDisplacements.needsUpdate = true;
}

/**
 * Writes the signed distance between `sourceModel` and the curtains mesh into the source's
 * `CurtainsDistance` attribute, then folds in the proximal distance.
 *
 * For every processed vertex, a ray is cast from a point just outside the vertex (offset along its
 * normal) in the direction opposite the normal.
 * - Ray hit: among the curtain hits closer than `MAX_DISTANCE`, the farthest one is stored as a positive
 *   distance, provided `allowBackFace` is set or its face normal opposes the vertex normal.
 * - Closest point: otherwise, the negated closest-point distance is used when it lies within `MAX_DISTANCE`.
 * - Neither: the vertex keeps its current value.
 *
 * Finally, where a valid `ProximalDistance` exists for the vertex, the larger of the two values is stored.
 * If the curtains value is still unset, the proximal value is stored on its own.
 *
 * @param sourceModel Geometry whose curtains attribute is updated; needs position and normal attributes
 * @param curtainsGeometry Curtains mesh; nothing happens when it is omitted or empty
 * @param rawSubsetVertices Vertex indices to process; all vertices when omitted
 * @param allowBackFace Accept ray hits regardless of the curtain face orientation
 */
export function computeCurtainsDistance(
    sourceModel: THREE.BufferGeometry,
    curtainsGeometry?: THREE.BufferGeometry | undefined,
    rawSubsetVertices?: number[],
    allowBackFace: boolean = false,
): void {
    const positionsAttribute = sourceModel.attributes.position;
    const normalsAttribute = sourceModel.attributes.normal;
    if (!positionsAttribute || !normalsAttribute || !curtainsGeometry) {
        return;
    }

    const sourceAttribute = sourceModel.attributes[AttributeName.CurtainsDistance];
    const sourceDistances: ArrayLike<number> = sourceAttribute?.array ?? [];

    const proximalAttribute = sourceModel.attributes[AttributeName.ProximalDistance];
    const proximalDistances: ArrayLike<number> = proximalAttribute?.array ?? [];

    if (isEmptyPositionOrIndex(curtainsGeometry)) {
        return;
    }
    // ensure this geometry has an index
    const curtainsBVH: MeshBVH = ensureMeshIndex(curtainsGeometry);

    const subsetVertices = rawSubsetVertices ?? _.range(positionsAttribute.count);

    // Scratch objects reused for every vertex to avoid allocations inside this hot loop.
    const p = new THREE.Vector3();
    const n = new THREE.Vector3();
    const ray = new THREE.Ray();

    for (const i of subsetVertices) {
        p.fromBufferAttribute(positionsAttribute, i);
        n.fromBufferAttribute(normalsAttribute, i);
        ray.origin.copy(p).addScaledVector(n, 0.0001);
        ray.direction.copy(n).negate();

        const nearbyHits = curtainsBVH.raycast(ray, THREE.DoubleSide).filter(hit => hit.distance < MAX_DISTANCE);
        const intersection = _.maxBy(nearbyHits, hit => hit.distance);

        let distance: number = sourceDistances[i] ?? ATTRIBUTE_MAP_INVALID_VALUE;
        if (
            intersection?.point &&
            (allowBackFace || (intersection.face && intersection.face.normal.dot(n) < 0))
        ) {
            distance = intersection.distance;
        } else {
            // Bounding the search lets the BVH prune far-away nodes; farther results are discarded anyway.
            const closest: HitPointInfo | null = curtainsBVH.closestPointToPoint(p, undefined, 0, MAX_DISTANCE);
            if (closest !== null && closest.distance < MAX_DISTANCE) {
                distance = -1 * closest.distance;
            }
        }

        const proximal = proximalDistances[i];
        if (proximal !== undefined && proximal !== ATTRIBUTE_MAP_INVALID_VALUE) {
            // Fold the proximal contact into the curtains map. An unset curtains value must not win
            // the max, which it could if the invalid marker is a large positive number.
            distance = distance === ATTRIBUTE_MAP_INVALID_VALUE ? proximal : Math.max(distance, proximal);
        }

        sourceAttribute?.setX(i, distance);
    }

    if (sourceAttribute) {
        sourceAttribute.needsUpdate = true;
    }
}

/**
 * Writes the signed cement gap between `sourceModel` and `targetModel` into the source's
 * `CementGapDistance` attribute.
 *
 * Each processed vertex is first reset to `ATTRIBUTE_MAP_INVALID_VALUE`. Vertices whose `IsIntaglio`
 * value is falsy (e.g. 0) are then skipped; when that attribute is absent, every vertex is processed.
 * For the rest, the distance to the closest point on the target is stored. It is negated when the
 * vector from that point to the vertex points against the hit triangle's normal
 * (`THREE.Triangle.isFrontFacing`).
 *
 * Does nothing when the source has no position attribute or no cement gap attribute, matching the
 * early-exit behavior of the other distance functions in this module.
 *
 * @param sourceModel Geometry whose cement gap attribute is updated
 * @param targetModel Geometry measured against
 * @param rawSubsetVertices Vertex indices to process; all vertices when omitted
 */
export function computeCementGapDistance(
    sourceModel: THREE.BufferGeometry,
    targetModel: THREE.BufferGeometry,
    rawSubsetVertices?: number[],
) {
    const posAttr = sourceModel.getAttribute(AttributeName.Position);
    const sourceAttr = sourceModel.attributes[AttributeName.CementGapDistance];
    // Without positions there is nothing to measure (and `posAttr.count` would throw); without the
    // destination attribute every write would be discarded.
    if (!posAttr || !sourceAttr) {
        return;
    }

    const targetBvh = ensureMeshIndex(targetModel);
    const intaglioAttr = sourceModel.attributes[AttributeName.IsIntaglio];
    const vertices = rawSubsetVertices ?? _.range(posAttr.count);
    const pt = new THREE.Vector3();
    const triangle = new THREE.Triangle();
    for (const vIdx of vertices) {
        // Make the value invalid here in case we skip this vertex.
        sourceAttr.setX(vIdx, ATTRIBUTE_MAP_INVALID_VALUE);
        const isIntaglio = intaglioAttr?.getX(vIdx) ?? true;
        if (!isIntaglio) {
            continue;
        }
        pt.fromBufferAttribute(posAttr, vIdx);
        const hit = targetBvh.closestPointToPoint(pt);
        if (!hit) {
            continue;
        }
        getTriangleByIndex(targetModel, hit.faceIndex, triangle);
        // pt becomes the vector from the closest surface point to the vertex.
        pt.sub(hit.point);
        const dist = triangle.isFrontFacing(pt) ? -hit.distance : hit.distance;
        sourceAttr.setX(vIdx, dist);
    }

    sourceAttr.needsUpdate = true;
}

/**
 * (Re)creates `distance_attribute_name` on `sourceModel` as a one-component float attribute with one
 * entry per vertex, every entry set to `ATTRIBUTE_MAP_INVALID_VALUE`. Any existing attribute with that
 * name is replaced. Does nothing when the geometry has no position attribute.
 * @param sourceModel Geometry that receives the attribute
 * @param distance_attribute_name Name of the attribute to (re)create
 */
export function resetDistanceField(sourceModel: THREE.BufferGeometry, distance_attribute_name: string): void {
    const positions = sourceModel.attributes.position;
    if (!positions) {
        return;
    }

    const invalidValues = new Float32Array(positions.count).fill(ATTRIBUTE_MAP_INVALID_VALUE);
    const freshAttribute = new THREE.Float32BufferAttribute(invalidValues, 1);
    sourceModel.setAttribute(distance_attribute_name, freshAttribute);
    freshAttribute.needsUpdate = true;
}

/**
 * Resets the listed vertices of attribute `distance_attribute_name` to `ATTRIBUTE_MAP_INVALID_VALUE` and
 * flags the attribute for re-upload. Does nothing when the attribute does not exist.
 *
 * The update flag is set here, not only by the compute pass that usually follows. Some callers reset
 * without recomputing: the displacement maps when no original geometry is given, or a compute function
 * that bails out early. Without the flag, those reset values would never be re-uploaded.
 * @param sourceModel Geometry owning the attribute
 * @param distance_attribute_name Name of the attribute to reset
 * @param subsetVertices Indices of the vertices to reset
 */
export function resetDistanceFieldForASubset(
    sourceModel: THREE.BufferGeometry,
    distance_attribute_name: string,
    subsetVertices: number[],
): void {
    const sourceAttribute = sourceModel.attributes[distance_attribute_name];
    if (!sourceAttribute) {
        return;
    }

    // reset distances for this region
    for (const vi of subsetVertices) {
        sourceAttribute.setX(vi, ATTRIBUTE_MAP_INVALID_VALUE);
    }
    sourceAttribute.needsUpdate = true;
}
