mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add decomposition of aten.numpy_T op
This commit adds the decomposition of `aten.numpy_T` op into `aten.t` or `aten.permute` op. Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/792/merge
parent
4605dc9c99
commit
77ab31641f
|
@ -160,4 +160,8 @@ TOSA_PASS_SET = {
|
||||||
"BaddbmmWithBetaModule_basic",
|
"BaddbmmWithBetaModule_basic",
|
||||||
"BaddbmmBroadcast1DInputModule_basic",
|
"BaddbmmBroadcast1DInputModule_basic",
|
||||||
"BaddbmmBroadcast2DInputModule_basic",
|
"BaddbmmBroadcast2DInputModule_basic",
|
||||||
|
"NumpyTRank1Module_basic",
|
||||||
|
"NumpyTRank2Module_basic",
|
||||||
|
"NumpyTRankNStaticModule_basic",
|
||||||
|
"NumpyTRankNDynamicModule_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -6002,6 +6002,28 @@ def Torch_AtenTOp : Torch_Op<"aten.t", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenNumpyTOp : Torch_Op<"aten.numpy_T", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::numpy_T : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenNumpyTOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenNumpyTOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenFullOp : Torch_Op<"aten.full", [
|
def Torch_AtenFullOp : Torch_Op<"aten.full", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -1957,6 +1957,29 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern<AtenFloorDivideOp> {
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Decompose `aten.numpy_T` op into `aten.permute` op.
|
||||||
|
class DecomposeAtenNumpyTOp : public OpRewritePattern<AtenNumpyTOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenNumpyTOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value self = op.self();
|
||||||
|
int64_t inputRank = getTensorRank(self);
|
||||||
|
|
||||||
|
SmallVector<Value> dimListElements;
|
||||||
|
for (int64_t i = inputRank - 1; i >= 0; i--)
|
||||||
|
dimListElements.push_back(rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(i)));
|
||||||
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||||
|
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
|
||||||
|
dimListElements);
|
||||||
|
rewriter.replaceOpWithNewOp<AtenPermuteOp>(op, op.getType(), self, dimList);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeComplexOpsPass
|
class DecomposeComplexOpsPass
|
||||||
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
|
||||||
|
@ -2102,6 +2125,8 @@ class DecomposeComplexOpsPass
|
||||||
target.addIllegalOp<AtenBaddbmmOp>();
|
target.addIllegalOp<AtenBaddbmmOp>();
|
||||||
patterns.add<DecomposeAtenFloorDivideOp>(context);
|
patterns.add<DecomposeAtenFloorDivideOp>(context);
|
||||||
target.addIllegalOp<AtenFloorDivideOp>();
|
target.addIllegalOp<AtenFloorDivideOp>();
|
||||||
|
patterns.add<DecomposeAtenNumpyTOp>(context);
|
||||||
|
target.addIllegalOp<AtenNumpyTOp>();
|
||||||
|
|
||||||
if (failed(applyPartialConversion(getOperation(), target,
|
if (failed(applyPartialConversion(getOperation(), target,
|
||||||
std::move(patterns)))) {
|
std::move(patterns)))) {
|
||||||
|
|
|
@ -37,7 +37,7 @@ static bool isViewLikeOp(Operation *op) {
|
||||||
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
|
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
|
||||||
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
|
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
|
||||||
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
|
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
|
||||||
TensorStaticInfoCastOp, AtenToDtypeLayoutOp>(op);
|
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
@ -642,7 +642,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
||||||
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
|
AtenZero_Op, AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp,
|
||||||
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
AtenIndexPutOp, ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
|
||||||
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
|
||||||
PrimAbsScalarOp>(op)) {
|
PrimAbsScalarOp, AtenNumpyTOp>(op)) {
|
||||||
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5628,6 +5628,19 @@ module {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %int0, %int1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.transpose(%arg0, %int0, %int1) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func.func @"__torch_mlir_shape_fn.aten.numpy_T"(%arg0: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%0 = torch.prim.ListConstruct : () -> !torch.list<int>
|
||||||
|
%1 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int
|
||||||
|
torch.prim.Loop %1, %true, init() {
|
||||||
|
^bb0(%arg1: !torch.int):
|
||||||
|
%2 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list<int>, !torch.int -> !torch.int
|
||||||
|
torch.aten.insert.t %0, %int0, %2 : !torch.list<int>, !torch.int, !torch.int
|
||||||
|
torch.prim.Loop.condition %true, iter()
|
||||||
|
} : (!torch.int, !torch.bool) -> ()
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.matmul"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.matmul"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch.jit._shape_functions.matmul(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
|
|
@ -529,6 +529,12 @@ def aten〇transpose〇int(self: List[int], dim0: int, dim1: int) -> List[int]:
|
||||||
def aten〇t(self: List[int]) -> List[int]:
|
def aten〇t(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.transpose(self, 0, 1)
|
return upstream_shape_functions.transpose(self, 0, 1)
|
||||||
|
|
||||||
|
def aten〇numpy_T(self: List[int]) -> List[int]:
|
||||||
|
result_shape: List[int] = []
|
||||||
|
for i in self:
|
||||||
|
result_shape.insert(0, i)
|
||||||
|
return result_shape
|
||||||
|
|
||||||
def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
|
def aten〇matmul(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.matmul(self, other)
|
return upstream_shape_functions.matmul(self, other)
|
||||||
|
|
||||||
|
|
|
@ -465,6 +465,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
emit_with_mutating_variants("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||||
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)")
|
||||||
emit("aten::t : (Tensor) -> (Tensor)")
|
emit("aten::t : (Tensor) -> (Tensor)")
|
||||||
|
emit("aten::numpy_T : (Tensor) -> (Tensor)")
|
||||||
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)")
|
||||||
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)")
|
||||||
|
|
|
@ -2167,3 +2167,101 @@ class BaddbmmBroadcast2DInputModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: BaddbmmBroadcast2DInputModule())
|
@register_test_case(module_factory=lambda: BaddbmmBroadcast2DInputModule())
|
||||||
def BaddbmmBroadcast2DInputModule_basic(module, tu: TestUtils):
|
def BaddbmmBroadcast2DInputModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(2, 7), tu.rand(5, 2, 9), tu.rand(5, 9, 7))
|
module.forward(tu.rand(2, 7), tu.rand(5, 2, 9), tu.rand(5, 9, 7))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyTRankNStaticModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([3, 4, 5, 6], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs):
|
||||||
|
return torch.ops.aten.numpy_T(lhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NumpyTRankNStaticModule())
|
||||||
|
def NumpyTRankNStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5, 6))
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyTRankNDynamicModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs):
|
||||||
|
return torch.ops.aten.numpy_T(lhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NumpyTRankNDynamicModule())
|
||||||
|
def NumpyTRankNDynamicModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4, 5, 6, 2))
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyTRank2Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs):
|
||||||
|
return torch.ops.aten.numpy_T(lhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NumpyTRank2Module())
|
||||||
|
def NumpyTRank2Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyTRank1Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs):
|
||||||
|
return torch.ops.aten.numpy_T(lhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NumpyTRank1Module())
|
||||||
|
def NumpyTRank1Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3))
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyTRank0Module(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs):
|
||||||
|
return torch.ops.aten.numpy_T(lhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: NumpyTRank0Module())
|
||||||
|
def NumpyTRank0Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor(7, dtype=torch.float32))
|
||||||
|
|
|
@ -1036,3 +1036,30 @@ func.func @torch.aten.floor_divide(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !tor
|
||||||
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
%0 = torch.aten.floor_divide %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||||
return %0 : !torch.vtensor<[?,?],f32>
|
return %0 : !torch.vtensor<[?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @torch.aten.numpy_T$rank_two(
|
||||||
|
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[4,5],f32> {
|
||||||
|
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST1]], %[[CST0]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[OUT:.*]] = torch.aten.permute %[[SELF]], %[[DIMS]] : !torch.vtensor<[5,4],f32>, !torch.list<int> -> !torch.vtensor<[4,5],f32>
|
||||||
|
// CHECK: return %[[OUT]] : !torch.vtensor<[4,5],f32>
|
||||||
|
func.func @torch.aten.numpy_T$rank_two(%arg0: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[4,5],f32> {
|
||||||
|
%0 = torch.aten.numpy_T %arg0 : !torch.vtensor<[5,4],f32> -> !torch.vtensor<[4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[4,5],f32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @torch.aten.numpy_T$rank_three(
|
||||||
|
// CHECK-SAME: %[[SELF:.*]]: !torch.vtensor<[5,4,3],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
// CHECK: %[[CST2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[CST2]], %[[CST1]], %[[CST0]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[OUT:.*]] = torch.aten.permute %[[SELF]], %[[DIMS]] : !torch.vtensor<[5,4,3],f32>, !torch.list<int> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
// CHECK: return %[[OUT]] : !torch.vtensor<[3,4,5],f32>
|
||||||
|
func.func @torch.aten.numpy_T$rank_three(%arg0: !torch.vtensor<[5,4,3],f32>) -> !torch.vtensor<[3,4,5],f32> {
|
||||||
|
%0 = torch.aten.numpy_T %arg0 : !torch.vtensor<[5,4,3],f32> -> !torch.vtensor<[3,4,5],f32>
|
||||||
|
return %0 : !torch.vtensor<[3,4,5],f32>
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue