mirror of https://github.com/llvm/torch-mlir
parent
eae3ff7f1c
commit
2389729fb9
|
@ -728,7 +728,10 @@ TOSA_PASS_SET = {
|
|||
"ConstantPadNdPartialStaticModule_basic",
|
||||
"ConstantPadNdStaticModule_basic",
|
||||
"PadModule_basic",
|
||||
"PadWithNoneValModule_basic"
|
||||
"PadWithNoneValModule_basic",
|
||||
"ElementwiseRemainderScalarModule_Float_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_basic"
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
|
|
|
@ -3499,7 +3499,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
|
|||
// Support for multiple index
|
||||
auto index = indexTensors[0];
|
||||
auto indexTorch = tensorsTorchType[0];
|
||||
// TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None))
|
||||
// TODO add support for none index input like torch.ops.aten.index(x, (None,
|
||||
// index1, index2, None))
|
||||
if (indexTorch.getType().isa<Torch::NoneType>())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only list ranked tensor types index are supported");
|
||||
|
@ -3772,6 +3773,58 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
|
||||
AtenRemainderScalarOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
Value self = adaptor.getSelf();
|
||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
||||
|
||||
if (!selfTy)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only ranked tensor types supported in TOSA Remainder");
|
||||
|
||||
auto outType =
|
||||
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Only floating-point or integer datatype legalization supported");
|
||||
|
||||
Value otherTensor;
|
||||
Value other = op.getOther();
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
|
||||
outElemTy, {})))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Remainder operation");
|
||||
|
||||
if (selfTy.getElementType() != outElemTy)
|
||||
self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
|
||||
|
||||
auto divTensor = self;
|
||||
// tosa::DivOp only supports int
|
||||
if (outElemTy.isa<mlir::FloatType>()) {
|
||||
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
|
||||
op.getLoc(), otherTensor.getType(), otherTensor);
|
||||
divTensor = rewriter.create<tosa::MulOp>(
|
||||
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
|
||||
divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
|
||||
} else {
|
||||
divTensor =
|
||||
rewriter.create<tosa::DivOp>(op.getLoc(), outType, self, otherTensor);
|
||||
}
|
||||
|
||||
auto mulTensor =
|
||||
rewriter.create<tosa::MulOp>(op.getLoc(), outType, otherTensor, divTensor,
|
||||
/*shift=*/0);
|
||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
|
@ -3798,7 +3851,8 @@ public:
|
|||
if (inputDim == kUnknownSize) {
|
||||
return kUnknownSize;
|
||||
} else {
|
||||
int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
|
||||
int64_t dimSize =
|
||||
inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
|
||||
if (ceilMode && (dimSize % stride != 0))
|
||||
return dimSize / stride + 2;
|
||||
return dimSize / stride + 1;
|
||||
|
@ -4308,14 +4362,15 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
|||
auto selfElemTy = selfTy.getElementType();
|
||||
int64_t rank = selfTy.getRank();
|
||||
|
||||
// START the code snippet from lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: ConvertAtenConstantPadNdOp)
|
||||
// Pattern match against the op's original operands, because otherwise we
|
||||
// will get the lowered version of the operands which is harder to pattern
|
||||
// match.
|
||||
// START the code snippet from
|
||||
// lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see:
|
||||
// ConvertAtenConstantPadNdOp) Pattern match against the op's original
|
||||
// operands, because otherwise we will get the lowered version of the operands
|
||||
// which is harder to pattern match.
|
||||
SmallVector<int64_t> padInts;
|
||||
if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "only support constant int pad ranges");
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"only support constant int pad ranges");
|
||||
uint64_t padRank = padInts.size() / 2;
|
||||
if (padRank * 2 != padInts.size())
|
||||
return rewriter.notifyMatchFailure(op, "pad range size is not even");
|
||||
|
@ -4331,7 +4386,9 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
|||
lowPadding[rank - i - 1] = padInts[i * 2];
|
||||
highPadding[rank - i - 1] = padInts[i * 2 + 1];
|
||||
}
|
||||
//END the code snippet from lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: ConvertAtenConstantPadNdOp)
|
||||
// END the code snippet from
|
||||
// lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see:
|
||||
// ConvertAtenConstantPadNdOp)
|
||||
|
||||
llvm::SmallVector<int64_t> translatePadsList;
|
||||
|
||||
|
@ -4359,7 +4416,8 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
|||
"TOSA pad operation");
|
||||
|
||||
rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), self, padsList1, padTensor);
|
||||
op, getTypeConverter()->convertType(op.getType()), self, padsList1,
|
||||
padTensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -4587,6 +4645,7 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenCopyOp);
|
||||
INSERT_ATENOP_PATTERN(AtenToDtypeOp);
|
||||
INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
|
||||
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -1100,3 +1100,24 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to
|
|||
%0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32>
|
||||
return %0 : !torch.vtensor<[1,12,5,5],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func.func @torch.aten.remainder.Scalar(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[VAL_6:.*]] = "tosa.reciprocal"(%[[VAL_5:.*]]) : (tensor<f32>) -> tensor<f32>
|
||||
// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_3:.*]], %[[VAL_6:.*]]) {shift = 0 : i32} : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_8:.*]] = "tosa.floor"(%[[VAL_7]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_5]], %[[VAL_8]]) {shift = 0 : i32} : (tensor<f32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_10:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_9]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
|
||||
// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32>
|
||||
// CHECK: }
|
||||
func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
|
||||
%int2 = torch.constant.int 2
|
||||
%0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32>
|
||||
return %0 : !torch.vtensor<[2, 4],f32>
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue