mirror of https://github.com/llvm/torch-mlir
55 lines
1.9 KiB
Python
55 lines
1.9 KiB
Python
import torch
|
|
from torch._decomp import get_decompositions
|
|
|
|
# default decompositions pulled from SHARK / torch._decomp
|
|
DEFAULT_DECOMPOSITIONS = [
|
|
torch.ops.aten.embedding_dense_backward,
|
|
torch.ops.aten.native_layer_norm_backward,
|
|
torch.ops.aten.slice_backward,
|
|
torch.ops.aten.select_backward,
|
|
torch.ops.aten.norm.ScalarOpt_dim,
|
|
torch.ops.aten.native_group_norm,
|
|
torch.ops.aten.upsample_bilinear2d.vec,
|
|
torch.ops.aten.split.Tensor,
|
|
torch.ops.aten.split_with_sizes,
|
|
torch.ops.aten.native_layer_norm,
|
|
torch.ops.aten.masked_fill.Tensor,
|
|
torch.ops.aten.masked_fill.Scalar,
|
|
torch.ops.aten.t,
|
|
torch.ops.aten.addmm,
|
|
# decompositions that aid us in handling nn.BatchNorm2d
|
|
torch.ops.aten._native_batch_norm_legit_functional,
|
|
torch.ops.aten._native_batch_norm_legit_no_training,
|
|
torch.ops.aten._native_batch_norm_legit,
|
|
torch.ops.aten._native_batch_norm_legit.no_stats,
|
|
torch.ops.aten.squeeze.dims,
|
|
# decompositions for miscellaneous ops that are not handled in torch-mlir but have available decompositions
|
|
torch.ops.aten.soft_margin_loss,
|
|
torch.ops.aten.im2col,
|
|
torch.ops.aten._euclidean_dist,
|
|
torch.ops.aten.index_copy,
|
|
torch.ops.aten.index_copy_,
|
|
torch.ops.aten.grid_sampler_2d,
|
|
torch.ops.aten.log_sigmoid_forward,
|
|
torch.ops.aten.unsafe_split.Tensor,
|
|
torch.ops.aten.binary_cross_entropy,
|
|
torch.ops.aten.dot,
|
|
torch.ops.aten._adaptive_avg_pool2d,
|
|
torch.ops.aten._prelu_kernel,
|
|
torch.ops.aten.full,
|
|
torch.ops.aten._log_softmax,
|
|
torch.ops.aten.nll_loss_forward,
|
|
torch.ops.aten.nll_loss_backward,
|
|
torch.ops.aten._to_copy,
|
|
torch.ops.aten._log_softmax_backward_data,
|
|
torch.ops.aten.lift_fresh_copy.default,
|
|
torch.ops.aten._unsafe_index.Tensor,
|
|
torch.ops.aten.linspace.default,
|
|
torch.ops.aten.triu.default,
|
|
torch.ops.aten.nan_to_num.default,
|
|
]
|
|
|
|
|
|
def get_decomposition_table():
|
|
return get_decompositions(DEFAULT_DECOMPOSITIONS)
|