import { AttributeName } from './BufferAttributeConstants';
import { generateFacets, getTriangleByIndex } from './Mesh3d.util';
import { ensureMeshIndex } from './MeshIndex';
import { ConeCaster, PrismCaster } from './ShapeCast';
import type { CastVolume } from './ShapeCast/PrismCaster.types';
import { THREE_CACHE } from './ThreeObjectCache';
import { isArrayMin1, type ArrayMin1 } from '@orthly/runtime-utils';
import * as THREE from 'three';

/**
 * One obstruction face hit by a cast.
 */
interface Collision {
    /** Index of the hit face within that mesh, as yielded by the caster's `castAll`. */
    faceIndex: number;
    /** Position of the hit mesh in the `obstructionGeoms` array given to the computing function. */
    meshIndex: number;
}

/** Fields shared by every undercut interference record. */
export interface UndercutInterferenceCommon {
    /** Obstruction faces hit by the cast; guaranteed non-empty. */
    collisions: ArrayMin1<Collision>;
    /**
     * Index of the source element in the reference geometry: a vertex index for cone casts,
     * a triangle index for prism casts.
     */
    sourceIndex: number;
}

/** Interference found by casting a cone from a single reference vertex. */
export interface SimpleUndercutInterference extends UndercutInterferenceCommon {
    type: 'cone';
    /** Half-angle of the cast cone, in radians. */
    coneAngleRadians: number;
    /** Snapshot (clone) of the cone's axis ray at the time of the cast; its origin is the source vertex. */
    coneAxis: THREE.Ray;
}

/** Interference found by casting a tapered prism from a reference triangle. */
export interface UndercutInterference extends UndercutInterferenceCommon {
    type: 'prism';
    /** Cast volume reported by the prism caster for the source triangle. */
    volume: CastVolume;
}

/** Parameters shared by all undercut casts. */
export interface UndercutParams {
    /**
     * Taper (draft) angle in degrees. Cone casts use it as the cone half-angle; prism casts use its
     * tangent as the side slope.
     */
    taperDegrees: number;
}

/** Fills in unspecified undercut parameters; `taperDegrees` defaults to 1. */
function getDefaultUndercutParams(params?: Partial<UndercutParams>): UndercutParams {
    const requested = params == null ? undefined : params.taperDegrees;
    const taperDegrees = requested == null ? 1 : requested;
    return { taperDegrees };
}

/**
 * Per-vertex undercut test. From every vertex of `refGeom`, casts a cone whose half-angle is the taper
 * angle and records every obstruction face the cone hits.
 *
 * NOTE(review): the cone direction here is `-axis`, whereas `computeUndercutEscape` and
 * `adjustForUndercut` cast their cones along `+axis` (and `computeUndercutInterference` builds its prism
 * with `+axis`). Confirm which convention is intended.
 *
 * @param refGeom geometry whose vertices are tested.
 * @param obstructionGeoms meshes that may obstruct; `Collision.meshIndex` indexes into this array.
 * @param axis reference axis (not mutated; copied into a cached ray).
 * @param params optional taper in degrees (defaults to 1).
 * @returns one record per vertex with at least one collision; `sourceIndex` is the vertex index.
 */
export function computeSimpleUndercutInterference(
    refGeom: THREE.BufferGeometry,
    obstructionGeoms: THREE.BufferGeometry[],
    axis: THREE.Vector3,
    params?: Partial<UndercutParams>,
): SimpleUndercutInterference[] {
    const taperDegrees = getDefaultUndercutParams(params).taperDegrees;
    const withCachedRay = THREE_CACHE.autoAcquire('ray');

    return withCachedRay(ray => {
        ray.direction.copy(axis).negate();
        const caster = new ConeCaster(ray, (taperDegrees * Math.PI) / 180);
        const found: SimpleUndercutInterference[] = [];
        const positions = refGeom.getAttribute(AttributeName.Position);
        const bvhs = obstructionGeoms.map(mesh => ensureMeshIndex(mesh));

        for (let vertex = 0; vertex < positions.count; vertex++) {
            // Re-anchor the shared cone at the current vertex.
            caster.axis.origin.fromBufferAttribute(positions, vertex);

            const collisions: Collision[] = [];
            for (const [meshIndex, bvh] of bvhs.entries()) {
                for (const faceIndex of caster.castAll(bvh)) {
                    collisions.push({ meshIndex, faceIndex });
                }
            }

            if (!isArrayMin1(collisions)) {
                continue;
            }
            found.push({
                type: 'cone',
                sourceIndex: vertex,
                coneAngleRadians: caster.halfAngle,
                // Clone: the caster's ray is reused (and comes from a shared cache).
                coneAxis: caster.axis.clone(),
                collisions,
            });
        }

        return found;
    });
}

/**
 * Per-triangle undercut test. For every triangle of `refGeom`, builds a tapered prism cast volume along
 * `axis` and records every obstruction face that intersects it. Triangles for which the caster reports
 * an invalid cast volume are skipped.
 *
 * @param refGeom geometry whose triangles are tested (indexed or non-indexed).
 * @param obstructionGeoms meshes that may obstruct; `Collision.meshIndex` indexes into this array.
 * @param axis cast direction (cloned; not mutated).
 * @param params optional taper in degrees (defaults to 1).
 * @returns one record per triangle with at least one collision; `sourceIndex` is the triangle index.
 */
export function computeUndercutInterference(
    refGeom: THREE.BufferGeometry,
    obstructionGeoms: THREE.BufferGeometry[],
    axis: THREE.Vector3,
    params?: Partial<UndercutParams>,
): UndercutInterference[] {
    const { taperDegrees } = getDefaultUndercutParams(params);
    const taperSlope = Math.tan((Math.PI * taperDegrees) / 180);

    const caster = new PrismCaster(axis.clone(), taperSlope);
    const interferences: UndercutInterference[] = [];
    const bvhs = obstructionGeoms.map(mesh => ensureMeshIndex(mesh));

    // Each triangle consumes three index entries (indexed geometry) or three vertices (non-indexed geometry).
    // Use the attribute's element `count`, not `array.length`: the position array stores three floats per
    // vertex, so `array.length / 3` is the vertex count, which would overcount triangles threefold.
    // Floor so a trailing partial triangle is never visited.
    const index = refGeom.getIndex();
    const cornerCount = index ? index.count : refGeom.getAttribute(AttributeName.Position).count;
    const triCount = Math.floor(cornerCount / 3);

    for (let i = 0; i < triCount; i += 1) {
        getTriangleByIndex(refGeom, i, caster.baseTriangle);
        if (!caster.isCastVolumeValid()) {
            continue;
        }
        const collisions: Collision[] = [];
        bvhs.forEach((bvh, meshIndex) => {
            for (const faceIndex of caster.castAll(bvh)) {
                collisions.push({ meshIndex, faceIndex });
            }
        });
        if (isArrayMin1(collisions)) {
            interferences.push({
                type: 'prism',
                sourceIndex: i,
                volume: caster.getCastVolume(),
                collisions,
            });
        }
    }

    return interferences;
}

/** Undercut parameters plus the limits of the escape-offset search. */
export interface UndercutEscapeParams extends UndercutParams {
    /** Largest escape offset length the search considers. */
    maxDistance: number;
    /** Target resolution of the escape distance; controls the number of binary-search steps. */
    precision: number;
}

/** Fills in unspecified escape parameters: taper 1°, `maxDistance` 1.0, `precision` 0.01. */
function getDefaultEscapeParams(params?: Partial<UndercutEscapeParams>): UndercutEscapeParams {
    const undercutDefaults = getDefaultUndercutParams(params);
    const maxDistance = params?.maxDistance ?? 1.0;
    const precision = params?.precision ?? 0.01;
    return { ...undercutDefaults, maxDistance, precision };
}

// Undercut escape uses a radial binary search on the escape distance. At every candidate distance the
// search probes a fixed ring of MAX_ANGULAR_ITERATIONS directions around the cast axis.
const MAX_ANGULAR_ITERATIONS = 32 as const;

// Evenly spaced unit directions on the circle in the YZ plane, i.e. perpendicular to a +X cast direction.
const ANGULAR_SAMPLES = Array.from({ length: MAX_ANGULAR_ITERATIONS }, (_, step) => {
    const angle = (2 * Math.PI * step) / MAX_ANGULAR_ITERATIONS;
    const cosine = Math.cos(angle);
    const sine = Math.sin(angle);
    return new THREE.Vector3(0, cosine, sine);
});

// The +X reference direction that ANGULAR_SAMPLES are defined relative to.
const XAxis = new THREE.Vector3(1, 0, 0);

/**
 * Searches for a short offset, perpendicular to `axis`, that moves `pt` out of undercut.
 *
 * A cone with half-angle `taperDegrees` is cast from a candidate position along `axis`. The position
 * "escapes" when the cone hits none of `obstructionGeoms`. Candidate offsets point in one of
 * MAX_ANGULAR_ITERATIONS evenly spaced directions around `axis`. Their length is found by binary search
 * over [0, maxDistance], running ceil(log2(maxDistance / precision)) steps.
 *
 * @param obstructionGeoms meshes the cone must avoid.
 * @param pt query point (not mutated).
 * @param axis cast direction (not mutated).
 * @param inParams taper and search parameters; see `getDefaultEscapeParams` for defaults.
 * @returns a zero vector if `pt` already escapes. Otherwise returns the shortest escaping offset found by
 *   the search, or `undefined` if no probed offset escapes.
 * @throws RangeError when maxDistance / precision is infinite (e.g. precision 0), which would otherwise
 *   make the search loop forever.
 */
export function computeUndercutEscape(
    obstructionGeoms: THREE.BufferGeometry[],
    pt: THREE.Vector3,
    axis: THREE.Vector3,
    inParams?: Partial<UndercutEscapeParams>,
): THREE.Vector3 | undefined {
    const params = getDefaultEscapeParams(inParams);
    // TODO: handle 0 taper case.
    const caster = new ConeCaster(new THREE.Ray(pt.clone(), axis.clone()), (params.taperDegrees * Math.PI) / 180);
    const bvhs = obstructionGeoms.map(g => ensureMeshIndex(g));
    // A cast position escapes only when the cone is clear of *every* obstruction mesh.
    const isClearOfAll = (): boolean => !bvhs.some(bvh => caster.castAny(bvh));

    // Test the query point itself; if it escapes, just return a 0 vector.
    if (isClearOfAll()) {
        return new THREE.Vector3();
    }

    const iterations = Math.ceil(Math.log2(params.maxDistance / params.precision));
    if (iterations === Infinity) {
        throw new RangeError(
            `computeUndercutEscape: maxDistance / precision must be finite ` +
                `(maxDistance=${params.maxDistance}, precision=${params.precision})`,
        );
    }

    // ANGULAR_SAMPLES are perpendicular to +X. Rotate +X onto the cast axis (not the reverse) so the samples
    // become perpendicular to the cast axis. setFromUnitVectors requires unit inputs, so normalize a copy.
    const unitAxis = axis.clone().normalize();
    const samples = THREE_CACHE.autoAcquire('quaternion')(q => {
        q.setFromUnitVectors(XAxis, unitAxis);
        return ANGULAR_SAMPLES.map(v => v.clone().applyQuaternion(q).normalize());
    });

    // Binary search over the normalized distance in [0, 1]; scaled by maxDistance when probing.
    let lo = 0;
    let hi = 1;
    let escape: THREE.Vector3 | undefined;
    for (let i = 0; i < iterations; i += 1) {
        const mid = 0.5 * (lo + hi);
        let anyEscaped = false;
        for (const offset of samples) {
            caster.axis.origin.copy(pt).addScaledVector(offset, mid * params.maxDistance);
            if (isClearOfAll()) {
                anyEscaped = true;
                if (!escape) {
                    escape = offset.clone().multiplyScalar(mid);
                } else if (mid * mid < escape.lengthSq()) {
                    escape.copy(offset).multiplyScalar(mid);
                }
                break;
            }
        }

        if (anyEscaped) {
            hi = mid;
        } else {
            lo = mid;
        }
    }

    return escape?.multiplyScalar(params.maxDistance);
}

/**
 * A vertex that needs an undercut adjustment but belongs to a facet rejected by the facet selector,
 * so it was left in place.
 */
export interface AdjustmentConflict {
    /** Escape offset, as returned by `computeUndercutEscape`, that the vertex would have been moved by. */
    desiredAdjustment: THREE.Vector3;
    /** Index of the vertex in the reference geometry's position attribute. */
    vertexIndex: number;
}

/**
 * Moves vertices of `refGeom` that are in undercut along `axis` by the escape offset found by
 * `computeUndercutEscape`. Mutates `refGeom`'s position attribute in place.
 *
 * Vertices are flagged per facet. If a facet's tapered prism cast hits an obstruction, all three of its
 * vertices are flagged. If the prism cast volume is invalid, each vertex is instead tested with a cone cast.
 *
 * @param refGeom geometry to adjust (its position attribute is mutated).
 * @param obstructionGeoms meshes the adjusted geometry should clear.
 * @param axis cast direction (not mutated).
 * @param facetSelector optional predicate over facet indices. Vertices of rejected facets are never moved.
 * @param inParams taper and escape-search parameters; see `getDefaultEscapeParams` for defaults.
 * @returns flagged vertices that needed a non-zero adjustment but were held fixed by an unselected facet.
 */
export function adjustForUndercut(
    refGeom: THREE.BufferGeometry,
    obstructionGeoms: THREE.BufferGeometry[],
    axis: THREE.Vector3,
    facetSelector?: (fIdx: number) => boolean,
    inParams?: Partial<UndercutEscapeParams>,
): AdjustmentConflict[] {
    const selector = facetSelector ?? (() => true);
    const { taperDegrees } = getDefaultEscapeParams(inParams);

    const taperRadians = (Math.PI * taperDegrees) / 180;
    const taperSlope = Math.tan(taperRadians);

    const prismCaster = new PrismCaster(axis.clone(), taperSlope);
    const coneCaster = new ConeCaster(new THREE.Ray(), taperRadians);
    coneCaster.axis.direction.copy(axis);

    const bvhs = obstructionGeoms.map(mesh => ensureMeshIndex(mesh));
    const posAttr = refGeom.getAttribute(AttributeName.Position);

    // Since the intended use-case of this function is for intaglio adjustment, we want to avoid adjusting the
    // boundary of the intaglio region (i.e. the margin). These sets discern which vertices lie on that boundary.
    // If a facet fails the selection filter, we add its vertices to the fixed set.
    const fixedVerts = new Set<number>();
    const freeVerts = new Set<number>();
    const tri = prismCaster.baseTriangle;

    for (const { index, verts } of generateFacets(refGeom)) {
        if (!selector(index)) {
            verts.forEach(vi => fixedVerts.add(vi));
            continue;
        }

        getTriangleByIndex(refGeom, index, tri);
        if (prismCaster.isCastVolumeValid()) {
            if (bvhs.some(bvh => prismCaster.castAny(bvh))) {
                verts.forEach(vi => freeVerts.add(vi));
            }
        } else {
            verts.forEach(vi => {
                coneCaster.axis.origin.fromBufferAttribute(posAttr, vi);
                if (bvhs.some(bvh => coneCaster.castAny(bvh))) {
                    freeVerts.add(vi);
                }
            });
        }
    }

    const conflicts: AdjustmentConflict[] = [];
    const pos = new THREE.Vector3();
    let moved = false;
    for (const vIdx of freeVerts) {
        pos.fromBufferAttribute(posAttr, vIdx);
        const adjust = computeUndercutEscape(obstructionGeoms, pos, axis, inParams);
        // Skip when no escape was found, or when the vertex is already clear on its own (zero offset).
        // The zero case arises because a facet-level prism hit flags all three corners, even if a given
        // corner's own cone is unobstructed. Such a vertex needs no move and is not a conflict.
        if (!adjust || adjust.lengthSq() === 0) {
            continue;
        }
        if (fixedVerts.has(vIdx)) {
            conflicts.push({ vertexIndex: vIdx, desiredAdjustment: adjust });
            continue;
        }
        pos.add(adjust);
        posAttr.setXYZ(vIdx, pos.x, pos.y, pos.z);
        moved = true;
    }

    if (moved) {
        // setXYZ only writes the CPU-side array; flag the attribute so three.js re-uploads it.
        posAttr.needsUpdate = true;
    }

    return conflicts;
}
