Added aten::t() Op

pull/500/head
Nirvedh 2021-12-17 04:08:07 +00:00 committed by nirvedhmeshram
parent 5eed562e19
commit 3cb46cecef
7 changed files with 125 additions and 2 deletions

View File

@ -1179,3 +1179,56 @@ class BoolTensorReturnMixedModule(torch.nn.Module):
@register_test_case(module_factory=lambda: BoolTensorReturnMixedModule())
def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
# ==============================================================================
class TModuleRank2(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, lhs):
return torch.t(lhs)
@register_test_case(module_factory=lambda: TModuleRank2())
def TModuleRank2_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class TModuleRank1(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float32, True),
])
def forward(self, lhs):
return torch.t(lhs)
@register_test_case(module_factory=lambda: TModuleRank1())
def TModuleRank1_basic(module, tu: TestUtils):
module.forward(tu.rand(3))
class TModuleRank0(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([], torch.float32, True),
])
def forward(self, lhs):
return torch.t(lhs)
@register_test_case(module_factory=lambda: TModuleRank0())
def TModuleRank0_basic(module, tu: TestUtils):
module.forward(torch.tensor(7, dtype=torch.float32))

View File

@ -41,4 +41,6 @@ TOSA_PASS_SET = {
"SqueezeModule_static",
"SqueezeModule_noUnitDim",
"SqueezeModule_allUnitDim",
"TModuleRank1_basic",
"TModuleRank0_basic",
}

View File

@ -2603,6 +2603,19 @@ def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
let assemblyFormat = "$input `,` $p `,` $train attr-dict `:` type($input) `,` type($p) `,` type($train) `->` type($result)";
}
def Torch_AtenTOp : Torch_Op<"aten.t", [
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::t : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
AllowsTypeRefinement,
HasValueSemantics

View File

@ -379,6 +379,36 @@ public:
};
} // namespace
namespace {
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTOp op,
PatternRewriter &rewriter) const override {
Value lhs = op.self();
int lhsRank = getTensorRank(lhs);
auto loc = op.getLoc();
if (lhsRank > 2 || lhsRank < 0) {
std::string errorMessage =
"t() expects a tensor with <=2 dimensions, but self is " +
std::to_string(lhsRank) + "D";
return rewriter.notifyMatchFailure(op, errorMessage.c_str());
} else if (lhsRank < 2)
rewriter.replaceOp(op, lhs);
else {
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
zero, one);
}
return success();
}
};
} // namespace
// Decompose torch.expand into torch.broadcast_to op.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
@ -565,6 +595,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenSelectIntOp>(context);
target.addIllegalOp<AtenSelectIntOp>();
patterns.add<DecomposeAtenMatmulOp>(context);
target.addIllegalOp<AtenTOp>();
patterns.add<DecomposeAtenTOp>(context);
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {

View File

@ -93,8 +93,8 @@ public:
AtenFlattenUsingIntsOp, AtenTransposeIntOp,
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp>(
op)) {
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
AtenTOp>(op)) {
// AtenContiguousOp might return a view, so this is conservatively
// correct. We could potentially be more precise and identify the cases
// that it does not return a view and treat those as having value

View File

@ -374,6 +374,8 @@ public:
return visitReshapeLikeOp(resize, operands);
} else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
return visitAtenTransposeIntOp(transposeInt, operands);
} else if (auto t = dyn_cast<AtenTOp>(op)) {
return visitAtenTOp(t, operands);
} else if (auto permute = dyn_cast<AtenPermuteOp>(op)) {
return visitAtenPermuteOp(permute, operands);
} else if (auto tensorFloat = dyn_cast<AtenTensorFloatOp>(op)) {
@ -550,6 +552,8 @@ private:
visitAtenTransposeIntOp(AtenTransposeIntOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenTOp(AtenTOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult
visitAtenPermuteOp(AtenPermuteOp op,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitNumToTensorOp(PrimNumToTensorScalarOp op);
@ -1242,6 +1246,24 @@ ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
return getLatticeElement(op.getResult()).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenTOp(
AtenTOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
knowledge.dtype = input.dtype;
if (!input.hasSizes)
return getLatticeElement(op.getResult()).join(knowledge);
int64_t inputRank = input.sizes.size();
if (inputRank >= 0 && inputRank <= 2) {
knowledge.hasSizes = input.hasSizes;
knowledge.sizes = input.sizes;
if (inputRank == 2)
std::swap(knowledge.sizes[0], knowledge.sizes[1]);
}
return getLatticeElement(op.getResult()).join(knowledge);
}
ChangeResult TypeAnalyzer::visitAtenPermuteOp(
AtenPermuteOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();

View File

@ -594,6 +594,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
emit("aten::t : (Tensor) -> (Tensor)")
# Dict ops.
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)