mirror of https://github.com/llvm/torch-mlir
Added aten::t() Op
parent
5eed562e19
commit
3cb46cecef
|
@ -1179,3 +1179,56 @@ class BoolTensorReturnMixedModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: BoolTensorReturnMixedModule())
|
@register_test_case(module_factory=lambda: BoolTensorReturnMixedModule())
|
||||||
def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
|
def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
|
||||||
module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
|
module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
class TModuleRank2(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs):
|
||||||
|
return torch.t(lhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TModuleRank2())
|
||||||
|
def TModuleRank2_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
class TModuleRank1(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs):
|
||||||
|
return torch.t(lhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TModuleRank1())
|
||||||
|
def TModuleRank1_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(3))
|
||||||
|
|
||||||
|
class TModuleRank0(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, lhs):
|
||||||
|
return torch.t(lhs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: TModuleRank0())
|
||||||
|
def TModuleRank0_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor(7, dtype=torch.float32))
|
||||||
|
|
||||||
|
|
|
@ -41,4 +41,6 @@ TOSA_PASS_SET = {
|
||||||
"SqueezeModule_static",
|
"SqueezeModule_static",
|
||||||
"SqueezeModule_noUnitDim",
|
"SqueezeModule_noUnitDim",
|
||||||
"SqueezeModule_allUnitDim",
|
"SqueezeModule_allUnitDim",
|
||||||
|
"TModuleRank1_basic",
|
||||||
|
"TModuleRank0_basic",
|
||||||
}
|
}
|
||||||
|
|
|
@ -2603,6 +2603,19 @@ def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
|
||||||
let assemblyFormat = "$input `,` $p `,` $train attr-dict `:` type($input) `,` type($p) `,` type($train) `->` type($result)";
|
let assemblyFormat = "$input `,` $p `,` $train attr-dict `:` type($input) `,` type($p) `,` type($train) `->` type($result)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenTOp : Torch_Op<"aten.t", [
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::t : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
|
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics
|
HasValueSemantics
|
||||||
|
|
|
@ -379,6 +379,36 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenTOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Value lhs = op.self();
|
||||||
|
int lhsRank = getTensorRank(lhs);
|
||||||
|
auto loc = op.getLoc();
|
||||||
|
|
||||||
|
if (lhsRank > 2 || lhsRank < 0) {
|
||||||
|
std::string errorMessage =
|
||||||
|
"t() expects a tensor with <=2 dimensions, but self is " +
|
||||||
|
std::to_string(lhsRank) + "D";
|
||||||
|
return rewriter.notifyMatchFailure(op, errorMessage.c_str());
|
||||||
|
} else if (lhsRank < 2)
|
||||||
|
rewriter.replaceOp(op, lhs);
|
||||||
|
else {
|
||||||
|
Value zero =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
Value one =
|
||||||
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
|
||||||
|
zero, one);
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Decompose torch.expand into torch.broadcast_to op.
|
// Decompose torch.expand into torch.broadcast_to op.
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
|
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
|
||||||
|
@ -565,6 +595,8 @@ class DecomposeComplexOpsPass
|
||||||
patterns.add<DecomposeAtenSelectIntOp>(context);
|
patterns.add<DecomposeAtenSelectIntOp>(context);
|
||||||
target.addIllegalOp<AtenSelectIntOp>();
|
target.addIllegalOp<AtenSelectIntOp>();
|
||||||
patterns.add<DecomposeAtenMatmulOp>(context);
|
patterns.add<DecomposeAtenMatmulOp>(context);
|
||||||
|
target.addIllegalOp<AtenTOp>();
|
||||||
|
patterns.add<DecomposeAtenTOp>(context);
|
||||||
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
|
patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
|
||||||
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
|
||||||
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
|
||||||
|
|
|
@ -93,8 +93,8 @@ public:
|
||||||
AtenFlattenUsingIntsOp, AtenTransposeIntOp,
|
AtenFlattenUsingIntsOp, AtenTransposeIntOp,
|
||||||
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
|
TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
|
||||||
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
|
AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
|
||||||
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp>(
|
AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
|
||||||
op)) {
|
AtenTOp>(op)) {
|
||||||
// AtenContiguousOp might return a view, so this is conservatively
|
// AtenContiguousOp might return a view, so this is conservatively
|
||||||
// correct. We could potentially be more precise and identify the cases
|
// correct. We could potentially be more precise and identify the cases
|
||||||
// that it does not return a view and treat those as having value
|
// that it does not return a view and treat those as having value
|
||||||
|
|
|
@ -374,6 +374,8 @@ public:
|
||||||
return visitReshapeLikeOp(resize, operands);
|
return visitReshapeLikeOp(resize, operands);
|
||||||
} else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
|
} else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
|
||||||
return visitAtenTransposeIntOp(transposeInt, operands);
|
return visitAtenTransposeIntOp(transposeInt, operands);
|
||||||
|
} else if (auto t = dyn_cast<AtenTOp>(op)) {
|
||||||
|
return visitAtenTOp(t, operands);
|
||||||
} else if (auto permute = dyn_cast<AtenPermuteOp>(op)) {
|
} else if (auto permute = dyn_cast<AtenPermuteOp>(op)) {
|
||||||
return visitAtenPermuteOp(permute, operands);
|
return visitAtenPermuteOp(permute, operands);
|
||||||
} else if (auto tensorFloat = dyn_cast<AtenTensorFloatOp>(op)) {
|
} else if (auto tensorFloat = dyn_cast<AtenTensorFloatOp>(op)) {
|
||||||
|
@ -550,6 +552,8 @@ private:
|
||||||
visitAtenTransposeIntOp(AtenTransposeIntOp op,
|
visitAtenTransposeIntOp(AtenTransposeIntOp op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
ChangeResult
|
ChangeResult
|
||||||
|
visitAtenTOp(AtenTOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
|
ChangeResult
|
||||||
visitAtenPermuteOp(AtenPermuteOp op,
|
visitAtenPermuteOp(AtenPermuteOp op,
|
||||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||||
ChangeResult visitNumToTensorOp(PrimNumToTensorScalarOp op);
|
ChangeResult visitNumToTensorOp(PrimNumToTensorScalarOp op);
|
||||||
|
@ -1242,6 +1246,24 @@ ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
|
||||||
return getLatticeElement(op.getResult()).join(knowledge);
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ChangeResult TypeAnalyzer::visitAtenTOp(
|
||||||
|
AtenTOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||||
|
auto input = operands[0]->getValue();
|
||||||
|
auto knowledge =
|
||||||
|
ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
|
||||||
|
knowledge.dtype = input.dtype;
|
||||||
|
if (!input.hasSizes)
|
||||||
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
|
int64_t inputRank = input.sizes.size();
|
||||||
|
if (inputRank >= 0 && inputRank <= 2) {
|
||||||
|
knowledge.hasSizes = input.hasSizes;
|
||||||
|
knowledge.sizes = input.sizes;
|
||||||
|
if (inputRank == 2)
|
||||||
|
std::swap(knowledge.sizes[0], knowledge.sizes[1]);
|
||||||
|
}
|
||||||
|
return getLatticeElement(op.getResult()).join(knowledge);
|
||||||
|
}
|
||||||
|
|
||||||
ChangeResult TypeAnalyzer::visitAtenPermuteOp(
|
ChangeResult TypeAnalyzer::visitAtenPermuteOp(
|
||||||
AtenPermuteOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
AtenPermuteOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||||
auto input = operands[0]->getValue();
|
auto input = operands[0]->getValue();
|
||||||
|
|
|
@ -594,6 +594,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
||||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
||||||
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
|
||||||
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
|
||||||
|
emit("aten::t : (Tensor) -> (Tensor)")
|
||||||
|
|
||||||
# Dict ops.
|
# Dict ops.
|
||||||
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
|
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
|
||||||
|
|
Loading…
Reference in New Issue