[Torch Dialect] emit aten.__or__Tensor Op (#2437)

* emit aten.__or__TensorOp

* bug fix

* remove convert to stablehlo

* code style refinement
pull/2443/head snapshot-20230906.953
Jiawei Wu 2023-09-06 14:21:51 +08:00 committed by GitHub
parent fcb3b718a5
commit b411a40b3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 115 additions and 0 deletions

View File

@ -452,6 +452,7 @@ STABLEHLO_PASS_SET = {
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseBitwiseNotInt32Module_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseBitwiseOrStaticShapeModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseClampModule_basic",
@ -985,6 +986,8 @@ TOSA_PASS_SET = {
"ElementwiseBitwiseAndStaticShapeModule_basic",
"ElementwiseBitwiseNotInt32Module_basic",
"ElementwiseBitwiseNotInt64Module_basic",
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseOrTensorModule_basic",
"ElementwiseBitwiseOrModule_basic",
"ElementwiseBitwiseOrStaticShapeModule_basic",
"ElementwiseBitwiseXorModule_basic",

View File

@ -6141,6 +6141,31 @@ def Torch_Aten__And__TensorOp : Torch_Op<"aten.__and__.Tensor", [
}];
}
def Torch_Aten__Or__TensorOp : Torch_Op<"aten.__or__.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__Or__TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void Aten__Or__TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_Aten_SoftmaxOp : Torch_Op<"aten._softmax", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -1124,6 +1124,19 @@ void AtenDivTensorModeOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// Aten__Or__TensorOp
//===----------------------------------------------------------------------===//
void Aten__Or__TensorOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add(+[](Aten__Or__TensorOp op, PatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<AtenBitwiseOrTensorOp>(
op, op.getType(), op.getSelf(), op.getOther());
return success();
});
}
//===----------------------------------------------------------------------===//
// AtenScalarImplicitOp
//===----------------------------------------------------------------------===//

View File

@ -7376,6 +7376,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.__or__.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.minimum\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@ -9154,6 +9158,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.__or__.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -771,6 +771,9 @@ def atenatan2〡shape(self: List[int], other: List[int]) -> List[int]:
def aten__and__Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def aten__or__Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atenminimum〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
@ -2223,6 +2226,14 @@ def aten__and__Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def aten__or__Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
other_rank, other_dtype = other_rank_dtype
self_rank, self_dtype = self_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def atenaddTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
other_rank, other_dtype = other_rank_dtype

View File

@ -456,6 +456,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
emit("aten::mean : (Tensor, int?) -> (Tensor)")
emit("aten::std : (Tensor, bool) -> (Tensor)")

View File

@ -2024,6 +2024,56 @@ def ElementwiseBitwiseOrStaticShapeModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseOrTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
([-1, -1], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.__or__(x, y)
@register_test_case(module_factory=lambda: ElementwiseOrTensorModule())
def ElementwiseOrTensorModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(3, 4, low=-10, high=10).to(torch.int32),
tu.randint(3, 4, low=-10, high=10))
# ==============================================================================
class ElementwiseOrTensorStaticShapeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([3, 4], torch.int32, True),
([4], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.__or__(x, y)
@register_test_case(module_factory=lambda: ElementwiseOrTensorStaticShapeModule())
def ElementwiseOrTensorStaticShapeModule_basic(module, tu: TestUtils):
module.forward(
tu.randint(3, 4, low=-10, high=10).to(torch.int32),
tu.randint(4, low=-10, high=10))
# ==============================================================================
class ElementwiseBitwiseXorModule(torch.nn.Module):
def __init__(self):