# Owner(s): ["oncall: pt2"]

"""
Shared xfail and skip sets for unbacked symint tests.

These sets are used by both test_ops_unbacked.py (base tensor tests)
and test_dtensor_ops.py (DTensor tests with unbacked dimensions).
"""


def xfail(op_name, variant_name="", *, device_type=None, dtypes=None):
    """Build an expected-failure entry for the op tables in this module.

    The entry is the 5-tuple
    ``(op_name, variant_name, device_type, dtypes, True)``. The trailing
    ``True`` marks it as an xfail; :func:`skip` produces the same shape
    with ``False``.

    Entries are collected into set literals in this module, so every field
    (notably ``dtypes``) must be hashable, e.g. a tuple rather than a set.
    """
    is_expected_failure = True
    return op_name, variant_name, device_type, dtypes, is_expected_failure


def skip(op_name, variant_name="", *, device_type=None, dtypes=None):
    """Build a skip entry for the op tables in this module.

    The entry is the 5-tuple
    ``(op_name, variant_name, device_type, dtypes, False)``. The trailing
    ``False`` distinguishes it from an :func:`xfail` entry, which uses
    ``True``.

    Entries are collected into set literals in this module, so every field
    (notably ``dtypes``) must be hashable, e.g. a tuple rather than a set.
    """
    is_expected_failure = False
    return op_name, variant_name, device_type, dtypes, is_expected_failure


# Ops expected to fail with data-dependent errors (DDEs) when their inputs are
# given unbacked dimensions. These failures already happen on plain (base)
# tensors, so they are not DTensor-specific.
#
# Each entry is an xfail(...) tuple:
#   (op_name, variant_name, device_type, dtypes, True)
# A second positional argument selects a named variant of the op, e.g.
# xfail("norm", "fro"); an empty variant_name means the default variant.
ops_dde_xfail = {
    xfail("_chunk_cat"),
    xfail("_unsafe_masked_index_put_accumulate"),
    xfail("_upsample_bilinear2d_aa"),
    xfail("addmv"),
    xfail("allclose"),
    xfail("as_strided_scatter"),
    xfail("baddbmm"),
    xfail("bernoulli"),
    xfail("cauchy"),
    xfail("cdist"),
    xfail("cholesky"),
    xfail("chunk"),
    xfail("combinations"),
    xfail("corrcoef"),
    xfail("cov"),
    xfail("cross"),
    xfail("cummax"),
    xfail("cummin"),
    xfail("cumulative_trapezoid"),
    xfail("diagonal_scatter"),
    xfail("diff"),
    xfail("dist"),
    xfail("dsplit"),
    xfail("equal"),
    xfail("exponential"),
    xfail("fft.fft"),
    xfail("fft.fft2"),
    xfail("fft.fftn"),
    xfail("fft.fftshift"),
    xfail("fft.hfft"),
    xfail("fft.hfft2"),
    xfail("fft.hfftn"),
    xfail("fft.ifft"),
    xfail("fft.ifft2"),
    xfail("fft.ifftn"),
    xfail("fft.ifftshift"),
    xfail("fft.ihfft"),
    xfail("fft.ihfft2"),
    xfail("fft.ihfftn"),
    xfail("fft.irfft"),
    xfail("fft.irfft2"),
    xfail("fft.irfftn"),
    xfail("fft.rfft"),
    xfail("fft.rfft2"),
    xfail("fft.rfftn"),
    xfail("float"),
    xfail("geometric"),
    xfail("geqrf"),
    xfail("gradient"),
    xfail("grid_sampler_2d"),
    xfail("hash_tensor"),
    xfail("histogram"),
    xfail("histogramdd"),
    xfail("hsplit"),
    xfail("index_fill"),
    xfail("inner"),
    xfail("kron"),
    xfail("linalg.cond"),
    xfail("linalg.cross"),
    xfail("linalg.householder_product"),
    xfail("linalg.ldl_solve"),
    xfail("linalg.lstsq"),
    xfail("linalg.lstsq", "grad_oriented"),
    xfail("linalg.lu_solve"),
    xfail("linalg.matrix_norm"),
    xfail("linalg.matrix_power"),
    xfail("linalg.matrix_rank"),
    xfail("linalg.matrix_rank", "hermitian"),
    xfail("linalg.multi_dot"),
    xfail("linalg.norm"),
    xfail("linalg.norm", "subgradients_at_zero"),
    xfail("linalg.pinv"),
    xfail("linalg.pinv", "hermitian"),
    xfail("linalg.pinv", "singular"),
    xfail("linalg.qr"),
    xfail("linalg.solve"),
    xfail("linalg.solve_ex"),
    xfail("linalg.solve_triangular"),
    xfail("linalg.svd"),
    xfail("linalg.svdvals"),
    xfail("linalg.tensorinv"),
    xfail("linalg.tensorsolve"),
    xfail("linalg.vander"),
    xfail("linalg.vector_norm"),
    xfail("log_normal"),
    xfail("logdet"),
    xfail("logsumexp"),
    xfail("lu_solve"),
    xfail("lu_unpack"),
    xfail("masked.amax"),
    xfail("masked.amin"),
    xfail("masked.argmax"),
    xfail("masked.argmin"),
    xfail("masked.cumprod"),
    xfail("masked.cumsum"),
    xfail("masked.log_softmax"),
    xfail("masked.logaddexp"),
    xfail("masked.logsumexp"),
    xfail("masked.mean"),
    xfail("masked.median"),
    xfail("masked.norm"),
    xfail("masked.prod"),
    xfail("masked.softmax"),
    xfail("masked.softmin"),
    xfail("masked.std"),
    xfail("masked.sum"),
    xfail("masked.var"),
    xfail("max_pool2d_with_indices_backward"),
    xfail("multinomial"),
    xfail("nanquantile"),
    xfail("nn.functional.adaptive_avg_pool1d"),
    xfail("nn.functional.adaptive_avg_pool2d"),
    xfail("nn.functional.adaptive_avg_pool3d"),
    xfail("nn.functional.adaptive_max_pool1d"),
    xfail("nn.functional.adaptive_max_pool2d"),
    xfail("nn.functional.adaptive_max_pool3d"),
    xfail("nn.functional.alpha_dropout"),
    xfail("nn.functional.avg_pool1d"),
    xfail("nn.functional.avg_pool2d"),
    xfail("nn.functional.avg_pool3d"),
    xfail("nn.functional.batch_norm"),
    xfail("nn.functional.bilinear"),
    xfail("nn.functional.binary_cross_entropy"),
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    xfail("nn.functional.channel_shuffle"),
    xfail("nn.functional.conv1d"),
    xfail("nn.functional.conv2d"),
    xfail("nn.functional.conv3d"),
    xfail("nn.functional.conv_transpose1d"),
    xfail("nn.functional.conv_transpose2d"),
    xfail("nn.functional.conv_transpose3d"),
    xfail("nn.functional.cosine_similarity"),
    xfail("nn.functional.cross_entropy"),
    xfail("nn.functional.ctc_loss"),
    xfail("nn.functional.dropout"),
    xfail("nn.functional.dropout2d"),
    xfail("nn.functional.dropout3d"),
    xfail("nn.functional.embedding"),
    xfail("nn.functional.embedding_bag"),
    xfail("nn.functional.feature_alpha_dropout", "with_train"),
    xfail("nn.functional.feature_alpha_dropout", "without_train"),
    xfail("nn.functional.fractional_max_pool2d"),
    xfail("nn.functional.fractional_max_pool3d"),
    xfail("nn.functional.gaussian_nll_loss"),
    xfail("nn.functional.glu"),
    xfail("nn.functional.grid_sample"),
    xfail("nn.functional.group_norm"),
    xfail("nn.functional.huber_loss"),
    xfail("nn.functional.instance_norm"),
    xfail("nn.functional.interpolate", "area"),
    xfail("nn.functional.interpolate", "bicubic"),
    xfail("nn.functional.interpolate", "bilinear"),
    xfail("nn.functional.interpolate", "linear"),
    xfail("nn.functional.interpolate", "trilinear"),
    xfail("nn.functional.l1_loss"),
    xfail("nn.functional.local_response_norm"),
    xfail("nn.functional.max_pool1d"),
    xfail("nn.functional.max_pool2d"),
    xfail("nn.functional.max_pool3d"),
    xfail("nn.functional.max_unpool1d"),
    xfail("nn.functional.max_unpool1d", "grad"),
    xfail("nn.functional.max_unpool2d"),
    xfail("nn.functional.max_unpool2d", "grad"),
    xfail("nn.functional.max_unpool3d"),
    xfail("nn.functional.max_unpool3d", "grad"),
    xfail("nn.functional.mse_loss"),
    xfail("nn.functional.multi_head_attention_forward"),
    xfail("nn.functional.multilabel_margin_loss"),
    xfail("nn.functional.nll_loss"),
    xfail("nn.functional.pad", "circular"),
    xfail("nn.functional.pad", "reflect"),
    xfail("nn.functional.pad", "replicate"),
    xfail("nn.functional.pad", "replicate_negative"),
    xfail("nn.functional.pdist"),
    xfail("nn.functional.pixel_shuffle"),
    xfail("nn.functional.prelu"),
    xfail("nn.functional.rrelu"),
    xfail("nn.functional.scaled_dot_product_attention"),
    xfail("nn.functional.smooth_l1_loss"),
    xfail("nn.functional.unfold"),
    xfail("nn.functional.upsample_bilinear"),
    xfail("norm"),
    xfail("norm", "fro"),
    xfail("norm", "nuc"),
    xfail("normal"),
    xfail("normal", "in_place"),
    xfail("normal", "number_mean"),
    xfail("ormqr"),
    xfail("pca_lowrank"),
    xfail("pinverse"),
    xfail("quantile"),
    xfail("qr"),
    xfail("rand_like"),
    xfail("randint_like"),
    xfail("randn_like"),
    xfail("repeat_interleave"),
    xfail("resize_"),
    xfail("resize_as_"),
    xfail("roll"),
    xfail("searchsorted"),
    xfail("sparse.mm", "reduce"),
    xfail("split"),
    xfail("stft"),
    xfail("svd"),
    xfail("svd_lowrank"),
    xfail("sum_to_size"),
    xfail("take"),
    xfail("take_along_dim"),
    xfail("tensordot"),
    xfail("tensor_split"),
    xfail("to_sparse"),
    xfail("trapezoid"),
    xfail("trapz"),
    xfail("unbind"),
    xfail("unbind_copy"),
    xfail("uniform"),
    xfail("unsafe_chunk"),
    xfail("unsafe_split"),
    xfail("vsplit"),
}

# Ops skipped outright (rather than xfailed) in the unbacked tests because they
# have no valid samples with dimensions that can be marked unbacked.
# NOTE(review): most entries appear to be factory/creation ops (arange, empty,
# full, ones, zeros, signal.windows.*), which presumably have no tensor input
# whose dims could be marked; confirm against the sample generators.
#
# Each entry is a skip(...) tuple:
#   (op_name, variant_name, device_type, dtypes, False)
ops_unbacked_skip = {
    skip("arange"),
    skip("broadcast_shapes"),
    skip("empty"),
    skip("empty_permuted"),
    skip("empty_strided"),
    skip("eye"),
    skip("full"),
    skip("item"),
    skip("linspace"),
    skip("linspace", "tensor_overload"),
    skip("logspace"),
    skip("logspace", "tensor_overload"),
    skip("ones"),
    skip("randint"),
    skip("randn"),
    skip("scalar_tensor"),
    skip("signal.windows.bartlett"),
    skip("signal.windows.blackman"),
    skip("signal.windows.cosine"),
    skip("signal.windows.exponential"),
    skip("signal.windows.gaussian"),
    skip("signal.windows.general_cosine"),
    skip("signal.windows.general_hamming"),
    skip("signal.windows.hamming"),
    skip("signal.windows.hann"),
    skip("signal.windows.kaiser"),
    skip("signal.windows.nuttall"),
    skip("zeros"),
    # Sparse ops that can't be deepcopied; skipped for that reason rather than
    # for lacking markable dims.
    skip("sparse.sampled_addmm"),
}
