import os

import torch

from ._internally_replaced_utils import _get_extension_path


def _load_library(lib_name):
    """Try to register the compiled extension ``lib_name`` with ``torch.ops``.

    Failures surfacing as ``ImportError`` or ``OSError`` are swallowed. A
    warning describing the failure is emitted only when the environment
    variable ``TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS`` is set to a
    non-empty value.

    Returns:
        bool: ``True`` if the library was located and loaded, ``False`` if
        either step failed with one of the exceptions above.
    """
    try:
        torch.ops.load_library(_get_extension_path(lib_name))
    except (ImportError, OSError) as err:
        # Any non-empty value (even "0") enables the warning: only truthiness is tested.
        warn_requested = os.environ.get("TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS")
        if warn_requested:
            import warnings

            warnings.warn(f"Failed to load '{lib_name}' extension: {type(err).__name__}: {err}")
        return False
    return True


# ``_has_ops`` is bound once, at import time, depending on whether the
# compiled ``_C`` extension could be loaded.
if _load_library("_C"):

    def _has_ops():
        """Return whether the ``_C`` custom ops are available (they are)."""
        return True

else:

    def _has_ops():
        """Return whether the ``_C`` custom ops are available (they are not)."""
        return False


def _assert_has_ops():
    """Fail loudly when the compiled custom ops are unavailable.

    Returns ``None`` if ``_has_ops()`` is true.

    Raises:
        RuntimeError: if ``_has_ops()`` is false. The message points the user
            at version-compatibility checks and at the env variable that
            enables detailed load-failure warnings.
    """
    if _has_ops():
        return
    message = (
        "Couldn't load custom C++ ops. This can happen if your PyTorch and "
        "torchvision versions are incompatible, or if you had errors while compiling "
        "torchvision from source. For further information on the compatible versions, check "
        "https://github.com/pytorch/vision#installation for the compatibility matrix. "
        "Please check your PyTorch version with torch.__version__ and your torchvision "
        "version with torchvision.__version__ and verify if they are compatible, and if not "
        "please reinstall torchvision so that it matches your PyTorch install. "
        "Set TORCHVISION_WARN_WHEN_EXTENSION_LOADING_FAILS=1 and retry to get more details."
    )
    raise RuntimeError(message)


def _check_cuda_version():
    """
    Make sure that CUDA versions match between the pytorch install and torchvision install.

    Returns:
        int: ``-1`` if the custom ops are not loaded. Otherwise, the value of
        ``torch.ops.torchvision._cuda_version()``, which is ``-1`` when no CUDA
        version is reported.

    Raises:
        RuntimeError: if PyTorch and torchvision were compiled against
            different CUDA major versions.
        AssertionError: if torchvision reports a CUDA version below 12000
            (only when assertions are enabled).
    """
    if not _has_ops():
        return -1
    from torch.version import cuda as torch_version_cuda

    _version = torch.ops.torchvision._cuda_version()
    if _version != -1 and torch_version_cuda is not None:
        tv_version = int(_version)
        assert tv_version >= 12000, f"Unexpected CUDA version {_version}, please file a bug report."
        # Decode arithmetically, assuming the ``1000 * major + 10 * minor`` encoding
        # used by CUDA's version APIs (e.g. 12010 -> 12.1). Unlike slicing the
        # decimal string, this stays correct for two-digit minors (12100 -> 12.10)
        # and for non-5-digit values when asserts are stripped under ``-O``.
        tv_major, tv_rest = divmod(tv_version, 1000)
        tv_minor = tv_rest // 10
        # PyTorch reports its CUDA version as a dotted string, e.g. "12.1".
        t_version = torch_version_cuda.split(".")
        t_major = int(t_version[0])
        t_minor = int(t_version[1])
        # Only the major version has to agree.
        if t_major != tv_major:
            raise RuntimeError(
                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
                f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
                f"CUDA Version={tv_major}.{tv_minor}. "
                "Please reinstall the torchvision that matches your PyTorch install."
            )
    return _version


# Run the CUDA compatibility check at import time. A CUDA major-version
# mismatch between PyTorch and torchvision then raises RuntimeError as soon
# as this module is imported. The check is a no-op returning -1 when the
# custom ops failed to load, and its return value is deliberately ignored here.
_check_cuda_version()
