mirror of https://github.com/llvm/torch-mlir
[MLIR][TORCH] Add E2E support for aten.addcmul_ and aten.addcdiv_ op
Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>pull/1535/head snapshot-20221028.640
parent
5d5aa47cdf
commit
ea602127b6
|
@ -482,6 +482,8 @@ TOSA_PASS_SET = {
|
|||
"HardTanhIntModule_basic",
|
||||
"AtenRoundIntModule_basic",
|
||||
"MseLossNoReductionModule_basic",
|
||||
"AddCMul_Module_basic",
|
||||
"AddCDiv_Module_basic",
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
|
|
|
@ -2796,6 +2796,31 @@ def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddcmul_Op : Torch_Op<"aten.addcmul_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::addcmul_ : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$tensor1,
|
||||
AnyTorchTensorType:$tensor2,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAddcmul_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenAddcmul_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddcdivOp : Torch_Op<"aten.addcdiv", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
@ -2822,6 +2847,31 @@ def Torch_AtenAddcdivOp : Torch_Op<"aten.addcdiv", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddcdiv_Op : Torch_Op<"aten.addcdiv_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::addcdiv_ : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchTensorType:$tensor1,
|
||||
AnyTorchTensorType:$tensor2,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAddcdiv_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenAddcdiv_Op::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -304,8 +304,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||
|
||||
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::mish : (Tensor) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue