mirror of https://github.com/llvm/torch-mlir
Generate underscore variant of functional ops (#915)
* Generate underscore variant of functional ops * Do not apply `IsTrailingUnderscoreInplaceVariant` trait to underscore variant of functional oppull/918/head
parent
bd53998da8
commit
c1da9edcf0
|
@ -2395,6 +2395,50 @@ def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenZero_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 1, 1);
|
||||||
|
}
|
||||||
|
void AtenZero_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 1, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -2593,29 +2637,6 @@ def Torch_AtenThresholdBackwardOp : Torch_Op<"aten.threshold_backward", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
|
|
||||||
AllowsTypeRefinement,
|
|
||||||
HasValueSemantics,
|
|
||||||
ReadOnly
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
AnyTorchTensorType:$self
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTorchTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
|
||||||
}
|
|
||||||
void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 1, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
|
def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
|
||||||
AllowsTypeRefinement
|
AllowsTypeRefinement
|
||||||
]> {
|
]> {
|
||||||
|
@ -4186,27 +4207,6 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
|
|
||||||
AllowsTypeRefinement
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
AnyTorchTensorType:$self
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTorchTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenZero_Op::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 1, 1);
|
|
||||||
}
|
|
||||||
void AtenZero_Op::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 1, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
|
def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -4515,11 +4515,9 @@ def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [
|
def Torch_AtenArangeStartOutOp : Torch_Op<"aten.arange.start_out", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement
|
||||||
HasValueSemantics,
|
|
||||||
ReadOnly
|
|
||||||
]> {
|
]> {
|
||||||
let summary = "Generated op for `aten::arange.start_out : (Scalar start, Scalar end, Scalar step, Tensor out) -> (Tensor)`";
|
let summary = "Generated op for `aten::arange.start_out : (Scalar, Scalar, Scalar, Tensor) -> (Tensor)`";
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
AnyTorchScalarType:$start,
|
AnyTorchScalarType:$start,
|
||||||
AnyTorchScalarType:$end,
|
AnyTorchScalarType:$end,
|
||||||
|
@ -5742,7 +5740,7 @@ def Torch_AtenScatterAddOp : Torch_Op<"aten.scatter_add", [
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
ReadOnly
|
ReadOnly
|
||||||
]> {
|
]> {
|
||||||
let summary = "Generated op for `aten::scatter_add : (Tensor self, int dim, Tensor index, Tensor src) -> (Tensor)`";
|
let summary = "Generated op for `aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)`";
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
AnyTorchTensorType:$self,
|
AnyTorchTensorType:$self,
|
||||||
Torch_IntType:$dim,
|
Torch_IntType:$dim,
|
||||||
|
@ -8293,3 +8291,4 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -218,9 +218,11 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
operator = registry[key]
|
operator = registry[key]
|
||||||
emit_op(operator, emitter_td, **kwargs)
|
emit_op(operator, emitter_td, **kwargs)
|
||||||
ns, unqual, overload = operator.triple
|
ns, unqual, overload = operator.triple
|
||||||
emit_op(registry.get_by_triple((ns, unqual + "_", overload)),
|
# Underscore variant of functional ops should have "functional" part removed.
|
||||||
|
is_functional_op = overload == 'functional'
|
||||||
|
emit_op(registry.get_by_triple((ns, unqual + "_", overload if not is_functional_op else "")),
|
||||||
emitter_td,
|
emitter_td,
|
||||||
traits=["IsTrailingUnderscoreInplaceVariant"])
|
traits=["IsTrailingUnderscoreInplaceVariant"] if not is_functional_op else [])
|
||||||
|
|
||||||
# ==========================================================================
|
# ==========================================================================
|
||||||
# `aten::` namespace.
|
# `aten::` namespace.
|
||||||
|
@ -279,7 +281,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
"aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||||
"aten::square : (Tensor) -> (Tensor)",
|
"aten::square : (Tensor) -> (Tensor)",
|
||||||
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
"aten::unsqueeze : (Tensor, int) -> (Tensor)",
|
||||||
|
"aten::zero.functional : (Tensor) -> (Tensor)",
|
||||||
]:
|
]:
|
||||||
emit_with_mutating_variants(key)
|
emit_with_mutating_variants(key)
|
||||||
# Elementwise tensor compute ops that don't have the standard mutating
|
# Elementwise tensor compute ops that don't have the standard mutating
|
||||||
|
@ -292,7 +294,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
emit("aten::gelu : (Tensor, str) -> (Tensor)")
|
||||||
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
|
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
|
||||||
emit("aten::zero.functional : (Tensor) -> (Tensor)")
|
|
||||||
|
|
||||||
# Ops without value semantics but the corresponding without trailing
|
# Ops without value semantics but the corresponding without trailing
|
||||||
# underscore variant doesn't exist.
|
# underscore variant doesn't exist.
|
||||||
|
@ -385,7 +386,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::zero_ : (Tensor) -> (Tensor)")
|
|
||||||
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||||
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
|
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
|
||||||
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
|
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue