import { boxEdges, boxPoints } from '../Box3.generators';
import { THREE_CACHE } from '../ThreeObjectCache';
import type { ShapeCaster } from './types';
import { CastMode } from './types';
import type * as THREE from 'three';
import type { MeshBVH, ShapecastIntersection } from 'three-mesh-bvh';
import { CONTAINED, INTERSECTED, NOT_INTERSECTED } from 'three-mesh-bvh';

/**
 * Exchanges the x/y/z components of the two vectors in place: afterwards `vec1` holds the old
 * components of `vec2` and vice versa.
 */
function swapVectors(vec1: THREE.Vector3, vec2: THREE.Vector3) {
    const [savedX, savedY, savedZ] = [vec1.x, vec1.y, vec1.z];
    vec1.copy(vec2);
    vec2.set(savedX, savedY, savedZ);
}

/** Returns `value` squared, computed with a single multiplication. */
function sqr(value: number): number {
    const squared = value * value;
    return squared;
}

/**
 * Shape caster that tests a three-mesh-bvh `MeshBVH` against a solid, infinite circular cone.
 *
 * The cone's apex is `axis.origin`. It opens along `axis.direction` with half-angle `halfAngle`
 * (radians, measured between the axis and the cone surface). Only the forward nappe counts:
 * points on or behind the plane through the apex perpendicular to the axis are never inside
 * (see `pointInside`).
 *
 * The caster keeps per-cast state (the cast mode and the collected triangle indices). That state
 * is reset at the start of every `castAny` / `castAll`.
 *
 * NOTE(review): the geometry assumes `axis.direction` is unit length and 0 < halfAngle < π/2.
 * Neither is enforced here; confirm at call sites.
 */
export class ConeCaster implements ShapeCaster {
    /**
     * Relative tolerance on sin²(angle between an edge and the cone axis). At or below it,
     * `intersectsEdge` treats the edge as parallel to the axis.
     */
    private static readonly PARALLEL_EPSILON_SQ = 1e-10;

    /** Node-level callback handed to `MeshBVH.shapecast`, bound to this instance. */
    readonly intersectsBounds: (
        box: THREE.Box3,
        isLeaf: boolean,
        score: number | undefined,
        depth: number,
        nodeIndex: number,
    ) => ShapecastIntersection;
    /** Triangle-level callback handed to `MeshBVH.shapecast`, bound to this instance. */
    readonly intersectsTriangle:
        | ((triangle: THREE.Triangle, triangleIndex: number, contained: boolean, depth: number) => boolean)
        | undefined;

    /** Indices of intersecting triangles collected while in `CastMode.FindAll`. */
    private facetIndices: Set<number> = new Set();

    // These three are assigned through the `halfAngle` setter from the constructor.
    private halfAngleInternal!: number;
    private cosHalfAngle!: number;
    private tanHalfAngle!: number;
    // Defaults to FindAll so the caster still collects hits when passed straight to `shapecast`
    // without going through castAny/castAll (an unset mode previously behaved the same way).
    private mode: CastMode = CastMode.FindAll;

    /**
     * @param axis Ray whose origin is the cone apex and whose direction is the cone axis.
     * @param halfAngle Angle in radians between the axis and the cone surface.
     */
    constructor(
        public readonly axis: THREE.Ray,
        halfAngle: number,
    ) {
        this.halfAngle = halfAngle;
        this.intersectsBounds = this.intersectsBoundsImpl.bind(this);
        this.intersectsTriangle = this.intersectsTriangleImpl.bind(this);
    }

    /** Half-angle of the cone in radians. */
    get halfAngle() {
        return this.halfAngleInternal;
    }

    /** Sets the half-angle and refreshes the cached cosine and tangent used by the tests. */
    set halfAngle(value: number) {
        this.halfAngleInternal = value;
        this.cosHalfAngle = Math.cos(this.halfAngleInternal);
        this.tanHalfAngle = Math.tan(this.halfAngleInternal);
    }

    /** Indices collected by the most recent `castAll`, as a new ascending array. */
    get intersectingFacets() {
        return Array.from(this.facetIndices).sort((a, b) => a - b);
    }

    /** Forgets any triangle indices collected so far. */
    clear() {
        this.facetIndices.clear();
    }

    /**
     * Returns true if any triangle of `bvh` intersects the cone. Traversal stops at the first hit
     * (`intersectsTriangleImpl` returns true in this mode).
     */
    castAny(bvh: MeshBVH): boolean {
        this.mode = CastMode.FindAny;
        this.clear();
        return bvh.shapecast(this);
    }

    /** Returns the ascending indices of every triangle of `bvh` that intersects the cone. */
    castAll(bvh: MeshBVH): number[] {
        this.mode = CastMode.FindAll;
        this.clear();
        bvh.shapecast(this);
        return this.intersectingFacets;
    }

    /**
     * Classifies a BVH node's bounding box against the cone.
     * Returns CONTAINED if every corner is inside, INTERSECTED if the box may touch the cone,
     * and NOT_INTERSECTED otherwise.
     */
    private intersectsBoundsImpl(box: THREE.Box3): ShapecastIntersection {
        return THREE_CACHE.autoAcquire(
            'vector3',
            'line3',
        )((vec, line) => {
            let anyIn = false;
            let allIn = true;

            // `vec` / `line` are scratch objects that boxPoints / boxEdges fill in
            // (presumably reused on every iteration; see Box3.generators).
            for (const pt of boxPoints(box, vec)) {
                if (this.pointInside(pt)) {
                    anyIn = true;
                } else {
                    allIn = false;
                }
            }

            // The cone is convex (for halfAngle < π/2), so if all corners are inside, the box is.
            if (allIn) {
                return CONTAINED;
            }

            if (anyIn || this.axis.intersectsBox(box)) {
                return INTERSECTED;
            }

            // At this point any intersection has to be between the cone and an edge of the box.
            // All corners are known to be outside, which is what intersectsEdge requires.
            for (const edge of boxEdges(box, line)) {
                if (this.intersectsEdge(edge)) {
                    return INTERSECTED;
                }
            }
            return NOT_INTERSECTED;
        });
    }

    /**
     * Per-triangle callback. `contained` means the enclosing node was reported CONTAINED, so the
     * exact test is skipped. Returning true stops the shapecast; that only happens in FindAny mode.
     */
    private intersectsTriangleImpl(triangle: THREE.Triangle, triangleIndex: number, contained: boolean): boolean {
        if (contained || this.intersectsTriangleHelper(triangle)) {
            // Find any mode means we can terminate early if we find any intersection
            if (this.mode === CastMode.FindAny) {
                return true;
            } else {
                this.facetIndices.add(triangleIndex);
            }
        }
        return false;
    }

    /**
     * This isn't exactly an intersection test, since it assumes that the endpoints of the line
     * have already been determined to be outside the cone.
     */
    private intersectsEdge(edge: THREE.Line3): boolean {
        // Edges entirely on or behind the apex plane can't touch the forward cone.
        const originD = this.axis.origin.dot(this.axis.direction);
        if (edge.start.dot(this.axis.direction) <= originD && edge.end.dot(this.axis.direction) <= originD) {
            return false;
        }
        return THREE_CACHE.autoAcquire(
            'vector3',
            'vector3',
            'matrix4',
        )((vec0, vec1, mat) => {
            // First we transform into a 2D projective space looking from the cone vertex along its axis.
            // This is a standard perspective transformation and, as such, lines through the cone vertex
            // become points and all other lines remain lines. The surface of the cone then becomes a circle
            // in our projective space and line segments become either 2D segments or 2D rays.
            vec1.subVectors(edge.end, edge.start);
            const edgeLengthSq = vec1.lengthSq();
            vec1.cross(this.axis.direction);
            // For a unit axis, |e × d|² = |e|² · sin²φ, where φ is the angle between edge and axis.
            // This must be tested BEFORE normalize(): afterwards the length is always 0 or 1.
            // An edge (nearly) parallel to the axis cannot enter the cone when both endpoints are
            // outside it. Once such a line is inside the cone, moving further along it towards the
            // cone's opening never leaves the cone, so the farther endpoint would be inside too.
            if (vec1.lengthSq() <= ConeCaster.PARALLEL_EPSILON_SQ * edgeLengthSq) {
                return false;
            }
            vec1.normalize();
            vec0.crossVectors(vec1, this.axis.direction).normalize();
            // makeBasis puts (vec0, vec1, axis) in the columns; transposing makes them rows. So
            // applying `mat` yields (p·vec0, p·vec1, p·axis): z is the depth along the cone axis.
            // vec1 is perpendicular to the edge, so both endpoints share the same y coordinate.
            mat.makeBasis(vec0, vec1, this.axis.direction).transpose();

            vec0.subVectors(edge.start, this.axis.origin).applyMatrix4(mat);
            vec1.subVectors(edge.end, this.axis.origin).applyMatrix4(mat);

            // Handle lines that cross the cone vertex plane by trimming them to that plane (z = 0).
            if (vec0.z < 0) {
                if (vec1.z <= 0) {
                    return false;
                }

                // Fraction of the way from vec1 towards vec0 at which z reaches 0.
                const t = vec1.z / (vec1.z - vec0.z);
                vec0.sub(vec1)
                    .multiplyScalar(t)
                    .add(vec1)
                    .setZ(0); // Avoid rounding issues
            } else if (vec1.z < 0) {
                // Fraction of the way from vec0 towards vec1 at which z reaches 0.
                const t = vec0.z / (vec0.z - vec1.z);
                vec1.sub(vec0)
                    .multiplyScalar(t)
                    .add(vec0)
                    .setZ(0);
            }

            // Perspective divide: a point at depth z > 0 maps to (x, y) / (z · tanθ), which puts the
            // cone surface on the unit circle. A point with z = 0 is left alone and acts as a
            // direction (a point at infinity) in the projected plane.
            if (vec0.z > 0) {
                vec0.multiplyScalar(1 / (this.tanHalfAngle * vec0.z));
            }

            if (vec1.z > 0) {
                vec1.multiplyScalar(1 / (this.tanHalfAngle * vec1.z));
            }

            // Put the z = 0 (direction) endpoint, if any, in vec0.
            if (vec1.z < vec0.z) {
                swapVectors(vec0, vec1);
            }

            vec1.setZ(0);
            if (vec0.z === 0) {
                // In this case our segment is actually a ray and vec0 represents its direction
                // To test if the ray hits our cone, which has been projected into a unit disk
                // we first test if the direction is "towards" the origin and then if the ray
                // comes within a distance of 1 of the origin.
                const s2 = vec0.lengthSq();
                return vec0.dot(vec1) < 0 && vec0.cross(vec1).lengthSq() <= s2;
            }

            // set z to 0 if it hadn't been already.
            vec0.setZ(0);

            // d2 = |vec1 - vec0|², and r2 = (2D cross product)² = (distance from origin to line)² · d2.
            // So r2 <= d2 means the infinite line comes within distance 1 of the origin.
            const d2 = vec0.distanceToSquared(vec1);
            const r2 = sqr(vec1.x * vec0.y - vec0.x * vec1.y);
            const s = vec0.dot(vec1);

            // (|v0|² - s) and (|v1|² - s) are d2 times the foot-of-perpendicular parameter t and (1 - t).
            // A positive ratio means the closest point lies strictly inside the segment.
            return r2 <= d2 && (vec0.lengthSq() - s) / (vec1.lengthSq() - s) > 0;
        });
    }

    /** Exact triangle-vs-cone test: a vertex inside, the axis hitting the face, or an edge crossing. */
    private intersectsTriangleHelper(triangle: THREE.Triangle): boolean {
        if ([triangle.a, triangle.b, triangle.c].some(pt => this.pointInside(pt))) {
            return true;
        }

        // From here on all three vertices are outside, as intersectsEdge requires.
        return THREE_CACHE.autoAcquire(
            'vector3',
            'line3',
        )((target, edge) => {
            // backfaceCulling = false: hits from either side of the triangle count.
            const hit = this.axis.intersectTriangle(triangle.a, triangle.b, triangle.c, false, target);
            if (hit !== null) {
                return true;
            }

            edge.start.copy(triangle.a);
            edge.end.copy(triangle.b);
            if (this.intersectsEdge(edge)) {
                return true;
            }
            edge.start.copy(triangle.b);
            edge.end.copy(triangle.c);
            if (this.intersectsEdge(edge)) {
                return true;
            }
            edge.start.copy(triangle.c);
            edge.end.copy(triangle.a);
            if (this.intersectsEdge(edge)) {
                return true;
            }

            return false;
        });
    }

    /**
     * True when `pt` is strictly in front of the apex plane and within the half-angle of the axis:
     * (pt - origin)·d > 0 and ((pt - origin)·d)² >= cos²θ · |pt - origin|².
     */
    private pointInside(pt: THREE.Vector3): boolean {
        const s = pt.dot(this.axis.direction) - this.axis.origin.dot(this.axis.direction);
        return s > 0 && s * s >= this.cosHalfAngle * this.cosHalfAngle * pt.distanceToSquared(this.axis.origin);
    }
}
