/* eslint-disable max-lines */
import { boxEdges, boxPoints } from '../Box3.generators';
import { lineIntersectPlaneAt, lineIntersectsTriangle, planesIntersection } from '../Intersection';
import { THREE_CACHE } from '../ThreeObjectCache';
import type { CastVolume } from './PrismCaster.types';
import { PlaneSide } from './PrismCaster.types';
import type { EdgePlane, VertPlane } from './PrismCaster.util';
import { PlaneID, boxPlaneSide, trianglePlaneSide } from './PrismCaster.util';
import type { ShapeCaster } from './types';
import { CastMode } from './types';
import type { ArrayN } from '@orthly/runtime-utils';
import { assertNever } from '@orthly/shared-types';
import * as THREE from 'three';
import type { MeshBVH, ShapecastIntersection } from 'three-mesh-bvh';
import { CONTAINED, INTERSECTED, NOT_INTERSECTED } from 'three-mesh-bvh';

/**
 * A shape caster class that can cast an optionally tapered extrusion of a triangle and report
 * all intersecting facets.
 *
 * The cast volume is the intersection of the "behind" half-spaces (signed distance <= 0) of a set
 * of bounding planes. Its shape depends on the effective slope, which is `taperSlope` clamped to
 * the base triangle's own slope relative to `direction`:
 * - zero: a prism swept from `baseTriangle` along `direction`;
 * - positive: a "cupola" whose side edges diverge; three extra vertex planes bound it;
 * - negative: a finite tetrahedron whose apex is where the three edge planes meet.
 *
 * NOTE(review): `direction` is treated as a unit vector throughout (its dot products are used as
 * cosines) — callers presumably pass a normalized vector.
 */
export class PrismCaster implements ShapeCaster {
    readonly intersectsBounds: (
        box: THREE.Box3,
        isLeaf: boolean,
        score: number | undefined,
        depth: number,
        nodeIndex: number,
    ) => ShapecastIntersection;
    readonly intersectsTriangle: (
        triangle: THREE.Triangle,
        triangleIndex: number,
        contained: boolean,
        depth: number,
    ) => boolean;

    /** Facet indices collected while casting in `CastMode.FindAll`. */
    private readonly facetIndices: Set<number> = new Set();
    /** Every bounding plane, indexed by `PlaneID`. */
    private readonly planes: ArrayN<THREE.Plane, 7> = [
        new THREE.Plane(),
        new THREE.Plane(),
        new THREE.Plane(),
        new THREE.Plane(),
        new THREE.Plane(),
        new THREE.Plane(),
        new THREE.Plane(),
    ];
    /**
     * The planes that actually bound the current volume. The vertex planes (A, B, C) are only
     * computed for a positive effective slope. Otherwise they hold default or stale values and must
     * not take part in any half-space test.
     */
    private activePlanes: readonly THREE.Plane[] = [];
    /** Meeting point of the three edge planes; only meaningful for a tetrahedron (negative slope). */
    private frustumVertex: THREE.Vector3 = new THREE.Vector3();
    /** True when the base triangle's winding normal points along `direction` and had to be negated. */
    private isFlipped: boolean = false;
    /** FindAll by default, matching the behavior when `shapecast` is invoked directly. */
    private mode: CastMode = CastMode.FindAll;
    /** `taperSlope` clamped so it never exceeds the base triangle's slope relative to `direction`. */
    private adjustedSlope: number = 0;

    constructor(
        public readonly direction: THREE.Vector3,
        public taperSlope: number = 0,
        public readonly baseTriangle: THREE.Triangle = new THREE.Triangle(),
    ) {
        this.intersectsBounds = this.intersectsBoundsImpl.bind(this);
        this.intersectsTriangle = this.intersectsTriangleImpl.bind(this);
        this.updateBoundingPlanes();
    }

    /** Indices of the facets collected so far, in ascending order. */
    get intersectingFacets() {
        return Array.from(this.facetIndices).sort((a, b) => a - b);
    }

    /**
     * Returns false when a positive `taperSlope` is steeper than the base triangle allows relative
     * to `direction`, i.e. when it would be clamped down when the planes are built. Non-positive
     * slopes are always valid.
     */
    isCastVolumeValid(): boolean {
        if (this.taperSlope <= 0) {
            return true;
        }

        return THREE_CACHE.autoAcquire('vector3')(normal => {
            const k = this.taperSlope * Math.sqrt(1 / (1 + this.taperSlope * this.taperSlope));
            return Math.abs(this.baseTriangle.getNormal(normal).dot(this.direction)) >= k;
        });
    }

    /** Describes the volume computed by the most recent plane update, using cloned geometry. */
    getCastVolume(): CastVolume {
        if (this.adjustedSlope > 0) {
            const flip = this.isFlipped ? 1 : -1;
            // Direction of the volume edge along which `first` and `second` meet.
            const edgeDirection = (first: THREE.Plane, second: THREE.Plane) =>
                new THREE.Vector3().crossVectors(first.normal, second.normal).multiplyScalar(flip).normalize();
            const p = this.planes;
            return {
                type: 'cupola',
                base: this.baseTriangle.clone(),
                aEdges: [edgeDirection(p[PlaneID.CA], p[PlaneID.A]), edgeDirection(p[PlaneID.A], p[PlaneID.AB])],
                bEdges: [edgeDirection(p[PlaneID.AB], p[PlaneID.B]), edgeDirection(p[PlaneID.B], p[PlaneID.BC])],
                cEdges: [edgeDirection(p[PlaneID.BC], p[PlaneID.C]), edgeDirection(p[PlaneID.C], p[PlaneID.CA])],
            };
        } else if (this.adjustedSlope < 0) {
            return { type: 'tetrahedron', base: this.baseTriangle.clone(), vertex: this.frustumVertex.clone() };
        } else {
            return { type: 'prism', base: this.baseTriangle.clone(), direction: this.direction.clone() };
        }
    }

    /** Forgets every facet collected so far. */
    clear() {
        this.facetIndices.clear();
    }

    /** Rebuilds the volume from the current inputs and reports whether any facet intersects it. */
    castAny(bvh: MeshBVH): boolean {
        this.mode = CastMode.FindAny;
        this.clear();
        this.updateBoundingPlanes();
        return bvh.shapecast(this);
    }

    /** Rebuilds the volume from the current inputs and returns every intersecting facet index, sorted. */
    castAll(bvh: MeshBVH): number[] {
        this.mode = CastMode.FindAll;
        this.clear();
        this.updateBoundingPlanes();
        bvh.shapecast(this);
        return this.intersectingFacets;
    }

    /** BVH bounds callback: classifies `box` against the cast volume. */
    private intersectsBoundsImpl(box: THREE.Box3): ShapecastIntersection {
        const sides = this.activePlanes.map(plane => boxPlaneSide(box, plane));

        // If the box is completely in front of any of the bounding planes, it's definitely not intersecting.
        if (sides.includes(PlaneSide.Front)) {
            return NOT_INTERSECTED;
        }
        // If the box is completely behind all of the bounding planes, it's entirely contained.
        if (sides.every(s => s === PlaneSide.Behind)) {
            return CONTAINED;
        }
        // If the box crosses *exactly* one plane and is behind all the others, it's intersecting.
        if (sides.filter(s => s === PlaneSide.Crossing).length === 1) {
            return INTERSECTED;
        }

        if (this.baseTriangle.intersectsBox(box)) {
            return INTERSECTED;
        }

        // Check if any box vertex is completely within the bounding planes
        const contained = THREE_CACHE.autoAcquire('vector3')(target => {
            for (const boxPt of boxPoints(box, target)) {
                if (this.activePlanes.every(({ normal, constant }) => normal.dot(boxPt) + constant <= 0)) {
                    return true;
                }
            }
            return false;
        });

        if (contained) {
            return INTERSECTED;
        }

        // At this point all vertices of the base triangle and box are outside of the opposing volume.
        // This does not mean there's no intersection, however.
        if (this.adjustedSlope < 0) {
            // Negative taper slope means the query volume is a finite tetrahedron.
            return THREE_CACHE.autoAcquire('triangle')(tri => {
                // We already checked the base face, so we can skip that here.
                const faces = [
                    [this.baseTriangle.a, this.baseTriangle.b, this.frustumVertex],
                    [this.baseTriangle.b, this.baseTriangle.c, this.frustumVertex],
                    [this.baseTriangle.c, this.baseTriangle.a, this.frustumVertex],
                ] as const;
                for (const [a, b, c] of faces) {
                    if (tri.set(a, b, c).intersectsBox(box)) {
                        return INTERSECTED;
                    }
                }
                return NOT_INTERSECTED;
            });
        }

        const raysIntersect = THREE_CACHE.autoAcquire('ray')(rayTarget => {
            for (const ray of this.generateEdgePlaneRays(rayTarget)) {
                if (ray.intersectsBox(box)) {
                    return true;
                }
            }
            return false;
        });
        return raysIntersect || this.intersectsBoxByEdges(box) ? INTERSECTED : NOT_INTERSECTED;
    }

    /**
     * BVH triangle callback. In FindAny mode the first hit returns true, which ends the shapecast
     * early; in FindAll mode hits are recorded and traversal continues.
     */
    private intersectsTriangleImpl(triangle: THREE.Triangle, triangleIndex: number, contained: boolean): boolean {
        if (contained || this.intersectsTriangleHelper(triangle)) {
            if (this.mode === CastMode.FindAny) {
                return true;
            } else {
                this.facetIndices.add(triangleIndex);
            }
        }
        return false;
    }

    /**
     * Recomputes every bounding plane, the effective slope, the tetrahedron apex and the set of
     * active planes from `baseTriangle`, `direction` and `taperSlope`.
     */
    private updateBoundingPlanes() {
        const base = this.planes[PlaneID.Base];
        // Orient the base plane so that `direction` points into its "behind" half-space.
        this.baseTriangle.getPlane(base);
        const baseDot = base.normal.dot(this.direction);
        this.isFlipped = baseDot > 0;
        if (this.isFlipped) {
            base.negate();
        }

        // |cos| / sin of the angle between base normal and direction. Clamp the radicand so that
        // round-off (|baseDot| marginally above 1) yields Infinity rather than NaN; a NaN slope would
        // fail every `> 0` / `< 0` check and silently turn any volume, even a tetrahedron, into a prism.
        const baseSin = Math.sqrt(Math.max(0, 1 - baseDot * baseDot));
        const baseSlope = Math.abs(baseDot) / baseSin;
        this.adjustedSlope = Math.min(this.taperSlope, baseSlope);

        this.updateEdgePlane(PlaneID.AB);
        this.updateEdgePlane(PlaneID.BC);
        this.updateEdgePlane(PlaneID.CA);

        if (this.adjustedSlope > 0) {
            // If taper slope is positive, we need to add additional bounding planes to prevent
            // skinny triangles from exploding into huge prisms.
            this.updateVertPlane(PlaneID.A);
            this.updateVertPlane(PlaneID.B);
            this.updateVertPlane(PlaneID.C);
            this.activePlanes = this.planes;
        } else {
            // The vertex planes were not (re)computed, so they must not constrain anything.
            this.activePlanes = [base, this.planes[PlaneID.AB], this.planes[PlaneID.BC], this.planes[PlaneID.CA]];
        }

        planesIntersection(
            this.planes[PlaneID.AB],
            this.planes[PlaneID.BC],
            this.planes[PlaneID.CA],
            this.frustumVertex,
        );
    }

    /**
     * Computes the side plane through the base edge `planeId`. With zero slope the plane contains the
     * edge and `direction` (normal = edge × direction). Otherwise it is tilted by `adjustedSlope`.
     * No tilt is applied when the edge is too closely aligned with `direction` for the requested
     * slope (1 - ed² < k²).
     */
    private updateEdgePlane(planeId: EdgePlane) {
        THREE_CACHE.autoAcquire('vector3')(edge => {
            let point: THREE.Vector3 | undefined;
            switch (planeId) {
                case PlaneID.AB:
                    point = this.baseTriangle.a;
                    edge.subVectors(point, this.baseTriangle.b).normalize();
                    break;
                case PlaneID.BC:
                    point = this.baseTriangle.b;
                    edge.subVectors(point, this.baseTriangle.c).normalize();
                    break;
                case PlaneID.CA:
                    point = this.baseTriangle.c;
                    edge.subVectors(point, this.baseTriangle.a).normalize();
                    break;
                default:
                    assertNever(planeId);
            }

            // Need to ensure the edge is oriented correctly so the following computation produces
            // the correct result; otherwise the meaning of taper ends up sign-flipped.
            if (this.isFlipped) {
                edge.negate();
            }

            const plane = this.planes[planeId];
            const k = -this.adjustedSlope / Math.sqrt(1 + this.adjustedSlope * this.adjustedSlope);
            const ed = edge.dot(this.direction);
            const s2 = 1 - ed * ed;
            const t = s2 >= k * k ? k / Math.sqrt(s2 - k * k) : 0;
            plane.normal
                .crossVectors(edge, this.direction)
                .addScaledVector(this.direction, t)
                .addScaledVector(edge, -t * ed)
                .normalize();
            plane.constant = -plane.normal.dot(point);
        });
    }

    /**
     * Computes the vertex plane through the base vertex `planeId`. Only meaningful for a positive
     * effective slope; it relies on the edge planes having been updated first.
     */
    private updateVertPlane(planeId: VertPlane): void {
        // For vertex bounding planes we're going to pick them so that they generate parallel edges
        // for each edge bounding plane. I.e. plane A will generate an edge with plane CA and plane
        // AB, the A/CA edge will be parallel to the C/CA edge and the A/AB edge will be parallel
        // to the B/AB edge.
        THREE_CACHE.autoAcquire('vector3')(vec => {
            let point: THREE.Vector3 | undefined;
            let ni: THREE.Vector3 | undefined;
            let nj: THREE.Vector3 | undefined;
            switch (planeId) {
                case PlaneID.A:
                    point = this.baseTriangle.a;
                    ni = this.planes[PlaneID.AB].normal;
                    nj = this.planes[PlaneID.CA].normal;
                    break;
                case PlaneID.B:
                    point = this.baseTriangle.b;
                    ni = this.planes[PlaneID.BC].normal;
                    nj = this.planes[PlaneID.AB].normal;
                    break;
                case PlaneID.C:
                    point = this.baseTriangle.c;
                    ni = this.planes[PlaneID.CA].normal;
                    nj = this.planes[PlaneID.BC].normal;
                    break;
                default:
                    assertNever(planeId);
            }

            if (this.isFlipped) {
                const tmp = ni;
                ni = nj;
                nj = tmp;
            }

            const plane = this.planes[planeId];
            vec.copy(this.direction).projectOnPlane(ni);
            plane.normal.copy(this.direction).projectOnPlane(nj).cross(vec).normalize();
            // Tilt the plane up if possible to reach desired taper
            const nd = plane.normal.dot(this.direction);
            // We want nd == -m / √(1 + m²), a suitable correction factor can be found:
            const t = nd + this.adjustedSlope * Math.sqrt(1 - nd * nd);
            // But we only want to make this adjustment to increase slopes, never decrease them.
            // The reason for this is that our initial normal is computed such that it would form
            // parallel edges in the neighboring planes. Increasing the slope of the resulting
            // normal will pull those edges outwards, making them divergent. Decreasing the slope
            // would cause them to converge, which makes the geometry trickier to deal with.
            if (t > 0) {
                plane.normal.addScaledVector(this.direction, -t).normalize();
            }
            plane.constant = -plane.normal.dot(point);
        });
    }

    /**
     * Yields the unbounded side edges of a prism or cupola as rays starting at base vertices and
     * pointing away from the base plane's normal. The same `target` ray is mutated and yielded each
     * time, so consume each value before advancing the iterator.
     */
    private *generateEdgePlaneRays(target: THREE.Ray) {
        const baseNormal = this.planes[PlaneID.Base].normal;
        const setOrigin = (pt: THREE.Vector3) => {
            target.origin.copy(pt);
        };
        const setDirection = (n1: THREE.Vector3, n2: THREE.Vector3) => {
            target.direction.crossVectors(n1, n2).normalize();
            if (target.direction.dot(baseNormal) > 0) {
                target.direction.negate();
            }
            return target;
        };
        if (this.adjustedSlope > 0) {
            let curr = this.planes[PlaneID.CA];
            let next = this.planes[PlaneID.A];
            setOrigin(this.baseTriangle.a);
            yield setDirection(curr.normal, next.normal);
            curr = next;
            next = this.planes[PlaneID.AB];
            yield setDirection(curr.normal, next.normal);
            curr = next;
            next = this.planes[PlaneID.B];
            setOrigin(this.baseTriangle.b);
            yield setDirection(curr.normal, next.normal);
            curr = next;
            next = this.planes[PlaneID.BC];
            yield setDirection(curr.normal, next.normal);
            curr = next;
            next = this.planes[PlaneID.C];
            setOrigin(this.baseTriangle.c);
            yield setDirection(curr.normal, next.normal);
            curr = next;
            next = this.planes[PlaneID.CA];
            yield setDirection(curr.normal, next.normal);
        } else {
            let curr = this.planes[PlaneID.CA];
            let next = this.planes[PlaneID.AB];
            setOrigin(this.baseTriangle.a);
            yield setDirection(curr.normal, next.normal);
            curr = next;
            next = this.planes[PlaneID.BC];
            setOrigin(this.baseTriangle.b);
            yield setDirection(curr.normal, next.normal);
            curr = next;
            next = this.planes[PlaneID.CA];
            setOrigin(this.baseTriangle.c);
            yield setDirection(curr.normal, next.normal);
        }
    }

    /**
     * Tests whether segment `edge` crosses one of the volume's side faces. A side face is the part
     * of its plane lying behind both cyclic neighbors and the base plane.
     */
    private intersectsEdge(edge: THREE.Line3): boolean {
        const planes =
            this.adjustedSlope > 0
                ? [
                      this.planes[PlaneID.A],
                      this.planes[PlaneID.AB],
                      this.planes[PlaneID.B],
                      this.planes[PlaneID.BC],
                      this.planes[PlaneID.C],
                      this.planes[PlaneID.CA],
                  ]
                : [this.planes[PlaneID.AB], this.planes[PlaneID.BC], this.planes[PlaneID.CA]];
        const base = this.planes[PlaneID.Base];
        const isPointAtBehind = (t: number, plane: THREE.Plane) =>
            plane.normal.dot(edge.start) * (1 - t) + plane.normal.dot(edge.end) * t + plane.constant <= 0;
        for (let i = 0; i < planes.length; i += 1) {
            const prev = planes[(i + planes.length - 1) % planes.length] as THREE.Plane;
            const curr = planes[i] as THREE.Plane;
            const next = planes[(i + 1) % planes.length] as THREE.Plane;
            const t = lineIntersectPlaneAt(edge, curr);
            if (t < 0 || 1 < t || isNaN(t)) {
                // Intersection is outside of the edge bounds, no intersection.
                continue;
            }

            // Make sure the resulting intersection point is "behind" the neighboring planes
            if (isPointAtBehind(t, prev) && isPointAtBehind(t, base) && isPointAtBehind(t, next)) {
                return true;
            }
        }

        return false;
    }

    /** Tests whether any of the twelve edges of `box` crosses a side face of the volume. */
    private intersectsBoxByEdges(box: THREE.Box3): boolean {
        return THREE_CACHE.autoAcquire('line3')(edgeTarget => {
            for (const edge of boxEdges(box, edgeTarget)) {
                if (this.intersectsEdge(edge)) {
                    return true;
                }
            }
            return false;
        });
    }

    /** Exact-ish facet test used when the BVH could not prove the facet's node is contained. */
    private intersectsTriangleHelper(triangle: THREE.Triangle): boolean {
        const sides = this.activePlanes.map(plane => trianglePlaneSide(triangle, plane));
        if (sides.includes(PlaneSide.Front)) {
            return false;
        }
        if (sides.every(side => side === PlaneSide.Behind)) {
            return true;
        }

        // Either a facet edge enters the volume, or the facet is large enough to span the volume's
        // cross-section, in which case a volume edge pierces the facet instead.
        return this.triangleEdgesIntersectVolume(triangle) || this.volumeEdgesIntersectTriangle(triangle);
    }

    /** Tests whether any edge of `triangle` crosses the base face or a side face of the volume. */
    private triangleEdgesIntersectVolume(triangle: THREE.Triangle): boolean {
        return THREE_CACHE.autoAcquire('line3')(edge => {
            const edges = [
                [triangle.a, triangle.b],
                [triangle.b, triangle.c],
                [triangle.c, triangle.a],
            ] as const;
            for (const [start, end] of edges) {
                edge.set(start, end);
                if (lineIntersectsTriangle(edge, this.baseTriangle) || this.intersectsEdge(edge)) {
                    return true;
                }
            }
            return false;
        });
    }

    /**
     * Tests whether any edge of the volume itself pierces `triangle`. The finite edges are the base
     * triangle's edges, plus the apex edges for a tetrahedron. The unbounded edges of a prism or
     * cupola are the rays from `generateEdgePlaneRays`.
     */
    private volumeEdgesIntersectTriangle(triangle: THREE.Triangle): boolean {
        const { a, b, c } = this.baseTriangle;
        const apex = this.frustumVertex;
        const isTetrahedron = this.adjustedSlope < 0;

        const segmentHit = THREE_CACHE.autoAcquire('line3')(segment => {
            const segments: Array<readonly [THREE.Vector3, THREE.Vector3]> = [
                [a, b],
                [b, c],
                [c, a],
            ];
            if (isTetrahedron) {
                segments.push([a, apex], [b, apex], [c, apex]);
            }
            return segments.some(([start, end]) => lineIntersectsTriangle(segment.set(start, end), triangle));
        });
        if (segmentHit || isTetrahedron) {
            return segmentHit;
        }

        return THREE_CACHE.autoAcquire('vector3')(hitPoint =>
            THREE_CACHE.autoAcquire('ray')(rayTarget => {
                for (const ray of this.generateEdgePlaneRays(rayTarget)) {
                    if (ray.intersectTriangle(triangle.a, triangle.b, triangle.c, false, hitPoint) !== null) {
                        return true;
                    }
                }
                return false;
            }),
        );
    }
}
