mirror of https://github.com/llvm/torch-mlir
parent: eae3ff7f1c
commit: 2389729fb9
@@ -728,7 +728,10 @@ TOSA_PASS_SET = {
     "ConstantPadNdPartialStaticModule_basic",
     "ConstantPadNdStaticModule_basic",
     "PadModule_basic",
-    "PadWithNoneValModule_basic"
+    "PadWithNoneValModule_basic",
+    "ElementwiseRemainderScalarModule_Float_basic",
+    "ElementwiseRemainderScalarModule_Int_Float_basic",
+    "ElementwiseRemainderScalarModule_Int_basic"
 }

 LTC_XFAIL_SET = {
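The three new entries mark the elementwise remainder-by-scalar e2e tests as passing through the TOSA backend. Their module definitions live in the e2e test suite and are not part of this diff; the sketch below is illustrative only and simply shows the kind of computation those tests exercise: torch.remainder against a Python scalar.

import torch

# Sketch only: not the suite's actual module definitions.
class RemainderScalarSketch(torch.nn.Module):
    def forward(self, x):
        return torch.remainder(x, 2.0)

print(RemainderScalarSketch()(torch.tensor([5.0, -5.0, 3.5, 7.0])))
# tensor([1.0000, 1.0000, 1.5000, 1.0000])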
@@ -3499,7 +3499,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
   // Support for multiple index
   auto index = indexTensors[0];
   auto indexTorch = tensorsTorchType[0];
-  // TODO add support for none index input like torch.ops.aten.index(x, (None, index1, index2, None))
+  // TODO add support for none index input like torch.ops.aten.index(x, (None,
+  // index1, index2, None))
   if (indexTorch.getType().isa<Torch::NoneType>())
     return rewriter.notifyMatchFailure(
         op, "Only list ranked tensor types index are supported");
@@ -3772,6 +3773,58 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
   return success();
 }

+template <>
+LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
+    AtenRemainderScalarOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  Value self = adaptor.getSelf();
+  auto selfTy = self.getType().template cast<RankedTensorType>();
+
+  if (!selfTy)
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor types supported in TOSA Remainder");
+
+  auto outType =
+      getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
+
+  Type outElemTy = outType.getElementType();
+  if (!outElemTy.isIntOrFloat())
+    return rewriter.notifyMatchFailure(
+        op, "Only floating-point or integer datatype legalization supported");
+
+  Value otherTensor;
+  Value other = op.getOther();
+  if (failed(torchScalarToTosaTensor(rewriter, op, other, otherTensor,
+                                     outElemTy, {})))
+    return rewriter.notifyMatchFailure(
+        op, "Currently only scalar constants are supported for "
+            "conversion in TOSA Remainder operation");
+
+  if (selfTy.getElementType() != outElemTy)
+    self = rewriter.create<tosa::CastOp>(op.getLoc(), outType, self);
+
+  auto divTensor = self;
+  // tosa::DivOp only supports int
+  if (outElemTy.isa<mlir::FloatType>()) {
+    auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
+        op.getLoc(), otherTensor.getType(), otherTensor);
+    divTensor = rewriter.create<tosa::MulOp>(
+        op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
+    divTensor = rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
+  } else {
+    divTensor =
+        rewriter.create<tosa::DivOp>(op.getLoc(), outType, self, otherTensor);
+  }
+
+  auto mulTensor =
+      rewriter.create<tosa::MulOp>(op.getLoc(), outType, otherTensor, divTensor,
+                                   /*shift=*/0);
+  rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
+
+  return success();
+}
+
 template <typename AtenOpT, typename TosaOpT>
 class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
 public:
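In the float branch, the pattern materializes remainder as self - floor(self * reciprocal(other)) * other; in the integer branch it takes the quotient with tosa::DivOp instead. The sketch below is PyTorch, illustrative only and not part of the patch, and reproduces the float decomposition against torch.remainder (assuming a positive scalar divisor):

import torch

def tosa_style_remainder(x: torch.Tensor, c: float) -> torch.Tensor:
    # reciprocal + mul + floor for the quotient, then mul + sub, mirroring
    # the float branch of the lowering above
    quotient = torch.floor(x * (1.0 / c))
    return x - c * quotient

x = torch.tensor([5.0, -5.0, 3.5, -3.5])
print(tosa_style_remainder(x, 2.0))  # tensor([1.0000, 1.0000, 1.5000, 0.5000])
print(torch.remainder(x, 2.0))       # matches the line above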
@@ -3798,7 +3851,8 @@ public:
     if (inputDim == kUnknownSize) {
       return kUnknownSize;
     } else {
-      int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
+      int64_t dimSize =
+          inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1;
       if (ceilMode && (dimSize % stride != 0))
         return dimSize / stride + 2;
       return dimSize / stride + 1;
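The reflowed expression is the standard pooling output-size computation; only the line wrapping changes here. A worked example with illustrative values (inputDim=7, padBefore=padAfter=1, kernelDim=3, dilation=1, stride=2):

dim_size = 7 + 1 + 1 - 1 * (3 - 1) - 1  # = 6
print(dim_size // 2 + 1)                # floor mode: 4 output positions
# when ceilMode is set and dim_size % stride != 0, the code above
# returns dim_size / stride + 2 instead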
@@ -4308,14 +4362,15 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
   auto selfElemTy = selfTy.getElementType();
   int64_t rank = selfTy.getRank();

-  // START the code snippet from lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: ConvertAtenConstantPadNdOp)
-  // Pattern match against the op's original operands, because otherwise we
-  // will get the lowered version of the operands which is harder to pattern
-  // match.
+  // START the code snippet from
+  // lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see:
+  // ConvertAtenConstantPadNdOp) Pattern match against the op's original
+  // operands, because otherwise we will get the lowered version of the operands
+  // which is harder to pattern match.
   SmallVector<int64_t> padInts;
   if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts)))
-    return rewriter.notifyMatchFailure(
-        op, "only support constant int pad ranges");
+    return rewriter.notifyMatchFailure(op,
+                                       "only support constant int pad ranges");
   uint64_t padRank = padInts.size() / 2;
   if (padRank * 2 != padInts.size())
     return rewriter.notifyMatchFailure(op, "pad range size is not even");
@@ -4328,10 +4383,12 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
   // Add the requested padding - note op.pad() is highest dim first ordered
   // pairs of low,high.
   for (uint64_t i = 0; i < padRank; ++i) {
-    lowPadding[rank-i-1] = padInts[i * 2];
-    highPadding[rank-i-1] = padInts[i * 2 + 1];
+    lowPadding[rank - i - 1] = padInts[i * 2];
+    highPadding[rank - i - 1] = padInts[i * 2 + 1];
   }
-  //END the code snippet from lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see: ConvertAtenConstantPadNdOp)
+  // END the code snippet from
+  // lib/Conversion/TorchToLinalg/TensorConstructors.cpp (see:
+  // ConvertAtenConstantPadNdOp)

   llvm::SmallVector<int64_t> translatePadsList;

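The loop above maps the flat aten pad list, which supplies (low, high) pairs starting from the innermost dimension, onto per-dimension padding vectors. A quick illustration with made-up values:

pad = [1, 2, 3, 4]  # (low, high) for dim -1, then (low, high) for dim -2
rank = 4
low, high = [0] * rank, [0] * rank
for i in range(len(pad) // 2):
    low[rank - i - 1] = pad[i * 2]
    high[rank - i - 1] = pad[i * 2 + 1]
print(low, high)  # [0, 0, 3, 1] [0, 0, 4, 2]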
@@ -4353,13 +4410,14 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(

   Value padTensor;
   if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), padValue,
                                      padTensor, selfElemTy, {})))
     return rewriter.notifyMatchFailure(
         op, "Pad value needs to be a scalar constant for conversion to "
             "TOSA pad operation");

   rewriter.replaceOpWithNewOp<mlir::tosa::PadOp>(
-      op, getTypeConverter()->convertType(op.getType()), self, padsList1, padTensor);
+      op, getTypeConverter()->convertType(op.getType()), self, padsList1,
+      padTensor);
   return success();
 }

@@ -4587,6 +4645,7 @@ public:
     INSERT_ATENOP_PATTERN(AtenCopyOp);
     INSERT_ATENOP_PATTERN(AtenToDtypeOp);
     INSERT_ATENOP_PATTERN(AtenConstantPadNdOp);
+    INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
 #undef INSERT_ATENOP_PATTERN

 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
@@ -1100,3 +1100,24 @@ func.func @torch.aten.where.self(%arg0: !torch.vtensor<[1,1,5,5],i1>, %arg1: !to
   %0 = torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[1,1,5,5],i1>, !torch.vtensor<[1,12,5,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[1,12,5,5],f32>
   return %0 : !torch.vtensor<[1,12,5,5],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.remainder.Scalar(
+// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> {
+// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[2,4],f32> -> tensor<2x4xf32>
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_5:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
+// CHECK: %[[VAL_6:.*]] = "tosa.reciprocal"(%[[VAL_5:.*]]) : (tensor<f32>) -> tensor<f32>
+// CHECK: %[[VAL_7:.*]] = "tosa.mul"(%[[VAL_3:.*]], %[[VAL_6:.*]]) {shift = 0 : i32} : (tensor<2x4xf32>, tensor<f32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_8:.*]] = "tosa.floor"(%[[VAL_7]]) : (tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_9:.*]] = "tosa.mul"(%[[VAL_5]], %[[VAL_8]]) {shift = 0 : i32} : (tensor<f32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_10:.*]] = "tosa.sub"(%[[VAL_3]], %[[VAL_9]]) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
+// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32>
+// CHECK: return %[[VAL_11]] : !torch.vtensor<[2,4],f32>
+// CHECK: }
+func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !torch.vtensor<[2, 4],f32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32>
+  return %0 : !torch.vtensor<[2, 4],f32>
+}
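Beyond this FileCheck test, the lowering can also be observed by compiling a small module straight to TOSA. The snippet below is a usage sketch only; it assumes the torch_mlir.compile Python entry point with output_type="tosa", which is not part of this diff.

import torch
import torch_mlir  # assumes the torch-mlir Python package with compile()

class Mod(torch.nn.Module):
    def forward(self, x):
        return torch.remainder(x, 2.0)

module = torch_mlir.compile(Mod(), torch.ones(2, 4), output_type="tosa")
print(module)  # expect tosa.reciprocal, tosa.mul, tosa.floor, tosa.mul, tosa.sub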