Add e2e support for aten logical or/and/xor/not ops (#1752)

pull/1757/head snapshot-20221226.699
Jiahao Li 2022-12-26 10:23:38 +08:00 committed by GitHub
parent 86d25e310b
commit eaab9be207
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 480 additions and 10 deletions

View File

@ -102,6 +102,8 @@ MHLO_PASS_SET = {
"BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic",
"BroadcastZeroRankInputStaticModule_basic",
"ElementwiseAtenLogicalNotOpModule_basic",
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
"ElementwiseAtenWhereSelfModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseClampMinModule_basic",

View File

@ -1021,6 +1021,145 @@ def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [
}];
}
def Torch_AtenLogicalAndOp : Torch_Op<"aten.logical_and", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::logical_and : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalAndOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogicalAndOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenLogicalAnd_Op : Torch_Op<"aten.logical_and_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::logical_and_ : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalAnd_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogicalAnd_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenLogicalXorOp : Torch_Op<"aten.logical_xor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::logical_xor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalXorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogicalXorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenLogicalXor_Op : Torch_Op<"aten.logical_xor_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::logical_xor_ : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalXor_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLogicalXor_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenLogicalNotOp : Torch_Op<"aten.logical_not", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::logical_not : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalNotOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogicalNotOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenLogicalNot_Op : Torch_Op<"aten.logical_not_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::logical_not_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogicalNot_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogicalNot_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,

View File

@ -215,7 +215,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
return b.create<arith::OrIOp>(loc, lhs, rhs);
}
if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
@ -224,7 +224,24 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero);
Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero);
return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
if (isa<AtenLogicalOrOp>(op)) {
return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
}
if (isa<AtenLogicalAndOp>(op)) {
return b.create<arith::AndIOp>(loc, lhsTest, rhsTest);
}
if (isa<AtenLogicalXorOp>(op)) {
return b.create<arith::XOrIOp>(loc, lhsTest, rhsTest);
}
llvm_unreachable("Unknown op type");
}
if (isa<AtenLogicalNotOp>(op)) {
MLIRContext *context = op->getContext();
Type floatDtype = mlir::FloatType::getF64(context);
Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
Value zero =
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
return createEqual(b, loc, floatDtype, self, zero);
}
if (isa<AtenAbsOp>(op))
return b.create<math::AbsFOp>(loc, payloadArgs[0]);
@ -1052,9 +1069,9 @@ public:
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(
op))
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1529,9 +1546,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
AtenFillTensorOp>();
AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -447,9 +447,45 @@ public:
return success();
}
};
} // namespace
// Binary op legalizations for Logical And/Or/Xor.
namespace {
template <typename AtenOpT, typename ChloOpT>
class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
public:
using OpConversionPattern<AtenOpT>::OpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;
LogicalResult
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
->convertType(op.getType())
.template cast<TensorType>();
Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
DenseIntElementsAttr bcastDimensions;
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
bcastDimensions);
return success();
}
};
} // namespace
// AtenLogicalNotOp
template <>
LogicalResult ConvertAtenOp<AtenLogicalNotOp>::matchAndRewrite(
AtenLogicalNotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
TensorType outType =
getTypeConverter()->convertType(op.getType()).cast<TensorType>();
Value self = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
rewriter.replaceOpWithNewOp<mhlo::NotOp>(op, outType, self);
return success();
}
// AtenTransposeIntOp
namespace {
class ConvertAtenTransposeIntOp
@ -1389,6 +1425,16 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp);
#undef INSERT_BINARY_COMPARE_PATTERN
#define INSERT_BINARY_LOGICAL_PATTERN(AtenOp, ChloOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenLogicalBinaryOp<AtenOp, ChloOp>>(typeConverter, \
context)
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
#undef INSERT_BINARY_LOGICAL_PATTERN
#define INSERT_ATENOP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)
@ -1401,6 +1447,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
INSERT_ATENOP_PATTERN(AtenContiguousOp);
INSERT_ATENOP_PATTERN(AtenLogicalNotOp);
INSERT_ATENOP_PATTERN(AtenReluOp);
INSERT_ATENOP_PATTERN(AtenGeluOp);

View File

@ -6501,6 +6501,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.logical_and\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.logical_xor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.logical_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.threshold\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"

View File

@ -701,7 +701,8 @@ void TypeAnalysis::visitOperation(Operation *op,
// Dtype is always i1.
if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp>(op)) {
AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
auto knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
knowledge.dtype = IntegerType::get(op->getContext(), 1);

View File

@ -681,6 +681,15 @@ def atenbitwise_not〡shape(self: List[int]) -> List[int]:
def atenlogical_or〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atenlogical_and〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atenlogical_xor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atenlogical_not〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)
def atenthreshold〡shape(self: List[int], threshold: float, value: float) -> List[int]:
return upstream_shape_functions.unary(self)

View File

@ -262,6 +262,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::logical_or : (Tensor, Tensor) -> (Tensor)",
"aten::logical_and : (Tensor, Tensor) -> (Tensor)",
"aten::logical_xor : (Tensor, Tensor) -> (Tensor)",
"aten::logical_not : (Tensor) -> (Tensor)",
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",

View File

@ -2192,6 +2192,130 @@ def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils):
# ==============================================================================
class ElementwiseAtenLogicalAndOpModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
([-1, -1], torch.bool, True),
])
def forward(self, x, y):
return torch.ops.aten.logical_and(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpModule())
def ElementwiseAtenLogicalAndOpModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, high=2).bool(), tu.randint(4, 5, high=2).bool())
# ==============================================================================
class ElementwiseAtenLogicalAndOpPromoteBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float64, True),
([-1, -1], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.logical_and(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpPromoteBroadcastModule())
def ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5), tu.randint(4, 5, low=-1, high=2))
# ==============================================================================
class ElementwiseAtenLogicalXorOpModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
([-1, -1], torch.bool, True),
])
def forward(self, x, y):
return torch.ops.aten.logical_xor(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpModule())
def ElementwiseAtenLogicalXorOpModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, high=2).bool(), tu.randint(4, 5, high=2).bool())
# ==============================================================================
class ElementwiseAtenLogicalXorOpPromoteBroadcastModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.float64, True),
([-1, -1], torch.int64, True),
])
def forward(self, x, y):
return torch.ops.aten.logical_xor(x, y)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpPromoteBroadcastModule())
def ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5), tu.randint(4, 5, low=-1, high=2))
# ==============================================================================
class ElementwiseAtenLogicalNotOpModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.bool, True),
])
def forward(self, x):
return torch.ops.aten.logical_not(x)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalNotOpModule())
def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, high=2).bool())
# ==============================================================================
class ElementwiseAtenLogicalNotOpPromoteModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.ops.aten.logical_not(x)
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalNotOpPromoteModule())
def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils):
module.forward(tu.randint(4, 5, low=-1, high=2))
# ==============================================================================
class ElementwiseAtenFloorDivideModule(torch.nn.Module):
def __init__(self):

View File

@ -624,3 +624,119 @@ func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_or$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[T2:.*]] = chlo.broadcast_or %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_or$promote(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
// CHECK: %[[T4:.*]] = chlo.broadcast_or %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
func.func @torch.aten.logical_or$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_and$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[T2:.*]] = chlo.broadcast_and %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_and$promote(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
// CHECK: %[[T4:.*]] = chlo.broadcast_and %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
func.func @torch.aten.logical_and$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_xor$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[T2:.*]] = chlo.broadcast_xor %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_xor$promote(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
// CHECK: %[[T4:.*]] = chlo.broadcast_xor %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
func.func @torch.aten.logical_xor$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_not$basic(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
// CHECK: %[[T1:.*]] = mhlo.not %[[T0]] : tensor<?x?xi1>
// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T2]] : !torch.vtensor<[?,?],i1>
func.func @torch.aten.logical_not$basic(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_not %arg0 : !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func.func @torch.aten.logical_not$promote(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[T1:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[T2:.*]] = mhlo.not %[[T1]] : tensor<?x?xi1>
// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
func.func @torch.aten.logical_not$promote(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.logical_not %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}