import last from 'lodash/fp/last'
import React from 'react'

/** A route as an ordered list of segment names; the last segment is the current one. */
type Route = string[]
type Props = {
  /** Current route; scroll offsets are remembered per segment of it. */
  route: Route
  /** CSS class applied to the scroll container element. */
  className: string
  // Declared explicitly: React 18's typings no longer add `children` to
  // React.FC props implicitly, and ScrollRetainer renders props.children.
  children?: React.ReactNode
}

/**
 * True when route `a` is a strict (proper) prefix of route `b`.
 * `a` must be shorter than `b`, and every segment of `a` must equal the segment of `b` at the same index.
 */
function isContained(a: Route, b: Route): boolean {
  // An equal-length or longer route can never be a proper prefix.
  if (a.length >= b.length) {
    return false
  }
  return !a.some((segment, idx) => segment !== b[idx])
}
/**
 * True when navigating from `prev` to `curr` moves back up the hierarchy.
 * That is the case when `curr` is empty (the root) or `curr` is a strict prefix of `prev`.
 */
function isGoingBack(prev: Route, curr: Route): boolean {
  // The empty route is the root, which always counts as going back.
  if (curr.length === 0) {
    return true
  }
  return isContained(curr, prev)
}
/**
 * True when `prev` and `curr` are unrelated, i.e. neither is a strict prefix of the other.
 * Note this is also true when the two routes are equal.
 */
function shouldDiscardHistory(prev: Route, curr: Route): boolean {
  const related = isContained(prev, curr) || isContained(curr, prev)
  return !related
}

/**
 * Scrollable container that remembers a scroll offset per route segment.
 *
 * While the user scrolls, the offset is recorded under the last segment of
 * `props.route` (or `'root'` when the route is empty). When the route changes:
 * - going back (to the root, or to a strict prefix of the previous route)
 *   restores the offset saved for the new last segment, defaulting to 0;
 * - any other change scrolls to the top;
 * - moving to an unrelated route (neither route is a prefix of the other)
 *   drops every saved offset except the root's.
 * After each route change, saved offsets are pruned to the root plus the
 * segments of the current route.
 */
const ScrollRetainer: React.FC<Props> = props => {
  const [prevRoute, setPrevRoute] = React.useState<Route>([])
  const [scrollPos, setScrollPos] = React.useState<SMap<number>>({})
  const [currentOffset, setOffset] = React.useState<number>(0)
  // Initialise with null so this is a RefObject React can attach to the div.
  const ref = React.useRef<HTMLDivElement>(null)
  const currRoute = last(props.route) || 'root'
  // Compare segment by segment. A length-only check missed navigation between
  // routes of the same depth (e.g. ['a'] -> ['b']), leaving prevRoute stale.
  const routeChanged =
    prevRoute.length !== props.route.length ||
    prevRoute.some((segment, i) => segment !== props.route[i])

  if (routeChanged) {
    // State updates during render, guarded so they happen once per route
    // change. React re-renders immediately with these values before committing.
    setOffset(
      isGoingBack(prevRoute, props.route) ? scrollPos[currRoute] || 0 : 0
    )
    if (shouldDiscardHistory(prevRoute, props.route)) {
      // Unrelated route: keep only the root's saved offset.
      setScrollPos({ root: scrollPos.root })
    }
    setPrevRoute(props.route)
  }

  React.useLayoutEffect(() => {
    if (ref.current) {
      ref.current.scrollTop = currentOffset
    }
    // Prune saved offsets to the root plus the segments of the current route.
    setScrollPos(saved => {
      const pruned: SMap<number> = { root: saved.root }
      props.route.forEach(segment => {
        pruned[segment] = saved[segment]
      })
      return pruned
    })
    // Deliberately not depending on scrollPos: re-running on every scroll
    // update would snap the view back to currentOffset. prevRoute changes
    // identity exactly when the route changes, even if the last segment doesn't.
  }, [currentOffset, prevRoute])

  const onScroll = (event: React.UIEvent<HTMLDivElement>) => {
    // Read scrollTop synchronously; the functional update keeps positions from
    // being lost if several scroll updates are batched.
    const top = event.currentTarget.scrollTop
    setScrollPos(saved => ({ ...saved, [currRoute]: top }))
  }

  return (
    <div className={props.className} ref={ref} onScroll={onScroll}>
      {props.children}
    </div>
  )
}

export default ScrollRetainer
