[MHLO] fix tensor mode aten.div op pattern (#1160)

* [MHLO] fix tensor mode aten.div op pattern

See RFC #999
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
pull/1169/head snapshot-20220807.557
Tanyo Kwok 2022-08-06 23:38:06 +08:00 committed by GitHub
parent 5618890ca0
commit 1ee865983b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 7 deletions

View File

@ -208,7 +208,6 @@ public:
"only floating-point or integer datatype legalization supported");
}
Value lhsTensor = lhs;
if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhs = lhs;
} else if (!rhsType) {
@ -217,8 +216,37 @@ public:
DenseIntElementsAttr bcastDimensions;
lhs = mhlo::promoteType(rewriter, lhs, outType);
rhs = mhlo::promoteType(rewriter, rhs, outType);
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
auto loc = op.getLoc();
Value result =
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
if (!isa<AtenDivTensorModeOp>(op)) {
rewriter.replaceOp(op, result);
return success();
}
AtenDivTensorModeOp divTensorModeOp =
llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
std::string roundingMode;
if (!matchPattern(divTensorModeOp.rounding_mode(),
m_TorchConstantStr(roundingMode)))
return rewriter.notifyMatchFailure(
op, "only support constant str rounding mode");
if (roundingMode == "trunc") {
// "trunc" - rounds the results of the division towards zero. Equivalent
// to C-style integer division.
auto sign = rewriter.create<mhlo::SignOp>(loc, result);
auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
}
if (roundingMode == "floor") {
// "floor" - rounds the results of the division down. Equivalent to
// floor division in Python (the // operator)
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
}
rewriter.replaceOp(op, result);
return success();
}
};
@ -554,7 +582,6 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
RankedTensorType outputType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
auto outputShape = outputType.getShape();
auto outputElemType = outputType.getElementType();
Value mhloTensor =
mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
@ -968,6 +995,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
#undef INSERT_BINARY_MULDIV_PATTERN

View File

@ -2167,8 +2167,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFloorDivideOp op,
PatternRewriter &rewriter) const override {
// https://pytorch.org/docs/stable/generated/torch.floor_divide.html
// PyTorch aten.floor_divide is a misnomer because it actually rounds
// the quotient towards zero instead of taking its floor.
Value cstStrFloor =
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
op, op.getType(), op.self(), op.other(),
/*rounding_mode=*/cstStrFloor);

View File

@ -540,3 +540,37 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
return %0 : !torch.vtensor<[?,?],i1>
}
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$trunc(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "trunc"
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%str = torch.constant.str "trunc"
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$floor(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
// CHECK: %[[STR:.*]] = torch.constant.str "floor"
// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%str = torch.constant.str "floor"
%0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}

View File

@ -1113,8 +1113,8 @@ func.func @torch.aten.baddbmm(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.
// CHECK-LABEL: func @torch.aten.floor_divide(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[OTHER:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[CSTFLOOR:.*]] = torch.constant.str "floor"
// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTFLOOR]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
// CHECK: %[[CSTTRUNC:.*]] = torch.constant.str "trunc"
// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTTRUNC]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>