Add e2e support for aten.expm1

pull/1097/head snapshot-20220727.546
Quinn Dawkins 2022-07-27 02:36:52 +00:00 committed by Vivek Khandelwal
parent 052d2f84dc
commit 3c9addf19c
7 changed files with 113 additions and 12 deletions

View File

@ -565,6 +565,51 @@ def Torch_AtenExp_Op : Torch_Op<"aten.exp_", [
}];
}
def Torch_AtenExpm1Op : Torch_Op<"aten.expm1", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::expm1 : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenExpm1Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenExpm1Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenExpm1_Op : Torch_Op<"aten.expm1_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::expm1_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenExpm1_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenExpm1_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenCosOp : Torch_Op<"aten.cos", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -131,6 +131,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createCalculationForMathOpWithDtypeConversion<math::ExpOp>(
b, converter, payloadArgs[0], op);
}
if (isa<AtenExpm1Op>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::ExpM1Op>(
b, converter, payloadArgs[0], op);
}
if (isa<AtenLogOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::LogOp>(
b, converter, payloadArgs[0], op);
@ -923,15 +927,15 @@ public:
if (!isa<AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp,
AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp,
AtenDivTensorOp, AtenDivTensorModeOp, AtenSubTensorOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenExpm1Op,
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenLogicalOrOp,
AtenTriuOp>(op))

View File

@ -663,9 +663,9 @@ void TypeAnalysis::visitOperation(Operation *op,
}
// Dtype is always float32, except for bfloat16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp>(op)) {
if (isa<AtenTanhOp, AtenExpOp, AtenExpm1Op, AtenSinOp, AtenCosOp,
AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype;

View File

@ -5337,6 +5337,10 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.expm1"(%arg0: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.sin"(%arg0: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>

View File

@ -315,6 +315,9 @@ def atensilu(self: List[int]) -> List[int]:
def atenexp(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenexpm1(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atensin(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

View File

@ -250,6 +250,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::silu : (Tensor) -> (Tensor)",
"aten::sin : (Tensor) -> (Tensor)",
"aten::exp : (Tensor) -> (Tensor)",
"aten::expm1 : (Tensor) -> (Tensor)",
"aten::cos : (Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)",

View File

@ -1528,6 +1528,50 @@ def ElementwiseExpIntModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseExpm1Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.special.expm1(a)
@register_test_case(module_factory=lambda: ElementwiseExpm1Module())
def ElementwiseExpm1Module_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseExpm1IntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, a):
return torch.special.expm1(a)
@register_test_case(module_factory=lambda: ElementwiseExpm1IntModule())
def ElementwiseExpm1IntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
# ==============================================================================
class ElementwiseSinModule(torch.nn.Module):
def __init__(self):