/* eslint-disable sonarjs/cognitive-complexity, max-lines */
import { logger } from '../Utils/Logger';
import { AttributeName } from './BufferAttributeConstants';
import { makeAlignmentMatrix, solveAlignmentForPairs } from './IterativeClosestPoint.utils';
import { computeLocalMaximaOfVertices, fitAPlaneToPoints, getAverageVector } from './Mesh3d.util';
import type { ToothNumber } from '@orthly/items';
import * as THREE from 'three';

/**
 * Faces of an indexed mesh that carry the same facet mark, as grouped by
 * `extractTeethFacesInfoFromFacetMarks`.
 */
export type ToothFacesInfo = {
    /** The facet mark value, cast to a tooth number (no validation is performed on the cast). */
    toothNumber: ToothNumber;
    /** Face (triangle) indices whose facet mark equals `toothNumber`; face i spans index entries 3i..3i+2. */
    faces: number[];
};

/**
 * Per-tooth geometry assembled by `extractAllTeethInfo`.
 */
export type ToothInfo = {
    /** Tooth number (facet mark). Elsewhere in this file, tooth number 1 is treated as gingiva and skipped. */
    toothNumber: ToothNumber;
    /** De-duplicated face indices belonging to this tooth. */
    faces: number[];
    /** De-duplicated vertex indices, read from the backup-index attribute for this tooth's faces. */
    verticesIndices: number[];
    /** Vertex positions, in the same order as `verticesIndices`. */
    vertices: THREE.Vector3[];
    /** Per-vertex normals, in the same order as `verticesIndices`. */
    normals: THREE.Vector3[];
    /** Arithmetic mean of `vertices`. */
    centroid: THREE.Vector3;
    /** Starts at the origin; overwritten by `computeLocalMaximas` for the current reference plane. */
    localMaxima: THREE.Vector3;
    /** Normalized mean of `normals` (left at zero when no normals were collected). */
    avgNormal: THREE.Vector3;
};

/**
 * Groups the faces of an indexed geometry by their facet mark.
 *
 * Face `i` is looked up as `facetMarks[i]`. Faces without a mark (the array is shorter than the face
 * count, or the slot is empty) are ignored. Each distinct mark becomes one map entry, whose
 * `toothNumber` is the mark itself. Map entries and the face lists inside them follow ascending face order.
 *
 * @param bufferGeom - Indexed geometry; only its index count is used to know how many faces exist.
 * @param facetMarks - One mark per face.
 * @returns Map from mark to the faces carrying it.
 * @throws Error when the geometry has no index.
 */
export function extractTeethFacesInfoFromFacetMarks(
    bufferGeom: THREE.BufferGeometry,
    facetMarks: number[],
): Map<number, ToothFacesInfo> {
    const indexAttr = bufferGeom.getIndex();
    if (!indexAttr) {
        throw new Error('No index found');
    }

    const facesByMark = new Map<number, ToothFacesInfo>();
    const faceCount = indexAttr.count / 3;

    for (let face = 0; face < faceCount; face++) {
        const mark = facetMarks[face];
        if (mark === undefined) {
            continue;
        }
        let entry = facesByMark.get(mark);
        if (entry === undefined) {
            entry = { toothNumber: mark as ToothNumber, faces: [] };
            facesByMark.set(mark, entry);
        }
        entry.faces.push(face);
    }

    return facesByMark;
}

/**
 * Builds per-tooth geometry (faces, vertex indices, positions, normals, centroid, average normal) from the
 * faces grouped per facet mark.
 *
 * Vertex indices are read from the `AttributeName.BackupIndex` attribute (three entries per face), not from
 * the geometry's current index. Positions and normals are then read at those indices.
 * NOTE(review): this presumes the backup index uses the same vertex numbering as the position/normal
 * attributes — confirm against the code that writes it.
 *
 * Entries are keyed by tooth number. If several `teethFacesInfo` entries share a tooth number, they are
 * merged. Faces and vertex indices are de-duplicated, keeping first-occurrence order. A tooth whose faces
 * resolve to no vertex at all is dropped (and logged), because its centroid would otherwise be NaN.
 *
 * @param bufferGeom - Indexed geometry with position, normal and backup-index attributes.
 * @param teethFacesInfo - Face indices grouped per tooth.
 * @returns Map from tooth number to its {@link ToothInfo}. `localMaxima` is left at the origin.
 * @throws Error if the geometry lacks positions, normals, an index, or the backup-index attribute.
 */
export function extractAllTeethInfo(
    bufferGeom: THREE.BufferGeometry,
    teethFacesInfo: Map<number, ToothFacesInfo>,
): Map<number, ToothInfo> {
    const positions = bufferGeom.getAttribute(AttributeName.Position);
    const normals = bufferGeom.getAttribute(AttributeName.Normal);
    if (!positions) {
        throw new Error('No positions found');
    }
    if (!normals) {
        throw new Error('No normals found');
    }
    // The index itself is not read below, but its presence is still required by this function's contract.
    if (!bufferGeom.getIndex()) {
        throw new Error('No index found');
    }
    const backupIndex = bufferGeom.getAttribute(AttributeName.BackupIndex);
    if (!backupIndex) {
        throw new Error('No backup index to check');
    }

    const result: Map<number, ToothInfo> = new Map();

    // Pass 1: collect (possibly duplicated) faces and vertex indices per tooth.
    for (const toothFacesInfo of teethFacesInfo.values()) {
        const mark = toothFacesInfo.toothNumber;
        if (mark === undefined) {
            continue;
        }
        for (const faceIndex of toothFacesInfo.faces) {
            if (faceIndex === undefined) {
                continue;
            }
            let toothInfo = result.get(mark);
            if (!toothInfo) {
                toothInfo = {
                    toothNumber: mark,
                    faces: [],
                    verticesIndices: [],
                    vertices: [],
                    normals: [],
                    centroid: new THREE.Vector3(0, 0, 0),
                    localMaxima: new THREE.Vector3(0, 0, 0),
                    avgNormal: new THREE.Vector3(0, 0, 0),
                };
                result.set(mark, toothInfo);
            }
            toothInfo.faces.push(faceIndex);
            for (let corner = 0; corner < 3; ++corner) {
                const vIndex = backupIndex.array[3 * faceIndex + corner];
                if (vIndex !== undefined) {
                    toothInfo.verticesIndices.push(vIndex);
                }
            }
        }
    }

    // Pass 2: de-duplicate, then gather positions/normals and average them.
    for (const [toothNumber, toothInfo] of result) {
        toothInfo.faces = [...new Set(toothInfo.faces)];
        toothInfo.verticesIndices = [...new Set(toothInfo.verticesIndices)];

        if (toothInfo.verticesIndices.length === 0) {
            // Every backup-index lookup missed; averaging would divide by zero and yield a NaN centroid.
            logger.error('Tooth has no resolvable vertices, skipping', { toothNumber });
            result.delete(toothNumber);
            continue;
        }

        for (const vertex of toothInfo.verticesIndices) {
            const v = new THREE.Vector3().fromBufferAttribute(positions, vertex);
            toothInfo.vertices.push(v);
            toothInfo.centroid.add(v);

            const n = new THREE.Vector3().fromBufferAttribute(normals, vertex);
            toothInfo.normals.push(n);
            toothInfo.avgNormal.add(n);
        }
        toothInfo.centroid.divideScalar(toothInfo.vertices.length);
        toothInfo.avgNormal.divideScalar(toothInfo.normals.length);
        toothInfo.avgNormal.normalize();
    }
    return result;
}

/**
 * For each tooth, computes the vertex that `computeLocalMaximaOfVertices` selects for `plane`. This is
 * presumably the vertex furthest along the plane normal; confirm in Mesh3d.util.
 *
 * A copy of each result is written into that tooth's `localMaxima` (mutating `teethInfo`), and the
 * vectors themselves are returned.
 *
 * Tooth number 1 is skipped: for some DCM files it designates the gingiva, not a tooth.
 * TODO: check the tooth number 1 is always the gingiva
 *
 * @param teethInfo - Per-tooth info; each processed entry's `localMaxima` is overwritten.
 * @param plane - Reference plane forwarded to `computeLocalMaximaOfVertices`.
 * @returns The per-tooth maxima, in `teethInfo` iteration order, with the gingiva excluded.
 */
export function computeLocalMaximas(teethInfo: Map<number, ToothInfo>, plane: THREE.Plane) {
    const GINGIVA_TOOTH_NUMBER = 1;
    const maxima: THREE.Vector3[] = [];
    teethInfo.forEach(toothInfo => {
        if (toothInfo.toothNumber === GINGIVA_TOOTH_NUMBER) {
            return;
        }
        const topVertex = computeLocalMaximaOfVertices(toothInfo.vertices, plane);
        toothInfo.localMaxima.copy(topVertex);
        maxima.push(topVertex);
    });
    return maxima;
}

/**
 * Returns reference ("canonical") positions for upper-jaw teeth 2–15 and lower-jaw teeth 18–31.
 *
 * The coordinates were read manually by raycasting mouse clicks on an occlusal plane object. All upper-jaw
 * points share z = 1.9984014443252818e-14 and all lower-jaw points have z = 0, so every point lies
 * (numerically) in a single plane. Each call builds a fresh Map with fresh vectors, so callers may mutate it.
 *
 * @returns Map from tooth number to its canonical position (upper jaw inserted first, then lower jaw).
 */
export function getCanonicalTeethCoords3D() {
    const UPPER_JAW_Z = 1.9984014443252818e-14;
    const LOWER_JAW_Z = 0;

    // [tooth number, x, y, z]
    const table: Array<[ToothNumber, number, number, number]> = [
        // upper jaw
        [15, 24.607746937862345, -14.960910696061728, UPPER_JAW_Z],
        [14, 24.165344602566172, -4.094496433483943, UPPER_JAW_Z],
        [13, 21.398219782582363, 5.827530352401, UPPER_JAW_Z],
        [12, 18.771611080318866, 13.268926310968657, UPPER_JAW_Z],
        [11, 15.129115534041938, 19.165141497996817, UPPER_JAW_Z],
        [10, 10.569541297201559, 23.662650311815852, UPPER_JAW_Z],
        [9, 3.6048077874378612, 26.24893475664708, UPPER_JAW_Z],
        [8, -4.099249041356152, 26.397404580878465, UPPER_JAW_Z],
        [7, -10.438859610448826, 24.034301258078635, UPPER_JAW_Z],
        [6, -15.28443131657158, 19.171542098553793, UPPER_JAW_Z],
        [5, -19.015804548615087, 13.180610896708878, UPPER_JAW_Z],
        [4, -21.422949915075087, 6.759663414648651, UPPER_JAW_Z],
        [3, -24.252140158080255, -2.9921722979938856, UPPER_JAW_Z],
        [2, -25.298314928528114, -14.418390108657245, UPPER_JAW_Z],

        // lower jaw
        [31, -23.162983049700518, -13.14116602084976, LOWER_JAW_Z],
        [30, -19.953443374187664, -1.0243383195155147, LOWER_JAW_Z],
        [29, -17.468077600573956, 7.2862234052236765, LOWER_JAW_Z],
        [28, -15.196327927934773, 14.946343289570398, LOWER_JAW_Z],
        [27, -11.514967816529603, 21.292788109840284, LOWER_JAW_Z],
        [26, -6.683022778019173, 23.847866922394783, LOWER_JAW_Z],
        [25, -2.058298148837304, 25.040023845198675, LOWER_JAW_Z],
        [24, 2.1148906501409277, 24.948565454661683, LOWER_JAW_Z],
        [23, 6.593793509795268, 23.517849233464087, LOWER_JAW_Z],
        [22, 11.53190702962837, 20.978759650024454, LOWER_JAW_Z],
        [21, 14.919704894484726, 14.60161550985404, LOWER_JAW_Z],
        [20, 17.633393443316837, 7.304131341832312, LOWER_JAW_Z],
        [19, 20.507065378259952, -1.3025509061140568, LOWER_JAW_Z],
        [18, 23.45753923574053, -12.518675345605452, LOWER_JAW_Z],
    ];

    return new Map<ToothNumber, THREE.Vector3>(
        table.map(([tooth, x, y, z]): [ToothNumber, THREE.Vector3] => [tooth, new THREE.Vector3(x, y, z)]),
    );
}

// Fewest point correspondences for which a rigid 3D alignment is uniquely determined
// (assuming the points are not collinear). With fewer, the rotation has a free axis.
const MIN_ALIGNMENT_PAIRS = 3;

// Sum of Euclidean distances between same-index points of two equally long arrays.
function sumOfPairDistances(moving: THREE.Vector3[], base: THREE.Vector3[]): number {
    let total = 0;
    moving.forEach((v, i) => {
        const vb = base[i];
        if (vb) {
            total += v.distanceTo(vb);
        }
    });
    return total;
}

/**
 * Computes the rigid transform that moves the canonical coordinates onto the computed coordinates.
 *
 * Only tooth numbers present in both maps are paired (in `computed3DTeethCoords` iteration order). The
 * rotation/translation comes from `solveAlignmentForPairs` and is composed by `makeAlignmentMatrix`. The
 * summed pair distance is logged before and after the alignment.
 *
 * @param computed3DTeethCoords - Target ("base") positions per tooth.
 * @param canonical3DTeethCoords - Positions to be moved onto the targets. Not mutated; clones are used.
 * @returns The alignment matrix, or `undefined` when:
 *   - there are too few matching teeth,
 *   - the solved rotation contains NaN/non-finite values, or
 *   - the composed matrix is not finite.
 */
function alignTwoSetsOf3DCoords(
    computed3DTeethCoords: Map<ToothNumber, THREE.Vector3>,
    canonical3DTeethCoords: Map<ToothNumber, THREE.Vector3>,
) {
    // set correspondences
    const pairsBaseVertices: THREE.Vector3[] = [];
    const pairsMoveableVertices: THREE.Vector3[] = [];
    for (const [toothNumber, computedCoord] of computed3DTeethCoords) {
        const canonicalCoord = canonical3DTeethCoords.get(toothNumber);
        if (!computedCoord || !canonicalCoord) {
            continue;
        }
        pairsBaseVertices.push(computedCoord.clone());
        pairsMoveableVertices.push(canonicalCoord.clone());
    }

    if (pairsBaseVertices.length < MIN_ALIGNMENT_PAIRS) {
        logger.error('Alignment step - not enough matching teeth to align', { pairs: pairsBaseVertices.length });
        return undefined;
    }

    logger.info(`totalDistanceBefore: ${sumOfPairDistances(pairsMoveableVertices, pairsBaseVertices)}`);

    const { R, T } = solveAlignmentForPairs(pairsBaseVertices, pairsMoveableVertices);
    const Rvalid = R.every(row => row.every(value => Number.isFinite(value)));
    if (!Rvalid) {
        logger.error('Alignment step - SVD failed, R is invalid', R);
        return undefined;
    }

    const alignMatrix = makeAlignmentMatrix(R, T);
    // R alone being finite does not guarantee T is; validate the composed transform as well.
    if (!alignMatrix.elements.every(value => Number.isFinite(value))) {
        logger.error('Alignment step - alignment matrix has non-finite entries', { T });
        return undefined;
    }

    // compute errors
    const alignedMoveable = pairsMoveableVertices.map(v => v.clone().applyMatrix4(alignMatrix));
    logger.info(`totalDistanceAfter: ${sumOfPairDistances(alignedMoveable, pairsBaseVertices)}`);

    return alignMatrix;
}

/**
 * Computes the matrix that moves the canonical tooth layout (`getCanonicalTeethCoords3D`) onto this scan.
 *
 * Each tooth's centroid is projected orthogonally onto `plane`. The projected points, keyed by tooth
 * number, serve as the alignment targets. All coordinates stay 3D; no 2D reduction or extra rotation is
 * applied here. Tooth number 1 is skipped because, for some DCM files, it designates the gingiva.
 * TODO: check the tooth number 1 is always the gingiva
 *
 * @param teethInfo - Per-tooth info; only `toothNumber` and `centroid` are read (not mutated).
 * @param plane - Plane the centroids are projected onto, typically the estimated occlusal plane.
 * @returns The alignment matrix, or `undefined` when the alignment step fails.
 */
export function computeOcclusalPlaneAlignment(teethInfo: Map<number, ToothInfo>, plane: THREE.Plane) {
    const projectedCentroids: Map<ToothNumber, THREE.Vector3> = new Map();

    for (const { toothNumber, centroid } of teethInfo.values()) {
        if (toothNumber === 1) {
            continue;
        }
        // projectPoint writes into the target vector and leaves `centroid` untouched.
        projectedCentroids.set(toothNumber as ToothNumber, plane.projectPoint(centroid, new THREE.Vector3()));
    }

    return alignTwoSetsOf3DCoords(projectedCentroids, getCanonicalTeethCoords3D());
}

/**
 * Result of `processTeethLabelsAndEstimateOcclusalPlane`.
 */
export type ProcessedTeethAndOcclusalEstimateData = {
    /** Matrix from `computeOcclusalPlaneAlignment` for the final plane; `undefined` if the alignment failed. */
    alignmentResult: THREE.Matrix4 | undefined;
    /** Final plane, fitted to the local maxima of the last refinement iteration and oriented along `avgNormal`. */
    plane: THREE.Plane;
    /** First estimate: plane fitted to the tooth centroids, oriented along `avgNormal`. */
    initialPlane: THREE.Plane;
    /** Average of `localMaximas` (via `getAverageVector`). */
    avgMaxima: THREE.Vector3;
    /** Per-tooth local maxima from the last refinement iteration (gingiva excluded). */
    localMaximas: THREE.Vector3[];
    /** Average of `centroids` (via `getAverageVector`). */
    avgCentroid: THREE.Vector3;
    /** Average of the per-tooth `avgNormal` vectors; NOTE(review): not explicitly re-normalized here — confirm getAverageVector. */
    avgNormal: THREE.Vector3;
    /** Per-tooth centroids (the same objects held in `teethInfo`), gingiva excluded. */
    centroids: THREE.Vector3[];
    /** Per-tooth info; `localMaxima` reflects the last refinement iteration. */
    teethInfo: Map<number, ToothInfo>;
};

// Refinement stops once the plane normal moves less than this between two iterations.
const OCCLUSAL_PLANE_NORMAL_TOLERANCE = 0.01;
// Hard cap on local-maxima / plane-fit refinement iterations.
const OCCLUSAL_PLANE_MAX_ITERATIONS = 50;
// Minimum number of teethInfo entries (gingiva included) required to attempt the estimation.
const MIN_TEETH_FOR_OCCLUSAL_PLANE = 4;

// Flips `plane` in place when its normal points away from `reference`; returns the same plane.
function orientPlaneAlong(plane: THREE.Plane, reference: THREE.Vector3): THREE.Plane {
    if (plane.normal.dot(reference) < 0) {
        plane.negate();
    }
    return plane;
}

// Alternates between per-tooth local maxima and a plane fit to them, until the normal stabilizes or the cap is hit.
function refinePlaneWithLocalMaxima(
    teethInfo: Map<number, ToothInfo>,
    startPlane: THREE.Plane,
    referenceNormal: THREE.Vector3,
) {
    let plane = startPlane;
    let previousNormal = startPlane.normal.clone();
    let localMaximas: THREE.Vector3[] = [];
    let iterations = 0;
    let normalShift = 0;
    do {
        localMaximas = computeLocalMaximas(teethInfo, plane);
        plane = orientPlaneAlong(fitAPlaneToPoints(localMaximas), referenceNormal);
        iterations++;
        normalShift = plane.normal.distanceTo(previousNormal);
        previousNormal = plane.normal.clone();
    } while (normalShift > OCCLUSAL_PLANE_NORMAL_TOLERANCE && iterations < OCCLUSAL_PLANE_MAX_ITERATIONS);
    return { plane, localMaximas, iterations, normalShift };
}

/**
 * Main entry point: turns facet-mark face groups into per-tooth data and estimates the occlusal plane.
 *
 * Steps:
 * 1. Build per-tooth info with `extractAllTeethInfo`.
 * 2. Fit an initial plane to the tooth centroids, oriented along the average tooth normal.
 * 3. Refine it by repeatedly fitting a plane to the per-tooth local maxima.
 * 4. Compute the canonical-to-scan alignment matrix for the final plane.
 *
 * Tooth number 1 is excluded from the centroid/normal statistics because, for some DCM files, it
 * designates the gingiva. TODO: check the tooth number 1 is always the gingiva
 *
 * @param bufferGeom - Indexed geometry with position, normal and backup-index attributes.
 * @param teethFacesInfo - Face indices grouped per tooth.
 * @param debug - When true, logs per-tooth vertex counts, iteration info and the final matrix.
 * @returns The estimation data, or `null` when fewer than 4 teethInfo entries (gingiva included) exist.
 * @throws Error propagated from `extractAllTeethInfo` when required attributes are missing.
 */
export function processTeethLabelsAndEstimateOcclusalPlane(
    bufferGeom: THREE.BufferGeometry,
    teethFacesInfo: Map<number, ToothFacesInfo>,
    debug: boolean = false,
): ProcessedTeethAndOcclusalEstimateData | null {
    const teethInfo = extractAllTeethInfo(bufferGeom, teethFacesInfo);
    if (teethInfo.size < MIN_TEETH_FOR_OCCLUSAL_PLANE) {
        logger.error('Not enough teeth to compute occlusal plane', { size: teethInfo.size });
        return null;
    }

    const centroids: THREE.Vector3[] = [];
    const avgNormals: THREE.Vector3[] = [];
    if (debug) {
        logger.info('****** statistics ******');
    }
    for (const toothInfo of teethInfo.values()) {
        if (debug) {
            logger.info(`tooth ${toothInfo.toothNumber} vertices: ${toothInfo.vertices.length}`);
        }
        // Tooth number 1 designates the gingiva, not a tooth (for some DCM files).
        if (toothInfo.toothNumber === 1) {
            continue;
        }
        centroids.push(toothInfo.centroid);
        avgNormals.push(toothInfo.avgNormal);
    }
    const avgCentroid = getAverageVector(centroids);
    const avgNormal = getAverageVector(avgNormals);

    // Initial estimate from centroids; orientation is fixed by the average tooth normal.
    const centroidPlane = orientPlaneAlong(fitAPlaneToPoints(centroids), avgNormal);
    const initialPlane = centroidPlane.clone();

    const { plane, localMaximas, iterations, normalShift } = refinePlaneWithLocalMaxima(
        teethInfo,
        centroidPlane,
        avgNormal,
    );
    if (debug) {
        logger.info(`finding maxima ${iterations} iterations, D=${normalShift}`);
    }
    const avgMaxima = getAverageVector(localMaximas);

    if (debug) {
        logger.info('******* computing 2d alignment with one ICP iteration ********');
    }
    const alignmentResult = computeOcclusalPlaneAlignment(teethInfo, plane);
    if (debug) {
        logger.info('Final alignment matrix: ', { alignmentResult });
    }
    return {
        alignmentResult,
        plane,
        initialPlane,
        avgMaxima,
        localMaximas,
        avgCentroid,
        avgNormal,
        centroids,
        teethInfo,
    };
}
