mirror of https://github.com/llvm/torch-mlir
* [tosa] Support for AtenPowTensorScalarOp with constant Scalar as input
Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>pull/522/head snapshot-20220112.201
parent
077e55d756
commit
d69d29b7a6
|
@ -50,4 +50,5 @@ TOSA_PASS_SET = {
|
|||
"SqueezeDimModule_identity",
|
||||
"SqueezeDimModule_unitDim",
|
||||
"ReturnTwoTensorF32I64_basic",
|
||||
"ElementwisePowModule_basic",
|
||||
}
|
||||
|
|
|
@ -635,6 +635,49 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
|
|||
}
|
||||
};
|
||||
|
||||
// FIXME(AG): This will eventually go into a Tosa*Utils file
|
||||
// Convert an fp32 scalar into tosa fp32 tensor.
|
||||
static LogicalResult
|
||||
tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op,
|
||||
Value torchScalarValue, Value &tosaTensor) {
|
||||
double scalarValue;
|
||||
|
||||
if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
|
||||
return failure();
|
||||
|
||||
// Construct a tosa.const
|
||||
tosaTensor =
|
||||
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value self = adaptor.self();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
|
||||
if (!selfTy)
|
||||
return op.emitError("Only ranked tensor types supported in TOSA Pow");
|
||||
|
||||
if (!selfTy.getElementType().isa<mlir::FloatType>())
|
||||
return op.emitError("Only floating-point datatype legalization supported");
|
||||
|
||||
Value expTensor;
|
||||
Value expScalar = op.exponent();
|
||||
if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar,
|
||||
expTensor)))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Pow operation");
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::PowOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, expTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
@ -740,6 +783,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
|
||||
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||
#include "mlir/Conversion/Passes.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||
|
@ -101,6 +102,9 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
|||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// Add the ToStandard pass for lowering some ops
|
||||
pm.addNestedPass<FuncOp>(createTosaToStandard());
|
||||
|
||||
// Finish the type conversion from `torch` types to the types of the
|
||||
// TOSA backend contract.
|
||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||
|
|
|
@ -44,6 +44,7 @@ class VerifyTosaBackendContractPass
|
|||
target.addLegalDialect<tosa::TosaDialect>();
|
||||
target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalOp<arith::ExtSIOp>(opHasLegalTypes);
|
||||
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||
|
|
|
@ -330,3 +330,20 @@ func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
|
|||
%0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.pow.Tensor_Scalar$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.pow"(%[[VAL_1]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%fp0 = torch.constant.float 3.123400e+00
|
||||
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue