mirror of https://github.com/llvm/torch-mlir
Add lowering of `aten.Int.Tensor` op.
The lowering of `aten.Int.Tensor` op has been added. The changes has been made as a part of `convert-torch-to-linalg` pass. Signed-off-by: Prashant Kumar <prashant@nod-labs.com>pull/375/head snapshot-20211101.58
parent
69eaf9a154
commit
53b4275ef5
|
@ -493,3 +493,22 @@ class ContiguousModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ContiguousModule())
|
||||
def ContiguousModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 1))
|
||||
|
||||
class TensorToInt(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([], torch.int64, True),
|
||||
([], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
# This is a workaround for not returning scalar value.
|
||||
a = int(x)
|
||||
return y.add(y, alpha=a)
|
||||
|
||||
@register_test_case(module_factory=lambda: TensorToInt())
|
||||
def TensorToInt_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10,[]), tu.rand())
|
||||
|
|
|
@ -1948,6 +1948,20 @@ def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [
|
|||
let assemblyFormat = "$t `,` $dtype `,` $device `,` $requires_grad attr-dict `:` type($t) `,` type($dtype) `,` type($device) `,` type($requires_grad) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_AtenIntTensorOp : Torch_Op<"aten.Int.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::Int.Tensor : (Tensor) -> (int)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$a
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
);
|
||||
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($result)";
|
||||
}
|
||||
|
||||
def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -2547,6 +2547,30 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
// Casts a 0d integer tensor to elemental type.
|
||||
namespace {
|
||||
class ConvertAtenIntTensorOp : public OpConversionPattern<AtenIntTensorOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AtenIntTensorOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
return failure();
|
||||
AtenIntTensorOp::Adaptor adaptor(operands);
|
||||
Value intTensor = adaptor.a();
|
||||
auto tensorType = intTensor.getType().cast<RankedTensorType>();
|
||||
|
||||
if (tensorType.getRank() != 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "invalid rank: the rank of the input tensor must be 0");
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, intTensor);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
|
||||
public:
|
||||
|
@ -2797,6 +2821,8 @@ public:
|
|||
patterns.add<ConvertAtenOnesOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenContiguousOp>();
|
||||
patterns.add<ConvertAtenContiguousOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenIntTensorOp>();
|
||||
patterns.add<ConvertAtenIntTensorOp>(typeConverter, context);
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
|
|
|
@ -555,6 +555,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)")
|
||||
emit("aten::IntImplicit : (Tensor) -> (int)")
|
||||
emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::Int.Tensor : (Tensor) -> (int)")
|
||||
|
||||
# Dict ops.
|
||||
emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True)
|
||||
|
|
|
@ -54,3 +54,16 @@ func @torch.aten.mm$no_convert$result_missing_dtype(%arg0: !torch.vtensor<[?,?],
|
|||
%0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor
|
||||
return %0 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @integer_extract
|
||||
// CHECK-SAME: (%[[A:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
|
||||
// CHECK: %[[B:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[],si64> -> tensor<i64>
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract %[[B]][] : tensor<i64>
|
||||
// CHECK: %[[RET:.*]] = torch_c.from_i64 %[[EXT]]
|
||||
// CHECK: return %[[RET]] : !torch.int
|
||||
func @integer_extract(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
|
||||
%0 = torch.aten.Int.Tensor %arg0 : !torch.vtensor<[],si64> -> !torch.int
|
||||
return %0 : !torch.int
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue