mirror of https://github.com/llvm/torch-mlir
* [tosa] Support for AtenPowTensorScalarOp with constant Scalar as input
Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>pull/522/head snapshot-20220112.201
parent
077e55d756
commit
d69d29b7a6
|
@ -50,4 +50,5 @@ TOSA_PASS_SET = {
|
||||||
"SqueezeDimModule_identity",
|
"SqueezeDimModule_identity",
|
||||||
"SqueezeDimModule_unitDim",
|
"SqueezeDimModule_unitDim",
|
||||||
"ReturnTwoTensorF32I64_basic",
|
"ReturnTwoTensorF32I64_basic",
|
||||||
|
"ElementwisePowModule_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -635,6 +635,49 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// FIXME(AG): This will eventually go into a Tosa*Utils file
|
||||||
|
// Convert an fp32 scalar into tosa fp32 tensor.
|
||||||
|
static LogicalResult
|
||||||
|
tosaF32TensorFromTorchFloat(ConversionPatternRewriter &rewriter, Operation *op,
|
||||||
|
Value torchScalarValue, Value &tosaTensor) {
|
||||||
|
double scalarValue;
|
||||||
|
|
||||||
|
if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Construct a tosa.const
|
||||||
|
tosaTensor =
|
||||||
|
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||||
|
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
|
Value self = adaptor.self();
|
||||||
|
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||||
|
|
||||||
|
if (!selfTy)
|
||||||
|
return op.emitError("Only ranked tensor types supported in TOSA Pow");
|
||||||
|
|
||||||
|
if (!selfTy.getElementType().isa<mlir::FloatType>())
|
||||||
|
return op.emitError("Only floating-point datatype legalization supported");
|
||||||
|
|
||||||
|
Value expTensor;
|
||||||
|
Value expScalar = op.exponent();
|
||||||
|
if (failed(tosaF32TensorFromTorchFloat(rewriter, op.getOperation(), expScalar,
|
||||||
|
expTensor)))
|
||||||
|
return op.emitError("Currently only scalar constants are supported for "
|
||||||
|
"conversion in TOSA Pow operation");
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<tosa::PowOp>(
|
||||||
|
op, getTypeConverter()->convertType(op.getType()), self, expTensor);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
@ -740,6 +783,7 @@ public:
|
||||||
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
|
INSERT_ATENOP_PATTERN(AtenMulTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
|
INSERT_ATENOP_PATTERN(AtenDivTensorOp);
|
||||||
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
INSERT_ATENOP_PATTERN(AtenArgmaxOp);
|
||||||
|
INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
|
||||||
#undef INSERT_ATENOP_PATTERN
|
#undef INSERT_ATENOP_PATTERN
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||||
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
|
||||||
|
#include "mlir/Conversion/Passes.h"
|
||||||
#include "mlir/Pass/PassManager.h"
|
#include "mlir/Pass/PassManager.h"
|
||||||
#include "mlir/Transforms/Passes.h"
|
#include "mlir/Transforms/Passes.h"
|
||||||
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
|
||||||
|
@ -101,6 +102,9 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
|
||||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add the ToStandard pass for lowering some ops
|
||||||
|
pm.addNestedPass<FuncOp>(createTosaToStandard());
|
||||||
|
|
||||||
// Finish the type conversion from `torch` types to the types of the
|
// Finish the type conversion from `torch` types to the types of the
|
||||||
// TOSA backend contract.
|
// TOSA backend contract.
|
||||||
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
|
||||||
|
|
|
@ -44,6 +44,7 @@ class VerifyTosaBackendContractPass
|
||||||
target.addLegalDialect<tosa::TosaDialect>();
|
target.addLegalDialect<tosa::TosaDialect>();
|
||||||
target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
|
target.addDynamicallyLegalOp<tensor::CastOp>(opHasLegalTypes);
|
||||||
target.addDynamicallyLegalOp<arith::ExtSIOp>(opHasLegalTypes);
|
target.addDynamicallyLegalOp<arith::ExtSIOp>(opHasLegalTypes);
|
||||||
|
target.addDynamicallyLegalOp<arith::ConstantOp>(opHasLegalTypes);
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||||
|
|
|
@ -330,3 +330,20 @@ func @torch.aten.minimum$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v
|
||||||
%0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
%0 = torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
return %0 : !torch.vtensor<[?,?],f32>
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @torch.aten.pow.Tensor_Scalar$basic(
|
||||||
|
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_2:.*]] = torch.constant.float 3.123400e+00
|
||||||
|
// CHECK: %[[VAL_3:.*]] = "tosa.const"() {value = dense<3.123400e+00> : tensor<f32>} : () -> tensor<f32>
|
||||||
|
// CHECK: %[[VAL_4:.*]] = "tosa.pow"(%[[VAL_1]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],f32>
|
||||||
|
// CHECK: }
|
||||||
|
func @torch.aten.pow.Tensor_Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||||
|
%fp0 = torch.constant.float 3.123400e+00
|
||||||
|
%0 = torch.aten.pow.Tensor_Scalar %arg0, %fp0 : !torch.vtensor<[?,?],f32>, !torch.float -> !torch.vtensor<[?,?],f32>
|
||||||
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue