mirror of https://github.com/llvm/torch-mlir
Generate underscore variant of functional ops (#915)
* Generate underscore variant of functional ops * Do not apply `IsTrailingUnderscoreInplaceVariant` trait to underscore variant of functional oppull/918/head
parent
bd53998da8
commit
c1da9edcf0
|
@ -2395,6 +2395,50 @@ def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenZero_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenZero_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -2593,29 +2637,6 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
|
@ -4186,27 +4207,6 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenZero_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||
}
|
||||
void AtenZero_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 1, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -4515,11 +4515,9 @@ def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [
|
|||
}
|
||||
|
||||
def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::arange.start_out : (Scalar start, Scalar end, Scalar step, Tensor out) -> (Tensor)`";
|
||||
let summary = "Generated op for `aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchScalarType:$start,
|
||||
AnyTorchScalarType:$end,
|
||||
|
@ -5742,7 +5740,7 @@ def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [
|
|||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::scatter_add : (Tensor self, int dim, Tensor index, Tensor src) -> (Tensor)`";
|
||||
let summary = "Generated op for `aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_IntType:$dim,
|
||||
|
@ -8293,3 +8291,4 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
|||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -218,9 +218,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
operator = registry[key]
|
||||
emit_op(operator, emitter_td, **kwargs)
|
||||
ns, unqual, overload = operator.triple
|
||||
emit_op(registry.get_by_triple((ns, unqual + "_", overload)),
|
||||
# Underscore variant of functional ops should have "functional" part removed.
|
||||
is_functional_op = overload == 'functional'
|
||||
emit_op(registry.get_by_triple((ns, unqual + "_", overload if not is_functional_op else "")),
|
||||
emitter_td,
|
||||
traits=["IsTrailingUnderscoreInplaceVariant"])
|
||||
traits=["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [])
|
||||
|
||||
# ==========================================================================
|
||||
# `aten::` namespace.
|
||||
|
@ -279,7 +281,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
"aten::square : (Tensor) -> (Tensor)",
|
||||
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
||||
|
||||
"aten::zero.functional : (Tensor) -> (Tensor)",
|
||||
]:
|
||||
emit_with_mutating_variants(key)
|
||||
# Elementwise tensor compute ops that don't have the standard mutating
|
||||
|
@ -292,7 +294,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
||||
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::zero.functional : (Tensor) -> (Tensor)")
|
||||
|
||||
# Ops without value semantics but the corresponding without trailing
|
||||
# underscore variant doesn't exist.
|
||||
|
@ -385,7 +386,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::zero_ : (Tensor) -> (Tensor)")
|
||||
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue