import { logger } from '../Utils/Logger';
import { AttributeName } from './BufferAttributeConstants';
import _ from 'lodash';
import * as THREE from 'three';

// Names of the per-vertex attributes we expect to already be precomputed.
// `as const` keeps this a readonly tuple of literal members, so that
// `PrimaryVertexAttr` below is the union of exactly these three names.
export const PRIMARY_VERTEX_ATTR_NAMES = [
    AttributeName.ThicknessDistance,
    AttributeName.ProximalDistance,
    AttributeName.OcclusalDistance,
] as const;
// Union of the element types of PRIMARY_VERTEX_ATTR_NAMES.
export type PrimaryVertexAttr = (typeof PRIMARY_VERTEX_ATTR_NAMES)[number];

/** Optional switches controlling what PLYExporter writes. */
export type PLYExporterOptions = {
    // When true, the per-vertex `s`/`t` properties are omitted.
    // Note: the per-face `texcoord` list is still written whenever the mesh has UVs.
    disableUVs?: boolean;
    // When true, the per-vertex `red`/`green`/`blue` properties are omitted.
    disableVertexColors?: boolean;
    // Names of geometry attributes to export as extra per-vertex `float` properties.
    // Each must hold exactly one value per vertex; names that are missing or do not
    // match are skipped with a logged warning rather than an error.
    customVertexAttributeNames?: string[];
};

/**
 * Looks up `attrName` on the geometry and returns it only when it is usable
 * as a per-vertex scalar.
 *
 * @param geom - Geometry whose attributes are searched.
 * @param attrName - Name of the attribute to look up.
 * @param numVerts - Number of vertices the attribute must cover.
 * @returns The attribute when it exists, has `numVerts` entries and an
 *          `itemSize` of 1; `undefined` in every other case.
 */
function ensureVertexScalarAttribute(
    geom: THREE.BufferGeometry,
    attrName: string,
    numVerts: number,
): THREE.BufferAttribute | THREE.InterleavedBufferAttribute | undefined {
    const candidate = geom.getAttribute(attrName);

    // Only a single float/scalar per vertex is supported for now.
    if (candidate && candidate.count === numVerts && candidate.itemSize === 1) {
        return candidate;
    }
    return undefined;
}
/**
 * Converts a single indexed ThreeJS mesh into a binary little-endian PLY ArrayBuffer.
 *
 * It is heavily based on ThreeJS's own PLY exporter
 * (https://github.com/pmndrs/three-stdlib/blob/main/src/exporters/PLYExporter.ts),
 * but with some important changes:
 * 1) It adds support for texcoords generation, which is required by some of our ML tooling (eg margin line suggestion).
 * 2) It assumes that you always pass it a single Mesh, whereas 3JS' can take an entire scene.
 * 3) Only little-endian binary output is supported, whereas 3JS' supports both ascii and big endian modes.
 * 4) It assumes that the model contains indices.
 * Differences (2), (3), and (4) leave significantly fewer edge cases to test.
 *
 * Binary layout (must always match the property order declared in the header):
 *   per vertex: x y z | nx ny nz | custom float attributes | s t | red green blue
 *   per face:   uchar 3 + 3 x int32 indices | uchar 6 + 6 x float texcoords
 * Every group other than the position and the face indices is optional.
 */
export class PLYExporter {
    private readonly vertexCount: number;
    private readonly faceCount: number;

    private readonly vertices: THREE.BufferAttribute | THREE.InterleavedBufferAttribute;
    private readonly normals?: THREE.BufferAttribute | THREE.InterleavedBufferAttribute;
    private readonly uvs?: THREE.BufferAttribute | THREE.InterleavedBufferAttribute;
    private readonly colors?: THREE.BufferAttribute | THREE.InterleavedBufferAttribute;
    // Insertion-ordered: this order is used both for the header and for the binary body.
    private readonly customVertexAttributes: Map<string, THREE.BufferAttribute | THREE.InterleavedBufferAttribute>;
    private readonly indices: THREE.BufferAttribute;

    /**
     * @param mesh - The mesh to export. Its geometry must have a position attribute and an index.
     * @param options - Optional switches; see PLYExporterOptions.
     * @throws Error when the geometry is not a BufferGeometry, has no position attribute, or has no index.
     */
    constructor(
        private readonly mesh: THREE.Mesh<THREE.BufferGeometry, THREE.Material>,
        private readonly options?: PLYExporterOptions,
    ) {
        if (!('isBufferGeometry' in mesh.geometry) || !mesh.geometry.isBufferGeometry) {
            throw new Error('Geometry is not of type THREE.BufferGeometry.');
        }

        this.vertices = mesh.geometry.getAttribute(AttributeName.Position);

        if (this.vertices === undefined) {
            throw new Error('No vertices in mesh');
        }

        this.normals = mesh.geometry.getAttribute(AttributeName.Normal);
        this.uvs = mesh.geometry.getAttribute(AttributeName.TexCoord);
        this.colors = mesh.geometry.getAttribute(AttributeName.Color);

        // The caller owns the responsibility to ensure the names exist
        // in the actual geometry.attributes. If a name is missing (or is not a
        // per-vertex scalar) a warning is logged, but no error is thrown.
        this.customVertexAttributes = new Map<string, THREE.BufferAttribute | THREE.InterleavedBufferAttribute>();
        options?.customVertexAttributeNames?.forEach(attrName => {
            const attribute = ensureVertexScalarAttribute(mesh.geometry, attrName, this.vertices.count);
            if (!attribute) {
                logger.warn(`Geometry did not have float attribute matching ${attrName}`);
                return;
            }
            this.customVertexAttributes.set(attrName, attribute);
        });

        const indices = mesh.geometry.getIndex();
        if (!indices) {
            throw new Error('No indices in mesh');
        }
        this.indices = indices;

        this.vertexCount = this.vertices.count;
        // Only whole triangles are exported. Flooring keeps the header's face count and the
        // buffer size integral even if the index buffer has trailing indices that do not
        // form a complete triangle.
        this.faceCount = Math.floor(this.indices.count / 3);
    }

    /** True when the geometry has a normal attribute. */
    get includeNormals(): boolean {
        return !!this.normals;
    }

    /** True when the geometry has a color attribute and vertex colors are not disabled. */
    get includeVertexColors(): boolean {
        return !!this.colors && !this.options?.disableVertexColors;
    }

    /** True when the geometry has UVs and per-vertex UVs are not disabled. */
    get includeUVs(): boolean {
        return !!this.uvs && !this.options?.disableUVs;
    }

    /** True when the geometry has UVs; the per-face texcoord list ignores `disableUVs`. */
    get includeTextureCoords(): boolean {
        return !!this.uvs;
    }

    // Generates the header of the PLY file, in plain-text.
    // The order of the property lines here defines the binary layout that
    // writeVertices/writeFacets must follow.
    private getHeader(): string {
        const parts = [
            // Standard header
            `ply`,
            `format binary_little_endian 1.0`,

            // Here we define how many vertices we have, and all of the per-vertex attributes.
            `element vertex ${this.vertexCount}`,
            // Position floats
            `property float x`,
            `property float y`,
            `property float z`,
            // Normals
            this.includeNormals ? `property float nx` : undefined,
            this.includeNormals ? `property float ny` : undefined,
            this.includeNormals ? `property float nz` : undefined,

            // Custom Vertex Attributes
            ...Array.from(this.customVertexAttributes.keys()).map(
                (attribName: string) => `property float ${attribName}`,
            ),
            // UVs
            this.includeUVs ? `property float s` : undefined,
            this.includeUVs ? `property float t` : undefined,
            // Vertex colors
            this.includeVertexColors ? `property uchar red` : undefined,
            this.includeVertexColors ? `property uchar green` : undefined,
            this.includeVertexColors ? `property uchar blue` : undefined,

            // Here we define how many facets we have, and all of the per-facet attributes.
            `element face ${this.faceCount}`,
            `property list uchar int vertex_indices`,
            // Texture coords (required for margin line processing!!!)
            this.includeTextureCoords ? `property list uchar float texcoord` : undefined,

            // End of standard header
            `end_header\n`,
        ];

        // Drop the `undefined` placeholders of the disabled optional properties.
        return _.compact(parts).join('\n');
    }

    /**
     * Serializes the mesh.
     *
     * @returns A new ArrayBuffer holding the plain-text header followed by the binary
     *          vertex list and the binary face list.
     */
    export(): ArrayBuffer {
        const header = this.getHeader();

        // Binary File Generation
        const headerBin = new TextEncoder().encode(header);

        // 3 position values at 4 bytes
        // 3 normal values at 4 bytes
        // 4 bytes each per custom float vertex attributes
        // 2 uv values at 4 bytes
        // 3 color channels with 1 byte
        const bytesPerVertex =
            4 * 3 +
            (this.includeNormals ? 4 * 3 : 0) +
            4 * this.customVertexAttributes.size +
            (this.includeUVs ? 4 * 2 : 0) +
            (this.includeVertexColors ? 3 : 0);
        const vertexListLength = this.vertexCount * bytesPerVertex;

        // 3 facet indices at 13 bytes (1 to describe length of array (3), 4 bytes for each of the 3 numbers)
        // 6 facet texture coords at 25 bytes (1 to describe length of array (6), 4 bytes for each of the 6 numbers)
        // Note, this assumes we are only using tri's
        const bytesPerFace = 4 * 3 + 1 + (this.includeTextureCoords ? 4 * 6 + 1 : 0);
        const faceListLength = this.faceCount * bytesPerFace;

        const output = new DataView(new ArrayBuffer(headerBin.length + vertexListLength + faceListLength));

        new Uint8Array(output.buffer).set(headerBin, 0);
        this.writeVertices(output, headerBin.length);
        this.writeFacets(output, headerBin.length + vertexListLength);

        return output.buffer;
    }

    // Writes all of the vertices to the output buffer, starting at `startingOffset`.
    // Fields are written in exactly the order getHeader() declares them:
    // position, normals, custom attributes, UVs, colors.
    // NOTE: assumes that the output buffer is large enough!!
    private writeVertices(output: DataView, startingOffset: number): void {
        // Positions and normals are exported in world space.
        const normalMatrixWorld = new THREE.Matrix3();
        normalMatrixWorld.getNormalMatrix(this.mesh.matrixWorld);
        // Scratch vector reused for both positions and normals.
        const vertex = new THREE.Vector3();

        let offset = startingOffset;
        for (let i = 0, l = this.vertices.count; i < l; i++) {
            vertex.x = this.vertices.getX(i);
            vertex.y = this.vertices.getY(i);
            vertex.z = this.vertices.getZ(i);

            vertex.applyMatrix4(this.mesh.matrixWorld);

            // Position information
            output.setFloat32(offset, vertex.x, true);
            offset += 4;

            output.setFloat32(offset, vertex.y, true);
            offset += 4;

            output.setFloat32(offset, vertex.z, true);
            offset += 4;

            // Normals information
            if (this.includeNormals && this.normals) {
                vertex.x = this.normals.getX(i);
                vertex.y = this.normals.getY(i);
                vertex.z = this.normals.getZ(i);

                vertex.applyMatrix3(normalMatrixWorld).normalize();

                output.setFloat32(offset, vertex.x, true);
                offset += 4;

                output.setFloat32(offset, vertex.y, true);
                offset += 4;

                output.setFloat32(offset, vertex.z, true);
                offset += 4;
            }

            // Custom Vertex information.
            // Must come right after the normals: that is where the header declares these
            // properties, and PLY readers decode the body strictly in header order.
            for (const attr of this.customVertexAttributes.values()) {
                output.setFloat32(offset, attr.getX(i), true);
                offset += 4;
            }

            // UV information
            if (this.includeUVs && this.uvs) {
                output.setFloat32(offset, this.uvs.getX(i), true);
                offset += 4;

                output.setFloat32(offset, this.uvs.getY(i), true);
                offset += 4;
            }

            // Color information: each channel is scaled by 255 and stored as one byte.
            if (this.includeVertexColors && this.colors) {
                output.setUint8(offset, Math.floor(this.colors.getX(i) * 255));
                offset += 1;

                output.setUint8(offset, Math.floor(this.colors.getY(i) * 255));
                offset += 1;

                output.setUint8(offset, Math.floor(this.colors.getZ(i) * 255));
                offset += 1;
            }
        }
    }

    // Writes all of the facets to the output buffer, starting at `startingOffset`.
    // Exactly `faceCount` triangles are written, so trailing indices that do not
    // form a full triangle are ignored.
    // NOTE: assumes that the output buffer is large enough!!
    private writeFacets(output: DataView, startingOffset: number): void {
        let fOffset = startingOffset;

        for (let face = 0; face < this.faceCount; face++) {
            const i = face * 3;

            // Indices: list length (always 3) followed by the three vertex indices.
            output.setUint8(fOffset, 3);
            fOffset += 1;

            output.setUint32(fOffset, this.indices.getX(i + 0), true);
            fOffset += 4;

            output.setUint32(fOffset, this.indices.getX(i + 1), true);
            fOffset += 4;

            output.setUint32(fOffset, this.indices.getX(i + 2), true);
            fOffset += 4;

            // Texture coordinates: list length (always 6) followed by (u, v) of each corner.
            if (this.includeTextureCoords && this.uvs) {
                output.setUint8(fOffset, 6);
                fOffset += 1;

                for (let idx = 0; idx < 3; idx++) {
                    const vertexId = this.indices.getX(i + idx);

                    output.setFloat32(fOffset, this.uvs.getX(vertexId), true);
                    fOffset += 4;

                    output.setFloat32(fOffset, this.uvs.getY(vertexId), true);
                    fOffset += 4;
                }
            }
        }
    }
}
