import React, { useCallback, useEffect, useMemo, useRef, useState } from 'react';
import ReactFlow, {
    Background,
    Controls,
    applyEdgeChanges,
    applyNodeChanges,
    useReactFlow,
    useStoreApi,
    ControlButton
} from 'reactflow';
import 'reactflow/dist/style.css';
import CustomNode from './CustomNode';
import CustomEdge from './CustomEdge';
import { useWorkflow } from '../WorkflowContext';
import { timbalGrey } from 'components/CustomColors';
import { LeakAddOutlined, LeakRemoveOutlined } from '@mui/icons-material';

/**
 * Auto-connect threshold: a candidate connection is proposed only when the
 * handle-to-handle distance is strictly below this value (same units as node
 * positions). Despite its name, it acts as an upper bound.
 */
const MIN_DISTANCE = 300;

/**
 * Computes the absolute coordinates of a node's connection handle, assuming
 * handles sit at the node's vertical middle: the 'source' handle on the right
 * edge, every other handle type on the left edge.
 *
 * @param {{ position: { x: number, y: number }, width: number, height: number }} node
 * @param {string} handleType - 'source' selects the right-edge handle; anything else the left edge.
 * @returns {{ x: number, y: number }} Handle position in the node's coordinate space.
 */
const getHandleCoordinates = (node, handleType) => {
    const { position, width, height } = node;
    const offsetX = handleType === 'source' ? width : 0;
    const offsetY = height / 2;
    return { x: position.x + offsetX, y: position.y + offsetY };
};

const ReactFlowCanvas = () => {
    const { bodyCanva, setContextMenu, repositionStep, removeStep, removeLink, addLink, lockStep, unlockStep, runningFlow, getCanvaInputs, getCanvaOutputs } = useWorkflow();
    const lastPosition = useRef({});
    const [nodes, setNodes] = useState([]);
    const [edges, setEdges] = useState([]);
    const [autoConnectEnabled, setAutoConnectEnabled] = useState(true);
    const [isHoveringNode, setIsHoveringNode] = useState(false);

    const { getNode } = useReactFlow();
    const store = useStoreApi();

    const nodeTypes = useMemo(
        () => ({
            customNode: CustomNode
        }),
        [],
    );

    const edgeTypes = useMemo(
        () => ({
            customEdge: CustomEdge
        }),
        [],
    );

    useMemo(() => {
        if (bodyCanva) {
            if (bodyCanva.nodes) {
                setNodes(bodyCanva.nodes);
            }
            if (bodyCanva.edges) {
                setEdges(bodyCanva.edges);
            }
        }
    }, [bodyCanva]);

    const toggleAutoConnect = useCallback(() => {
        setAutoConnectEnabled(prev => !prev);
    }, []);

    const getClosestEdge = useCallback((node) => {
        if (!autoConnectEnabled) return null;

        const { nodeInternals } = store.getState();
        const storeNode = getNode(node.id);

        if (!storeNode) {
            return null;
        }

        const draggedSourcePos = getHandleCoordinates(storeNode, 'source');
        const draggedTargetPos = getHandleCoordinates(storeNode, 'target');

        let closestNode = null;
        let minDistance = Infinity;
        let isSource = true;

        Array.from(nodeInternals.values()).forEach((otherNode) => {
            if (otherNode.id !== storeNode.id) {
                const otherSourcePos = getHandleCoordinates(otherNode, 'source');
                const otherTargetPos = getHandleCoordinates(otherNode, 'target');

                // Check distance when dragged node is source
                const distanceAsSource = Math.sqrt(
                    Math.pow(draggedSourcePos.x - otherTargetPos.x, 2) +
                    Math.pow(draggedSourcePos.y - otherTargetPos.y, 2)
                );

                // Check distance when dragged node is target
                const distanceAsTarget = Math.sqrt(
                    Math.pow(draggedTargetPos.x - otherSourcePos.x, 2) +
                    Math.pow(draggedTargetPos.y - otherSourcePos.y, 2)
                );

                if (distanceAsSource < minDistance && distanceAsSource < MIN_DISTANCE) {
                    minDistance = distanceAsSource;
                    closestNode = otherNode;
                    isSource = true;
                }

                if (distanceAsTarget < minDistance && distanceAsTarget < MIN_DISTANCE) {
                    minDistance = distanceAsTarget;
                    closestNode = otherNode;
                    isSource = false;
                }
            }
        });

        if (!closestNode) {
            return null;
        }

        return {
            id: `${isSource ? storeNode.id : closestNode.id}-${isSource ? closestNode.id : storeNode.id}`,
            source: isSource ? storeNode.id : closestNode.id,
            target: isSource ? closestNode.id : storeNode.id,
            sourceHandle: 'source',
            targetHandle: 'target',
        };
    }, [getNode, autoConnectEnabled, store]);

    const onNodeDrag = useCallback(
        (_, node) => {
            if (!autoConnectEnabled) return;

            const closeEdge = getClosestEdge(node);
            setEdges((eds) => {
                // Remove any temporary edges for this node, whether it's the source or target
                const filteredEdges = eds.filter(e =>
                    !e.id.includes('-temp') ||
                    (e.source !== node.id && e.target !== node.id)
                );

                if (closeEdge) {
                    // Check if this connection already exists
                    const existingEdge = filteredEdges.find(
                        e => e.source === closeEdge.source && e.target === closeEdge.target
                    );

                    if (!existingEdge) {
                        // Add the new temporary edge only if it doesn't already exist
                        const tempEdge = {
                            ...closeEdge,
                            id: `${closeEdge.id}-temp`, // Add a temp suffix to identify this edge
                            style: { stroke: '#999', strokeDasharray: '5,5' } // Optional: style for temporary edge
                        };
                        return [...filteredEdges, tempEdge];
                    }
                }

                return filteredEdges;
            });
        },
        [getClosestEdge, autoConnectEnabled]
    );

    const onNodeDragStop = useCallback(
        (_, node) => {
            if (!autoConnectEnabled) return;

            const closeEdge = getClosestEdge(node);
            setEdges((eds) => {
                // Remove any temporary edges
                const remainingEdges = eds.filter(e => !e.id.includes('-temp'));

                if (closeEdge) {
                    // Check if this connection already exists
                    const existingEdge = remainingEdges.find(
                        e => e.source === closeEdge.source && e.target === closeEdge.target
                    );

                    if (!existingEdge) {
                        addLink({ step_id: closeEdge.source, next_step_id: closeEdge.target });
                        return [...remainingEdges, closeEdge];
                    }
                }

                return remainingEdges;
            });
        },
        [getClosestEdge, addLink, autoConnectEnabled]
    );

    const onNodesChange = useCallback(
        async (changes) => {
            changes.sort((a, b) => (a.selected && !b.selected) ? -1 : 1);
            for (let change of changes) {
                if (change.type === 'select') {
                    if (change.selected) {
                        lockStep({ step_id: change.id });
                    } else {
                        unlockStep({ step_id: change.id });
                    }
                } else if (change.type === 'position') {
                    if (lastPosition.current[change.id]) {
                        if (lastPosition.current[change.id].dragging && !change.dragging) {
                            repositionStep({
                                step_id: change.id,
                                x: lastPosition.current[change.id].position.x,
                                y: lastPosition.current[change.id].position.y
                            });
                            lastPosition.current[change.id] = null;
                        } else {
                            lastPosition.current[change.id] = change;
                        }
                    } else {
                        lastPosition.current[change.id] = change;
                    }
                } else if (change.type === 'remove') {
                    await removeStep({ step_id: change.id });
                    getCanvaInputs();
                    getCanvaOutputs();
                    return;
                }
            }

            setNodes((n) => {
                const newNodes = applyNodeChanges(changes, n);
                return newNodes;
            });
        },
        [
            setNodes,
            repositionStep,
            removeStep,
            lockStep,
            unlockStep,
            getCanvaInputs,
            getCanvaOutputs
        ]
    );

    const onEdgesChange = useCallback(
        (changes) => {
            for (let change of changes) {
                if (change.type === 'remove') {
                    const edge = edges.find((edge) => edge.id === change.id);
                    if (edge && edge.selected) {
                        removeLink({ link_id: change.id });
                    }
                    return;
                }
            }
            setEdges((e) => {
                const newEdges = applyEdgeChanges(changes, e);
                return newEdges;
            });
        },
        [
            edges,
            setEdges,
            removeLink
        ]
    );

    const handleOnConnect = (params) => {
        addLink({ step_id: params.source, next_step_id: params.target });
    }

    const handleContextMenu = (event, item) => {
        event.preventDefault();
        setContextMenu({
            mouseX: event.clientX - 2,
            mouseY: event.clientY - 4,
            item,
        });
    };

    const onNodeMouseEnter = useCallback(() => {
        setIsHoveringNode(true);
    }, []);

    const onNodeMouseLeave = useCallback(() => {
        setIsHoveringNode(false);
    }, []);

    if (!bodyCanva) return null; // Render null if not ready

    return (
        <ReactFlow
            nodes={nodes}
            edges={edges}
            onNodesChange={onNodesChange}
            onEdgesChange={onEdgesChange}
            onConnect={handleOnConnect}
            onNodeContextMenu={(event, node) => handleContextMenu(event, node)}
            onEdgeContextMenu={(event, edge) => handleContextMenu(event, edge)}
            onNodeDrag={autoConnectEnabled ? onNodeDrag : null}
            onNodeDragStop={autoConnectEnabled ? onNodeDragStop : null}
            onNodeMouseEnter={onNodeMouseEnter}
            onNodeMouseLeave={onNodeMouseLeave}
            nodeDragThreshold={5}
            zoomOnScroll={false}
            panOnScroll={!isHoveringNode}
            panOnScrollSpeed={1}
            minZoom={0.1}
            maxZoom={2.0}
            fitView={true}
            fitViewOptions={{ padding: 0.2, includeHiddenNodes: false }}
            deleteKeyCode={['Backspace', 'Delete']}
            proOptions={{ hideAttribution: true }}
            nodeTypes={nodeTypes}
            edgeTypes={edgeTypes}
            defaultEdgeOptions={{ type: 'customEdge' }}
            connectionLineType='smoothstep'
            edgesUpdatable={!runningFlow}
            edgesFocusable={!runningFlow}
            nodesDraggable={!runningFlow}
            nodesConnectable={!runningFlow}
        >
            <Controls>
                <ControlButton
                    onClick={toggleAutoConnect}
                    title={autoConnectEnabled ? "Disable Auto-Connect" : "Enable Auto-Connect"}
                >
                    {autoConnectEnabled ?
                        <LeakAddOutlined /> :
                        <LeakRemoveOutlined />
                    }
                </ControlButton>
            </Controls>
            <Background color={timbalGrey[300]} size={2} style={{ backgroundColor: '#fefdf7' }} />
        </ReactFlow>
    );
};

export default ReactFlowCanvas;
