diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py
index a0c12c796..7f9be0987 100644
--- a/e2e_testing/torchscript/xfail_sets.py
+++ b/e2e_testing/torchscript/xfail_sets.py
@@ -50,4 +50,5 @@ TOSA_PASS_SET = {
     "SqueezeDimModule_identity",
     "SqueezeDimModule_unitDim",
     "ReturnTwoTensorF32I64_basic",
+    "ElementwisePowModule_basic",
 }
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 6d3a28c17..ae98dd5e3 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -635,6 +635,49 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
   }
 };
 
+// FIXME(AG): This will eventually go into a Tosa*Utils file.
+// Convert an fp32 scalar into a TOSA fp32 tensor.
+static LogicalResult
+tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op,
+                            Value torchScalarValue, Value &tosaTensor) {
+  double scalarValue;
+
+  if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
+    return failure();
+
+  // Construct a tosa.const holding the scalar value.
+  tosaTensor =
+      mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
+
+  return success();
+}
+
+template <>
+LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
+    AtenPowTensorScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  Value self = adaptor.self();
+  auto selfTy = self.getType().template dyn_cast<RankedTensorType>();
+
+  if (!selfTy)
+    return op.emitError("Only ranked tensor types supported in TOSA Pow");
+
+  if (!selfTy.getElementType().isa<mlir::FloatType>())
+    return op.emitError("Only floating-point datatype legalization supported");
+
+  Value expTensor;
+  Value expScalar = op.exponent();
+  if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar,
+                                         expTensor)))
+    return op.emitError("Currently only scalar constants are supported for "
+                        "conversion in TOSA Pow operation");
+
+  rewriter.replaceOpWithNewOp<tosa::PowOp>(
+      op, getTypeConverter()->convertType(op.getType()), self, expTensor);
+
+  return success();
+}
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -740,6 +783,7 @@ public:
     INSERT_ATENOP_PATTERN(AtenMulTensorOp);
     INSERT_ATENOP_PATTERN(AtenDivTensorOp);
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
+    INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index ba513e6ff..a04bc8e1a 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Conversion/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
@@ -101,6 +102,9 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
     pm.addNestedPass<FuncOp>(createCSEPass());
   }
 
+  // Add the TosaToStandard pass to lower some TOSA ops (e.g. tosa.const).
+  pm.addNestedPass<FuncOp>(tosa::createTosaToStandard());
+
   // Finish the type conversion from `torch` types to the types of the
   // TOSA backend contract.
   pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp
index d4052f653..4fb3f5d8b 100644
--- a/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/VerifyTosaBackendContract.cpp
@@ -44,6 +44,7 @@ class VerifyTosaBackendContractPass
     target.addLegalDialect<tosa::TosaDialect>();
     target.addDynamicallyLegalOp<ToBuiltinTensorOp>(opHasLegalTypes);
     target.addDynamicallyLegalOp<FromBuiltinTensorOp>(opHasLegalTypes);
+    target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
 
     RewritePatternSet patterns(context);
     if (failed(applyFullConversion(module, target, std::move(patterns)))) {
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 99de8868a..2aaa71edd 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -330,3 +330,20 @@ func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL:   func @torch.aten.pow.Tensor_Scalar$basic(
+// CHECK-SAME:                                             %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK:           %[[VAL_4:.*]] = "tosa.pow"(%[[VAL_1]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
+// CHECK:         }
+func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %fp0 = torch.constant.float 3.123400e+00
+  %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}
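
Note: `ElementwisePowModule_basic`, added to `TOSA_PASS_SET` above, names an existing end-to-end test in `e2e_testing/torchscript/elementwise.py`. A sketch of its shape, reconstructed here only for illustration (the file itself is authoritative; decorators and tensor sizes may differ):

```python
# Illustrative sketch of the e2e test enabled for the TOSA backend by this
# change; see e2e_testing/torchscript/elementwise.py for the real definition.
class ElementwisePowModule(torch.nn.Module):
    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),  # dynamically shaped fp32 input
    ])
    def forward(self, a):
        # The exponent is a constant Python float, so the lowering takes the
        # aten.pow.Tensor_Scalar -> tosa.pow path added in this patch.
        return torch.pow(a, 2.0)


@register_test_case(module_factory=lambda: ElementwisePowModule())
def ElementwisePowModule_basic(module, tu):
    module.forward(tu.rand(3, 4))
```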