mirror of https://github.com/llvm/torch-mlir

commit 1ee865983b (parent 5618890ca0)
pull/1169/head snapshot-20220807.557

[MHLO] fix tensor mode aten.div op pattern (#1160)

See RFC #999

Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
@@ -208,7 +208,6 @@ public:
           "only floating-point or integer datatype legalization supported");
     }
 
-    Value lhsTensor = lhs;
     if (std::is_same<AtenOpT, AtenSquareOp>()) {
       rhs = lhs;
     } else if (!rhsType) {
@@ -217,8 +216,37 @@ public:
     DenseIntElementsAttr bcastDimensions;
     lhs = mhlo::promoteType(rewriter, lhs, outType);
     rhs = mhlo::promoteType(rewriter, rhs, outType);
-    rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
-                                         bcastDimensions);
+    auto loc = op.getLoc();
+    Value result =
+        rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
+
+    if (!isa<AtenDivTensorModeOp>(op)) {
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    AtenDivTensorModeOp divTensorModeOp =
+        llvm::dyn_cast<AtenDivTensorModeOp>(op.getOperation());
+    std::string roundingMode;
+    if (!matchPattern(divTensorModeOp.rounding_mode(),
+                      m_TorchConstantStr(roundingMode)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant str rounding mode");
+
+    if (roundingMode == "trunc") {
+      // "trunc" - rounds the results of the division towards zero. Equivalent
+      // to C-style integer division.
+      auto sign = rewriter.create<mhlo::SignOp>(loc, result);
+      auto abs = rewriter.create<mhlo::AbsOp>(loc, result);
+      auto floor = rewriter.create<mhlo::FloorOp>(loc, abs);
+      result = rewriter.create<mhlo::MulOp>(loc, sign, floor).getResult();
+    }
+    if (roundingMode == "floor") {
+      // "floor" - rounds the results of the division down. Equivalent to
+      // floor division in Python (the // operator)
+      result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
+    }
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
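Note: for the "trunc" mode, the pattern above builds trunc(q) out of the ops MHLO does have, using the identity trunc(q) == sign(q) * floor(|q|). A minimal standalone C++ sketch of that identity (the sample quotients and variable names are illustrative, not from the patch):

#include <cmath>
#include <cstdio>

// Sketch: check trunc(q) == sign(q) * floor(|q|) for a few quotients.
// This mirrors the mhlo.sign / mhlo.abs / mhlo.floor / mhlo.multiply
// chain emitted above for rounding_mode = "trunc".
int main() {
  const double qs[] = {3.5, -3.5, 1.25, -1.25, 0.0};
  for (double q : qs) {
    double sign = (q > 0) - (q < 0); // -1, 0, or +1, like mhlo.sign
    double viaSignFloor = sign * std::floor(std::fabs(q));
    std::printf("q=%+.2f trunc=%+.2f sign*floor(abs)=%+.2f\n", q,
                std::trunc(q), viaSignFloor);
  }
  return 0;
}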
@@ -554,7 +582,6 @@ LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
   RankedTensorType outputType = getTypeConverter()
                                     ->convertType(op->getResult(0).getType())
                                     .cast<RankedTensorType>();
-  auto outputShape = outputType.getShape();
   auto outputElemType = outputType.getElementType();
   Value mhloTensor =
       mhlo::scalarToMhloTensor(rewriter, op, adaptor.a(), outputElemType);
@@ -968,6 +995,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_MULDIV_PATTERN(AtenMulTensorOp, chlo::BroadcastMulOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenMulScalarOp, chlo::BroadcastMulOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorOp, chlo::BroadcastDivOp);
+  INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
   INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
 #undef INSERT_BINARY_MULDIV_PATTERN
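The INSERT_BINARY_MULDIV_PATTERN macro is defined earlier in the same file; its definition is outside this hunk. As a hypothetical reconstruction (not the verbatim macro), registration macros of this kind in MLIR conversion passes typically mark the aten op illegal for the target and wire the shared templated pattern to it:

// Hypothetical sketch of what such a macro usually expands to; registering
// AtenDivTensorModeOp here is what routes it through the same
// ConvertAtenMulDivOp pattern that now handles the rounding mode.
#define INSERT_BINARY_MULDIV_PATTERN(AtenOp, ChloOp)                           \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenMulDivOp<AtenOp, ChloOp>>(typeConverter, context);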
@@ -2167,8 +2167,11 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenFloorDivideOp op,
                                 PatternRewriter &rewriter) const override {
+    // https://pytorch.org/docs/stable/generated/torch.floor_divide.html
+    // PyTorch aten.floor_divide is a misnomer because it actually rounds
+    // the quotient towards zero instead of taking its floor.
     Value cstStrFloor =
-        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
+        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "trunc");
     rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
         op, op.getType(), op.self(), op.other(),
         /*rounding_mode=*/cstStrFloor);
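The comment added above is the crux of the fix: floor and trunc disagree exactly when the quotient is negative and non-integral, so decomposing aten.floor_divide with rounding_mode "floor" silently changed results. A small C++ check of the divergence (values chosen purely for illustration):

#include <cmath>
#include <cstdio>

// floor rounds down; trunc rounds toward zero. For q = -7/2 = -3.5 they
// differ, which is why the decomposition now emits the "trunc" constant to
// match PyTorch's actual aten.floor_divide behavior.
int main() {
  double q = -7.0 / 2.0;                            // -3.5
  std::printf("floor = %.1f\n", std::floor(q));     // -4.0
  std::printf("trunc = %.1f\n", std::trunc(q));     // -3.0
  return 0;
}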
@@ -540,3 +540,37 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1
   return %0 : !torch.vtensor<[?,?],i1>
 }
 
+// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$trunc(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[STR:.*]] = torch.constant.str "trunc"
+// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %str = torch.constant.str "trunc"
+  %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$floor(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor<?x?x?x?xf32>
+// CHECK: %[[STR:.*]] = torch.constant.str "floor"
+// CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor<?x?x?x?xf32>
+// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
+// CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32>
+func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %str = torch.constant.str "floor"
+  %0 = torch.aten.div.Tensor_mode %arg0, %arg1, %str : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.str -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
@@ -1113,8 +1113,8 @@ func.func @torch.aten.baddbmm(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.
 // CHECK-LABEL: func @torch.aten.floor_divide(
 // CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
 // CHECK-SAME: %[[OTHER:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-// CHECK: %[[CSTFLOOR:.*]] = torch.constant.str "floor"
-// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTFLOOR]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
+// CHECK: %[[CSTTRUNC:.*]] = torch.constant.str "trunc"
+// CHECK: %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTTRUNC]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
   %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>