[MLIR][TORCH] Add E2E support for aten.polar op (#3671)

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
pull/3685/head
Vivek Khandelwal 2024-09-03 10:51:03 +05:30 committed by GitHub
parent 3180704b14
commit 567ed44fd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 205 additions and 0 deletions

View File

@ -5332,6 +5332,30 @@ def Torch_AtenSoftshrinkOp : Torch_Op<"aten.softshrink", [
}]; }];
} }
def Torch_AtenPolarOp : Torch_Op<"aten.polar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::polar : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$abs,
AnyTorchTensorType:$angle
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenPolarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenPolarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [ def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -3295,6 +3295,72 @@ public:
}; };
} // namespace } // namespace
namespace {
class ConvertAtenPolarOp : public OpConversionPattern<AtenPolarOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenPolarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
const TypeConverter *typeConverter = getTypeConverter();
MLIRContext *context = rewriter.getContext();
Value absTensor = adaptor.getAbs();
Value angleTensor = adaptor.getAngle();
RankedTensorType resultType =
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
auto elementType = resultType.getElementType();
SmallVector<Value> resultShape;
for (int64_t i = 0; i < resultType.getRank(); i++) {
auto currentDimSize = rewriter.create<tensor::DimOp>(loc, absTensor, i);
resultShape.push_back(currentDimSize);
}
Value outTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(resultShape), elementType);
SmallVector<AffineExpr> outputExpr;
for (unsigned i = 0; i < resultType.getRank(); i++) {
outputExpr.push_back(getAffineDimExpr(i, context));
}
AffineMap identityMap =
AffineMap::get(resultType.getRank(), 0, outputExpr, op->getContext());
SmallVector<AffineMap> indexingMaps{identityMap, identityMap, identityMap};
SmallVector<utils::IteratorType> iteratorTypes(
resultType.getRank(), utils::IteratorType::parallel);
auto complexVar =
rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(), ValueRange{absTensor, angleTensor},
outTensor, indexingMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
// out = abs⋅cos(angle) + abs⋅sin(angle)⋅j
Value abs = args[0];
Value angle = args[1];
Value realVal = b.create<math::CosOp>(loc, angle);
Value imagVal = b.create<math::SinOp>(loc, angle);
realVal = b.create<arith::MulFOp>(loc, abs, realVal);
imagVal = b.create<arith::MulFOp>(loc, abs, imagVal);
Value complexVal = b.create<complex::CreateOp>(
loc, elementType, realVal, imagVal);
b.create<linalg::YieldOp>(loc, complexVal);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, complexVar);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns, TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) { ConversionTarget &target) {
@ -3355,4 +3421,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertInterpolateOp>(typeConverter, context); patterns.add<ConvertInterpolateOp>(typeConverter, context);
target.addIllegalOp<AtenLinalgDetOp>(); target.addIllegalOp<AtenLinalgDetOp>();
patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context); patterns.add<ConvertAtenLinalgDetOp>(typeConverter, context);
target.addIllegalOp<AtenPolarOp>();
patterns.add<ConvertAtenPolarOp>(typeConverter, context);
} }

View File

@ -6715,6 +6715,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.polar\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.mish\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.mish\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -11276,6 +11280,44 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
" return %1 : !torch.int\n" " return %1 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.polar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int9 = torch.constant.int 9\n"
" %int6 = torch.constant.int 6\n"
" %int10 = torch.constant.int 10\n"
" %int7 = torch.constant.int 7\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %false = torch.constant.bool false\n"
" %true = torch.constant.bool true\n"
" %0 = torch.prim.Uninitialized : !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.aten.eq.int %1#1, %2#1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int10 : !torch.bool, !torch.int\n"
" } else {\n"
" %7 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.int) {\n"
" torch.prim.If.yield %true, %int9 : !torch.bool, !torch.int\n"
" } else {\n"
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.int\n"
" }\n"
" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.int\n"
" }\n"
" %6 = torch.prim.If %5#0 -> (!torch.int) {\n"
" torch.prim.If.yield %5#1 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %1#1 : !torch.int\n"
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"

View File

@ -2428,6 +2428,8 @@ ONNX_XFAIL_SET = {
"AtenMmQMixedSigni8_basic", "AtenMmQMixedSigni8_basic",
"AtenMmQint8_basic", "AtenMmQint8_basic",
"AtenMmQuint8_basic", "AtenMmQuint8_basic",
"AtenPolarFloatModule_basic",
"AtenPolarDoubleModule_basic",
"AtenRealView128Module_basic", "AtenRealView128Module_basic",
"AtenRealView64Module_basic", "AtenRealView64Module_basic",
"AtenSubFloatModule_basic", "AtenSubFloatModule_basic",
@ -3794,6 +3796,8 @@ ONNX_TOSA_XFAIL_SET = {
"AtenMmQMixedSigni8_basic", "AtenMmQMixedSigni8_basic",
"AtenMmQint8_basic", "AtenMmQint8_basic",
"AtenMmQuint8_basic", "AtenMmQuint8_basic",
"AtenPolarFloatModule_basic",
"AtenPolarDoubleModule_basic",
"AtenRealView128Module_basic", "AtenRealView128Module_basic",
"AtenRealView64Module_basic", "AtenRealView64Module_basic",
"AtenRoundFloatHalfToEvenModule_basic", "AtenRoundFloatHalfToEvenModule_basic",

View File

@ -322,6 +322,9 @@ def atenhardshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]:
def atensoftshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]: def atensoftshrink〡shape(self: List[int], lambd: float = 0.5) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
def atenpolar〡shape(abs: List[int], angle: List[int]) -> List[int]:
return upstream_shape_functions.unary(abs)
def atenmish〡shape(self: List[int]) -> List[int]: def atenmish〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
@ -2595,6 +2598,17 @@ def atensoftshrink〡dtype(self_rank_dtype: Tuple[int, int], lambd: Union[int
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype
return _get_dtype_of_floating_point_op(self_dtype) return _get_dtype_of_floating_point_op(self_dtype)
def atenpolar〡dtype(abs_rank_dtype: Tuple[int, int], angle_rank_dtype: Tuple[int, int]) -> int:
_, abs_dtype = abs_rank_dtype
_, angle_dtype = angle_rank_dtype
assert (abs_dtype == angle_dtype)
if abs_dtype == torch.float64:
return torch.complex128
elif abs_dtype == torch.float32:
return torch.complex64
return abs_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atenlogit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int: def atenlogit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype

View File

@ -501,6 +501,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::log_sigmoid : (Tensor) -> (Tensor)") emit("aten::log_sigmoid : (Tensor) -> (Tensor)")
emit("aten::hardshrink : (Tensor, Scalar) -> (Tensor)") emit("aten::hardshrink : (Tensor, Scalar) -> (Tensor)")
emit("aten::softshrink : (Tensor, Scalar) -> (Tensor)") emit("aten::softshrink : (Tensor, Scalar) -> (Tensor)")
emit("aten::polar : (Tensor, Tensor) -> (Tensor)")
# Ops with dynamic number of outputs # Ops with dynamic number of outputs
emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])") emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")

View File

@ -5761,3 +5761,55 @@ class UnfoldModule(torch.nn.Module):
@register_test_case(module_factory=lambda: UnfoldModule()) @register_test_case(module_factory=lambda: UnfoldModule())
def UnfoldModule_basic(module, tu: TestUtils): def UnfoldModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 3, 4)) module.forward(tu.rand(2, 5, 3, 4))
# ==============================================================================
class AtenPolarFloatModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.unfold = torch.nn.Unfold(kernel_size=(2, 3))
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, abs_, angle):
return torch.ops.aten.polar(torch.ops.aten.abs(abs_), angle)
@register_test_case(module_factory=lambda: AtenPolarFloatModule())
def AtenPolarFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5, 3, 4), tu.rand(2, 5, 3, 4))
# ==============================================================================
class AtenPolarDoubleModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.unfold = torch.nn.Unfold(kernel_size=(2, 3))
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float64, True),
([-1, -1, -1, -1], torch.float64, True),
]
)
def forward(self, abs_, angle):
return torch.ops.aten.polar(torch.ops.aten.abs(abs_), angle)
@register_test_case(module_factory=lambda: AtenPolarDoubleModule())
def AtenPolarDoubleModule_basic(module, tu: TestUtils):
module.forward(
tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64)
)