[MLIR][TORCH] Add decomposition of aten.numpy_T op

This commit adds the decomposition of the `aten.numpy_T` op into the
`aten.permute` op (permuting the input with its dimensions in reverse order).

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/792/merge
Vivek Khandelwal 2022-06-03 18:08:59 +05:30
parent 4605dc9c99
commit 77ab31641f
10 changed files with 198 additions and 2 deletions
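
For context, `aten.numpy_T` backs NumPy-style `Tensor.T`, which reverses the dimension order. A minimal PyTorch sketch (illustrative only, not part of this patch) of the equivalence the decomposition relies on:

import torch

x = torch.rand(3, 4, 5)

# Tensor.T dispatches to aten::numpy_T and reverses the dimension order,
# which is exactly a permute with the dims listed in reverse.
reversed_dims = list(range(x.dim() - 1, -1, -1))  # [2, 1, 0]
assert torch.equal(x.T, x.permute(reversed_dims))
print(x.T.shape)  # torch.Size([5, 4, 3])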

@@ -160,4 +160,8 @@ TOSA_PASS_SET = {
    "BaddbmmWithBetaModule_basic",
    "BaddbmmBroadcast1DInputModule_basic",
    "BaddbmmBroadcast2DInputModule_basic",
    "NumpyTRank1Module_basic",
    "NumpyTRank2Module_basic",
    "NumpyTRankNStaticModule_basic",
    "NumpyTRankNDynamicModule_basic",
}

@@ -6002,6 +6002,28 @@ def Torch_AtenTOp : Torch_Op<"aten.t", [
  }];
}

def Torch_AtenNumpyTOp : Torch_Op<"aten.numpy_T", [
    AllowsTypeRefinement,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::numpy_T : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenNumpyTOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenNumpyTOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenFullOp : Torch_Op<"aten.full", [
    AllowsTypeRefinement,
    HasValueSemantics,

@@ -1957,6 +1957,29 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
};
} // namespace

namespace {
// Decompose `aten.numpy_T` op into `aten.permute` op.
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNumpyTOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.self();
    int64_t inputRank = getTensorRank(self);

    SmallVector<Value> dimListElements;
    for (int64_t i = inputRank - 1; i >= 0; i--)
      dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(i)));
    Value dimList = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
        dimListElements);

    rewriter.replaceOpWithNewOp<AtenPermuteOp>(op, op.getType(), self, dimList);
    return success();
  }
};
} // namespace

namespace {
class DecomposeComplexOpsPass
    : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -2102,6 +2125,8 @@ class DecomposeComplexOpsPass
    target.addIllegalOp<AtenBaddbmmOp>();
    patterns.add<DecomposeAtenFloorDivideOp>(context);
    target.addIllegalOp<AtenFloorDivideOp>();
    patterns.add<DecomposeAtenNumpyTOp>(context);
    target.addIllegalOp<AtenNumpyTOp>();

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns)))) {

@@ -37,7 +37,7 @@ static bool isViewLikeOp(Operation *op) {
             Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
             AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
             AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
-            TensorStaticInfoCastOp, AtenToDtypeLayoutOp>(op);
+            TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp>(op);
}
namespace {

@@ -642,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
          AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
          AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
          AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
-         PrimAbsScalarOp>(op)) {
+         PrimAbsScalarOp, AtenNumpyTOp>(op)) {
    return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
  }

@@ -5628,6 +5628,19 @@ module {
    %0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %int0, %int1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>
    return %0 : !torch.list<int>
  }
  func.func @"__torch_mlir_shape_fn.aten.numpy_T"(%arg0: !torch.list<int>) -> !torch.list<int> {
    %int0 = torch.constant.int 0
    %true = torch.constant.bool true
    %0 = torch.prim.ListConstruct : () -> !torch.list<int>
    %1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
    torch.prim.Loop %1, %true, init() {
    ^bb0(%arg1: !torch.int):
      %2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
      torch.aten.insert.t %0, %int0, %2 : !torch.list<int>, !torch.int, !torch.int
      torch.prim.Loop.condition %true, iter()
    } : (!torch.int, !torch.bool) -> ()
    return %0 : !torch.list<int>
  }
  func.func @"__torch_mlir_shape_fn.aten.matmul"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
    %0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
    return %0 : !torch.list<int>

@@ -529,6 +529,12 @@ def aten〇transpose〇int(self: List[int], dim0: int, dim1: int) -> List[int]:
def aten〇t(self: List[int]) -> List[int]:
    return upstream_shape_functions.transpose(self, 0, 1)

def aten〇numpy_T(self: List[int]) -> List[int]:
    result_shape: List[int] = []
    for i in self:
        result_shape.insert(0, i)
    return result_shape

def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
    return upstream_shape_functions.matmul(self, other)
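
The new shape transfer function simply builds the input shape in reverse order; a tiny standalone check (illustrative only, not part of the patch) of that behavior:

def reversed_shape(shape):
    # Same logic as the numpy_T shape function above:
    # prepending each dim yields the input shape reversed.
    result = []
    for d in shape:
        result.insert(0, d)
    return result

assert reversed_shape([3, 4, 5, 6]) == [6, 5, 4, 3]
assert reversed_shape([]) == []  # rank-0 input: shape is unchanged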

@@ -465,6 +465,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
    emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
    emit("aten::t : (Tensor) -> (Tensor)")
    emit("aten::numpy_T : (Tensor) -> (Tensor)")
    emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
    emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
    emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")

@@ -2167,3 +2167,101 @@ class BaddbmmBroadcast2DInputModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BaddbmmBroadcast2DInputModule())
def BaddbmmBroadcast2DInputModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 7), tu.rand(5, 2, 9), tu.rand(5, 9, 7))


# ==============================================================================


class NumpyTRankNStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 5, 6], torch.float32, True),
    ])
    def forward(self, lhs):
        return torch.ops.aten.numpy_T(lhs)


@register_test_case(module_factory=lambda: NumpyTRankNStaticModule())
def NumpyTRankNStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, 6))


class NumpyTRankNDynamicModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1, -1, -1], torch.float32, True),
    ])
    def forward(self, lhs):
        return torch.ops.aten.numpy_T(lhs)


@register_test_case(module_factory=lambda: NumpyTRankNDynamicModule())
def NumpyTRankNDynamicModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 5, 6, 2))


class NumpyTRank2Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, lhs):
        return torch.ops.aten.numpy_T(lhs)


@register_test_case(module_factory=lambda: NumpyTRank2Module())
def NumpyTRank2Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))


class NumpyTRank1Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
    ])
    def forward(self, lhs):
        return torch.ops.aten.numpy_T(lhs)


@register_test_case(module_factory=lambda: NumpyTRank1Module())
def NumpyTRank1Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(3))


class NumpyTRank0Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.float32, True),
    ])
    def forward(self, lhs):
        return torch.ops.aten.numpy_T(lhs)


@register_test_case(module_factory=lambda: NumpyTRank0Module())
def NumpyTRank0Module_basic(module, tu: TestUtils):
    module.forward(torch.tensor(7, dtype=torch.float32))

@@ -1036,3 +1036,30 @@ func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !tor
  %0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
  return %0 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func @torch.aten.numpy_T$rank_two(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[4,5],f32> {
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUT:.*]] = torch.aten.permute %[[SELF]], %[[DIMS]] : !torch.vtensor<[5,4],f32>, !torch.list<int> -> !torch.vtensor<[4,5],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[4,5],f32>
func.func @torch.aten.numpy_T$rank_two(%arg0: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[4,5],f32> {
  %0 = torch.aten.numpy_T %arg0 : !torch.vtensor<[5,4],f32> -> !torch.vtensor<[4,5],f32>
  return %0 : !torch.vtensor<[4,5],f32>
}

// -----

// CHECK-LABEL: func @torch.aten.numpy_T$rank_three(
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[5,4,3],f32>) -> !torch.vtensor<[3,4,5],f32> {
// CHECK: %[[CST2:.*]] = torch.constant.int 2
// CHECK: %[[CST1:.*]] = torch.constant.int 1
// CHECK: %[[CST0:.*]] = torch.constant.int 0
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]], %[[CST1]], %[[CST0]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[OUT:.*]] = torch.aten.permute %[[SELF]], %[[DIMS]] : !torch.vtensor<[5,4,3],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[3,4,5],f32>
func.func @torch.aten.numpy_T$rank_three(%arg0: !torch.vtensor<[5,4,3],f32>) -> !torch.vtensor<[3,4,5],f32> {
  %0 = torch.aten.numpy_T %arg0 : !torch.vtensor<[5,4,3],f32> -> !torch.vtensor<[3,4,5],f32>
  return %0 : !torch.vtensor<[3,4,5],f32>
}