import type { TreeProps, TreeState } from '@react-stately/tree'
import { useTreeState } from '@react-stately/tree'
import type {
  Collection,
  CollectionChildren,
  CollectionElement,
  ItemElement,
  Node,
  SectionElement,
} from '@react-types/shared'
import type { JSXElementConstructor, ReactElement, ReactNode } from 'react'
import { cloneElement, isValidElement } from 'react'

import unwrapFragment from '@src/lib/unwrapFragment'
import type { MenuItemProps } from '@ui/Menu/MenuListItem'
import type { MenuListSectionProps } from '@ui/Menu/MenuListSection/MenuListSection'
import type { SubMenuProviderProps } from '@ui/Menu/SubMenuProvider'
import type { MenuChild } from '@ui/Menu/types'

/**
 * Shape of a node in a menu's tree collection.
 *
 * Based on `Node<T>` from `@react-types/shared`, but with `type`, `props` and
 * `childNodes` replaced by menu-specific versions:
 * - `childNodes` yields `MenuNode<T>` rather than plain `Node<T>`.
 * - `type` + `props` form a discriminated union. Note that TWO variants share
 *   `type: 'item'`; narrowing on `type` alone is not enough to know which
 *   props shape you have — check `props.isSubMenu` (`true` only on the
 *   `SubMenuProviderProps` variant, never set on the `MenuItemProps` one).
 * - `'separator'` nodes have no `props` at all (it is removed by the `Omit`
 *   and not re-added).
 */
export type MenuNode<T> = Omit<Node<T>, 'type' | 'props' | 'childNodes'> & {
  childNodes: Iterable<MenuNode<T>>
} & (
    // Item whose props are `MenuItemProps`; `isSubMenu` must be absent.
    | {
        type: 'item'
        props: MenuItemProps<T> & { isSubMenu?: never }
      }
    // Section node whose props are `MenuListSectionProps`.
    | {
        type: 'section'
        props: MenuListSectionProps<T>
      }
    // Separator node: no props.
    | {
        type: 'separator'
      }
    // Item whose props are `SubMenuProviderProps`; flagged by `isSubMenu: true`.
    | {
        type: 'item'
        props: SubMenuProviderProps<T> & { isSubMenu: true }
      }
  )

/**
 * `TreeState<T>` from `@react-stately/tree` with `collection` re-typed so its
 * entries are `MenuNode<T>`. This is a compile-time refinement only: nothing
 * in this type checks at runtime that the collection's nodes actually match
 * `MenuNode<T>`.
 */
export type MenuTreeState<T> = Omit<TreeState<T>, 'collection'> & {
  readonly collection: Collection<MenuNode<T>>
}

/**
 * Props for `useMenuTreeState`: every `TreeProps<T>` option from
 * `@react-stately/tree` except `children`, which is replaced by the menu's
 * own `MenuChild<T>` type.
 */
interface MenuTreeStateProps<T> extends Omit<TreeProps<T>, 'children'> {
  children: MenuChild<T>
}

/**
 * Creates react-stately tree state for a menu.
 *
 * `children` is normalised with `unwrapChildren` before being handed to
 * `useTreeState`; every other prop is forwarded unchanged. The resulting
 * state is re-typed as `MenuTreeState<T>` via a type assertion — the nodes
 * in its collection are not validated at runtime.
 *
 * @param menuProps - tree props plus the menu's `children`
 * @returns the tree state, typed as `MenuTreeState<T>`
 */
export default function useMenuTreeState<T extends object>(
  menuProps: MenuTreeStateProps<T>,
) {
  const { children: rawChildren, ...treeProps } = menuProps
  const collectionChildren = unwrapChildren(rawChildren)

  return useTreeState({ ...treeProps, children: collectionChildren }) as MenuTreeState<T>
}

/**
 * Normalises menu children into the `CollectionChildren` shape passed to
 * `useTreeState`.
 *
 * Top level:
 * - a function (render-prop children) is returned unchanged;
 * - an array is unwrapped entry-by-entry and flattened one level (`flatMap`);
 * - any other value is unwrapped as a single child.
 *
 * Unwrapping a single child:
 * - anything that is not a valid React element becomes `null`;
 * - otherwise it goes through `unwrapFragment`. If that yields an array, each
 *   entry is unwrapped recursively (the result is not flattened further). If
 *   it yields a non-element, the child becomes `null`.
 * - an element accepted by `isSection` is cloned with its own children
 *   unwrapped recursively; any other element is returned as-is.
 *
 * NOTE(review): the `null as unknown as ItemElement<T>` casts hide the fact
 * that `null` entries can end up in the output; callers rely on the
 * collection builder tolerating them — confirm if the builder changes.
 */
function unwrapChildren<T>(children: MenuChild<T> | ReactNode): CollectionChildren<T> {
  // Render-prop children describe a dynamic collection; leave them untouched.
  if (typeof children === 'function') {
    return children as (item: T) => CollectionElement<T>
  }

  const unwrapSingle = (node: unknown) => {
    // Non-element nodes (strings, numbers, null, ...) are replaced with null.
    if (!isValidElement(node)) {
      return null as unknown as ItemElement<T>
    }

    const inner = unwrapFragment(node)

    // unwrapFragment may expand to several nodes; unwrap each recursively.
    if (Array.isArray(inner)) {
      return inner.map(unwrapChildren) as (SectionElement<T> | ItemElement<T>)[]
    }

    if (!isValidElement(inner)) {
      return null as unknown as ItemElement<T>
    }

    // Anything that is not section-shaped is passed through as an item.
    if (!isSection(inner)) {
      return inner as ItemElement<T>
    }

    // Sections keep their props but have their own children unwrapped too.
    const sectionItems = unwrapChildren(inner.props.children) as ItemElement<T>[]
    return cloneElement(inner, inner.props, sectionItems) as SectionElement<T>
  }

  return Array.isArray(children) ? children.flatMap(unwrapSingle) : unwrapSingle(children)
}

/**
 * Type guard deciding whether an element should be treated as a section.
 *
 * The check is purely structural. The element's props must be a non-null
 * object with a string `title` and a `children` value whose `typeof` is
 * `'object'`; arrays, elements and `null` all qualify. The element's
 * component type is not inspected.
 *
 * @param element - the React element to inspect
 * @returns `true` when the props look like a section's props
 */
function isSection(
  element: ReactElement<unknown, string | JSXElementConstructor<any>>,
): element is ReactElement<{ title: string; children: ReactNode }> {
  const { props } = element

  // Only a real (non-null) object can carry the keys we look for.
  if (props === null || typeof props !== 'object') {
    return false
  }

  // A section needs a string title...
  if (!('title' in props) || typeof props.title !== 'string') {
    return false
  }

  // ...plus object-typed children (element, array, or null).
  return 'children' in props && typeof props.children === 'object'
}
