mirror of https://github.com/llvm/torch-mlir

Added aten::t() Op

parent 5eed562e19
commit 3cb46cecef
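
For context: `torch.t()` transposes a rank-2 tensor, is a no-op on rank-0 and rank-1 tensors, and rejects anything of higher rank. A minimal eager-mode sketch of that behavior (illustration only, not part of this commit):

# Illustration only: eager-mode semantics of torch.t().
import torch

x0 = torch.tensor(7.0)   # rank 0
x1 = torch.rand(3)       # rank 1
x2 = torch.rand(3, 4)    # rank 2

assert torch.equal(torch.t(x0), x0)   # rank 0: unchanged
assert torch.equal(torch.t(x1), x1)   # rank 1: unchanged
assert torch.t(x2).shape == (4, 3)    # rank 2: dims 0 and 1 swapped
# torch.t(torch.rand(2, 3, 4)) raises: "t() expects a tensor with <= 2 dimensions"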

@@ -1179,3 +1179,56 @@ class BoolTensorReturnMixedModule(torch.nn.Module):

@register_test_case(module_factory=lambda: BoolTensorReturnMixedModule())
def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
    module.forward(torch.tensor([[1, 0], [0,1]], dtype=torch.bool))


# ==============================================================================

class TModuleRank2(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, lhs):
        return torch.t(lhs)


@register_test_case(module_factory=lambda: TModuleRank2())
def TModuleRank2_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))


class TModuleRank1(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1], torch.float32, True),
    ])
    def forward(self, lhs):
        return torch.t(lhs)


@register_test_case(module_factory=lambda: TModuleRank1())
def TModuleRank1_basic(module, tu: TestUtils):
    module.forward(tu.rand(3))


class TModuleRank0(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([], torch.float32, True),
    ])
    def forward(self, lhs):
        return torch.t(lhs)


@register_test_case(module_factory=lambda: TModuleRank0())
def TModuleRank0_basic(module, tu: TestUtils):
    module.forward(torch.tensor(7, dtype=torch.float32))

@@ -41,4 +41,6 @@ TOSA_PASS_SET = {
    "SqueezeModule_static",
    "SqueezeModule_noUnitDim",
    "SqueezeModule_allUnitDim",
    "TModuleRank1_basic",
    "TModuleRank0_basic",
}

@@ -2603,6 +2603,19 @@ def Torch_AtenDropoutOp : Torch_Op<"aten.dropout", [
  let assemblyFormat = "$input `,` $p `,` $train attr-dict `:` type($input) `,` type($p) `,` type($train) `->` type($result)";
}

def Torch_AtenTOp : Torch_Op<"aten.t", [
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::t : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self attr-dict `:` type($self) `->` type($result)";
}

def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
    AllowsTypeRefinement,
    HasValueSemantics
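
The generated ODS entry above follows the registry signature `aten::t : (Tensor) -> (Tensor)`. As a quick illustration (not part of this commit), scripting a function that calls `Tensor.t()` shows the op in a TorchScript graph; the printed value names are approximate:

# Illustration only: aten::t as it appears in a TorchScript graph.
import torch

def f(x):
    return x.t()

print(torch.jit.script(f).graph)
# Prints roughly:
#   graph(%x.1 : Tensor):
#     %1 : Tensor = aten::t(%x.1)
#     return (%1)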

@@ -379,6 +379,36 @@ public:
};
} // namespace

namespace {
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenTOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.self();
    int lhsRank = getTensorRank(lhs);
    auto loc = op.getLoc();

    if (lhsRank > 2 || lhsRank < 0) {
      std::string errorMessage =
          "t() expects a tensor with <=2 dimensions, but self is " +
          std::to_string(lhsRank) + "D";
      return rewriter.notifyMatchFailure(op, errorMessage.c_str());
    } else if (lhsRank < 2)
      rewriter.replaceOp(op, lhs);
    else {
      Value zero =
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
      Value one =
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
      rewriter.replaceOpWithNewOp<AtenTransposeIntOp>(op, op.getType(), lhs,
                                                      zero, one);
    }
    return success();
  }
};
} // namespace

// Decompose torch.expand into torch.broadcast_to op.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
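
In eager-PyTorch terms, this rewrite turns `t()` into `transpose(0, 1)` for rank-2 inputs and into the input itself for rank 0 and 1, failing to match for higher or unknown ranks. A rough Python equivalent (illustration only; the actual pattern rewrites Torch-dialect IR):

# Illustration only: eager-mode equivalent of the DecomposeAtenTOp rewrite.
import torch

def decomposed_t(x: torch.Tensor) -> torch.Tensor:
    rank = x.dim()
    if rank > 2:
        # Mirrors the pattern's failure for unsupported ranks.
        raise RuntimeError(
            f"t() expects a tensor with <=2 dimensions, but self is {rank}D")
    if rank < 2:
        return x                       # rank 0/1: t() is the identity
    return torch.transpose(x, 0, 1)    # rank 2: swap the two dimensions

x = torch.rand(3, 4)
assert torch.equal(decomposed_t(x), torch.t(x))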

@@ -565,6 +595,8 @@ class DecomposeComplexOpsPass
    patterns.add<DecomposeAtenSelectIntOp>(context);
    target.addIllegalOp<AtenSelectIntOp>();
    patterns.add<DecomposeAtenMatmulOp>(context);
    target.addIllegalOp<AtenTOp>();
    patterns.add<DecomposeAtenTOp>(context);
    patterns.add<DecomposeAten_LogSoftmaxBackwardDataOp>(context);
    target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
    target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {

@@ -93,8 +93,8 @@ public:
            AtenFlattenUsingIntsOp, AtenTransposeIntOp,
            TensorStaticInfoCastOp, AtenBroadcastToOp, AtenToDtypeOp,
            AtenContiguousOp, AtenPermuteOp, AtenViewOp, AtenExpandOp,
-           AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp>(
-           op)) {
+           AtenFill_ScalarOp, AtenSliceTensorOp, AtenSelectIntOp,
+           AtenTOp>(op)) {
      // AtenContiguousOp might return a view, so this is conservatively
      // correct. We could potentially be more precise and identify the cases
      // that it does not return a view and treat those as having value
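
Per the surrounding comment, this list holds ops whose result may be a view of their input, and `t()` does alias its input in eager PyTorch. A small illustration (not part of this commit):

# Illustration only: torch.t() returns a view that aliases its input.
import torch

x = torch.zeros(2, 3)
y = torch.t(x)                 # y is a 3x2 view of x
y[0, 1] = 5.0
assert x[1, 0].item() == 5.0   # the write through the view is visible in x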

@@ -374,6 +374,8 @@ public:
      return visitReshapeLikeOp(resize, operands);
    } else if (auto transposeInt = dyn_cast<AtenTransposeIntOp>(op)) {
      return visitAtenTransposeIntOp(transposeInt, operands);
    } else if (auto t = dyn_cast<AtenTOp>(op)) {
      return visitAtenTOp(t, operands);
    } else if (auto permute = dyn_cast<AtenPermuteOp>(op)) {
      return visitAtenPermuteOp(permute, operands);
    } else if (auto tensorFloat = dyn_cast<AtenTensorFloatOp>(op)) {

@@ -550,6 +552,8 @@ private:
  visitAtenTransposeIntOp(AtenTransposeIntOp op,
                          ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenTOp(AtenTOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenPermuteOp(AtenPermuteOp op,
                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult visitNumToTensorOp(PrimNumToTensorScalarOp op);

@@ -1242,6 +1246,24 @@ ChangeResult TypeAnalyzer::visitAtenTransposeIntOp(
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenTOp(
    AtenTOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op.getContext());
  knowledge.dtype = input.dtype;
  if (!input.hasSizes)
    return getLatticeElement(op.getResult()).join(knowledge);
  int64_t inputRank = input.sizes.size();
  if (inputRank >= 0 && inputRank <= 2) {
    knowledge.hasSizes = input.hasSizes;
    knowledge.sizes = input.sizes;
    if (inputRank == 2)
      std::swap(knowledge.sizes[0], knowledge.sizes[1]);
  }
  return getLatticeElement(op.getResult()).join(knowledge);
}

ChangeResult TypeAnalyzer::visitAtenPermuteOp(
    AtenPermuteOp op, ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
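
The transfer function above propagates the input dtype, keeps rank-0/1 shapes unchanged, and swaps the two sizes for rank-2 inputs. A short Python sketch of that shape rule (illustration only; -1 stands for an unknown dimension, as in the test annotations):

# Illustration only: the shape rule implemented by visitAtenTOp.
from typing import List

def t_result_sizes(sizes: List[int]) -> List[int]:
    result = list(sizes)
    if len(result) == 2:
        result[0], result[1] = result[1], result[0]   # rank 2: swap the dims
    return result                                     # rank 0/1: unchanged

assert t_result_sizes([]) == []          # rank 0
assert t_result_sizes([-1]) == [-1]      # rank 1, unknown size
assert t_result_sizes([3, 4]) == [4, 3]  # rank 2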

@@ -594,6 +594,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
    emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True)
    emit("aten::dropout : (Tensor, float, bool) -> (Tensor)")
    emit("aten::t : (Tensor) -> (Tensor)")

    # Dict ops.
    emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)