import { logger } from '../../Utils/Logger';
import { AttributeName } from '../BufferAttributeConstants';
import type { IndexedBufferGeometry } from '../BufferGeometry.types';
import { getFacetVerts, isIndexedGeometry } from '../BufferGeometry.util';
import { getTriangleByIndex, generateFacets } from '../Mesh3d.util';
import type { AdjacencyMatrix } from '../MeshConnectivityGraph';
import { ensureMeshIndex, getIndexMap } from '../MeshIndex';
import { THREE_CACHE } from '../ThreeObjectCache';
import { triangleFromBufferAttributes, safeTriangleClosestPointToPoint } from '../Triangle.util';
import { computeMillInterference } from './Interference';
import { getMillStatistics, isItMillableFromStatistics } from './Measurements';
import type { MillInterference, MillInterferenceAdjustmentParams } from './types';
import { getMillAdjustmentParams } from './utils';
import * as THREE from 'three';
import type { MeshBVH } from 'three-mesh-bvh';

/**
 * Returns a getter for `map` that inserts `init()` for missing keys and returns the stored value.
 *
 * Presence is decided with `Map.has`, so values already stored under a key — including falsy
 * ones such as `0`, `''`, `false`, `null` or `undefined` — are returned as-is rather than being
 * replaced by a fresh `init()` result.
 *
 * @param map Map to read from and insert into.
 * @param init Factory for the value stored under a key that is not yet present.
 * @returns A function mapping a key to its (possibly newly inserted) value.
 */
function getMapDefaulter<Key, Value>(map: Map<Key, Value>, init: () => Value): (key: Key) => Value {
    return (key: Key) => {
        if (map.has(key)) {
            return map.get(key) as Value;
        }
        const value = init();
        map.set(key, value);
        return value;
    };
}

/**
 * Computes how far the source facet of `interference` should move along its own normal to clear
 * the colliding facets.
 *
 * Side effect: writes the unit normal of the source facet into `outNormal`.
 *
 * For each colliding facet, `pos` is the point on that facet closest to the sphere center,
 * expressed relative to the center, and `xi` is its component along `outNormal`. Solving
 * |pos - d * outNormal|^2 = r^2 for `d` gives d = xi ± sqrt(disc) with
 * disc = xi^2 + r^2 - |pos|^2, so `sqrt(disc) - xi` is the negated smaller root.
 *
 * @param geom Geometry holding the source and colliding facets.
 * @param interference Interference whose source facet is being adjusted.
 * @param outNormal Receives the source facet normal.
 * @returns `-maxDist`, where `maxDist >= 0` is the largest `sqrt(disc) - xi` over the facets that
 *   contribute (facets with `xi < 0` or `disc < 0` are skipped); i.e. min(0, smallest root).
 */
function computePushbackAdjustment(
    geom: IndexedBufferGeometry,
    interference: MillInterference,
    outNormal: THREE.Vector3,
) {
    return THREE_CACHE.autoAcquire(
        'triangle',
        'vector3',
    )((triangle, pos) => {
        const index = geom.getIndex();
        const posAttr = geom.getAttribute(AttributeName.Position);
        // Load the source facet first so its normal can be captured.
        triangleFromBufferAttributes(triangle, index, posAttr, interference.sourceIndex);
        triangle.getNormal(outNormal);
        let maxDist = 0;
        const c = interference.sphere.center;
        const r = interference.sphere.radius;
        for (const fIdx of interference.collisionIndices) {
            triangleFromBufferAttributes(triangle, index, posAttr, fIdx);
            // Closest point on this facet, made relative to the sphere center.
            // NOTE(review): assumes safeTriangleClosestPointToPoint returns its target (`pos`).
            safeTriangleClosestPointToPoint(triangle, c, pos).sub(c);
            const xi = outNormal.dot(pos);
            // Ignore any interferences that are on the near side of the sphere since this
            // will produce extreme and incorrect adjustments. Tilt adjustment will generally
            // take care of these.
            if (xi < 0) {
                continue;
            }
            const ri2 = pos.lengthSq();
            const disc = xi * xi + r * r - ri2;
            // Real roots exist only when disc >= 0.
            if (disc >= 0) {
                maxDist = Math.max(maxDist, Math.sqrt(disc) - xi);
            }
        }

        return -maxDist;
    });
}

/**
 * Computes a linear "tilt" field for the source facet of `interference`.
 *
 * Each colliding facet contributes when it faces the sphere center. Its closest point to the
 * sphere center is expressed relative to the source facet, normalized by the sphere radius. That
 * vector becomes a unit direction perpendicular to the source normal, scaled by `tanT` (the
 * tangent of the tilt angle derived below). The contributions are averaged into `target`.
 *
 * Degenerate contributions are skipped rather than allowed to write NaN/Infinity into the
 * average: a zero-length vector, a cosine that rounds past ±1, or a non-finite `tanT`.
 *
 * `constant` is chosen so that `p·axis + constant` is 0 at the source-facet vertex with the
 * largest projection onto `axis`, and <= 0 at the other two.
 *
 * @param geom Geometry holding the source and colliding facets.
 * @param interference Interference to evaluate.
 * @param target Output vector; overwritten and returned as `axis`.
 * @returns `axis` (the same object as `target`) and the `constant` offset.
 */
function computeTiltAdjustment(
    geom: IndexedBufferGeometry,
    interference: MillInterference,
    target: THREE.Vector3,
): { axis: THREE.Vector3; constant: number } {
    return THREE_CACHE.autoAcquire(
        'triangle',
        'vector3',
        'vector3',
    )((triangle, vec, norm) => {
        target.set(0, 0, 0);
        const { sourceIndex, collisionIndices, sphere } = interference;
        getTriangleByIndex(geom, sourceIndex, triangle);
        triangle.getNormal(norm);
        let totalWeight = 0;
        for (const fIdx of collisionIndices) {
            getTriangleByIndex(geom, fIdx, triangle);
            safeTriangleClosestPointToPoint(triangle, sphere.center, vec);
            vec.sub(sphere.center).divideScalar(sphere.radius);
            // Ignore any interfering facets that aren't facing the center of the sphere. Tilt
            // always moves the center away from the interference, so if the triangle isn't front
            // facing the adjustment would be making the interference worse.
            if (!triangle.isFrontFacing(vec)) {
                continue;
            }
            // We want a vector from the center of the source facet to the interference point and,
            // at this point, `vec` is a vector from the sphere center normalized by sphere radius.
            // Adding the source normal gets us the vector we want.
            vec.add(norm);
            const s = vec.length();
            // s === 0 would make cosA = 0/0; s > 2 would make sqrt(4 - s*s) NaN.
            if (s === 0 || s > 2) {
                continue;
            }
            const dn = vec.dot(norm);
            // Rounding can push |dn / s| marginally above 1, which would make sinA NaN.
            const cosA = Math.min(1, Math.max(-1, dn / s));
            const sinA = Math.sqrt(1 - cosA * cosA);
            if (sinA === 0) {
                continue;
            }
            const tanA = sinA / cosA;
            const disc = Math.sqrt(4 - s * s);
            const tanT = (disc - s * tanA) / (s + disc * tanA);
            // e.g. cosA === 0 yields Infinity / Infinity; never let that reach the average.
            if (!Number.isFinite(tanT)) {
                continue;
            }
            // Keep only the component perpendicular to the source normal (length s * sinA) and
            // rescale it to length tanT.
            vec.addScaledVector(norm, -dn).multiplyScalar(tanT / (s * sinA));
            totalWeight += 1;
            target.add(vec);
        }

        if (totalWeight > 0) {
            target.divideScalar(totalWeight);
        }

        getTriangleByIndex(geom, sourceIndex, triangle);

        const constant = -Math.max(triangle.a.dot(target), triangle.b.dot(target), triangle.c.dot(target));
        return { axis: target, constant };
    });
}

/**
 * Rejects a proposed vertex position that lies in front of the reference surface indexed by `bvh`.
 *
 * Finds the triangle of `bvh.geometry` nearest to `target` and measures the signed distance of
 * `target` from that triangle's plane, along the triangle's normal.
 * - Distance positive (in front of the plane): `target` is reset to `pos`.
 * - Distance <= 0, or no nearest triangle found: `target` is left as proposed.
 *
 * @param bvh Spatial index of the reference surface.
 * @param pos Current vertex position, used as the fallback.
 * @param target In: proposed position. Out: accepted position.
 */
function constrainBehindBvh(bvh: MeshBVH, pos: THREE.Vector3, target: THREE.Vector3) {
    THREE_CACHE.autoAcquire(
        'triangle',
        'vector3',
    )((nearestTri, planeNormal) => {
        const hit = bvh.closestPointToPoint(target);
        if (!hit) {
            return;
        }

        getTriangleByIndex(bvh.geometry, hit.faceIndex, nearestTri);
        nearestTri.getNormal(planeNormal);
        const planeOffset = nearestTri.a.dot(planeNormal);
        const signedDistance = target.dot(planeNormal) - planeOffset;

        // On or behind the plane: accept the proposal unchanged.
        if (signedDistance <= 0) {
            return;
        }

        target.copy(pos);
    });
}

/**
 * Collects boundary vertices inferred from `meshAdj`.
 *
 * A vertex is reported when either:
 * - it has an empty neighbor list, or
 * - it is an endpoint of an edge that appears exactly once across all neighbor lists.
 *
 * NOTE(review): open edges are only detected if `meshAdj` records neighbors per directed
 * half-edge. With a fully symmetric adjacency every edge is counted twice, so only isolated
 * vertices are reported — confirm against `MeshConnectivityGraph`.
 *
 * @param meshAdj Neighbor lists, indexed by vertex.
 * @param target Set that boundary vertices are added to; a new set by default.
 * @returns `target`, with the boundary vertices added.
 */
function findBoundaryVerts(meshAdj: AdjacencyMatrix, target: Set<number> = new Set()): Set<number> {
    // Keyed by the normalized "min,max" pair. The endpoints are stored alongside the count so they
    // never need to be parsed back out of the key. (`split(',').map(parseInt)` passes the array
    // index as the radix, so `parseInt(x, 1)` is NaN and the second endpoint is lost.)
    const edges = new Map<string, { lo: number; hi: number; count: number }>();
    meshAdj.forEach((neighbors, vI) => {
        if (neighbors.length === 0) {
            target.add(vI);
        }
        for (const vJ of neighbors) {
            const lo = Math.min(vI, vJ);
            const hi = Math.max(vI, vJ);
            const key = `${lo},${hi}`;
            const edge = edges.get(key);
            if (edge) {
                edge.count += 1;
            } else {
                edges.set(key, { lo, hi, count: 1 });
            }
        }
    });
    for (const { lo, hi, count } of edges.values()) {
        // Vertex 0 is a valid index, so add both endpoints unconditionally.
        if (count === 1) {
            target.add(lo);
            target.add(hi);
        }
    }
    return target;
}

/** Pending displacement for one vertex, together with the number of contributions summed into it. */
interface Adjustment {
    adjustment: THREE.Vector3;
    count: number;
}

/**
 * Accumulates the pushback and tilt corrections of every interference into `adjustmentsByVert`.
 *
 * Each interference adds a displacement along its source facet normal to every vertex of that
 * source facet. The magnitude is
 * `pushback * params.pushbackTension + tiltOffset * params.tiltTension`, where `tiltOffset` is
 * the tilt field evaluated at the vertex.
 *
 * Neighbors of those vertices get an entry as well (with an empty adjustment if they had none),
 * so the surface tension pass also moves them. Finally, every entry with a non-zero count is
 * divided by that count, turning the sums into averages.
 */
function applyInterferenceAdjustments(
    outGeom: IndexedBufferGeometry,
    interferences: MillInterference[],
    adjustmentsByVert: Map<number, Adjustment>,
    meshAdj: AdjacencyMatrix,
    params: MillInterferenceAdjustmentParams,
) {
    THREE_CACHE.autoAcquire(
        'vector3',
        'vector3',
        'vector3',
    )((vertexPos, tiltAxis, facetNormal) => {
        const entryFor = getMapDefaulter(adjustmentsByVert, () => ({
            adjustment: new THREE.Vector3(),
            count: 0,
        }));
        const positions = outGeom.getAttribute(AttributeName.Position);

        for (const interference of interferences) {
            // Also writes the source facet normal into `facetNormal`, the direction of the shift below.
            const pushback = computePushbackAdjustment(outGeom, interference, facetNormal);
            const tilt = computeTiltAdjustment(outGeom, interference, tiltAxis);

            for (const vI of getFacetVerts(outGeom, interference.sourceIndex)) {
                const entry = entryFor(vI);
                vertexPos.fromBufferAttribute(positions, vI);
                const tiltOffset = vertexPos.dot(tilt.axis) + tilt.constant;
                const shift = pushback * params.pushbackTension + tiltOffset * params.tiltTension;
                entry.adjustment.addScaledVector(facetNormal, shift);
                entry.count += 1;

                // Register neighbors so the surface tension pass also applies to them.
                meshAdj[vI]?.forEach(vJ => {
                    entryFor(vJ);
                });
            }
        }

        adjustmentsByVert.forEach(entry => {
            if (entry.count > 0) {
                entry.adjustment.divideScalar(entry.count);
            }
        });
    });
}

/**
 * Adds a smoothing term to every pending adjustment.
 *
 * The term is the offset from the vertex to the average position of its neighbors, scaled by
 * `params.surfaceTension` (uniform Laplacian smoothing). Vertices with a missing or empty neighbor
 * list are left unchanged.
 *
 * @param geom Geometry providing the current vertex positions.
 * @param adjustmentsByVert Pending adjustments; updated in place.
 * @param meshAdj Neighbor lists, indexed by vertex.
 * @param params Supplies `surfaceTension`.
 */
function applySurfaceTensions(
    geom: IndexedBufferGeometry,
    adjustmentsByVert: Map<number, Adjustment>,
    meshAdj: AdjacencyMatrix,
    params: MillInterferenceAdjustmentParams,
) {
    THREE_CACHE.autoAcquire(
        'vector3',
        'vector3',
        'vector3',
    )((pos, avgPos, vec) => {
        const posAttr = geom.getAttribute(AttributeName.Position);
        for (const [vIdx, { adjustment }] of adjustmentsByVert) {
            const neighbors = meshAdj[vIdx];
            // An empty list would make the average 0/0 and write NaN into the adjustment.
            if (!neighbors || neighbors.length === 0) {
                continue;
            }
            pos.fromBufferAttribute(posAttr, vIdx);
            avgPos.set(0, 0, 0);
            for (const vJ of neighbors) {
                vec.fromBufferAttribute(posAttr, vJ);
                avgPos.add(vec);
            }
            avgPos.divideScalar(neighbors.length).sub(pos).multiplyScalar(params.surfaceTension);
            adjustment.add(avgPos);
        }
    });
}

/**
 * Scales each pending adjustment down so that, projected onto any incident edge, the vertex
 * travels no more than 47.5% of the way toward that neighbor's planned position.
 *
 * Adjustments are scaled in place while the map is iterated, so vertices visited later see the
 * already-limited adjustments of neighbors visited earlier. Vertices without a neighbor list are
 * left unchanged.
 *
 * @param geom Geometry providing the current vertex positions.
 * @param adjustmentsByVert Pending adjustments; scaled in place.
 * @param meshAdj Neighbor lists, indexed by vertex.
 */
function applyNeighborLimit(
    geom: IndexedBufferGeometry,
    adjustmentsByVert: Map<number, Adjustment>,
    meshAdj: AdjacencyMatrix,
) {
    THREE_CACHE.autoAcquire(
        'vector3',
        'vector3',
    )((pos, vec) => {
        const posAttr = geom.getAttribute(AttributeName.Position);
        for (const [vIdx, { adjustment }] of adjustmentsByVert) {
            pos.fromBufferAttribute(posAttr, vIdx);
            const neighbors = meshAdj[vIdx];
            if (!neighbors) {
                continue;
            }
            let t = 1;
            for (const vJ of neighbors) {
                // Edge from this vertex's current position to the neighbor's planned position
                // (pJ + aJ). A neighbor that stays put therefore yields the same limit whether or
                // not it has an entry. Two vertices moving toward each other also tighten the
                // limit rather than loosen it.
                vec.fromBufferAttribute(posAttr, vJ).sub(pos);
                const otherAdj = adjustmentsByVert.get(vJ)?.adjustment;
                if (otherAdj) {
                    vec.add(otherAdj);
                }
                // Stop just short of edge midpoint
                vec.multiplyScalar(0.475);
                // (t * adjustment)·vec <= vec·vec
                // t <= (vec·vec) / (adjustment·vec)
                // A negative ratio (moving away from the neighbor) or NaN imposes no limit.
                const tJ = vec.dot(vec) / adjustment.dot(vec);
                if (tJ >= 0) {
                    t = Math.min(t, tJ);
                }
            }

            adjustment.multiplyScalar(t);
        }
    });
}

/**
 * Iteratively deforms a copy of `geom` so that a mill of radius `millRadiusMm` can reach it.
 *
 * Each iteration:
 * 1. Measures interferences on the working copy.
 * 2. Stops if the statistics say the copy is millable.
 * 3. Otherwise moves vertices by the combined pushback/tilt, surface-tension and neighbor-limited
 *    adjustments.
 *
 * Constraints on vertex moves:
 * - Vertices of facets rejected by `params.facetSelector` never move.
 * - Boundary vertices never move when `params.preventBoundaryAdjustment` is set.
 * - A move that would land in front of the nearest facet plane of the index built for the input
 *   `geom` is rejected.
 *
 * The tensions are multiplied by their relaxation factors after every pass. The result of the
 * final pass is measured as well, so `millable` reflects the returned geometry.
 *
 * @param geom Input geometry; must be indexed. The adjustments are made on a clone.
 * @param millRadiusMm Mill radius (millimetres, per the parameter name).
 * @param meshAdj Neighbor lists, indexed by vertex.
 * @param inParams Overrides for the adjustment parameters.
 * @returns The adjusted copy, whether it was judged millable, and the number of adjustment passes
 *   applied.
 * @throws Error if `geom` is not indexed.
 */
export function getMillAdjusted(
    geom: THREE.BufferGeometry,
    millRadiusMm: number,
    meshAdj: AdjacencyMatrix,
    inParams?: Partial<MillInterferenceAdjustmentParams>,
): { geometry: THREE.BufferGeometry; millable: boolean; iterations: number } {
    // NOTE(review): the tension fields of `params` are relaxed in place after every pass; this
    // assumes getMillAdjustmentParams returns a fresh object rather than shared defaults — confirm.
    const params = getMillAdjustmentParams(inParams);

    if (!isIndexedGeometry(geom)) {
        throw new Error('Mill interference correction requires indexed geometry.');
    }

    // Make sure the bvh is built. It is obtained from the input `geom`, not the working copy, and
    // is what adjusted vertices are constrained against.
    const bvh = ensureMeshIndex(geom);

    const outGeom = geom.clone();

    const posAttr = outGeom.getAttribute(AttributeName.Position);
    const indexMap = getIndexMap(outGeom);
    // Vertices that must never move.
    const fixedVerts = new Set<number>();

    for (const { index, verts } of generateFacets(outGeom)) {
        const oldIndex = indexMap[index];
        if (oldIndex !== undefined && !params.facetSelector(index, oldIndex)) {
            verts.forEach(v => fixedVerts.add(v));
        }
    }

    if (params.preventBoundaryAdjustment) {
        findBoundaryVerts(meshAdj, fixedVerts);
    }

    const pos = new THREE.Vector3();
    // `i` is the number of adjustment passes applied so far. The loop body runs once more than
    // there are passes (i === maxIterations) so that the final pass is measured. Otherwise it
    // would be reported as unmillable without ever being checked.
    for (let i = 0; i <= params.maxIterations; i += 1) {
        logger.info(`Iteration ${i}`);

        outGeom.boundsTree?.refit();
        const interferences = computeMillInterference(outGeom, millRadiusMm, params);
        const statistics = getMillStatistics(outGeom, interferences, millRadiusMm);
        logger.info('statistics:', statistics);
        if (isItMillableFromStatistics(statistics, params.jiggleIterations > 0)) {
            return { geometry: outGeom, millable: true, iterations: i };
        }
        if (i === params.maxIterations) {
            break;
        }

        const adjustmentsByVert = new Map<number, Adjustment>();

        // Accumulate interference adjustments
        applyInterferenceAdjustments(outGeom, interferences, adjustmentsByVert, meshAdj, params);

        // Accumulate neighbor tensions and finalize desired adjustment
        applySurfaceTensions(outGeom, adjustmentsByVert, meshAdj, params);

        // Limit adjustments according to edge lengths
        applyNeighborLimit(outGeom, adjustmentsByVert, meshAdj);

        // Apply adjustments, skipping fixed vertices and rejecting moves in front of the input surface
        for (const [vIdx, { adjustment }] of adjustmentsByVert) {
            if (fixedVerts.has(vIdx)) {
                continue;
            }
            pos.fromBufferAttribute(posAttr, vIdx);
            // Turn the displacement into an absolute target position.
            adjustment.add(pos);
            constrainBehindBvh(bvh, pos, adjustment);
            posAttr.setXYZ(vIdx, adjustment.x, adjustment.y, adjustment.z);
        }

        // Relax the tensions for the next pass.
        params.surfaceTension *= params.surfaceRelaxation;
        params.pushbackTension *= params.pushbackRelaxation;
        params.tiltTension *= params.tiltRelaxation;
    }

    return { geometry: outGeom, millable: false, iterations: params.maxIterations };
}
