import { useEffect, useCallback } from 'react';

/**
 * Collect the keyboard-tabbable elements inside a container.
 *
 * Candidates are everything matching `focusableElements` plus any element with
 * `tabindex="0"`. Disabled controls and elements with a negative `tabindex`
 * are dropped because the browser skips them when tabbing. If one of them were
 * reported as the first/last element, focus could never land on it, and the
 * wrap-around in `useTrapFocus` would silently stop working.
 *
 * @param {{ current: Element | null }} ref - React ref to the container.
 * @param {string} [focusableElements='button, input, a'] - CSS selector for
 *   candidate elements.
 * @returns {[Element | undefined, Element | undefined, Element[]]} The first
 *   and last tabbable elements (undefined when there are none), followed by
 *   the full list in document order.
 */
const getFocusableElements = (ref, focusableElements = 'button, input, a') => {
  const container = ref.current;
  if (!container) {
    // Ref not attached (yet) or already detached: there is nothing to trap.
    return [undefined, undefined, []];
  }

  const tabbable = Array.from(
    container.querySelectorAll(`${focusableElements}, [tabindex='0']`)
  ).filter(
    (element) =>
      // A missing tabindex attribute gives Number(null) === 0, which is kept.
      !element.disabled && !(Number(element.getAttribute('tabindex')) < 0)
  );

  return [tabbable[0], tabbable[tabbable.length - 1], tabbable];
};

/**
 * Keep keyboard focus inside the element referenced by `ref` while `trapped`
 * is true.
 *
 * While active, a document-level `keydown` listener intercepts Tab and
 * Shift+Tab:
 * - Tab from the last tabbable element wraps to the first.
 * - Shift+Tab from the first wraps to the last.
 * - Tab pressed while focus is outside the container (for example on `body`
 *   after a click) pulls focus back in, honouring the direction.
 *
 * @param {{ current: Element | null }} ref - React ref to the container.
 * @param {boolean} trapped - Whether the trap is active.
 * @param {string} [elementsToTrap] - CSS selector for candidate focusable
 *   elements, forwarded to `getFocusableElements` (which supplies its own
 *   default when this is undefined).
 */
const useTrapFocus = (ref, trapped, elementsToTrap) => {
  const trapFocus = useCallback(
    (event) => {
      if (event.key !== 'Tab') {
        return;
      }

      const container = ref.current;
      if (!container) {
        return;
      }

      const [firstFocusableButton, lastFocusableButton] = getFocusableElements(
        ref,
        elementsToTrap
      );
      if (!firstFocusableButton) {
        // Nothing tabbable inside the container; leave Tab to the browser.
        return;
      }

      const { shiftKey, target } = event;

      if (!container.contains(target)) {
        // Focus is outside the trap: bring it back in. Otherwise Tab would
        // keep walking through the rest of the page.
        event.preventDefault();
        (shiftKey ? lastFocusableButton : firstFocusableButton).focus();
      } else if (!shiftKey && target === lastFocusableButton) {
        event.preventDefault();
        firstFocusableButton.focus();
      } else if (shiftKey && target === firstFocusableButton) {
        event.preventDefault();
        lastFocusableButton.focus();
      }
    },
    [ref, elementsToTrap]
  );

  useEffect(() => {
    if (!trapped) {
      return undefined;
    }

    document.addEventListener('keydown', trapFocus, false);
    return () => {
      document.removeEventListener('keydown', trapFocus, false);
    };
  }, [trapped, trapFocus]);
};

export default useTrapFocus;
