mirror of https://github.com/llvm/torch-mlir
[Torch] Support Aten__And__ScalarOp (#3114)
parent
2c56ef9252
commit
84c24e5771
|
@ -7556,6 +7556,31 @@ def Torch_Aten__And__TensorOp : Torch_Op<"aten.__and__.Tensor", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_Aten__And__ScalarOp : Torch_Op<"aten.__and__.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten__And__ScalarOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void Aten__And__ScalarOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten__Or__TensorOp : Torch_Op<"aten.__or__.Tensor", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -577,13 +577,24 @@ public:
|
|||
LogicalResult
|
||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.getSelf();
|
||||
Value rhs = adaptor.getOther();
|
||||
|
||||
RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
|
||||
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
return op.emitError("lhs must be a ranked tensor type");
|
||||
|
||||
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||
->convertType(op.getType())
|
||||
.template cast<TensorType>();
|
||||
Value lhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), adaptor.getSelf(), outType);
|
||||
Value rhs =
|
||||
hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType);
|
||||
Type outElemTy = outType.getElementType();
|
||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType);
|
||||
if (!rhsTy) {
|
||||
rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy);
|
||||
}
|
||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType);
|
||||
|
||||
DenseI64ArrayAttr bcastDimensions;
|
||||
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
|
||||
|
@ -1861,6 +1872,8 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
|
|||
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
|
||||
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
|
||||
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
|
||||
INSERT_BINARY_LOGICAL_PATTERN(AtenBitwiseAndScalarOp, chlo::BroadcastAndOp);
|
||||
|
||||
#undef INSERT_BINARY_LOGICAL_PATTERN
|
||||
|
||||
#define INSERT_ATENOP_PATTERN(AtenOp) \
|
||||
|
|
|
@ -1872,6 +1872,18 @@ void Aten__Or__TensorOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Aten__And__ScalarOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
void Aten__And__ScalarOp::getCanonicalizationPatterns(
|
||||
RewritePatternSet &patterns, MLIRContext *context) {
|
||||
patterns.add(+[](Aten__And__ScalarOp op, PatternRewriter &rewriter) {
|
||||
rewriter.replaceOpWithNewOp<AtenBitwiseAndScalarOp>(
|
||||
op, op.getType(), op.getSelf(), op.getOther());
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenScalarImplicitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -6938,6 +6938,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.__and__.Scalar\"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.remainder.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -10839,6 +10843,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
|
||||
" %none = torch.constant.none\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1 = torch.prim.ListConstruct %0#0, %none : (!torch.int, !torch.none) -> !torch.list<optional<int>>\n"
|
||||
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg1) : (!torch.number) -> !torch.int\n"
|
||||
" %3 = torch.prim.ListConstruct %0#1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
|
||||
" return %4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
|
|
|
@ -500,6 +500,7 @@ STABLEHLO_PASS_SET = {
|
|||
"ElementwiseNeIntTensorStaticModule_basic",
|
||||
"ElementwiseNegModule_basic",
|
||||
"ElementwiseOrTensorStaticShapeModule_basic",
|
||||
"ElementwiseAndScalarStaticShapeModule_basic",
|
||||
"ElementwisePowTensorBroadcastStaticModule_basic",
|
||||
"ElementwisePowTensorStaticModule_basic",
|
||||
"ElementwisePreluStaticModule_basic",
|
||||
|
@ -1667,6 +1668,8 @@ ONNX_XFAIL_SET = {
|
|||
"DivIntModule_basic",
|
||||
"ElementwiseAcoshIntModule_basic",
|
||||
"ElementwiseAcoshModule_basic",
|
||||
"ElementwiseAndScalarModule_basic",
|
||||
"ElementwiseAndScalarStaticShapeModule_basic",
|
||||
"ElementwiseAsinhIntModule_basic",
|
||||
"ElementwiseAsinhModule_basic",
|
||||
"ElementwiseAtanhIntModule_basic",
|
||||
|
|
|
@ -478,6 +478,9 @@ def aten〇div〇Scalar〡shape(self: List[int], other: float) -> List[int]:
|
|||
def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇__and__〇Scalar〡shape(self: List[int], other: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇remainder〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -3002,6 +3005,15 @@ def aten〇rsub〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))
|
||||
def aten〇__and__〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[int, float, complex]) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
ranks: List[Optional[int]] = [self_rank, None]
|
||||
dtypes = [self_dtype, get_dtype_of_scalar(other)]
|
||||
return promote_dtypes(ranks, dtypes)
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
def aten〇__and__〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
||||
other_rank, other_dtype = other_rank_dtype
|
||||
|
|
|
@ -535,6 +535,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::logsumexp : (Tensor, int[], bool) -> (Tensor)")
|
||||
emit("aten::mean.dim : (Tensor, int[]?, bool, int?) -> (Tensor)")
|
||||
emit("aten::__and__.Tensor : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::__and__.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::__or__.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::_softmax : (Tensor, int, bool) -> (Tensor)")
|
||||
emit("aten::mean : (Tensor, int?) -> (Tensor)")
|
||||
|
|
|
@ -3070,6 +3070,51 @@ def ElementwiseOrTensorStaticShapeModule_basic(module, tu: TestUtils):
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAndscalarModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.__and__(x, 12)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAndscalarModule())
|
||||
def ElementwiseAndScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(3, 4, low=-10, high=10).to(torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseAndScalarStaticShapeModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([3, 4], torch.int32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.__and__(x, 12)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAndScalarStaticShapeModule())
|
||||
def ElementwiseAndScalarStaticShapeModule_basic(module, tu: TestUtils):
|
||||
module.forward(
|
||||
tu.randint(3, 4, low=-10, high=10).to(torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseBitwiseXorModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
|
|
Loading…
Reference in New Issue