import { AttributeName } from '../BufferAttributeConstants';
import { isIndexedGeometry } from '../BufferGeometry.util';
import { ensureMeshIndex, getIndexMap } from '../MeshIndex';
import { SphereCaster } from '../ShapeCast';
import { triangleFromBufferAttributes, getIncenter } from '../Triangle.util';
import type { MillInterferenceParams, MillInterference } from './types';
import { getMillInterferenceParams } from './utils';
import { isArrayMin1, type ArrayMin1 } from '@orthly/runtime-utils';
import * as THREE from 'three';

/**
 * Module-level scratch objects shared by `applyJiggle` and `limitJiggle` so the
 * per-facet jiggle loop does not allocate. Because they are shared, those helpers
 * are not reentrant: each call overwrites whatever the previous call left here.
 */
const JIGGLE_CACHE: {
    readonly triangle: THREE.Triangle;
    readonly closestPt: THREE.Vector3;
    readonly vec: THREE.Vector3;
} = {
    triangle: new THREE.Triangle(),
    closestPt: new THREE.Vector3(),
    vec: new THREE.Vector3(),
};

/**
 * Nudges `sphere` away from the nearest of the triangles listed in `hits`.
 *
 * For every face index in `hits` the closest point on that triangle to the sphere
 * center is computed, and the overall nearest one is kept. If that nearest point is
 * farther than the radius, the sphere is left alone. Otherwise the center is moved
 * along (center - nearestPoint) by the penetration depth, inflated by
 * `1 + relativeEpsilon` so the sphere ends up slightly clear of that point.
 *
 * If the center lies exactly on the nearest triangle, the push direction is a zero
 * vector (three.js `normalize()` leaves a zero vector at zero), so the sphere
 * presumably does not move in that case.
 *
 * Mutates `sphere.center` in place and overwrites the `JIGGLE_CACHE` scratch objects.
 *
 * @param hits - face indices (triangle numbers in `index`) the sphere currently collides with
 * @param sphere - sphere to move; its center is modified
 * @param index - geometry index buffer
 * @param posAttr - geometry position attribute
 * @param relativeEpsilon - relative overshoot applied to the push distance
 */
function applyJiggle(
    hits: ArrayMin1<number>,
    sphere: THREE.Sphere,
    index: THREE.BufferAttribute,
    posAttr: THREE.BufferAttribute | THREE.InterleavedBufferAttribute,
    relativeEpsilon: number,
) {
    const { triangle, closestPt: nearestPt, vec: candidate } = JIGGLE_CACHE;
    const { center, radius } = sphere;
    let nearestDistSq = Infinity;

    for (const faceIdx of hits) {
        triangleFromBufferAttributes(triangle, index, posAttr, faceIdx);
        const distSq = triangle.closestPointToPoint(center, candidate).distanceToSquared(center);
        if (distSq < nearestDistSq) {
            nearestDistSq = distSq;
            nearestPt.copy(candidate);
        }
    }

    const radiusSq = radius * radius;
    if (nearestDistSq > radiusSq) {
        // Nearest surface point is outside the sphere: nothing to push against.
        return;
    }

    const pushDir = candidate.subVectors(center, nearestPt).normalize();
    const overlap = radius - Math.sqrt(nearestDistSq);
    center.addScaledVector(pushDir, overlap * (1 + relativeEpsilon));
}

/**
 * Clamps how far `sphere.center` may have drifted from `initialPt`.
 *
 * The displacement (center - initialPt) is split into a component along `normal`
 * and a lateral remainder. Each component is clamped independently:
 * - the lateral part to at most `lateralFactor * sphere.radius` in length;
 * - the normal part to at most `normalFactor * sphere.radius` in magnitude, keeping its sign.
 *
 * If both are already within their limits, the sphere is left untouched.
 * The decomposition assumes `normal` is unit length. The caller in this file passes
 * the output of `THREE.Triangle.getNormal`, which is normalized.
 *
 * Mutates `sphere.center` in place and overwrites `JIGGLE_CACHE.vec`.
 *
 * @param initialPt - reference center the sphere started from
 * @param normal - unit direction used to separate normal vs lateral motion
 * @param sphere - sphere whose center is clamped
 * @param lateralFactor - max lateral displacement, as a multiple of the radius
 * @param normalFactor - max normal displacement, as a multiple of the radius
 */
function limitJiggle(
    initialPt: THREE.Vector3,
    normal: THREE.Vector3,
    sphere: THREE.Sphere,
    lateralFactor: number,
    normalFactor: number,
) {
    const { vec: lateral } = JIGGLE_CACHE;
    const maxLateral = lateralFactor * sphere.radius;
    const maxNormal = normalFactor * sphere.radius;

    lateral.subVectors(sphere.center, initialPt);
    const dn = lateral.dot(normal);
    lateral.addScaledVector(normal, -dn);
    const dl = lateral.length();
    if (dl < maxLateral && Math.abs(dn) < maxNormal) {
        return;
    }

    // Only shrink the lateral offset when it actually exceeds the limit. Rescaling
    // unconditionally would inflate an in-bounds lateral offset up to the limit
    // whenever only the normal component was out of range. It would also divide by
    // zero (NaN center) when there is no lateral motion at all.
    if (dl > maxLateral && dl > 0) {
        lateral.multiplyScalar(maxLateral / dl);
    }
    sphere.center
        .copy(initialPt)
        .add(lateral)
        .addScaledVector(normal, Math.sign(dn) * Math.min(maxNormal, Math.abs(dn)));
}

/**
 * Finds facets where a spherical mill of radius `millRadiusMm` cannot be placed
 * against the facet without hitting other facets of the same mesh.
 *
 * For each facet that has an entry in the geometry's index map and is accepted by
 * `facetSelector(faceIndex, originalIndex)`:
 * 1. A probe sphere is placed at the facet incenter, offset by `millRadiusMm` along
 *    the facet normal. Its radius is `millRadiusMm / (1 + relativeEpsilon)`.
 * 2. Other facets returned by `SphereCaster.castAll` are collected. The facet itself
 *    is excluded.
 * 3. While hits remain, for up to `jiggleIterations` rounds, the sphere is pushed
 *    away from the nearest hit. Its drift is then clamped laterally and normally
 *    relative to the starting center, and the cast is repeated.
 *
 * Facets that still have hits afterwards are reported with a copy of the final
 * probe sphere and the colliding face indices. Facets with a zero-length
 * (degenerate) normal are skipped.
 *
 * @param geom - indexed geometry to analyze; a mesh index (BVH) is ensured on it
 * @param millRadiusMm - mill tool radius
 * @param params - optional overrides, resolved through `getMillInterferenceParams`
 * @returns one entry per interfering facet
 * @throws Error if `geom` is not indexed
 */
export function computeMillInterference(
    geom: THREE.BufferGeometry,
    millRadiusMm: number,
    params?: Partial<MillInterferenceParams>,
): MillInterference[] {
    if (!isIndexedGeometry(geom)) {
        throw new Error('Mill interference operation unsupported for non-indexed geometry.');
    }
    const millParams = getMillInterferenceParams(params);
    const { relativeEpsilon, jiggleIterations } = millParams;
    const bvh = ensureMeshIndex(geom);
    const posAttr = geom.getAttribute(AttributeName.Position);
    const index = geom.getIndex();
    const indexMap = getIndexMap(geom);

    const faceTriangle = new THREE.Triangle();
    const restCenter = new THREE.Vector3();
    const faceNormal = new THREE.Vector3();
    const result: MillInterference[] = [];
    const caster = new SphereCaster(new THREE.Vector3(), millRadiusMm / (1 + relativeEpsilon));
    const faceCount = index.count / 3;

    for (let face = 0; face < faceCount; face += 1) {
        const originalIndex = indexMap[face];
        if (originalIndex === undefined || !millParams.facetSelector(face, originalIndex)) {
            continue;
        }

        triangleFromBufferAttributes(faceTriangle, index, posAttr, face);
        if (faceTriangle.getNormal(faceNormal).lengthSq() === 0) {
            // Degenerate triangle: no usable normal to offset the probe along.
            continue;
        }

        getIncenter(faceTriangle, restCenter);
        restCenter.addScaledVector(faceNormal, millRadiusMm);
        caster.sphere.center.copy(restCenter);

        // Cast the current probe and ignore the facet the probe was placed on.
        const castOthers = () => caster.castAll(bvh).filter(hit => hit !== face);

        let hits = castOthers();
        let round = 0;
        while (round < jiggleIterations && isArrayMin1(hits)) {
            applyJiggle(hits, caster.sphere, index, posAttr, relativeEpsilon);
            limitJiggle(
                restCenter,
                faceNormal,
                caster.sphere,
                millParams.maxLateralJiggleFactor,
                millParams.maxNormalJiggleFactor,
            );
            hits = castOthers();
            round += 1;
        }

        if (isArrayMin1(hits)) {
            result.push({ sourceIndex: face, sphere: caster.sphere.clone(), collisionIndices: hits });
        }
    }
    return result;
}
