[TORCH][MLIR] Add E2E support for aten.clone (#571)

This commit adds support for the aten.clone op.
pull/573/head
Ramiro Leal-Cavazos 2022-02-09 19:31:03 -08:00 committed by GitHub
parent bd177bdfc7
commit 9b89f8eb3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 4 deletions

View File

@ -1106,3 +1106,21 @@ class ElementwiseAddScalarFloatModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseAddScalarFloatModule()) @register_test_case(module_factory=lambda: ElementwiseAddScalarFloatModule())
def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils): def ElementwiseAddScalarFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4)) module.forward(tu.rand(3, 4))
class ElementwiseCloneModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
])
def forward(self, x):
return torch.clone(x)
@register_test_case(module_factory=lambda: ElementwiseCloneModule())
def ElementwiseCloneModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))

View File

@ -2262,6 +2262,21 @@ def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [
let assemblyFormat = "$self `,` $boundaries `,` $out_int32 `,` $right attr-dict `:` qualified(type($self)) `,` qualified(type($boundaries)) `,` qualified(type($out_int32)) `,` qualified(type($right)) `->` qualified(type($result))"; let assemblyFormat = "$self `,` $boundaries `,` $out_int32 `,` $right attr-dict `:` qualified(type($self)) `,` qualified(type($boundaries)) `,` qualified(type($out_int32)) `,` qualified(type($right)) `->` qualified(type($result))";
} }
def Torch_AtenCloneOp : Torch_Op<"aten.clone", [
AllowsTypeRefinement,
HasValueSemantics
]> {
let summary = "Generated op for `aten::clone : (Tensor, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
TorchOptionalIntType:$memory_format
);
let results = (outs
AnyTorchTensorType:$result
);
let assemblyFormat = "$self `,` $memory_format attr-dict `:` qualified(type($self)) `,` qualified(type($memory_format)) `->` qualified(type($result))";
}
def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [ def Torch_AtenContiguousOp : Torch_Op<"aten.contiguous", [
AllowsTypeRefinement AllowsTypeRefinement
]> { ]> {

View File

@ -1660,6 +1660,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return b.create<math::SqrtOp>(loc, payloadArgs[0]); return b.create<math::SqrtOp>(loc, payloadArgs[0]);
if (isa<AtenRsqrtOp>(op)) if (isa<AtenRsqrtOp>(op))
return b.create<math::RsqrtOp>(loc, payloadArgs[0]); return b.create<math::RsqrtOp>(loc, payloadArgs[0]);
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
if (!clone.memory_format().getType().isa<Torch::NoneType>()) {
clone.emitError("unimplemented: only default memory format is supported");
return nullptr;
}
return payloadArgs[0];
}
if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) { if (auto bitwiseAndTensor = dyn_cast<AtenBitwiseAndTensorOp>(op)) {
if (bitwiseAndTensor.getType() if (bitwiseAndTensor.getType()
.cast<ValueTensorType>() .cast<ValueTensorType>()
@ -2450,7 +2457,7 @@ struct ConvertElementwiseOp : ConversionPattern {
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenEqScalarOp,
AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenLtScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp>(op)) AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -4571,7 +4578,7 @@ public:
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqScalarOp, AtenLtScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
AtenThresholdBackwardOp>(); AtenThresholdBackwardOp, AtenCloneOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenSqueezeOp>(); target.addIllegalOp<AtenSqueezeOp>();
patterns.add<ConvertAtenSqueezeOp>(typeConverter, context); patterns.add<ConvertAtenSqueezeOp>(typeConverter, context);

View File

@ -242,8 +242,8 @@ public:
AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp, AtenClampOp, AtenLogOp, AtenNegOp, AtenSqrtOp, AtenFloorOp,
AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp, AtenLog2Op, Aten_SoftmaxBackwardDataOp, AtenRsqrtOp, AtenDropoutOp,
AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp, AtenTanhBackwardOp, Aten_LogSoftmaxBackwardDataOp, AtenAddIntOp,
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp>( AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
op)) { AtenCloneOp>(op)) {
return getLatticeElement(op->getResult(0)).join(*operands[0]); return getLatticeElement(op->getResult(0)).join(*operands[0]);
} }

View File

@ -577,6 +577,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)")
emit("aten::contiguous : (Tensor, int) -> (Tensor)") emit("aten::contiguous : (Tensor, int) -> (Tensor)")
emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)") emit("aten::copy_ : (Tensor, Tensor, bool) -> (Tensor)")
emit("aten::detach : (Tensor) -> (Tensor)") emit("aten::detach : (Tensor) -> (Tensor)")