[MLIR][TORCH] Add e2e support for aten.pow.Scalar op

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/2432/head
Vivek Khandelwal 2023-08-31 12:20:22 +00:00
parent aa15f0d4ca
commit 5c43daa3bf
5 changed files with 77 additions and 21 deletions

View File

@ -660,6 +660,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
divTensorMode.emitError("invalid rounding mode");
return nullptr;
}
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
if (!dtype.isa<mlir::FloatType>()) {
pow.emitError("unimplemented: non-floating point dtype");
return nullptr;
}
Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype);
Value expPromoted = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
return b.create<math::PowFOp>(loc, selfPromoted, expPromoted);
}
if (auto pow = dyn_cast<AtenPowTensorScalarOp>(op)) {
if (!pow.getType()
.cast<ValueTensorType>()
@ -1162,20 +1174,20 @@ public:
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp,
AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp,
AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
AtenDivScalarOp, AtenRemainderScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp,
AtenBitwiseXorTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenAtanOp, AtenRealOp, AtenImagOp>(op))
AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp,
AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenAtanOp, AtenRealOp, AtenImagOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1697,17 +1709,18 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenAtan2Op,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPreluOp,
AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog1pOp,
AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp,
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp, AtenRealOp, AtenImagOp>();
AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp,
AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp,
AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenLogicalXorOp,
AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
AtenRealOp, AtenImagOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -6522,6 +6522,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.pow.Scalar\"(%arg0: !torch.float, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -9747,6 +9751,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %6 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%3, %5) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Scalar\"(%arg0: !torch.number, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.ListConstruct %none, %0#0 : (!torch.none, !torch.int) -> !torch.list<optional<int>>\n"
" %2 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.get_dtype_of_scalar(%arg0) : (!torch.number) -> !torch.int\n"
" %3 = torch.prim.ListConstruct %2, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.pow.Tensor_Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -446,7 +446,6 @@ PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
if output_type == OutputType.RAW:
return mb.module
# mb.module.dump()
option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + \
" extra-library=" + extra_library_file_name + "}"
run_pipeline_with_repro_report(

View File

@ -267,6 +267,9 @@ def atenremainderScalar〡shape(self: List[int], other: float) -> List[int
def atenfloor_divideScalar〡shape(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self)
def atenpowScalar〡shape(self: float, exponent: List[int]) -> List[int]:
return upstream_shape_functions.unary(exponent)
def atenpowTensor_Scalar〡shape(self: List[int], exponent: float) -> List[int]:
return upstream_shape_functions.unary(self)
@ -2704,6 +2707,12 @@ def atenfloor_divideScalar〡dtype(self_rank_dtype: Tuple[int, int], other
dtypes = [self_dtype, get_dtype_of_scalar(other)]
return promote_dtypes(ranks, dtypes)
def atenpowScalar〡dtype(self: Union[int, float, complex], exponent_rank_dtype: Tuple[int, int]) -> int:
exponent_rank, exponent_dtype = exponent_rank_dtype
ranks: List[Optional[int]] = [None, exponent_rank]
dtypes = [get_dtype_of_scalar(self), exponent_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, exponent=1.0))
def atenpowTensor_Scalar〡dtype(self_rank_dtype: Tuple[int, int], exponent: Union[int, float, complex]) -> int:

View File

@ -1468,6 +1468,28 @@ def ElementwisePowTensorBroadcastStaticModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwisePowScalarModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.float32, True),
])
def forward(self, exp):
return torch.pow(2.0, exp)
@register_test_case(module_factory=lambda: ElementwisePowScalarModule())
def ElementwisePowScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseToDtypeF32ToI64Module(torch.nn.Module):
def __init__(self):