import {
    AttributeName,
    ComputeVertexNormalsByAngle,
    ensureDistanceAttributesInitialized,
    ensureMeshIndex,
    isEmptyPositionOrIndex,
} from '../../Utils3D';
import { addIndexAsUnindexed } from '../../Utils3D/BufferGeometry.util';
import { collissionGroupsToGeometry, computeAllCollisionLines } from '../../Utils3D/MeshCollisions';
import type { AdjacencyMatrix, MeshConnectivityGraph } from '../../Utils3D/MeshConnectivityGraph';
import { buildMeshAdjacencyAndVerticesFaces } from '../../Utils3D/MeshConnectivityGraph';
import { cropGeometryNearCenter, cropGeometryNearGeometry, makeCurtainGeometry } from '../../Utils3D/ShadowVolumeMesh';
import { logger } from '../../Utils/Logger';
import {
    DesignCaseFileUtil,
    EligibleForHeatmapToothTypes,
    type ToothElement,
    ToothElementType,
} from '../DesignCaseFile';
import type { DesignProjectAsset } from './DesignZipReading.types';
import type { MinimallyProcessedDesignAssets } from './DesignZipReading.util';
import type { ToothNumber } from '@orthly/items';
import { ToothUtils } from '@orthly/items';
import { Format } from '@orthly/runtime-utils';
import { Jaw } from '@orthly/shared-types';
import _, { compact } from 'lodash';
import * as THREE from 'three';
import { mergeBufferGeometries } from 'three-stdlib';

// Crop radius applied around a bridge asset when neighboring geometry is cropped before
// curtain generation (non-bridge assets derive their radius from their bounding box instead).
// NOTE(review): units are presumably those of the mesh coordinates — confirm.
const MESH_TO_MESH_CURTAIN_CROP_DISTANCE = 3;

// This is like an offline version of ModelPayloadItem
/**
 * A design asset's geometry together with the attributes resolved for it from the parsed design case.
 */
interface DesignAssetWithDesignAttribs {
    // Tooth numbers associated with the asset. `createAssetWithDesignAttribs` fills this from the tooth
    // elements of the matching model element (only for the 'meIndicationRegular' and 'meSplint' model
    // types) and otherwise leaves it empty.
    // NOTE(review): presumably Universal Numbering System values — confirm.
    unns: number[];
    // Geometry of the asset, taken from the originating DesignProjectAsset.
    geom: THREE.BufferGeometry;
    // Insertion axis looked up by model element ID; absent when no model element or axis matched.
    insertionAxis?: THREE.Vector3;
    // ID of the model element whose model file path matched `sourceFile`, if any.
    modelElementId?: string;
    // Source file name of the originating DesignProjectAsset.
    sourceFile: string;
    // The three flags below are populated while `computeDesignDistances` builds restorative assets.
    // True when at least one of `unns` is among the design's implant tooth numbers.
    isImplant?: boolean;
    // True when the asset covers more than one tooth number.
    isBridge?: boolean;
    // True when any tooth element of the model element is of the RemovablePartialDenture type.
    isRemovablePartialDenture?: boolean;
}

/**
 * A design asset that is eligible for heatmaps. Compared to the base interface, the implant and
 * bridge flags are mandatory and the mesh connectivity of the asset's geometry is attached.
 */
interface RestorativeAsset extends DesignAssetWithDesignAttribs {
    isImplant: boolean;
    isBridge: boolean;
    // Result of `buildMeshAdjacencyAndVerticesFaces` for `geom`; handed to the crop helpers when
    // this asset is cropped as a neighbor during curtain generation.
    meshConnectivity: MeshConnectivityGraph;
}

/**
 * Resolves the design attributes (tooth numbers, insertion axis, model element ID) for a design
 * project asset by matching its source file against the model elements of the parsed case.
 *
 * File names are compared case-insensitively after `DesignCaseFileUtil.cleanXmlNameForComparison`.
 *
 * @param dpa The asset whose geometry and source file are carried over.
 * @param design The design providing the parsed case and the insertion axes.
 * @param override Values spread last, so they take precedence over the resolved attributes.
 * @returns The asset with its resolved attributes. When no model element matches, `unns` is empty
 * and both `insertionAxis` and `modelElementId` are undefined (unless overridden).
 */
export function createAssetWithDesignAttribs(
    dpa: DesignProjectAsset,
    design: MinimallyProcessedDesignAssets,
    override: Partial<DesignAssetWithDesignAttribs>,
): DesignAssetWithDesignAttribs {
    // The asset's normalized name does not depend on the candidate, so compute it once rather
    // than once per model element inside the search.
    const normalizedSourceFile = DesignCaseFileUtil.cleanXmlNameForComparison(dpa.sourceFile).toLowerCase();
    const modelElement = design.parsedCase.modelElements.find(
        candidate =>
            DesignCaseFileUtil.cleanXmlNameForComparison(candidate.modelFilePath).toLowerCase() ===
            normalizedSourceFile,
    );

    // A missing key already yields undefined, so a separate `has` check is unnecessary.
    const insertionAxis: THREE.Vector3 | undefined = modelElement
        ? design.insertionAxes.get(modelElement.modelElementID)
        : undefined;

    // Tooth numbers are only resolved for these two model types.
    const unns =
        modelElement && (modelElement.modelType === 'meIndicationRegular' || modelElement.modelType === 'meSplint')
            ? DesignCaseFileUtil.getTeethByModelID(design.parsedCase.toothElements, modelElement.modelElementID).map(
                  (tE: ToothElement) => tE.toothNumber,
              )
            : [];

    return {
        unns,
        geom: dpa.geom,
        sourceFile: dpa.sourceFile,
        insertionAxis,
        modelElementId: modelElement ? modelElement.modelElementID : undefined,
        ...override,
    };
}

/**
 * Tells whether any valid tooth number of `assetX` is adjacent to any valid tooth number of `assetY`.
 * Values rejected by `ToothUtils.isToothNumber` are ignored on both sides.
 */
function areAssetsAdjacent(assetX: RestorativeAsset, assetY: RestorativeAsset) {
    for (const unnX of assetX.unns) {
        if (!ToothUtils.isToothNumber(unnX)) {
            continue;
        }
        for (const unnY of assetY.unns) {
            if (ToothUtils.isToothNumber(unnY) && ToothUtils.areAdjacent(unnX, unnY)) {
                return true;
            }
        }
    }
    return false;
}

/**
 * Builds a symmetric adjacency list over `assets`: entry `i` holds the indices of every other
 * asset that `areAssetsAdjacent` reports as adjacent to asset `i`.
 */
function getAssetAdjacency(assets: RestorativeAsset[]): AdjacencyMatrix {
    const neighbors: number[][] = Array.from({ length: assets.length }, () => []);

    // Visit each unordered pair once and record the relation in both directions.
    for (let first = 0; first < assets.length; first += 1) {
        for (let second = first + 1; second < assets.length; second += 1) {
            const firstAsset = assets[first];
            const secondAsset = assets[second];
            if (firstAsset && secondAsset && areAssetsAdjacent(firstAsset, secondAsset)) {
                neighbors[first]?.push(second);
                neighbors[second]?.push(first);
            }
        }
    }

    return neighbors;
}

/**
 * Builds one merged curtain geometry per asset of a single arch and stores it in `curtainsMap`
 * under the asset's model element ID.
 *
 * For each asset, the geometries of its adjacent assets and the arch scan are cropped (near the
 * asset's geometry for bridges, near its bounding-box center otherwise), turned into curtain
 * geometries along the asset's insertion axis and merged. Assets lacking a model element ID, an
 * insertion axis or a bounding box are skipped, as are assets for which no curtain was produced.
 *
 * @param design Source of the insertion axes, keyed by model element ID.
 * @param archAssets Restorative assets of one arch.
 * @param archScan Scan geometry of the same arch, if available.
 * @param archScanConnectivity Connectivity graph of `archScan`, if available.
 * @param curtainsMap Output map, mutated in place.
 */
function populateCurtainsForArchAssets(
    design: MinimallyProcessedDesignAssets,
    archAssets: RestorativeAsset[],
    archScan: THREE.BufferGeometry | undefined,
    archScanConnectivity: MeshConnectivityGraph | undefined,
    curtainsMap: Map<string, THREE.BufferGeometry>,
) {
    const neighborsByAsset = getAssetAdjacency(archAssets);
    // Scratch vector reused for every asset's bounding-box center.
    const assetCenter = new THREE.Vector3();

    for (let assetIdx = 0; assetIdx < archAssets.length; assetIdx += 1) {
        const asset = archAssets[assetIdx];
        if (!asset) {
            continue;
        }
        const modelElementId = asset.modelElementId;
        if (!modelElementId) {
            continue;
        }
        const direction = design.insertionAxes.get(modelElementId);
        if (!direction) {
            continue;
        }

        asset.geom.computeBoundingBox();
        const bounds = asset.geom.boundingBox;
        if (!bounds) {
            logger.warn(`Bounding box could not be computed for asset with modelElementId=${modelElementId}`);
            continue;
        }
        bounds.getCenter(assetCenter);

        // Bridges use a fixed distance; other assets use half the box diagonal, but at least 10.
        const cropRadius = asset.isBridge
            ? MESH_TO_MESH_CURTAIN_CROP_DISTANCE
            : Math.max(10, bounds.min.distanceTo(bounds.max) / 2);

        // Crops a neighboring geometry around this asset; yields undefined when either input is missing.
        const cropAroundAsset = (
            geom: THREE.BufferGeometry | undefined,
            connectivity: MeshConnectivityGraph | undefined,
        ) => {
            if (!geom || !connectivity) {
                return undefined;
            }
            return asset.isBridge
                ? cropGeometryNearGeometry(geom, asset.geom, cropRadius, connectivity)
                : cropGeometryNearCenter(geom, assetCenter, cropRadius, connectivity);
        };

        const neighborIndices = neighborsByAsset[assetIdx] ?? [];
        const croppedGeoms = neighborIndices.map(neighborIdx => {
            const neighbor = archAssets[neighborIdx];
            return cropAroundAsset(neighbor?.geom, neighbor?.meshConnectivity);
        });
        croppedGeoms.push(cropAroundAsset(archScan, archScanConnectivity));

        const usableCrops = _.compact(croppedGeoms).filter(cropped => !isEmptyPositionOrIndex(cropped));
        const curtainParts = usableCrops.map(cropped => makeCurtainGeometry(cropped, asset.geom, direction));
        const mergedCurtain = curtainParts.length > 0 ? mergeBufferGeometries(curtainParts) : null;

        // The intermediates are no longer needed once the merge has been attempted.
        for (const part of curtainParts) {
            part.dispose();
        }
        for (const cropped of croppedGeoms) {
            cropped?.dispose();
        }

        if (!mergedCurtain) {
            continue;
        }
        mergedCurtain.computeBoundingBox();
        curtainsMap.set(modelElementId, mergedCurtain);
    }
}

/**
 * The minimally processed design assets extended with the derived geometries that
 * `computeDesignDistances` returns.
 */
export interface DesignAssetsWithExtras extends MinimallyProcessedDesignAssets {
    // Keys are collision geometry file base names
    collisions: Map<string, THREE.BufferGeometry>;
    // Keys are corresponding Model Element IDs
    curtains: Map<string, THREE.BufferGeometry>;
}

/**
 * Converts a CAD asset into a restorative asset, or returns `undefined` when it should not receive
 * heatmaps: either no model element carries its model element ID, or none of that element's teeth
 * has a type listed in `EligibleForHeatmapToothTypes`.
 */
function toRestorativeAsset(
    design: MinimallyProcessedDesignAssets,
    assetWithAttributes: DesignAssetWithDesignAttribs,
): RestorativeAsset | undefined {
    const modelElement = design.parsedCase.modelElements.find(
        candidate => candidate.modelElementID === assetWithAttributes.modelElementId,
    );
    if (!modelElement) {
        return undefined;
    }

    const teeth = DesignCaseFileUtil.getTeethByModelID(design.parsedCase.toothElements, modelElement.modelElementID);
    // make sure that we have a margin item in the mE
    const eligibleForHeatmaps = teeth.some((tE: ToothElement) =>
        EligibleForHeatmapToothTypes.includes(tE.cacheToothTypeClass),
    );
    if (!eligibleForHeatmaps) {
        return undefined;
    }

    const isRemovablePartialDenture = teeth.some(
        (tE: ToothElement) => tE.cacheToothTypeClass === ToothElementType.RemovablePartialDenture,
    );
    const isImplant = assetWithAttributes.unns.some(el => design.implantToothNumbers.has(el as ToothNumber));
    const isBridge = assetWithAttributes.unns.length > 1;
    const meshConnectivity = buildMeshAdjacencyAndVerticesFaces(assetWithAttributes.geom);

    return {
        ...assetWithAttributes,
        isImplant,
        isBridge,
        isRemovablePartialDenture,
        meshConnectivity,
    };
}

/**
 * Prepares every restorative asset of one arch (vertex normals, mesh index) and initializes its
 * distance attributes against the same-arch geometries, the opposing-arch geometries and, when
 * one was generated for the asset, its curtain geometry.
 */
function initializeArchAssetDistances(
    archAssets: RestorativeAsset[],
    archScanGeom: THREE.BufferGeometry | undefined,
    opposingAssets: RestorativeAsset[],
    opposingScanGeom: THREE.BufferGeometry | undefined,
    curtainsMap: Map<string, THREE.BufferGeometry>,
): void {
    archAssets.forEach(asset => {
        ComputeVertexNormalsByAngle(asset.geom);
        ensureMeshIndex(asset.geom);

        const curtainGeom = curtainsMap.get(asset.modelElementId ?? '');

        ensureDistanceAttributesInitialized(
            asset.geom,
            // Same arch: the scan plus every other restorative of the arch.
            compact([archScanGeom, ...archAssets.filter(other => other !== asset).map(other => other.geom)]),
            // Opposing arch: the scan plus all of its restoratives.
            compact([opposingScanGeom, ...opposingAssets.map(other => other.geom)]),
            curtainGeom,
        );
    });
}

/**
 * Computes the distance (heatmap) attributes, the curtain geometries and the collision geometries
 * for a design.
 *
 * The helpers invoked on the scan and restorative geometries are called purely for their effect on
 * those geometries (their return values are ignored), so the geometries referenced by `design` are
 * expected to be modified in place.
 *
 * Steps, in order:
 *  1. vertex normals and occlusal distance attributes on the upper and lower scans;
 *  2. selection of the CAD assets that are eligible for heatmaps, split by arch;
 *  3. curtain geometries per arch (these must exist before the heatmaps are computed);
 *  4. distance attributes on the upper, then the lower, restorative assets;
 *  5. collision geometries derived from the computed attributes.
 *
 * @param design The minimally processed design assets.
 * @returns The input design extended with `collisions` and `curtains`.
 */
export function computeDesignDistances(design: MinimallyProcessedDesignAssets): DesignAssetsWithExtras {
    const { upperMbScan, lowerMbScan } = design;

    // Make sure everyone has normals and indexes
    if (upperMbScan) {
        ComputeVertexNormalsByAngle(upperMbScan.geom);
    }
    if (lowerMbScan) {
        ComputeVertexNormalsByAngle(lowerMbScan.geom);
    }

    // put occlusion heatmap on upper scan
    if (upperMbScan) {
        ensureDistanceAttributesInitialized(upperMbScan.geom, [], compact([lowerMbScan?.geom]));
    }

    // put occlusion heatmap on lower scan
    if (lowerMbScan) {
        ensureDistanceAttributesInitialized(lowerMbScan.geom, [], compact([upperMbScan?.geom]));
    }

    // get our assets and attribs assembled
    const cadAssets: DesignAssetWithDesignAttribs[] = design.cadAssets.map((dpa: DesignProjectAsset) =>
        createAssetWithDesignAttribs(dpa, design, {}),
    );

    const crownBridgeAndRemovablePartialDentureAssets = compact(
        cadAssets.map(assetWithAttributes => toRestorativeAsset(design, assetWithAttributes)),
    );

    const upperAssets = crownBridgeAndRemovablePartialDentureAssets.filter(asset =>
        asset.unns.some(unn => ToothUtils.toothIsUpper(unn)),
    );
    const lowerAssets = crownBridgeAndRemovablePartialDentureAssets.filter(asset =>
        asset.unns.some(unn => ToothUtils.toothIsLower(unn)),
    );

    // Do curtains before heatmaps
    const upperGraph = upperMbScan ? buildMeshAdjacencyAndVerticesFaces(upperMbScan.geom) : undefined;
    const lowerGraph = lowerMbScan ? buildMeshAdjacencyAndVerticesFaces(lowerMbScan.geom) : undefined;
    const curtainsMap = new Map<string, THREE.BufferGeometry>();

    populateCurtainsForArchAssets(design, upperAssets, upperMbScan?.geom, upperGraph, curtainsMap);
    populateCurtainsForArchAssets(design, lowerAssets, lowerMbScan?.geom, lowerGraph, curtainsMap);

    // put all heatmaps on upper crowns, then on lower crowns
    initializeArchAssetDistances(upperAssets, upperMbScan?.geom, lowerAssets, lowerMbScan?.geom, curtainsMap);
    initializeArchAssetDistances(lowerAssets, lowerMbScan?.geom, upperAssets, upperMbScan?.geom, curtainsMap);

    // compute isolines in the heatmaps
    const collisions = computeDesignCollisions(
        { [Jaw.UPPER]: upperMbScan, [Jaw.LOWER]: lowerMbScan },
        { [Jaw.UPPER]: upperAssets, [Jaw.LOWER]: lowerAssets },
    );

    return { ...design, collisions, curtains: curtainsMap };
}

/**
 * Builds the collision geometries for the scans and the restorative assets of a design.
 *
 * Requires that the distance attributes of the THREE.BufferGeometry elements in the design have
 * already been computed with `computeDesignDistances`.
 *
 * A scan collision geometry is produced for a jaw only when that jaw has at least one restorative
 * asset and a scan is present. Every restorative asset gets its own collision geometry.
 *
 * @param mbScans Scans by jaw; either may be missing.
 * @param crownAndBridgeAssets Restorative assets by jaw.
 * @returns Collision geometries keyed by their base name. A later entry replaces an earlier one
 * with the same name.
 */
export function computeDesignCollisions(
    mbScans: Partial<Record<Jaw, DesignProjectAsset>>,
    crownAndBridgeAssets: Record<Jaw, DesignAssetWithDesignAttribs[]>,
): Map<string, THREE.BufferGeometry> {
    const collisionsByName = new Map<string, THREE.BufferGeometry>();

    // Scans first: upper, then lower.
    for (const jaw of [Jaw.UPPER, Jaw.LOWER]) {
        if (crownAndBridgeAssets[jaw].length === 0) {
            continue;
        }
        const scan = mbScans[jaw];
        if (!scan) {
            continue;
        }
        const scanCollisions = createScanCollisionsGeometry(scan.geom);
        collisionsByName.set(getScanCollisionsGeometryName(jaw), scanCollisions);
    }

    // Then one geometry per restorative asset.
    for (const cadAsset of Object.values(crownAndBridgeAssets).flat()) {
        const restorativeCollisions = createRestorativeCollisionsGeometry(cadAsset.geom);
        collisionsByName.set(getRestorativeCollisionsGeometryName(cadAsset), restorativeCollisions);
    }

    return collisionsByName;
}

/**
 * Creates the collision geometry for a scan from the collision lines computed on its occlusal
 * distance attribute. The result is passed through `addIndexAsUnindexed` before it is returned.
 */
function createScanCollisionsGeometry(scanGeometry: THREE.BufferGeometry): THREE.BufferGeometry {
    const occlusalLines: Map<number, THREE.Vector3[]> = computeAllCollisionLines(
        scanGeometry,
        AttributeName.OcclusalDistance,
    );

    const geometry = collissionGroupsToGeometry([{ name: 'faceToLineMapOcclusal', data: occlusalLines }]);
    addIndexAsUnindexed(geometry);

    return geometry;
}

/**
 * Creates the collision geometry for a restorative asset from the collision lines computed on its
 * proximal and occlusal distance attributes. The result is passed through `addIndexAsUnindexed`
 * before it is returned.
 *
 * Collisions with curtains are deliberately left out: the client expects the cached collision
 * geometry file to contain only proximal and occlusal collisions.
 */
function createRestorativeCollisionsGeometry(restorativeGeometry: THREE.BufferGeometry): THREE.BufferGeometry {
    const proximalLines: Map<number, THREE.Vector3[]> = computeAllCollisionLines(
        restorativeGeometry,
        AttributeName.ProximalDistance,
    );
    const occlusalLines: Map<number, THREE.Vector3[]> = computeAllCollisionLines(
        restorativeGeometry,
        AttributeName.OcclusalDistance,
    );

    const geometry = collissionGroupsToGeometry([
        { name: 'faceToLineMapProximal', data: proximalLines },
        { name: 'faceToLineMapOcclusal', data: occlusalLines },
    ]);
    addIndexAsUnindexed(geometry);

    return geometry;
}

// Common prefix of every collision geometry base name produced by the naming helpers below.
const COLLISIONS_PREFIX = 'QCCollision';

/**
 * Returns the base name of the scan collisions geometry file for a jaw: the collisions prefix,
 * `_PREP_SCAN_`, then the title-cased jaw. The format has to match the one used by root-canal.
 */
function getScanCollisionsGeometryName(jaw: Jaw): string {
    const jawLabel = Format.titleCase(jaw);
    return `${COLLISIONS_PREFIX}_PREP_SCAN_${jawLabel}`;
}

/**
 * Returns the base name of the restorative collisions geometry file for an asset. The format has
 * to match the one used by root-canal.
 *
 * The name is the collisions prefix followed by `_`-joined segments: an optional `IMP` for
 * implants, then `BRDG` with the `<lowest>x<highest>` tooth range for bridges, or `CRN` with the
 * first tooth number otherwise.
 */
function getRestorativeCollisionsGeometryName(asset: DesignAssetWithDesignAttribs): string {
    const segments: string[] = [];

    if (asset.isImplant) {
        segments.push('IMP');
    }

    if (asset.isBridge) {
        segments.push('BRDG', `${Math.min(...asset.unns)}x${Math.max(...asset.unns)}`);
    } else {
        segments.push('CRN', String(asset.unns[0]));
    }

    return `${COLLISIONS_PREFIX}${segments.join('_')}`;
}
