import { logger } from '../Utils/Logger';
import { makeAlignmentMatrix, solveAlignmentForPairs } from './IterativeClosestPoint.utils';
import { ensureMeshIndex } from './MeshIndex';
import { kdTree } from 'kd-tree-javascript';
import prand from 'pure-rand';
import * as THREE from 'three';
import type { HitPointInfo, MeshBVH } from 'three-mesh-bvh';

// Main class for ICP (iterative closest point) alignment of two geometries.
// Example use:
// const alignment = new Alignment({ stopCriteria: { minError: 0.0001, maxIterations: 25, criteriaJoin: 'OR' } });
// const { resultMatrix, alignmentError } = alignment.DoAlignmentUsingICP(baseGeometry, moveableGeometry);

/*
 *  My rule of thumb: set each threshold to one tenth of the scale of a meaningful distance in the system.
 *  Lengths below are in mm, so 0.001 corresponds to 1 micron.
 *
 *  minRotationRadians values:
 *   COARSE: .0004 rad -> .1 * rotation that would result in 200 microns of movement at an estimated 50mm cross arch radius
 *   FINE: .0001 rad -> .1 * rotation that would result in 50 microns of movement at an estimated 50mm cross arch radius
 *   ULTRA_FINE: .00002 rad -> .1 * rotation that would result in 10 microns of movement at a 50mm cross arch radius
 *               -> theoretical cross arch accuracy of a single scan
 *
 *  minTranslation values:
 *    COARSE: 0.01 -> .1 * 100 microns
 *    MEDIUM: 0.005 -> .1 * 50 microns
 *    FINE:   0.001 -> .1 * 10 microns  -> expected scanner precision local to the prep area, or max cross arch
 *    ULTRAFINE:  0.0005 -> .1 * 5 microns -> smaller than local scanner precision
 *
 *  For minError (a mean correspondence distance), choosing a value on the same scale as the surface noise or
 *  the precision of the system is a crude starting point.
 *  minError values:
 *    COARSE: 0.05 -> 50 microns (that much surface noise is a lot)
 *    MEDIUM: 0.01 -> 10 microns -> equivalent to cross arch precision
 *    FINE:   0.005 -> 5 microns  -> equivalent to scanner precision
 *    ULTRAFINE:  0.0001 -> .1 microns -> expected when aligning two copies of the same object
 * */
interface AlignmentConvergenceCriteria {
    /** Iteration cap for the ICP loop. */
    maxIterations: number;
    /** Error threshold: converged when |mean correspondence distance| <= minError. */
    minError: number;
    /**
     * Translation threshold: a step counts as small when its translation length is below this value.
     * Leaving it unset means the transform-step test can never pass.
     */
    minTranslation?: number;
    /**
     * Rotation threshold (radians): a step counts as small when its rotation angle is below this value.
     * Leaving it unset means the transform-step test can never pass.
     */
    minRotationRadians?: number;
    /**
     * How the error test combines with the transform-step test (which itself requires BOTH translation and
     * rotation to be small): 'AND' needs both tests to pass, 'OR' needs either.
     */
    criteriaJoin: 'AND' | 'OR';
}

export interface AlignmentOptions {
    stopCriteria: AlignmentConvergenceCriteria;
    sampleSize: number;
    useBvh: boolean;
    registrationDistance: number;
    prngSeed?: number;
}

export interface AlignmentResult {
    /** Whether the stop criteria were satisfied (false when the loop stopped for another reason). */
    converged: boolean;
    /** Iteration count reported by the ICP loop. */
    numIterations: number;
    /** Adjustment transform to apply to moveable points expressed in the world frame. */
    resultMatrix: THREE.Matrix4;
    /** Mean correspondence distance measured in the final iteration; non-finite if no correspondences were found. */
    alignmentError: number;
    /** Mean time spent on the correspondence search per iteration (from `performance.now()`, i.e. ms). */
    averageCorrespondenceTimePerIteration: number;
    /** Mean time spent solving for the step transform per iteration (from `performance.now()`, i.e. ms). */
    averageAlignmentTimePerIteration: number;
}

/**
 * Fallback criteria used when the caller supplies none. No transform-step thresholds are set, so only the
 * mean-error test or the iteration cap can stop the loop.
 */
const DEFAULT_STOP_CRITERIA: AlignmentConvergenceCriteria = {
    minError: 1e-4,
    maxIterations: 40,
    criteriaJoin: 'OR',
};

/** Exported preset; presumably tuned for bite registration (inferred from the name — confirm with callers). */
export const DEFAULT_BITE_STOP_CRITERIA: AlignmentConvergenceCriteria = {
    minError: 5e-2,
    maxIterations: 60,
    minRotationRadians: 2e-5,
    minTranslation: 1e-4,
    criteriaJoin: 'OR',
};

/** Exported preset; presumably tuned for pre-prep registration (inferred from the name — confirm with callers). */
export const DEFAULT_PREPREP_STOP_CRITERIA: AlignmentConvergenceCriteria = {
    minError: 2.5e-2,
    maxIterations: 60,
    minRotationRadians: 5e-5,
    minTranslation: 1e-5,
    criteriaJoin: 'OR',
};

export class Alignment {
    /** Stop rules for the ICP loop; the constructor stores a private copy so presets are never mutated. */
    stopCriteria: AlignmentConvergenceCriteria = { ...DEFAULT_STOP_CRITERIA };

    /** Number of moveable-geometry vertices sampled (with replacement) per alignment. */
    sampleSize = 10000;
    /** Closest points via a BVH over the base mesh (true) or a KD-tree over the base vertices (false). */
    useBvh = true;
    /** Samples farther than this from the base are excluded from the correspondence set. */
    registrationDistance = 0.3;

    /** Seeded generator used for sampling; advanced in place on every draw. */
    rng: prand.RandomGenerator;

    baseCentroid: THREE.Vector3;

    /** Moveable samples expressed in the base frame; transformed in place after every ICP step. */
    movableSamples: THREE.Vector3[];

    /** Index-aligned correspondence pairs produced by the latest `FindCorrespondenses` call. */
    pairsMoveableVertices: THREE.Vector3[];
    pairsBaseVertices: THREE.Vector3[];

    baseKDTree: kdTree<THREE.Vector3>;
    baseBvh: MeshBVH;
    /** Mean correspondence distance from the latest `FindCorrespondenses` call (Infinity when no pairs). */
    stepError: number = 0;

    private static readonly REGISTRATION_DISTANCE_EPSILON = 0.0001;

    constructor(options: Partial<AlignmentOptions> = {}) {
        const DEFAULT_OPTIONS: AlignmentOptions = {
            stopCriteria: DEFAULT_STOP_CRITERIA,
            sampleSize: 10000,
            useBvh: true,
            registrationDistance: 0.3,
        };
        const allOptions: AlignmentOptions = {
            ...DEFAULT_OPTIONS,
            ...options,
        };

        Object.assign(this, allOptions);
        // Copy the criteria so mutating this instance never alters the shared module-level presets.
        this.stopCriteria = { ...(allOptions.stopCriteria ?? DEFAULT_STOP_CRITERIA) };

        this.rng = prand.mersenne(allOptions.prngSeed ?? Date.now());
    }

    /**
     * Do alignment using an ICP (iterative closest point) algorithm.
     *
     * NB: The returned transformation is to be applied to the moveable geometry in the world frame. For a point p
     * belonging to the moveable geometry and expressed in the moveable geometry frame, the aligned point is
     *
     *   q_world = resultMatrix * worldTMoveable * p
     *
     * or, expressed in the base geometry frame (baseTWorld = inverse(worldTBase)),
     *
     *   q_base = baseTWorld * resultMatrix * worldTMoveable * p
     *
     * At least one iteration always runs; at most `stopCriteria.maxIterations` run (when that is >= 1). The loop
     * also stops early, unconverged, if no sample lies within `registrationDistance` of the base geometry.
     *
     * @param baseGeometry The base (static) geometry
     * @param moveableGeometry The moveable geometry, which will be aligned to the base geometry
     * @param worldTBase Transformation to the world frame from the frame of `baseGeometry`. If not supplied, the
     *   identity matrix is used.
     * @param worldTMoveable Transformation to the world frame from the frame of `moveableGeometry`. If not supplied,
     *   the identity matrix is used.
     * @returns The adjustment transformation and metadata about the alignment; `numIterations` is the number of
     *   ICP iterations executed.
     */
    DoAlignmentUsingICP(
        baseGeometry: THREE.BufferGeometry,
        moveableGeometry: THREE.BufferGeometry,
        worldTBase?: THREE.Matrix4,
        worldTMoveable?: THREE.Matrix4,
    ): AlignmentResult {
        const IDENTITY = new THREE.Matrix4().identity();
        const baseTWorld = worldTBase?.clone().invert() ?? IDENTITY.clone();
        const baseTMoveable = new THREE.Matrix4().multiplyMatrices(baseTWorld, worldTMoveable ?? IDENTITY);

        // All ICP work happens in the base frame. Correspondences are limited to `this.registrationDistance`,
        // which comes from the constructor options and is deliberately not overwritten here.
        this.movableSamples = getSomeSamplesFromVector3BufferAttrib(
            this.sampleSize,
            this.rng,
            moveableGeometry.attributes.position,
        );
        this.movableSamples.forEach(v => v.applyMatrix4(baseTMoveable));

        this.buildBaseIndex(baseGeometry);

        // Accumulated base-frame transform applied to the moveable samples.
        const totalMatrix = new THREE.Matrix4().identity();

        let totalCorrespondenceTime = 0;
        let totalAlignmentTime = 0;
        let stepError = 0;
        let iteration = 0; // number of ICP iterations executed so far
        let stepTranslationMagnitudeSq = Number.MAX_VALUE;
        let stepRotationMagnitude = Number.MAX_VALUE;
        let converged = false;

        const stepTranslationVector = new THREE.Vector3();
        const stepRotationQuaternion = new THREE.Quaternion();
        const identityQuaternion = new THREE.Quaternion();

        do {
            const stepResult = this.SharedStepsOfICP();
            totalCorrespondenceTime += stepResult.correspondenceTime;
            totalAlignmentTime += stepResult.alignmentTime;
            iteration++;
            stepError = this.stepError;

            if (this.pairsMoveableVertices.length === 0) {
                // A step solved from an empty correspondence set is meaningless; stop without applying it.
                logger.info('   No correspondences within registrationDistance; stopping.');
                break;
            }

            // From rotation and translation, calculate the step transform and move the samples with it.
            const stepMatrix = makeAlignmentMatrix(stepResult.R, stepResult.T);
            for (const sample of this.movableSamples) {
                sample.applyMatrix4(stepMatrix);
            }

            stepTranslationMagnitudeSq = stepTranslationVector.setFromMatrixPosition(stepMatrix).lengthSq();
            stepRotationMagnitude = stepRotationQuaternion
                .setFromRotationMatrix(stepMatrix)
                .angleTo(identityQuaternion);

            totalMatrix.premultiply(stepMatrix);

            converged = this.isStepConverged(stepError, stepTranslationMagnitudeSq, stepRotationMagnitude);
        } while (!converged && iteration < this.stopCriteria.maxIterations);

        logger.info(`   Iterations: ${iteration}`);
        logger.info(`   stepError: ${stepError}`);
        logger.info(`   lastTransMag: ${Math.sqrt(stepTranslationMagnitudeSq)}`);
        logger.info(`   lastRotMag: ${stepRotationMagnitude}`);

        // The do/while guarantees iteration >= 1, so these divisions are well defined.
        const averageCorrespondenceTimePerIteration = totalCorrespondenceTime / iteration;
        const averageAlignmentTimePerIteration = totalAlignmentTime / iteration;

        // Apply a similarity transformation to get the alignment transformation that shall be applied to the moveable
        // geometry points expressed in the world frame.
        const worldTotalMatrix = (worldTBase ?? IDENTITY).clone().multiply(totalMatrix).multiply(baseTWorld);

        return {
            converged,
            resultMatrix: worldTotalMatrix,
            alignmentError: stepError,
            averageCorrespondenceTimePerIteration,
            averageAlignmentTimePerIteration,
            numIterations: iteration,
        };
    }

    /**
     * Runs one ICP iteration: find correspondences, then solve for the step transform.
     *
     * @returns The step rotation `R` and translation `T` plus the time (ms, from `performance.now()`) spent in
     *   each phase.
     */
    SharedStepsOfICP() {
        // first step : Matching
        const correspondenceStartTime = performance.now();
        this.FindCorrespondenses();
        const correspondenceEndTime = performance.now();
        const correspondenceTime = correspondenceEndTime - correspondenceStartTime;

        // Weighting
        // Rejection

        // second step: Alignment
        const alignmentStartTime = performance.now();
        const transform = this.FindTransformationUsingCorrespondence();
        const alignmentEndTime = performance.now();
        const alignmentTime = alignmentEndTime - alignmentStartTime;

        return { R: transform.R, T: transform.T, correspondenceTime, alignmentTime };
    }

    /**
     * Pairs every moveable sample with its closest base point and keeps the pairs closer than
     * `registrationDistance`. Both the BVH and the KD-tree branch measure plain (non-squared) distances.
     * Sets `stepError` to the mean kept distance, or `Number.POSITIVE_INFINITY` when no pair is kept.
     */
    FindCorrespondenses() {
        this.pairsMoveableVertices = [];
        this.pairsBaseVertices = [];
        let distanceSum = 0;

        for (const sample of this.movableSamples) {
            // Pair with a copy: the samples themselves are transformed in place after each step.
            const moveablePoint = sample.clone();
            if (this.useBvh) {
                const target = {} as HitPointInfo;
                this.baseBvh.closestPointToPoint(
                    moveablePoint,
                    target,
                    Alignment.REGISTRATION_DISTANCE_EPSILON,
                    this.registrationDistance + Alignment.REGISTRATION_DISTANCE_EPSILON,
                );
                // If no hit is written into `target`, `distance` is undefined and the comparison is false.
                const distance = target.distance;
                if (distance < this.registrationDistance) {
                    this.pairsMoveableVertices.push(moveablePoint);
                    this.pairsBaseVertices.push(target.point.clone());
                    distanceSum += distance;
                }
            } else {
                const nearest = this.baseKDTree.nearest(moveablePoint, 1)?.[0];
                if (nearest !== undefined) {
                    const [basePoint, distanceSq] = nearest;
                    // The tree metric (`distanceFunction`) is the squared distance; convert it so the threshold
                    // and the reported error use the same linear units as the BVH branch.
                    const distance = Math.sqrt(distanceSq);
                    if (distance < this.registrationDistance) {
                        this.pairsMoveableVertices.push(moveablePoint);
                        this.pairsBaseVertices.push(basePoint.clone());
                        distanceSum += distance;
                    }
                }
            }
        }

        const pairCount = this.pairsMoveableVertices.length;
        // With no pairs the mean is undefined; report an unbounded error instead of NaN (0 / 0).
        this.stepError = pairCount > 0 ? distanceSum / pairCount : Number.POSITIVE_INFINITY;
    }

    /** Solves for the step transform from the current correspondence pairs (base pairs first, moveable second). */
    FindTransformationUsingCorrespondence() {
        const { R, T } = solveAlignmentForPairs(this.pairsBaseVertices, this.pairsMoveableVertices);
        return { R, T };
    }

    /** Mean distance from the current moveable samples to the base index; 0 when no index has been built. */
    CalculateError() {
        if (this.useBvh) {
            if (!this.baseBvh) {
                return 0;
            }
            return CalculateAlignmentErrorUsingBVH(this.movableSamples, this.baseBvh);
        }

        if (!this.baseKDTree) {
            return 0;
        }
        return CalculateAlignmentErrorUsingKDTree(this.movableSamples, this.baseKDTree);
    }

    /** Builds the base-geometry spatial index (BVH or KD-tree) queried by `FindCorrespondenses`. */
    private buildBaseIndex(baseGeometry: THREE.BufferGeometry): void {
        if (this.useBvh) {
            this.baseBvh = ensureMeshIndex(baseGeometry);
            baseGeometry.boundingBox = this.baseBvh.getBoundingBox(new THREE.Box3());
            return;
        }

        // Read one point per vertex straight from the position attribute instead of going through the legacy
        // `THREE.Geometry` conversion (removed from three.js core in r125), which also built faces we never use.
        const position = baseGeometry.attributes.position;
        const baseVertices: THREE.Vector3[] = [];
        if (position !== undefined) {
            for (let i = 0; i < position.count; i++) {
                baseVertices.push(new THREE.Vector3().fromBufferAttribute(position, i));
            }
        }
        this.baseKDTree = new kdTree(baseVertices, distanceFunction, ['x', 'y', 'z']);
    }

    /**
     * Evaluates the stop criteria for one ICP step.
     *
     * @param stepError Mean correspondence distance of the step.
     * @param translationSq Squared length of the step translation.
     * @param rotation Rotation angle of the step, in radians.
     */
    private isStepConverged(stepError: number, translationSq: number, rotation: number): boolean {
        const { minError, criteriaJoin, minRotationRadians, minTranslation } = this.stopCriteria;

        // Translation and rotation must both be small for the transform to count as converged; an unset threshold
        // never passes.
        const convergedRotation = minRotationRadians !== undefined && rotation < minRotationRadians;
        const convergedTranslation = minTranslation !== undefined && translationSq < minTranslation ** 2;
        const transformConverged = convergedTranslation && convergedRotation;

        const convergedError = Math.abs(stepError) <= minError;

        return criteriaJoin === 'AND'
            ? convergedError && transformConverged
            : convergedError || transformConverged;
    }
}

/**
 * Mean Euclidean distance from each vertex to its nearest neighbour in `kdtree`.
 *
 * The tree metric is expected to be `distanceFunction`, which yields squared distances, hence the square root.
 * Vertices for which the tree returns no neighbour add 0 to the sum but still count toward the mean.
 *
 * @returns The mean distance, or 0 when `vertices` is empty (instead of NaN from 0 / 0).
 */
function CalculateAlignmentErrorUsingKDTree(vertices: THREE.Vector3[], kdtree: kdTree<THREE.Vector3>) {
    if (vertices.length === 0) {
        return 0;
    }
    let totalError = 0;
    for (const vertex of vertices) {
        const nearest = kdtree.nearest(vertex.clone(), 1)?.[0];
        if (nearest !== undefined) {
            totalError += Math.sqrt(nearest[1]);
        }
    }
    return totalError / vertices.length;
}

/**
 * Mean distance from each vertex to the closest point on the mesh indexed by `bvh` (no distance threshold).
 *
 * @returns The mean distance, or 0 when `vertices` is empty (instead of NaN from 0 / 0).
 */
function CalculateAlignmentErrorUsingBVH(vertices: THREE.Vector3[], bvh: MeshBVH) {
    if (vertices.length === 0) {
        return 0;
    }
    let totalError = 0;
    for (const vertex of vertices) {
        const target = {} as HitPointInfo;
        bvh.closestPointToPoint(vertex.clone(), target);
        totalError += target.distance;
    }
    return totalError / vertices.length;
}

/** Squared Euclidean distance between two points (no square root); used as the KD-tree metric. */
function distanceFunction(a: THREE.Vector3, b: THREE.Vector3) {
    const deltaX = a.x - b.x;
    const deltaY = a.y - b.y;
    const deltaZ = a.z - b.z;
    let squared = deltaX * deltaX;
    squared += deltaY * deltaY;
    squared += deltaZ * deltaZ;
    return squared;
}

/**
 * Draws `numberOfSamples` vertex positions uniformly at random, with replacement, from a position attribute.
 *
 * @param numberOfSamples How many samples to draw; values <= 0 yield an empty array.
 * @param rng Generator advanced in place (via `prand.unsafeUniformIntDistribution`) once per sample.
 * @param positionAttribute Attribute to sample, read through `Vector3.fromBufferAttribute`.
 * @returns Fresh `THREE.Vector3` instances; empty when the attribute is missing or has no vertices.
 */
export function getSomeSamplesFromVector3BufferAttrib(
    numberOfSamples: number,
    rng: prand.RandomGenerator,
    positionAttribute: THREE.BufferAttribute | THREE.InterleavedBufferAttribute | undefined,
): THREE.Vector3[] {
    const samples: THREE.Vector3[] = [];
    // With zero vertices the index range would be [0, -1], which yields NaN indices and NaN vectors.
    if (!positionAttribute || positionAttribute.count <= 0) {
        return samples;
    }
    const lastIndex = positionAttribute.count - 1;
    for (let drawn = 0; drawn < numberOfSamples; drawn++) {
        const index = prand.unsafeUniformIntDistribution(0, lastIndex, rng);
        samples.push(new THREE.Vector3().fromBufferAttribute(positionAttribute, index));
    }
    return samples;
}
