import _ from 'lodash';
import numeric from 'numeric';
import * as THREE from 'three';

// Backing storage for a Matrix3: exactly three rows, each holding exactly three numbers
type Matrix3Data = [[number, number, number], [number, number, number], [number, number, number]];

// Represents a mutable 3x3 square matrix stored as three rows of three numbers.
// `add` and `multiplyScalar` change the instance in place and return it so calls can be chained.
export class Matrix3 {
    // Literal-typed row/column positions. Indexing the tuple-typed `data` with these is fully
    // type-checked, so the element-wise loops below need no `any` cast.
    private static readonly INDICES = [0, 1, 2] as const;

    data: Matrix3Data;

    // Builds a matrix from `data`, which is deep-copied so the caller's arrays are never shared.
    // When `data` is omitted the matrix is all zeros.
    // Throws an Error when `data` does not have exactly 3 rows of exactly 3 elements each.
    constructor(data?: Matrix3Data) {
        if (data && data.length !== 3) {
            throw new Error('Matrix3 data must have 3 rows');
        }

        if (data && data.some(row => row.length !== 3)) {
            throw new Error('Every row of Matrix3 data must have 3 elements');
        }

        this.data = data
            ? _.cloneDeep(data)
            : [
                  [0, 0, 0],
                  [0, 0, 0],
                  [0, 0, 0],
              ];
    }

    // Creates a matrix from (a x b^T) where a and b are column vectors,
    // i.e. the entry at row i, column j is a[i] * b[j]
    static fromVectorTransposeProduct(a: THREE.Vector3, b: THREE.Vector3): Matrix3 {
        return new Matrix3([
            [a.x * b.x, a.x * b.y, a.x * b.z],
            [a.y * b.x, a.y * b.y, a.y * b.z],
            [a.z * b.x, a.z * b.y, a.z * b.z],
        ]);
    }

    // Adds `other` to this matrix element by element, in-place. `other` is not modified.
    // Returns this matrix.
    add(other: Matrix3): Matrix3 {
        for (const i of Matrix3.INDICES) {
            for (const j of Matrix3.INDICES) {
                this.data[i][j] += other.data[i][j];
            }
        }

        return this;
    }

    // Multiplies every element of this matrix by the scalar value `scalar`, in-place.
    // Returns this matrix.
    multiplyScalar(scalar: number): Matrix3 {
        for (const i of Matrix3.INDICES) {
            for (const j of Matrix3.INDICES) {
                this.data[i][j] *= scalar;
            }
        }

        return this;
    }
}

// Finds the centroid (component-wise mean) of a set of vertices.
// Returns a new vector; for an empty set that vector is (0, 0, 0).
export function findCentroid(vertices: THREE.Vector3[]) {
    const centroid = new THREE.Vector3(0, 0, 0);
    const count = vertices.length;

    if (count) {
        // sum every vertex into the accumulator, then scale the sum down to the mean
        vertices.forEach(vertex => {
            centroid.add(vertex);
        });
        centroid.multiplyScalar(1 / count);
    }

    return centroid;
}

// Finds the optimal rotation matrix by the SVD method: decomposes `h` and returns V * U^T.
// When det(U) * det(V) is negative, the third column of V is negated before the product is formed.
// https://sigmaland.ir/wp-content/uploads/2022/01/Sigmaland-Robust-registration-of-point-sets-using-iteratively-reweighted-least-squares.pdf
export function getRotationBySvd(h: Matrix3): number[][] {
    // TODO: test performance using other SVD libraries
    const { U, V } = numeric.svd(h.data);

    // det(V * U^T) = det(V) * det(U); a negative value means the product would be a reflection
    const isReflection = numeric.det(U) * numeric.det(V) < 0;
    if (isReflection) {
        // flip the sign of the third entry in each of the first three rows of V
        for (const row of [V[0], V[1], V[2]]) {
            if (row !== undefined && row[2] !== undefined) {
                row[2] = -row[2];
            }
        }
    }

    const uTransposed = numeric.transpose(U);
    return numeric.dot(V, uTransposed) as number[][];
}

// Gets the optimal translation vector from the rotation matrix and the sample centroids:
// each component is the matching component of `secondcentroid - R * firstcentroid`.
// A component stays 0 when the product does not supply a value for it, and all three stay 0
// when the product has fewer than 3 entries.
// https://sigmaland.ir/wp-content/uploads/2022/01/Sigmaland-Robust-registration-of-point-sets-using-iteratively-reweighted-least-squares.pdf
export function getTranslationFromRotation(
    R: number[][],
    firstcentroid: THREE.Vector3,
    secondcentroid: THREE.Vector3,
): [number, number, number] {
    const translation: [number, number, number] = [0, 0, 0];
    const rotated = numeric.dot(R, [firstcentroid.x, firstcentroid.y, firstcentroid.z]) as number[];

    // keep the zero translation unless the product has more than 2 entries
    if (rotated === undefined || !(rotated.length > 2)) {
        return translation;
    }

    const rotatedX = rotated[0];
    if (rotatedX !== undefined) {
        translation[0] = secondcentroid.x - rotatedX;
    }

    const rotatedY = rotated[1];
    if (rotatedY !== undefined) {
        translation[1] = secondcentroid.y - rotatedY;
    }

    const rotatedZ = rotated[2];
    if (rotatedZ !== undefined) {
        translation[2] = secondcentroid.z - rotatedZ;
    }

    return translation;
}

// Substitutes 0 for an undefined value; any other value is returned unchanged
function getNumber(value: number | undefined) {
    if (value === undefined) {
        return 0;
    }

    return value;
}

// Packs a rotation `R` and a translation `T` into a single THREE.Matrix4.
// The values handed to `Matrix4.set` are, in order: each of the first three rows of `R` followed
// by the matching entry of `T`, then (0, 0, 0, 1). Missing entries are replaced by 0.
// If any of the first three rows of `R` is undefined, the freshly constructed matrix is returned
// without calling `set`.
export function makeAlignmentMatrix(R: number[][], T: number[]): THREE.Matrix4 {
    const transform = new THREE.Matrix4();
    const rowX = R[0];
    const rowY = R[1];
    const rowZ = R[2];

    if (rowX === undefined || rowY === undefined || rowZ === undefined) {
        return transform;
    }

    // prettier-ignore
    transform.set(
        getNumber(rowX[0]), getNumber(rowX[1]), getNumber(rowX[2]), getNumber(T[0]),
        getNumber(rowY[0]), getNumber(rowY[1]), getNumber(rowY[2]), getNumber(T[1]),
        getNumber(rowZ[0]), getNumber(rowZ[1]), getNumber(rowZ[2]), getNumber(T[2]),
        0, 0, 0, 1,
    );

    return transform;
}

// Accumulates, over index-matched point pairs, the product
//   (moving_i - centroidMoving) x (static_i - centroidStatic)^T
// into a single Matrix3. Only the first min(pointsStatic.length, pointsMoving.length) pairs are
// visited, and a pair is skipped when either of its points is missing.
export function generateMatrixFromPairs(
    pointsStatic: THREE.Vector3[],
    pointsMoving: THREE.Vector3[],
    centroidStatic: THREE.Vector3,
    centroidMoving: THREE.Vector3,
): Matrix3 {
    const accumulated = new Matrix3();
    // provide array length safety so we don't keep indexing into undefined later
    const pairCount = Math.min(pointsStatic.length, pointsMoving.length);

    // scratch vectors, overwritten on every iteration instead of allocating new ones
    const scratchMoving = new THREE.Vector3();
    const scratchStatic = new THREE.Vector3();

    for (let pairIndex = 0; pairIndex < pairCount; pairIndex++) {
        const movingPoint = pointsMoving[pairIndex];
        const staticPoint = pointsStatic[pairIndex];

        if (movingPoint && staticPoint) {
            const movingOffset = scratchMoving.subVectors(movingPoint, centroidMoving);
            const staticOffset = scratchStatic.subVectors(staticPoint, centroidStatic);
            accumulated.add(Matrix3.fromVectorTransposeProduct(movingOffset, staticOffset));
        }
    }

    return accumulated;
}

// Rotation and translation returned by solveAlignmentForPairs
interface UnstructuredAlignmentResult {
    // Rotation matrix as an array of rows
    R: number[][];
    // Translation vector as [x, y, z]
    T: [number, number, number];
}

// Solves for the rotation R and translation T relating `pointsMoving` to `pointsStatic`, pairing
// points by index. T is computed as centroidStatic - R * centroidMoving.
// Only the first min(pointsStatic.length, pointsMoving.length) points of each array are used.
// With fewer than 3 pairs, the identity rotation and a zero translation are returned.
export function solveAlignmentForPairs(
    pointsStatic: THREE.Vector3[],
    pointsMoving: THREE.Vector3[],
): UnstructuredAlignmentResult {
    // make sure we only use a matched set of points
    const pairCount = Math.min(pointsStatic.length, pointsMoving.length);

    if (pairCount < 3) {
        // too few pairs to solve: bail and return identity
        const identityRotation = [
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
        ];
        return { R: identityRotation, T: [0, 0, 0] };
    }

    // only take a (shallow) truncated copy when the arrays differ in length
    const lengthsMatch = pointsMoving.length === pointsStatic.length;
    const staticSample = lengthsMatch ? pointsStatic : pointsStatic.slice(0, pairCount);
    const movingSample = lengthsMatch ? pointsMoving : pointsMoving.slice(0, pairCount);

    const centroidStatic = findCentroid(staticSample);
    const centroidMoving = findCentroid(movingSample);

    const crossProductSum = generateMatrixFromPairs(pointsStatic, pointsMoving, centroidStatic, centroidMoving);

    // calculate R and T from the accumulated matrix by SVD
    const R = getRotationBySvd(crossProductSum);
    const T = getTranslationFromRotation(R, centroidMoving, centroidStatic);
    return { R, T };
}
