import { AttributeName } from './BufferAttributeConstants';
import { ensureMeshIndex } from './MeshIndex';
import * as THREE from 'three';
import type { HitPointInfo } from 'three-mesh-bvh';

/**
 * Result computed for a single query vertex by `computeClosestPointOnCpu`.
 */
export interface VertexResult {
    /**
     * Distance from the (transformed) query vertex to `closestPoint`, signed by the normal of the reference triangle
     * containing `closestPoint`: positive when the vertex lies on the side that normal points toward, negative
     * otherwise.
     */
    signedDistance: number;
    /** Closest point on the reference geometry, expressed in the reference mesh frame. */
    closestPoint: THREE.Vector3;
}
/**
 * Computes, on the CPU, the closest point on a reference geometry to every vertex of a query geometry.
 *
 * The closest-point search is delegated to the BVH returned by `ensureMeshIndex`. The distance is signed using the
 * normal of the reference triangle that contains the closest point. It is positive when the query vertex lies on the
 * side that normal points toward. It is negative otherwise, including when the offset is perpendicular to the normal.
 * Query vertices lying exactly on the surface get a signed distance of `0`.
 *
 * NOTE(review): a single face normal only approximates inside/outside. When the closest point falls on an edge or vertex
 * shared by several triangles, the sign may be wrong near sharp features. A degenerate (zero-area) triangle presumably
 * yields a zero normal and therefore a negative sign. Confirm whether this is acceptable for callers.
 *
 * @param referenceMesh The reference geometry on which to find the closest points. It must have a position attribute,
 *   and its index (read after `ensureMeshIndex` runs) must be present.
 * @param queryMesh The closest point is calculated for each vertex in this geometry. It must have a position attribute.
 * @param queryToReference The transformation from the query mesh frame to the reference mesh frame. If not provided,
 *   query positions are used as-is (identity transformation).
 * @returns One result per query vertex, in vertex order. Closest points are expressed in the reference mesh frame.
 *   Returns `undefined` in three cases: a required attribute or the reference index is missing, no closest point is
 *   found for some vertex, or the reported face cannot be looked up in the reference index.
 */
export function computeClosestPointOnCpu(
    referenceMesh: THREE.BufferGeometry,
    queryMesh: THREE.BufferGeometry,
    queryToReference?: THREE.Matrix4,
): VertexResult[] | undefined {
    // The index is read only after `ensureMeshIndex` runs. Presumably it may create or reorder the index, and the BVH's
    // face indices must refer to the same buffer used for the triangle lookup below.
    const referenceBvh = ensureMeshIndex(referenceMesh);

    const referencePositions = referenceMesh.attributes[AttributeName.Position];
    const referenceIndex = referenceMesh.index?.array;
    const queryPositions = queryMesh.attributes[AttributeName.Position];

    if (!(queryPositions && referencePositions && referenceIndex)) {
        return;
    }

    const queryPosition = new THREE.Vector3();
    const hitFace = new THREE.Triangle();
    const hitFaceNormal = new THREE.Vector3();
    // Reused across iterations; every result receives its own clone of `target.point`.
    const target: HitPointInfo = {
        point: new THREE.Vector3(),
        distance: 0,
        faceIndex: -1,
    };
    const results: VertexResult[] = [];

    const numVertices = queryPositions.count;
    for (let i = 0; i < numVertices; i++) {
        queryPosition.fromBufferAttribute(queryPositions, i);
        if (queryToReference) {
            queryPosition.applyMatrix4(queryToReference);
        }

        // Reset so that a query which finds nothing cannot be mistaken for the previous vertex's hit.
        target.faceIndex = -1;
        referenceBvh.closestPointToPoint(queryPosition, target);
        if (target.faceIndex < 0) {
            return;
        }

        if (!setFaceTriangle(target.faceIndex, referencePositions, referenceIndex, hitFace)) {
            return;
        }

        hitFace.getNormal(hitFaceNormal);
        // `queryPosition` now holds the vector from the closest point on the mesh to the query point.
        queryPosition.sub(target.point);

        const side = queryPosition.dot(hitFaceNormal) > 0 ? 1 : -1;
        // Report +0 (not -0) for query points lying exactly on the surface.
        const signedDistance = target.distance === 0 ? 0 : target.distance * side;

        results.push({
            closestPoint: target.point.clone(),
            signedDistance,
        });
    }

    return results;
}

/**
 * Loads the three vertices of one indexed triangle into `triangle`.
 *
 * @param faceIndex Index of the triangle. Its vertex indices are `meshIndex[3 * faceIndex]` through
 *   `meshIndex[3 * faceIndex + 2]`.
 * @param positions Vertex position attribute that `meshIndex` refers to.
 * @param meshIndex Triangle index buffer, three vertex indices per face.
 * @param triangle Output triangle. Its `a`, `b` and `c` are overwritten only when this function returns `true`.
 * @returns `true` once `triangle` has been filled. Returns `false`, leaving `triangle` untouched, in three cases:
 *   `faceIndex` is not a non-negative integer, the face extends past the end of `meshIndex`, or any of its vertex
 *   indices is outside `positions`.
 */
function setFaceTriangle(
    faceIndex: number,
    positions: THREE.BufferAttribute | THREE.InterleavedBufferAttribute,
    meshIndex: ArrayLike<number>,
    triangle: THREE.Triangle,
): boolean {
    // Without this guard a negative face index (e.g. the -1 "no hit" sentinel) would read undefined entries.
    if (!Number.isInteger(faceIndex) || faceIndex < 0) {
        return false;
    }

    const offset = faceIndex * 3;
    if (offset + 2 >= meshIndex.length) {
        return false;
    }

    const vertexA = meshIndex[offset];
    const vertexB = meshIndex[offset + 1];
    const vertexC = meshIndex[offset + 2];
    const vertexCount = positions.count;
    if (
        !(
            isVertexIndexInRange(vertexA, vertexCount) &&
            isVertexIndexInRange(vertexB, vertexCount) &&
            isVertexIndexInRange(vertexC, vertexCount)
        )
    ) {
        return false;
    }

    triangle.a.fromBufferAttribute(positions, vertexA);
    triangle.b.fromBufferAttribute(positions, vertexB);
    triangle.c.fromBufferAttribute(positions, vertexC);

    return true;
}

/** Returns whether `value` is an integer vertex index within `[0, vertexCount)`. */
function isVertexIndexInRange(value: number | undefined, vertexCount: number): value is number {
    return value !== undefined && Number.isInteger(value) && value >= 0 && value < vertexCount;
}
