[tosa] Support for AtenPowTensorScalarOp with constant Scalar as input

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>
Anup Gangwar 2022-01-11 12:49:17 -08:00 committed by Yi Zhang
parent 077e55d756
commit d69d29b7a6
5 changed files with 67 additions and 0 deletions


@@ -50,4 +50,5 @@ TOSA_PASS_SET = {
    "SqueezeDimModule_identity",
    "SqueezeDimModule_unitDim",
    "ReturnTwoTensorF32I64_basic",
    "ElementwisePowModule_basic",
}


@@ -635,6 +635,49 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
  }
};
// FIXME(AG): This will eventually go into a Tosa*Utils file.
// Convert an fp32 Torch scalar into a TOSA fp32 constant tensor.
static LogicalResult
tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op,
                            Value torchScalarValue, Value &tosaTensor) {
  double scalarValue;
  if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
    return failure();

  // Construct a tosa.const rank-zero tensor holding the scalar value.
  tosaTensor =
      mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
  return success();
}
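getTosaConstTensorSingleF32 lives in torch-mlir's TOSA legalization utilities and is not part of this diff; in essence it materializes the value as a rank-zero tosa.const. A minimal sketch of that helper, assuming the upstream shape rather than quoting it:

// Sketch: build a rank-zero f32 tensor type, wrap the value in a dense
// splat attribute, and materialize it as a tosa.const at op's location.
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                  float val) {
  auto constType = RankedTensorType::get({}, rewriter.getF32Type());
  auto constAttr = DenseElementsAttr::get(constType, val);
  return rewriter.create<tosa::ConstOp>(op->getLoc(), constType, constAttr);
}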
template <>
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
    AtenPowTensorScalarOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Value self = adaptor.self();
  auto selfTy = self.getType().dyn_cast<RankedTensorType>();
  if (!selfTy)
    return op.emitError("Only ranked tensor types supported in TOSA Pow");

  if (!selfTy.getElementType().isa<mlir::FloatType>())
    return op.emitError("Only floating-point datatype legalization supported");

  Value expTensor;
  Value expScalar = op.exponent();
  if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar,
                                         expTensor)))
    return op.emitError("Currently only scalar constants are supported for "
                        "conversion in TOSA Pow operation");

  rewriter.replaceOpWithNewOp<tosa::PowOp>(
      op, getTypeConverter()->convertType(op.getType()), self, expTensor);

  return success();
}
} // namespace

// -----------------------------------------------------------------------------
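One note on the pattern above: replaceOpWithNewOp is shorthand for creating the new op at the original op's location and replacing all uses of the original with the new results. An equivalent expansion, as a sketch using the same values:

// Equivalent to the replaceOpWithNewOp<tosa::PowOp> call above.
auto outType = getTypeConverter()->convertType(op.getType());
auto powOp =
    rewriter.create<tosa::PowOp>(op->getLoc(), outType, self, expTensor);
rewriter.replaceOp(op, powOp.getResult());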
@@ -740,6 +783,7 @@ public:
    INSERT_ATENOP_PATTERN(AtenMulTensorOp);
    INSERT_ATENOP_PATTERN(AtenDivTensorOp);
    INSERT_ATENOP_PATTERN(AtenArgmaxOp);
    INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
#undef INSERT_ATENOP_PATTERN

    if (failed(applyPartialConversion(getOperation(), target,
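INSERT_ATENOP_PATTERN is defined just above this hunk in the same file; it expands roughly as below (a sketch, not part of this diff), so the one added line both marks AtenPowTensorScalarOp illegal for the conversion target and registers the pattern defined earlier:

// Rough expansion of the macro: each aten op is declared illegal and its
// ConvertAtenOp specialization is registered with the pattern set.
#define INSERT_ATENOP_PATTERN(AtenOp)                                          \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);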


@@ -12,6 +12,7 @@
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
@@ -101,6 +102,9 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
    pm.addNestedPass<FuncOp>(createCSEPass());
  }

  // Add the TosaToStandard pass to lower some TOSA ops to the Standard/Arith
  // dialects.
  pm.addNestedPass<FuncOp>(createTosaToStandard());

  // Finish the type conversion from `torch` types to the types of the
  // TOSA backend contract.
  pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
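TosaToStandard is an upstream MLIR conversion. The relevant effect here is that tosa.const becomes arith.constant, which is why the backend-contract verifier below gains a legality rule for arith::ConstantOp. A minimal sketch of that kind of rewrite, assuming the upstream pattern's shape:

// Sketch: carry the dense attribute of a tosa.const over to arith.constant.
struct ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
  using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ConstOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.value());
    return success();
  }
};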


@@ -44,6 +44,7 @@ class VerifyTosaBackendContractPass
    target.addLegalDialect<tosa::TosaDialect>();
    target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
    target.addDynamicallyLegalOp<arith::ExtSIOp>(opHasLegalTypes);
    target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);

    RewritePatternSet patterns(context);
    if (failed(applyFullConversion(module, target, std::move(patterns)))) {
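For context, opHasLegalTypes is the pass's dynamic-legality callback, defined just above this hunk; it is roughly the following (a sketch, not part of this diff):

// An op is legal for the backend contract once the type converter accepts
// all of its operand and result types.
auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); };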


@@ -330,3 +330,20 @@ func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.pow.Tensor_Scalar$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
// CHECK: %[[VAL_3:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = "tosa.pow"(%[[VAL_1]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
  %fp0 = torch.constant.float 3.123400e+00
  %0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}