[MLIR][TORCH] Add E2E support for aten.addcmul_ and aten.addcdiv_ op

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
pull/1535/head snapshot-20221028.640
Vivek Khandelwal 2022-10-26 18:11:52 +05:30
parent 5d5aa47cdf
commit ea602127b6
3 changed files with 54 additions and 2 deletions

View File

@ -482,6 +482,8 @@ TOSA_PASS_SET = {
"HardTanhIntModule_basic",
"AtenRoundIntModule_basic",
"MseLossNoReductionModule_basic",
"AddCMul_Module_basic",
"AddCDiv_Module_basic",
}
LTC_XFAIL_SET = {

View File

@ -2796,6 +2796,31 @@ def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
}];
}
def Torch_AtenAddcmul_Op : Torch_Op<"aten.addcmul_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::addcmul_ : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$tensor1,
AnyTorchTensorType:$tensor2,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAddcmul_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenAddcmul_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenAddcdivOp : Torch_Op<"aten.addcdiv", [
AllowsTypeRefinement,
HasValueSemantics,
@ -2822,6 +2847,31 @@ def Torch_AtenAddcdivOp : Torch_Op<"aten.addcdiv", [
}];
}
def Torch_AtenAddcdiv_Op : Torch_Op<"aten.addcdiv_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::addcdiv_ : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$tensor1,
AnyTorchTensorType:$tensor2,
AnyTorchScalarType:$value
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAddcdiv_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenAddcdiv_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}
def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -304,8 +304,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
emit("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::mish : (Tensor) -> (Tensor)")