Add support for aten_remainder in TorchToTosa (#1966)

pull/1969/head snapshot-20230324.787
Michael Feliz 2023-03-23 17:55:58 -07:00 committed by GitHub
parent eae3ff7f1c
commit 2389729fb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 101 additions and 18 deletions

View File

@ -728,7 +728,10 @@ TOSA_PASS_SET = {
"ConstantPadNdPartialStaticModule_basic", "ConstantPadNdPartialStaticModule_basic",
"ConstantPadNdStaticModule_basic", "ConstantPadNdStaticModule_basic",
"PadModule_basic", "PadModule_basic",
"PadWithNoneValModule_basic" "PadWithNoneValModule_basic",
"ElementwiseRemainderScalarModule_Float_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Int_basic"
} }
LTC_XFAIL_SET = { LTC_XFAIL_SET = {

View File

@ -3499,7 +3499,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
// Support for multiple index // Support for multiple index
auto index = indexTensors[0]; auto index = indexTensors[0];
auto indexTorch = tensorsTorchType[0]; auto indexTorch = tensorsTorchType[0];
// TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None)) // TODO add support for none index input like torch.ops.aten.index(x, (None,
// index1, index2, None))
if (indexTorch.getType().isa<Torch::NoneType>()) if (indexTorch.getType().isa<Torch::NoneType>())
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Only list ranked tensor types index are supported"); op, "Only list ranked tensor types index are supported");
@ -3772,6 +3773,58 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
return success(); return success();
} }
template <>
LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
AtenRemainderScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Value self = adaptor.getSelf();
auto selfTy = self.getType().template cast<RankedTensorType>();
if (!selfTy)
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Remainder");
auto outType =
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
return rewriter.notifyMatchFailure(
op, "Only floating-point or integer datatype legalization supported");
Value otherTensor;
Value other = op.getOther();
if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
outElemTy, {})))
return rewriter.notifyMatchFailure(
op, "Currently only scalar constants are supported for "
"conversion in TOSA Remainder operation");
if (selfTy.getElementType() != outElemTy)
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
auto divTensor = self;
// tosa::DivOp only supports int
if (outElemTy.isa<mlir::FloatType>()) {
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
op.getLoc(), otherTensor.getType(), otherTensor);
divTensor = rewriter.create<tosa::MulOp>(
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
} else {
divTensor =
rewriter.create<tosa::DivOp>(op.getLoc(), outType, self, otherTensor);
}
auto mulTensor =
rewriter.create<tosa::MulOp>(op.getLoc(), outType, otherTensor, divTensor,
/*shift=*/0);
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
return success();
}
template <typename AtenOpT, typename TosaOpT> template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> { class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public: public:
@ -3798,7 +3851,8 @@ public:
if (inputDim == kUnknownSize) { if (inputDim == kUnknownSize) {
return kUnknownSize; return kUnknownSize;
} else { } else {
int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; int64_t dimSize =
inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
if (ceilMode && (dimSize % stride != 0)) if (ceilMode && (dimSize % stride != 0))
return dimSize / stride + 2; return dimSize / stride + 2;
return dimSize / stride + 1; return dimSize / stride + 1;
@ -4308,14 +4362,15 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
auto selfElemTy = selfTy.getElementType(); auto selfElemTy = selfTy.getElementType();
int64_t rank = selfTy.getRank(); int64_t rank = selfTy.getRank();
// START the code snippet from lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: ConvertAtenConstantPadNdOp) // START the code snippet from
// Pattern match against the op's original operands, because otherwise we // lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see:
// will get the lowered version of the operands which is harder to pattern // ConvertAtenConstantPadNdOp) Pattern match against the op's original
// match. // operands, because otherwise we will get the lowered version of the operands
// which is harder to pattern match.
SmallVector<int64_t> padInts; SmallVector<int64_t> padInts;
if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(op,
op, "only support constant int pad ranges"); "only support constant int pad ranges");
uint64_t padRank = padInts.size() / 2; uint64_t padRank = padInts.size() / 2;
if (padRank * 2 != padInts.size()) if (padRank * 2 != padInts.size())
return rewriter.notifyMatchFailure(op, "pad range size is not even"); return rewriter.notifyMatchFailure(op, "pad range size is not even");
@ -4328,10 +4383,12 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
// Add the requested padding - note op.pad() is highest dim first ordered // Add the requested padding - note op.pad() is highest dim first ordered
// pairs of low,high. // pairs of low,high.
for (uint64_t i = 0; i < padRank; ++i) { for (uint64_t i = 0; i < padRank; ++i) {
lowPadding[rank-i-1] = padInts[i * 2]; lowPadding[rank - i - 1] = padInts[i * 2];
highPadding[rank-i-1] = padInts[i * 2 + 1]; highPadding[rank - i - 1] = padInts[i * 2 + 1];
} }
//END the code snippet from lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: ConvertAtenConstantPadNdOp) // END the code snippet from
// lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see:
// ConvertAtenConstantPadNdOp)
llvm::SmallVector<int64_t> translatePadsList; llvm::SmallVector<int64_t> translatePadsList;
@ -4353,13 +4410,14 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
Value padTensor; Value padTensor;
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), padValue, if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), padValue,
padTensor, selfElemTy, {}))) padTensor, selfElemTy, {})))
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "Pad value needs to be a scalar constant for conversion to " op, "Pad value needs to be a scalar constant for conversion to "
"TOSA pad operation"); "TOSA pad operation");
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>( rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
op, getTypeConverter()->convertType(op.getType()), self, padsList1, padTensor); op, getTypeConverter()->convertType(op.getType()), self, padsList1,
padTensor);
return success(); return success();
} }
@ -4587,6 +4645,7 @@ public:
INSERT_ATENOP_PATTERN(AtenCopyOp); INSERT_ATENOP_PATTERN(AtenCopyOp);
INSERT_ATENOP_PATTERN(AtenToDtypeOp); INSERT_ATENOP_PATTERN(AtenToDtypeOp);
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
#undef INSERT_ATENOP_PATTERN #undef INSERT_ATENOP_PATTERN
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ #define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \

View File

@ -1100,3 +1100,24 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to
%0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32> %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32>
return %0 : !torch.vtensor<[1,12,5,5],f32> return %0 : !torch.vtensor<[1,12,5,5],f32>
} }
// -----
// CHECK-LABEL: func.func @torch.aten.remainder.Scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = "tosa.reciprocal"(%[[VAL_5:.*]]) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_3:.*]], %[[VAL_6:.*]]) {shift = 0 : i32} : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_8:.*]] = "tosa.floor"(%[[VAL_7]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_5]], %[[VAL_8]]) {shift = 0 : i32} : (tensor<f32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_10:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_9]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32>
// CHECK: }
func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
%int2 = torch.constant.int 2
%0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32>
return %0 : !torch.vtensor<[2, 4],f32>
}