Add support for aten::arange.start_out (#905)

pull/909/head
Jae Hoon (Antonio) Kim 2022-06-06 15:02:27 -04:00 committed by GitHub
parent 2718b4d838
commit 8a1839a17e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 0 deletions

View File

@ -4465,6 +4465,32 @@ def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [
}];
}
def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::arange.start_out : (Scalar start, Scalar end, Scalar step, Tensor out) -> (Tensor)`";
let arguments = (ins
AnyTorchScalarType:$start,
AnyTorchScalarType:$end,
AnyTorchScalarType:$step,
AnyTorchTensorType:$out
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenArangeStartOutOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenArangeStartOutOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -397,6 +397,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start : (Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start_step : (Scalar, Scalar, Scalar, int?, int?, Device?, bool?) -> (Tensor)")
emit("aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)")
emit("aten::argmax : (Tensor, int?, bool) -> (Tensor)")
emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)")
emit("aten::clone : (Tensor, int?) -> (Tensor)")