[MLIR][TORCH] Add decomposition of aten.floor_divide op

This commit adds the decomposition of the `aten.floor_divide` op into
the `aten.div.Tensor_mode` op.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
Vivek Khandelwal 2022-06-09 11:39:28 +05:30
parent 0d4445eaf9
commit 33fa8e7761
8 changed files with 105 additions and 1 deletion
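For reference, the identity this decomposition relies on can be checked directly in PyTorch. A minimal sketch (plain PyTorch, not part of the commit): floor division is true division with the quotient rounded toward negative infinity, which is exactly what `aten.div.Tensor_mode` with `rounding_mode="floor"` computes.

    import torch

    x = torch.tensor([5.0, -5.0, 7.0])
    y = torch.tensor([2.0, 2.0, -2.0])

    # Quotient rounded toward -inf, not toward zero: note -5/2 -> -3, 7/-2 -> -4.
    floored = torch.div(x, y, rounding_mode="floor")  # tensor([ 2., -3., -4.])
    assert torch.equal(floored, torch.floor(x / y))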

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

@@ -2637,6 +2637,30 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
   }];
 }
 
+def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::floor_divide : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenFloorDivideOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenFloorDivideOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
     AllowsTypeRefinement
   ]> {

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

@@ -1941,6 +1941,22 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
 };
 } // namespace
 
+namespace {
+// Decompose `aten.floor_divide` op into `aten.div.Tensor_mode` op.
+class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenFloorDivideOp op,
+                                PatternRewriter &rewriter) const override {
+    Value cstStrFloor =
+        rewriter.create<Torch::ConstantStrOp>(op.getLoc(), "floor");
+    rewriter.replaceOpWithNewOp<AtenDivTensorModeOp>(
+        op, op.getType(), op.self(), op.other(),
+        /*rounding_mode=*/cstStrFloor);
+    return success();
+  }
+};
+} // namespace
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -2084,6 +2100,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenClampMaxOp>();
     patterns.add<DecomposeAtenBaddbmmOp>(context);
     target.addIllegalOp<AtenBaddbmmOp>();
+    patterns.add<DecomposeAtenFloorDivideOp>(context);
+    target.addIllegalOp<AtenFloorDivideOp>();
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns)))) {
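To see the decomposition fire end to end, one can lower a small module through the Torch dialect. A hedged sketch, assuming the `torch_mlir.compile` Python API available around this snapshot (exact names may differ):

    import torch
    import torch_mlir  # assumption: the in-tree Python package from this era

    class FloorDivideModule(torch.nn.Module):
        def forward(self, x, y):
            return torch.ops.aten.floor_divide(x, y)

    # After DecomposeComplexOps runs, the printed IR should contain
    # torch.aten.div.Tensor_mode with a "floor" rounding mode in place of
    # torch.aten.floor_divide.
    compiled = torch_mlir.compile(
        FloorDivideModule(), [torch.rand(4, 3), torch.rand(4, 3)],
        output_type=torch_mlir.OutputType.TORCH)
    print(compiled)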

lib/Dialect/Torch/Transforms/RefineTypes.cpp

@@ -702,7 +702,8 @@ ChangeResult TypeAnalyzer::visitOperation(
   // Promote the two dtypes assuming possibly-zero rank.
   if (isa<AtenAddTensorOp, AtenSubTensorOp, AtenMulTensorOp, AtenDivTensorOp,
           AtenDivTensorModeOp, Aten__And__TensorOp, AtenMinimumOp,
-          AtenMaximumOp, AtenBitwiseAndTensorOp, AtenThresholdBackwardOp>(op)) {
+          AtenMaximumOp, AtenBitwiseAndTensorOp, AtenThresholdBackwardOp,
+          AtenFloorDivideOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = getPromotedResultType(
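The `getPromotedResultType` call mirrors PyTorch's usual binary-op type promotion, which is the behavior this analysis must predict for the result tensor. For example, in plain PyTorch:

    import torch

    i = torch.arange(6, dtype=torch.int64)
    f = torch.full((6,), 2.5)  # float32

    # Mixed int/float floor division promotes to the floating-point dtype.
    print(torch.div(i, f, rounding_mode="floor").dtype)  # torch.float32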

lib/Dialect/Torch/Transforms/ShapeLibrary.cpp

@@ -6077,6 +6077,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.floor_divide"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.__and__.Tensor"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py

@@ -761,6 +761,9 @@ def aten〇div〇Tensor(self: List[int], other: List[int]) -> List[int]:
 def aten〇div〇Tensor_mode(self: List[int], other: List[int], rounding_mode: Optional[str]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇floor_divide(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
 def aten〇__and__〇Tensor(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
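Both the generated shape library entry above and its Python source delegate to the upstream broadcast helper: the result shape of `aten.floor_divide` is the ordinary broadcast of the operand shapes. A quick check in plain PyTorch:

    import torch

    x = torch.rand(3)     # shape [3]
    y = torch.rand(4, 3)  # shape [4, 3]

    # The result shape is the broadcast of the two input shapes, matching
    # upstream_shape_functions.broadcast([3], [4, 3]) == [4, 3].
    print(torch.div(x, y, rounding_mode="floor").shape)  # torch.Size([4, 3])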

python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py

@@ -302,6 +302,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::gelu : (Tensor, str) -> (Tensor)")
     emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
+    emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
 
     # Ops without value semantics but the corresponding without trailing
     # underscore variant doesn't exist.

python/torch_mlir_e2e_test/test_suite/elementwise.py

@@ -1720,4 +1720,45 @@ def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(3, (3,)), torch.randint(3, (4, 3)))
 
 # ==============================================================================
+
+class ElementwiseAtenFloorDivideModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.floor_divide(x, y)
+
+
+@register_test_case(module_factory=lambda: ElementwiseAtenFloorDivideModule())
+def ElementwiseAtenFloorDivideModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 3), tu.rand(4, 3))
+
+# ==============================================================================
+
+class ElementwiseAtenFloorDivideBroadcastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float32, True),
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.floor_divide(x, y)
+
+
+@register_test_case(
+    module_factory=lambda: ElementwiseAtenFloorDivideBroadcastModule())
+def ElementwiseAtenFloorDivideBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(4, 3))

test/Dialect/Torch/decompose-complex-ops.mlir

@@ -1024,3 +1024,15 @@ func.func @torch.aten.baddbmm(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.
   %0 = torch.aten.baddbmm %arg0, %arg1, %arg2, %int1, %int1 : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],f32>, !torch.int , !torch.int -> !torch.vtensor<[?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?],f32>
 }
+
+// -----
+// CHECK-LABEL:   func @torch.aten.floor_divide(
+// CHECK-SAME:        %[[SELF:.*]]: !torch.vtensor<[?,?],f32>,
+// CHECK-SAME:        %[[OTHER:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+// CHECK:           %[[CSTFLOOR:.*]] = torch.constant.str "floor"
+// CHECK:           %[[OUT:.*]] = torch.aten.div.Tensor_mode %[[SELF]], %[[OTHER]], %[[CSTFLOOR]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>, !torch.str -> !torch.vtensor<[?,?],f32>
+// CHECK:           return %[[OUT]] : !torch.vtensor<[?,?],f32>
+func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+  return %0 : !torch.vtensor<[?,?],f32>
+}