Generate underscore variant of functional ops (#915)

* Generate underscore variant of functional ops

* Do not apply `IsTrailingUnderscoreInplaceVariant` trait to underscore variant of functional op
pull/918/head
Henry Tu 2022-06-08 14:27:36 -04:00 committed by GitHub
parent bd53998da8
commit c1da9edcf0
2 changed files with 53 additions and 54 deletions
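
The generator change in the second file below drives everything in the first: emit_with_mutating_variants derives the trailing-underscore (in-place) op from a registry key, and after this commit a "functional" overload is dropped when forming that name, so aten::zero.functional pairs with plain aten::zero_ rather than a nonexistent aten::zero_.functional. A minimal standalone sketch of that naming rule (parse_key and underscore_variant are hypothetical helpers for illustration, not the real torch_ods_gen.py code):

def parse_key(key: str) -> tuple:
    """Split a registry key such as "aten::zero.functional : (Tensor) -> (Tensor)"
    into the (namespace, unqualified name, overload) triple."""
    name = key.split(" : ")[0]
    ns, rest = name.split("::")
    unqual, _, overload = rest.partition(".")
    return ns, unqual, overload

def underscore_variant(ns: str, unqual: str, overload: str) -> tuple:
    """Triple of the in-place variant: append "_" and, for functional ops,
    drop the "functional" overload entirely."""
    is_functional_op = overload == "functional"
    return ns, unqual + "_", "" if is_functional_op else overload

assert parse_key("aten::zero.functional : (Tensor) -> (Tensor)") == ("aten", "zero", "functional")
assert underscore_variant("aten", "zero", "functional") == ("aten", "zero_", "")
assert underscore_variant("aten", "unsqueeze", "") == ("aten", "unsqueeze_", "")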

Changed file 1 of 2 (generated TableGen op definitions):

@@ -2395,6 +2395,50 @@ def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
   }];
 }
 
+def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenZero_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenZero_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -2593,29 +2637,6 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
   }];
 }
 
-def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-}
-
 def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
     AllowsTypeRefinement
   ]> {
@@ -4186,27 +4207,6 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
   }];
 }
 
-def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
-    AllowsTypeRefinement
-  ]> {
-  let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self
-  );
-  let results = (outs
-    AnyTorchTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenZero_Op::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 1, 1);
-    }
-    void AtenZero_Op::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 1, 1);
-    }
-  }];
-}
-
 def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -4515,11 +4515,9 @@ def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [
 }
 
 def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
+    AllowsTypeRefinement
   ]> {
-  let summary = "Generated op for `aten::arange.start_out : (Scalar start, Scalar end, Scalar step, Tensor out) -> (Tensor)`";
+  let summary = "Generated op for `aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)`";
   let arguments = (ins
     AnyTorchScalarType:$start,
     AnyTorchScalarType:$end,
@@ -5742,7 +5740,7 @@ def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::scatter_add : (Tensor self, int dim, Tensor index, Tensor src) -> (Tensor)`";
+  let summary = "Generated op for `aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
     Torch_IntType:$dim,
@@ -8293,3 +8291,4 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
     }
   }];
 }
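
Why the two new defs carry different traits: aten.zero.functional returns a fresh zeroed tensor, so it is marked HasValueSemantics and ReadOnly, while aten.zero_ mutates its operand and gets neither. At the PyTorch level the pair behaves roughly as below (a sketch assuming a standard torch install; torch.zeros_like stands in for the internal zero.functional overload, which has no direct Python spelling):

import torch

t = torch.ones(2, 3)

# Value-semantic form: a new tensor is produced and the input is untouched,
# matching HasValueSemantics/ReadOnly on aten.zero.functional.
z = torch.zeros_like(t)
assert t.sum().item() == 6.0

# In-place form: aten::zero_ overwrites its operand (and returns it).
t.zero_()
assert torch.equal(t, z)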

Changed file 2 of 2 (ODS generator script, emit_ops):

@ -218,9 +218,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
operator = registry[key] operator = registry[key]
emit_op(operator, emitter_td, **kwargs) emit_op(operator, emitter_td, **kwargs)
ns, unqual, overload = operator.triple ns, unqual, overload = operator.triple
-        emit_op(registry.get_by_triple((ns, unqual + "_", overload)),
+        # Underscore variant of functional ops should have "functional" part removed.
+        is_functional_op = overload == 'functional'
+        emit_op(registry.get_by_triple((ns, unqual + "_", overload if not is_functional_op else "")),
                 emitter_td,
-                traits=["IsTrailingUnderscoreInplaceVariant"])
+                traits=["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [])
 
     # ==========================================================================
     # `aten::` namespace.
@@ -279,7 +281,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
             "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
             "aten::square : (Tensor) -> (Tensor)",
             "aten::unsqueeze : (Tensor, int) -> (Tensor)",
+            "aten::zero.functional : (Tensor) -> (Tensor)",
     ]:
         emit_with_mutating_variants(key)
 
     # Elementwise tensor compute ops that don't have the standard mutating
@@ -292,7 +294,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::gelu : (Tensor, str) -> (Tensor)")
     emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
-    emit("aten::zero.functional : (Tensor) -> (Tensor)")
 
     # Ops without value semantics but the corresponding without trailing
     # underscore variant doesn't exist.
@@ -385,7 +386,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
-    emit("aten::zero_ : (Tensor) -> (Tensor)")
     emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
     emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")