From 8a1839a17e1c52080da168bece005a4be3a1c3c1 Mon Sep 17 00:00:00 2001 From: "Jae Hoon (Antonio) Kim" <17433012+antoniojkim@users.noreply.github.com> Date: Mon, 6 Jun 2022 15:02:27 -0400 Subject: [PATCH] Add support for aten::arange.start_out (#905) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 +++++++++++++++++++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + 2 files changed, 27 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c41075b03..cf2653d0f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4465,6 +4465,32 @@ def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [ }]; } +def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::arange.start_out : (Scalar start, Scalar end, Scalar step, Tensor out) -> (Tensor)`"; + let arguments = (ins + AnyTorchScalarType:$start, + AnyTorchScalarType:$end, + AnyTorchScalarType:$step, + AnyTorchTensorType:$out + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenArangeStartOutOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenArangeStartOutOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index b92433ec0..aa3a4ede7 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -397,6 +397,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)") emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::clone : (Tensor, int?) -> (Tensor)")