import { logger } from '../Utils/Logger';
import { AttributeName } from './BufferAttributeConstants';
import {
    Matrix3,
    getRotationBySvd,
    getTranslationFromRotation,
    makeAlignmentMatrix,
} from './IterativeClosestPoint.utils';
import { ensureMeshIndex } from './MeshIndex';
import _ from 'lodash';
import prand from 'pure-rand';
import * as THREE from 'three';
import type { MeshBVH, HitPointInfo } from 'three-mesh-bvh';

/**
 * Tuning parameters for `deconflictGeometries`. Distances are compared against distances measured in the static
 * geometry's local frame, which is where the algorithm does its work.
 */
interface DeconflictionOptions {
    /** Maximum number of ICP iterations (incremental transformations) to perform. */
    maxIterations: number;
    /**
     * Total number of correspondences sampled per iteration: up to half come from colliding / offset-region points,
     * the rest from non-colliding points within `registrationDistance`.
     */
    numSamples: number;
    /** Factor by which the weight of the non-colliding samples is multiplied after every iteration. */
    nonCollidingBackoff: number;
    /**
     * Non-colliding points at most this far from the static geometry form the offset region. Such points may be used
     * as extra "push apart" samples, displaced along their normals by this distance (less `conflictDistance`).
     */
    offsetDistance: number;
    /**
     * Maximum closest-point search distance. Moving points farther than this from the static geometry are ignored
     * entirely (neither sampled nor considered when checking for collisions).
     */
    registrationDistance: number;
    /**
     * The distance by which to overlap conflicting geometries. For the termination check, a point only counts as
     * colliding once it penetrates the static geometry by more than this distance.
     */
    conflictDistance: number;
    /** Seed for the PRNG that randomizes sample selection; `Date.now()` is used when omitted. */
    prngSeed?: number;
    /** Transformation from the static geometry's local frame to the world frame. */
    worldTStatic: THREE.Matrix4;
    /** Transformation from the moving geometry's local frame to the world frame. */
    worldTMoving: THREE.Matrix4;
}

/** Statistics for a single deconfliction (ICP) iteration. */
interface IterationResult {
    /** Number of correspondences taken from colliding points. */
    numCollidingSamples: number;
    /** Number of correspondences taken from the offset region to top up the colliding samples. */
    numOffsetSamples: number;
    /** Number of correspondences taken from non-colliding points within the registration distance. */
    numNoncollidingSamples: number;
    /** Effective per-sample weight applied to the colliding and offset samples. */
    collidingWeight: number;
    /** Effective per-sample weight applied to the non-colliding samples. */
    noncollidingWeight: number;
    /** This iteration's incremental transformation of the moving geometry, expressed in the world frame. */
    transformation: THREE.Matrix4;
}

/** Output of `deconflictGeometries`. */
interface DeconflictionResult {
    /**
     * Whether no collisions (deeper than `conflictDistance`) remain after applying `transformation`. Also `false`
     * when the moving geometry lacks the attributes needed to run the algorithm.
     */
    success: boolean;
    /** Accumulated transformation to apply to the moving geometry, in the world frame. */
    transformation: THREE.Matrix4;
    /** Number of incremental transformations that were computed and applied. */
    numIterations: number;
    /** Per-iteration statistics, one entry per applied incremental transformation. */
    iterationResults: IterationResult[];
}

/** Shared identity matrix (a freshly constructed `Matrix4` is the identity). Treat as read-only; clone before mutating. */
const IDENTITY = new THREE.Matrix4();

/**
 * Deconflict two geometries, i.e. find a transformation such that the geometries do not collide.
 * NB: The returned transformation is to be applied to the moving geometry in the world frame.
 *
 * This algorithm is a modified version of ICP. The main difference is in how the sample points are selected. On each
 * iteration, we include samples from the following three categories:
 * 1. Colliding points: These are points in the moving geometry that are colliding with the static geometry. We include
 * them to force the geometries apart.
 * 2. Offset points: These are points in the moving geometry that are not colliding with the static geometry, but are
 * within `offsetDistance` of it. We include them, offset by `offsetDistance` to further force the geometries apart, and
 * to increase the number of samples in the colliding region, which will shrink in the latter iterations, in order to
 * not give too much weight to only a handful of colliding points.
 * 3. Non-colliding points: These are points in the moving geometry that are not colliding with the static geometry, but
 * are within `registrationDistance` of it. We include them to encourage the geometries to stay close together, so that
 * the final transformation is not larger in magnitude than necessary to remove the collisions.
 *
 * All computation happens in the static geometry's local frame. The moving positions are mapped there with
 * `staticTMoving = inverse(worldTStatic) * worldTMoving`, and the moving normals with the normal matrix (inverse
 * transpose) of `staticTMoving`, so they stay correct even under non-uniform scaling.
 *
 * @param staticGeometry The static, i.e. non-moving geometry
 * @param movingGeometry The moving geometry, to which the output of this function should be applied. It needs position
 * and normal attributes with the same number of entries.
 * @param options Options for the deconfliction algorithm. Options that are omitted or explicitly `undefined` take their
 * default values.
 * @returns The resultant transformation and some statistics. `success` is `false` (with an identity transformation) if
 * the moving geometry lacks usable positions/normals, and `false` (with the transformation accumulated so far) if
 * collisions remain after `maxIterations` iterations or no correspondences could be sampled.
 */
export function deconflictGeometries(
    staticGeometry: THREE.BufferGeometry,
    movingGeometry: THREE.BufferGeometry,
    options: Partial<DeconflictionOptions> = {},
): DeconflictionResult {
    const DEFAULT_OPTIONS: DeconflictionOptions = {
        maxIterations: 25,
        numSamples: 5000,
        nonCollidingBackoff: 0.8,
        offsetDistance: 0.02,
        registrationDistance: 0.3,
        conflictDistance: 0.0,
        worldTStatic: IDENTITY,
        worldTMoving: IDENTITY,
    };
    // Unlike an object spread, `_.defaults` also falls back to the default for keys that are present but `undefined`
    // (which `Partial<>` permits), rather than letting `undefined` override the default.
    const {
        maxIterations,
        numSamples,
        nonCollidingBackoff,
        offsetDistance,
        registrationDistance,
        conflictDistance,
        worldTStatic,
        worldTMoving,
        prngSeed,
    } = _.defaults({}, options, DEFAULT_OPTIONS);

    const rng = prand.mersenne(prngSeed ?? Date.now());

    // Get the positions and normals of the moving geometry in the static geometry frame.

    const movingPosition = movingGeometry.getAttribute(AttributeName.Position);
    const movingNormal = movingGeometry.getAttribute(AttributeName.Normal);
    if (!(movingPosition && movingNormal)) {
        logger.warn('Missing positions and/or normals for moving geometry.');
        return { success: false, transformation: IDENTITY.clone(), numIterations: 0, iterationResults: [] };
    }

    const staticTWorld = worldTStatic.clone().invert();
    const staticTMoving = new THREE.Matrix4().multiplyMatrices(staticTWorld, worldTMoving);
    // Normals transform with the inverse transpose of the linear part; `transformDirection` would skew them whenever
    // the transformation contains non-uniform scaling.
    const staticNormalMatrixMoving = new THREE.Matrix3().getNormalMatrix(staticTMoving);

    const movingPoints = _.chunk(movingPosition.array, 3).map(p =>
        new THREE.Vector3(p[0], p[1], p[2]).applyMatrix4(staticTMoving),
    );
    const movingNormals = _.chunk(movingNormal.array, 3).map(n =>
        new THREE.Vector3(n[0], n[1], n[2]).applyNormalMatrix(staticNormalMatrixMoving),
    );
    if (movingPoints.length !== movingNormals.length) {
        logger.warn('Mismatched position and normal counts for moving geometry.');
        return { success: false, transformation: IDENTITY.clone(), numIterations: 0, iterationResults: [] };
    }

    const staticBoundsTree = ensureMeshIndex(staticGeometry);

    // Initialize loop variables.

    let i = 0;
    let hasCollisions = true;
    const iterationResults: IterationResult[] = [];
    // Relative weight of the non-colliding samples; multiplied by `nonCollidingBackoff` after every iteration.
    let noncollidingWeight = 0.1;
    const accumulatedTransformation = IDENTITY.clone();

    // Pre-allocate the closest-point results once; `calculateClosestPoints` overwrites them in place on every pass.
    const closestPoints: ClosestPointResult[] = movingPoints.map(() => ({
        closestPoint: new THREE.Vector3(),
        signedDistance: undefined,
    }));

    for (; i < maxIterations; ++i) {
        calculateClosestPoints(staticBoundsTree, movingPoints, movingNormals, registrationDistance, closestPoints);
        if (!checkHasCollisions(closestPoints, conflictDistance)) {
            hasCollisions = false;
            break;
        }

        const correspondences = createCorrespondences(
            movingPoints,
            movingNormals,
            closestPoints,
            numSamples,
            offsetDistance,
            conflictDistance,
            rng,
        );
        if (correspondences.staticSamples.length === 0) {
            // Nothing to align against (possible e.g. for very small `numSamples`). Solving would produce a NaN
            // transformation, so stop here; the check below then reports failure.
            logger.warn('No correspondences could be sampled; aborting deconfliction.');
            break;
        }
        const { numCollidingSamples, numOffsetSamples, numNoncollidingSamples } = correspondences;

        const transformationResult = findTransformation(correspondences, noncollidingWeight);
        const incrementalTransformation = makeAlignmentMatrix(transformationResult.R, transformationResult.T);

        const worldIncrementalTransformation = worldTStatic
            .clone()
            .multiply(incrementalTransformation)
            .multiply(staticTWorld);
        iterationResults.push({
            numCollidingSamples,
            numOffsetSamples,
            numNoncollidingSamples,
            collidingWeight: transformationResult.collidingWeight,
            noncollidingWeight: transformationResult.nonCollidingWeight,
            transformation: worldIncrementalTransformation,
        });

        // The incremental transformation is rigid, so `transformDirection` is correct for the normals here.
        movingPoints.forEach(p => p.applyMatrix4(incrementalTransformation));
        movingNormals.forEach(n => n.transformDirection(incrementalTransformation));
        accumulatedTransformation.premultiply(incrementalTransformation);

        noncollidingWeight *= nonCollidingBackoff;
    }

    if (hasCollisions) {
        // We may have finally removed all collisions on the last iteration, but exit checks happen at the beginning of
        // the loop, so we need to check here.
        calculateClosestPoints(staticBoundsTree, movingPoints, movingNormals, registrationDistance, closestPoints);
        hasCollisions = checkHasCollisions(closestPoints, conflictDistance);
    }

    const worldTotalMatrix = worldTStatic.clone().multiply(accumulatedTransformation).multiply(staticTWorld);
    return { success: !hasCollisions, transformation: worldTotalMatrix, numIterations: i, iterationResults };
}

/** Closest-point query result for a single moving-geometry vertex (filled in by `calculateClosestPoints`). */
interface ClosestPointResult {
    /** Closest point on the reference geometry; only meaningful while `signedDistance` is defined. */
    closestPoint: THREE.Vector3;
    /**
     * Distance to `closestPoint`, signed by the vertex normal. It is positive if the closest point lies on the side the
     * normal points towards; otherwise it is zero or negative, and such points are sampled as colliding.
     * `undefined` if no point was found within the search distance.
     */
    signedDistance?: number;
}

// For every query vertex, looks up the closest point on `referenceBvh` within `maxDistance` and writes it, together
// with a signed distance, into the pre-allocated entry at the same index of `closestPoints`.
//
// The distance is signed by the query vertex normal. It is positive when the closest point lies on the side the normal
// points towards, and negative (or zero) otherwise. When nothing is found within `maxDistance`, the entry's
// `signedDistance` is cleared to `undefined` and its `closestPoint` is left untouched.
//
// Throws Error('Index error') if, at some index, the position, the normal or the result entry is missing.
function calculateClosestPoints(
    referenceBvh: MeshBVH,
    queryVertexPositions: THREE.Vector3[],
    queryVertexNormals: THREE.Vector3[],
    maxDistance: number,
    closestPoints: ClosestPointResult[],
): void {
    // Scratch objects shared by every query to avoid per-vertex allocations.
    const hitInfo = {} as HitPointInfo;
    const toClosestPoint = new THREE.Vector3();

    for (let index = 0; index < queryVertexPositions.length; index++) {
        const position = queryVertexPositions[index];
        const normal = queryVertexNormals[index];
        const entry = closestPoints[index];
        if (!position || !normal || !entry) {
            throw new Error('Index error');
        }

        const found = referenceBvh.closestPointToPoint(position, hitInfo, 0, maxDistance);
        if (found) {
            toClosestPoint.subVectors(hitInfo.point, position);
            entry.closestPoint.copy(hitInfo.point);
            entry.signedDistance = Math.sign(normal.dot(toClosestPoint)) * hitInfo.distance;
        } else {
            entry.signedDistance = undefined;
        }
    }
}

// Returns true if any query point penetrates the reference geometry by more than `conflictDistance`, i.e. if some
// defined signed distance satisfies `signedDistance + conflictDistance < 0`. Points without a closest point are ignored.
function checkHasCollisions(closestPoints: ClosestPointResult[], conflictDistance: number): boolean {
    const penetratesTooFar = ({ signedDistance }: ClosestPointResult): boolean =>
        signedDistance !== undefined && signedDistance + conflictDistance < 0;
    return closestPoints.some(penetratesTooFar);
}

/**
 * Paired samples for one ICP step: `staticSamples[k]` is the target for `movingSamples[k]`. Samples are ordered
 * colliding, then offset, then non-colliding, and the three counts give the length of each run.
 */
interface CorrespondencesResult {
    staticSamples: THREE.Vector3[];
    movingSamples: THREE.Vector3[];
    numCollidingSamples: number;
    numOffsetSamples: number;
    numNoncollidingSamples: number;
}

// Creates correspondences between the query geometry points (potentially modified if they are in the offset region) and
// the reference geometry points.
//
// The query points are visited in a random order (driven by `rng`) and classified by signed distance. Then:
// - up to floor(numSamples / 2) colliding points (signed distance <= 0) are shifted against their normals by
//   `conflictDistance` ("bumped in"), to encourage leaving that much overlap between the geometries;
// - if there were fewer colliding points than that, the remainder is filled with offset-region points
//   (0 < signed distance <= offsetDistance), shifted along their normals by `offsetDistance - conflictDistance`
//   ("bumped out"), to encourage separation;
// - up to numSamples - floor(numSamples / 2) registration-region points (signed distance > offsetDistance) are used
//   unchanged, to keep the geometries close together.
// Points with no closest point are skipped. Samples are returned in that order (colliding, offset, non-colliding), which
// `findTransformation` relies on when weighting them.
function createCorrespondences(
    queryVertexPositions: THREE.Vector3[],
    queryVertexNormals: THREE.Vector3[],
    closestPoints: ClosestPointResult[],
    numSamples: number,
    offsetDistance: number,
    conflictDistance: number,
    rng: prand.RandomGenerator,
): CorrespondencesResult {
    // Scramble the indices so that we process the points in a random order. If we then select the first N points from
    // a set, we will have gotten a random sample of N points.
    const scrambledIndices = createScrambledIndexVector(queryVertexPositions.length, rng);
    const { collidingIndices, withinOffsetIndices, withinRegistrationIndices } = classifyQueryPoints(
        closestPoints,
        scrambledIndices,
        offsetDistance,
    );

    const desiredNumCollidingSamples = Math.floor(numSamples / 2);
    const desiredNumNoncollidingSamples = numSamples - desiredNumCollidingSamples;
    const staticSamples: THREE.Vector3[] = [];
    const movingSamples: THREE.Vector3[] = [];
    const addSamples = (indices: number[], maxCount: number, normalOffset?: number): number =>
        appendSamples(
            indices,
            maxCount,
            normalOffset,
            queryVertexPositions,
            queryVertexNormals,
            closestPoints,
            staticSamples,
            movingSamples,
        );

    // "Bump in" colliding points to encourage leaving some overlap between the geometries.
    const numCollidingSamples = addSamples(collidingIndices, desiredNumCollidingSamples, -conflictDistance);
    // Top up the colliding half with offset-region points, "bumped out" to encourage separation between the
    // geometries, minus any desired conflict distance.
    const numOffsetSamples = addSamples(
        withinOffsetIndices,
        desiredNumCollidingSamples - numCollidingSamples,
        offsetDistance - conflictDistance,
    );
    // Non-colliding points are used as-is so the geometries stay registered.
    const numNoncollidingSamples = addSamples(withinRegistrationIndices, desiredNumNoncollidingSamples);

    return { staticSamples, movingSamples, numCollidingSamples, numOffsetSamples, numNoncollidingSamples };
}

// Splits the given (already scrambled) indices into colliding (signed distance <= 0), within-offset
// (0 < signed distance <= offsetDistance) and within-registration (signed distance > offsetDistance) buckets,
// preserving their order. Indices without a closest point are dropped.
function classifyQueryPoints(
    closestPoints: ClosestPointResult[],
    orderedIndices: number[],
    offsetDistance: number,
): { collidingIndices: number[]; withinOffsetIndices: number[]; withinRegistrationIndices: number[] } {
    const collidingIndices: number[] = [];
    const withinOffsetIndices: number[] = [];
    const withinRegistrationIndices: number[] = [];

    for (const i of orderedIndices) {
        const signedDistance = closestPoints[i]?.signedDistance;
        if (signedDistance === undefined) {
            continue;
        }

        if (signedDistance <= 0) {
            collidingIndices.push(i);
        } else if (signedDistance <= offsetDistance) {
            withinOffsetIndices.push(i);
        } else {
            withinRegistrationIndices.push(i);
        }
    }

    return { collidingIndices, withinOffsetIndices, withinRegistrationIndices };
}

// Appends correspondences for the first `maxCount` of `indices` (none if `maxCount` <= 0) and returns how many were
// appended. With a `normalOffset`, each moving sample is displaced by that amount along its normal; without one, the
// moving point is used unchanged and its normal is not required. Throws Error('Index error') on missing data.
function appendSamples(
    indices: number[],
    maxCount: number,
    normalOffset: number | undefined,
    queryVertexPositions: THREE.Vector3[],
    queryVertexNormals: THREE.Vector3[],
    closestPoints: ClosestPointResult[],
    staticSamples: THREE.Vector3[],
    movingSamples: THREE.Vector3[],
): number {
    // Clamp so that a non-positive budget selects nothing (`slice(0, -k)` would otherwise drop from the end instead).
    const selected = indices.slice(0, Math.max(0, maxCount));
    for (const i of selected) {
        const staticPoint = closestPoints[i]?.closestPoint;
        const movingPoint = queryVertexPositions[i];
        if (!(staticPoint && movingPoint)) {
            throw new Error('Index error');
        }

        const movingSample = movingPoint.clone();
        if (normalOffset !== undefined) {
            const movingNormal = queryVertexNormals[i];
            if (!movingNormal) {
                throw new Error('Index error');
            }
            movingSample.addScaledVector(movingNormal, normalOffset);
        }

        staticSamples.push(staticPoint.clone());
        movingSamples.push(movingSample);
    }
    return selected.length;
}

// Returns a random permutation of 0 .. length - 1 using a forward Fisher–Yates shuffle. At each position, a swap
// partner is drawn from [position, length - 1] via pure-rand's `unsafe*` API, which advances `rng` in place.
function createScrambledIndexVector(length: number, rng: prand.RandomGenerator): number[] {
    const indices = _.range(length);
    for (let position = 0; position + 1 < length; position++) {
        const partner = prand.unsafeUniformIntDistribution(position, length - 1, rng);
        const current = indices[position];
        const chosen = indices[partner];
        if (current !== undefined && chosen !== undefined) {
            indices[position] = chosen;
            indices[partner] = current;
        }
    }
    return indices;
}

// Finds the incremental transformation that best aligns the moving samples to the static samples.
//
// Samples are split into two groups by position. The first `numCollidingSamples + numOffsetSamples` entries form the
// colliding group and the rest form the non-colliding group. The per-sample weights are chosen so that, whatever the
// group sizes, the groups contribute in the ratio (1 - rawNoncollidingWeight) : rawNoncollidingWeight. If only one
// group is populated, every sample is weighted equally.
//
// Returns the rotation `R`, the translation `T` and the effective per-sample weights used.
// NB: at least one sample is required; with none, the centroids (and hence the result) are NaN.
function findTransformation(correspondences: CorrespondencesResult, rawNoncollidingWeight: number) {
    const { staticSamples, movingSamples } = correspondences;
    const rawCollidingWeight = 1 - rawNoncollidingWeight;

    // The number of colliding and non-colliding samples may not be equal (there may not have been enough colliding
    // points, or even enough non-colliding points within the registration distance), so we need to modify the colliding
    // and non-colliding weights to account for this disparity.

    const numSamples = staticSamples.length;
    const numCollidingSamples = correspondences.numCollidingSamples + correspondences.numOffsetSamples;
    const numNoncollidingSamples = numSamples - numCollidingSamples;

    let collidingWeight = 1;
    let nonCollidingWeight = 1;
    if (numCollidingSamples > 0 && numNoncollidingSamples > 0) {
        const ratioColliding = numCollidingSamples / numSamples;
        collidingWeight = rawCollidingWeight * (1 - ratioColliding);
        nonCollidingWeight = rawNoncollidingWeight * ratioColliding;
    }
    // Otherwise only one group is populated. The balancing formula would give that group a weight of zero, and
    // therefore a zero total weight and NaN centroids. Relative weighting is meaningless with a single group anyway,
    // so the uniform weights above are kept.

    const staticCentroid = findWeightedCentroid(
        staticSamples,
        numCollidingSamples,
        collidingWeight,
        nonCollidingWeight,
    );
    const movingCentroid = findWeightedCentroid(
        movingSamples,
        numCollidingSamples,
        collidingWeight,
        nonCollidingWeight,
    );

    // Create the `C` matrix (weighted cross-covariance of the moving and static samples).
    // https://sigmaland.ir/wp-content/uploads/2022/01/Sigmaland-Robust-registration-of-point-sets-using-iteratively-reweighted-least-squares.pdf

    const C = new Matrix3();
    for (let i = 0; i < numSamples; ++i) {
        const staticSample = staticSamples[i];
        const movingSample = movingSamples[i];
        if (!(staticSample && movingSample)) {
            throw new Error('Index error');
        }

        const product = Matrix3.fromVectorTransposeProduct(movingSample, staticSample).multiplyScalar(
            i < numCollidingSamples ? collidingWeight : nonCollidingWeight,
        );
        C.add(product);
    }

    const totalWeight = numCollidingSamples * collidingWeight + numNoncollidingSamples * nonCollidingWeight;
    C.multiplyScalar(1 / totalWeight);

    C.add(Matrix3.fromVectorTransposeProduct(movingCentroid, staticCentroid).multiplyScalar(-1));

    // calculate R and T from C by SVD
    const R = getRotationBySvd(C);
    const T = getTranslationFromRotation(R, movingCentroid, staticCentroid);
    return { R, T, collidingWeight, nonCollidingWeight };
}

// Weighted mean of `samples`: the first `numCollidingSamples` entries carry `collidingWeight`, the remainder carry
// `nonCollidingWeight`. Returns a newly allocated vector.
function findWeightedCentroid(
    samples: THREE.Vector3[],
    numCollidingSamples: number,
    collidingWeight: number,
    nonCollidingWeight: number,
): THREE.Vector3 {
    const weightAt = (index: number): number => (index < numCollidingSamples ? collidingWeight : nonCollidingWeight);

    const weightedSum = samples.reduce(
        (sum, sample, index) => sum.addScaledVector(sample, weightAt(index)),
        new THREE.Vector3(),
    );

    const totalWeight =
        numCollidingSamples * collidingWeight + (samples.length - numCollidingSamples) * nonCollidingWeight;
    return weightedSum.multiplyScalar(1 / totalWeight);
}
