Add e2e linalg support for aten.atan (#2070)

* new atan op

* update shape

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
pull/2040/head snapshot-20230428.822
Ze Zhang 2023-04-28 00:04:58 -07:00 committed by GitHub
parent a58442b50d
commit 7b73e0cfaf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 108 additions and 6 deletions

View File

@ -700,6 +700,51 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [
}];
}
def Torch_AtenAtanOp : Torch_Op<"aten.atan", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::atan : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAtanOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenAtanOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenAtan_Op : Torch_Op<"aten.atan_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::atan_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAtan_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenAtan_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -229,6 +229,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createCalculationForMathOpWithDtypeConversion<math::CosOp>(
b, converter, payloadArgs[0], op);
}
if (isa<AtenAtanOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::AtanOp>(
b, converter, payloadArgs[0], op);
}
if (auto clone = dyn_cast<AtenCloneOp>(op)) {
int64_t memoryFormat;
if (!clone.getMemoryFormat().getType().isa<Torch::NoneType>() &&
@ -1119,7 +1123,7 @@ public:
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1595,7 +1599,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);

View File

@ -6115,6 +6115,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.tanh\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

View File

@ -664,10 +664,10 @@ void TypeAnalysis::visitOperation(Operation *op,
}
// Dtype is always float32, except for bfloat16, float16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp, AtenFrobeniusNormDimOp,
PrimsSqrtOp>(op)) {
if (isa<AtenAtanOp, AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp,
AtenSigmoidOp, AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenErfOp, AtenSoftplusOp,
AtenFrobeniusNormDimOp, PrimsSqrtOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype;

View File

@ -46,6 +46,9 @@ def _embedding_bag_helper(weight: List[int], indices: List[int], offsets: List[i
def atentriu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
return upstream_shape_functions.unary(self)
def atenatan〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atentanh〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

View File

@ -255,6 +255,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::exp : (Tensor) -> (Tensor)",
"aten::expm1 : (Tensor) -> (Tensor)",
"aten::cos : (Tensor) -> (Tensor)",
"aten::atan : (Tensor) -> (Tensor)",
"aten::atan2 : (Tensor, Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)",
"aten::floor : (Tensor) -> (Tensor)",

View File

@ -933,6 +933,51 @@ def ElementwiseMishModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseAtanTensorFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float32, True),
])
def forward(self, a):
return torch.atan(a)
@register_test_case(module_factory=lambda: ElementwiseAtanTensorFloatModule())
def ElementwiseAtanTensorFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(4, 4))
# ==============================================================================
class ElementwiseAtanTensorIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int32, True),
])
def forward(self, a):
return torch.atan(a)
@register_test_case(module_factory=lambda: ElementwiseAtanTensorIntModule())
def ElementwiseAtanTensorIntModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(4, low=1, high=10).type(torch.int32))
# ==============================================================================
class ElementwiseAtan2TensorFloatModule(torch.nn.Module):
def __init__(self):