mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add e2e support for aten.pow.Scalar op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/2432/head
parent
aa15f0d4ca
commit
5c43daa3bf
|
@ -660,6 +660,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
divTensorMode.emitError("invalid rounding mode");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
||||
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
|
||||
if (!dtype.isa<mlir::FloatType>()) {
|
||||
pow.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype);
|
||||
Value expPromoted = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||
return b.create<math::PowFOp>(loc, selfPromoted, expPromoted);
|
||||
}
|
||||
|
||||
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
|
||||
if (!pow.getType()
|
||||
.cast<ValueTensorType>()
|
||||
|
@ -1162,20 +1174,20 @@ public:
|
|||
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
|
||||
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
|
||||
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
|
||||
AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
|
||||
AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
|
||||
AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
|
||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
|
||||
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
|
||||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
|
||||
AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||
AtenAtanOp, AtenRealOp, AtenImagOp>(op))
|
||||
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
|
||||
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
|
||||
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
|
||||
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp,
|
||||
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -1697,17 +1709,18 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op,
|
||||
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
|
||||
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
|
||||
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp,
|
||||
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
|
||||
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
|
||||
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
|
||||
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
|
||||
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
|
||||
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
|
||||
AtenFillTensorOp, AtenRealOp, AtenImagOp>();
|
||||
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
|
||||
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
|
||||
AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
|
||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
|
||||
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
|
||||
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
|
||||
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
|
||||
AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
|
||||
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
|
||||
AtenRealOp, AtenImagOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -6522,6 +6522,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.pow.Scalar\"(%arg0: !torch.float, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -9747,6 +9751,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -446,7 +446,6 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
|
|||
if output_type == OutputType.RAW:
|
||||
return mb.module
|
||||
|
||||
# mb.module.dump()
|
||||
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
|
||||
" extra-library=" + extra_library_file_name + "}"
|
||||
run_pipeline_with_repro_report(
|
||||
|
|
|
@ -267,6 +267,9 @@ def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int
|
|||
def aten〇floor_divide〇Scalar〡shape(self: List[int], other: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇pow〇Scalar〡shape(self: float, exponent: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(exponent)
|
||||
|
||||
def aten〇pow〇Tensor_Scalar〡shape(self: List[int], exponent: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -2704,6 +2707,12 @@ def aten〇floor_divide〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other
|
|||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
def aten〇pow〇Scalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
|
||||
exponent_rank, exponent_dtype = exponent_rank_dtype
|
||||
ranks: List[Optional[int]] = [None, exponent_rank]
|
||||
dtypes = [get_dtype_of_scalar(self), exponent_dtype]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0))
|
||||
def aten〇pow〇Tensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float, complex]) -> int:
|
||||
|
|
|
@ -1468,6 +1468,28 @@ def ElementwisePowTensorBroadcastStaticModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwisePowScalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, exp):
|
||||
return torch.pow(2.0, exp)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwisePowScalarModule())
|
||||
def ElementwisePowScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue