mirror of https://github.com/llvm/torch-mlir
parent 86d25e310b
commit eaab9be207

@@ -102,6 +102,8 @@ MHLO_PASS_SET = {
     "BroadcastToModule_basic",
     "BroadcastToSameRankStaticModule_basic",
     "BroadcastZeroRankInputStaticModule_basic",
+    "ElementwiseAtenLogicalNotOpModule_basic",
+    "ElementwiseAtenLogicalNotOpPromoteModule_basic",
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseClampModule_basic",
     "ElementwiseClampMinModule_basic",

@@ -1021,6 +1021,145 @@ def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [
   }];
 }
 
+def Torch_AtenLogicalAndOp : Torch_Op<"aten.logical_and", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logical_and : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogicalAndOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogicalAndOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenLogicalAnd_Op : Torch_Op<"aten.logical_and_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::logical_and_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogicalAnd_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogicalAnd_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenLogicalXorOp : Torch_Op<"aten.logical_xor", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logical_xor : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogicalXorOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogicalXorOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenLogicalXor_Op : Torch_Op<"aten.logical_xor_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::logical_xor_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$other
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogicalXor_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenLogicalXor_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenLogicalNotOp : Torch_Op<"aten.logical_not", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::logical_not : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogicalNotOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenLogicalNotOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenLogicalNot_Op : Torch_Op<"aten.logical_not_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::logical_not_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLogicalNot_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenLogicalNot_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
     AllowsTypeRefinement,
     HasValueSemantics,

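For orientation: the generated `aten.logical_*` ops follow PyTorch's semantics, where operands of any dtype are read as booleans (zero is false, nonzero is true) and the result dtype is always bool; this is why the type-analysis change further down pins the result dtype to i1. A small plain-PyTorch snippet illustrating the expected behavior (illustrative only, not part of this commit):

import torch

x = torch.tensor([0.0, 1.5, -2.0])
y = torch.tensor([3, 0, 7])
print(torch.ops.aten.logical_and(x, y))  # tensor([False, False,  True])
print(torch.ops.aten.logical_xor(x, y))  # tensor([ True,  True, False])
print(torch.ops.aten.logical_not(y))     # tensor([False,  True, False])
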
@@ -215,7 +215,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
     return b.create<arith::OrIOp>(loc, lhs, rhs);
   }
-  if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
+  if (isa<AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp>(op)) {
     MLIRContext *context = op->getContext();
     Type floatDtype = mlir::FloatType::getF64(context);
     Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);

@@ -224,7 +224,24 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
         b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
     Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero);
     Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero);
-    return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
+    if (isa<AtenLogicalOrOp>(op)) {
+      return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
+    }
+    if (isa<AtenLogicalAndOp>(op)) {
+      return b.create<arith::AndIOp>(loc, lhsTest, rhsTest);
+    }
+    if (isa<AtenLogicalXorOp>(op)) {
+      return b.create<arith::XOrIOp>(loc, lhsTest, rhsTest);
+    }
+    llvm_unreachable("Unknown op type");
   }
+  if (isa<AtenLogicalNotOp>(op)) {
+    MLIRContext *context = op->getContext();
+    Type floatDtype = mlir::FloatType::getF64(context);
+    Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
+    Value zero =
+        b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
+    return createEqual(b, loc, floatDtype, self, zero);
+  }
   if (isa<AtenAbsOp>(op))
     return b.create<math::AbsFOp>(loc, payloadArgs[0]);

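The branch added above normalizes both operands to f64, tests each against zero, and then combines the two i1 results with arith ori/andi/xori; logical_not instead compares its single input against zero for equality. A rough scalar-level Python equivalent of that logic, as a sketch only (these helper names are invented for illustration and do not exist in the codebase):

def logical_binary_payload(lhs, rhs, kind):
    # Operands are first tested against zero, then the boolean
    # results are combined, mirroring the lowering above.
    lhs_test = float(lhs) != 0.0
    rhs_test = float(rhs) != 0.0
    if kind == "or":
        return lhs_test or rhs_test
    if kind == "and":
        return lhs_test and rhs_test
    if kind == "xor":
        return lhs_test != rhs_test
    raise ValueError("unknown op type")

def logical_not_payload(x):
    # logical_not lowers to an equality test against zero.
    return float(x) == 0.0
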
@@ -1052,9 +1069,9 @@ public:
             AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
             AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
             AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
-            AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
-            AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(
-            op))
+            AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
+            AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenBitwiseNotOp,
+            AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))

@@ -1529,9 +1546,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
       AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
       AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
-      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenTriuOp,
-      AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-      AtenFillTensorOp>();
+      AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
+      AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenRemainderScalarOp,
+      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);

@@ -447,9 +447,45 @@ public:
     return success();
   }
 };
 
 } // namespace
 
+// Binary op legalizations for Logical And/Or/Xor.
+namespace {
+template <typename AtenOpT, typename ChloOpT>
+class ConvertAtenLogicalBinaryOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorType outType = OpConversionPattern<AtenOpT>::getTypeConverter()
+                             ->convertType(op.getType())
+                             .template cast<TensorType>();
+    Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
+    Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType);
+
+    DenseIntElementsAttr bcastDimensions;
+    rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs,
+                                         bcastDimensions);
+    return success();
+  }
+};
+} // namespace
+
+// AtenLogicalNotOp
+template <>
+LogicalResult ConvertAtenOp<AtenLogicalNotOp>::matchAndRewrite(
+    AtenLogicalNotOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  TensorType outType =
+      getTypeConverter()->convertType(op.getType()).cast<TensorType>();
+  Value self = mhlo::promoteType(rewriter, adaptor.getSelf(), outType);
+  rewriter.replaceOpWithNewOp<mhlo::NotOp>(op, outType, self);
+  return success();
+}
+
 // AtenTransposeIntOp
 namespace {
 class ConvertAtenTransposeIntOp

@@ -1389,6 +1425,16 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp);
 #undef INSERT_BINARY_COMPARE_PATTERN
 
+#define INSERT_BINARY_LOGICAL_PATTERN(AtenOp, ChloOp)                      \
+  target.addIllegalOp<AtenOp>();                                           \
+  patterns.add<ConvertAtenLogicalBinaryOp<AtenOp, ChloOp>>(typeConverter,  \
+                                                           context)
+
+  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalOrOp, chlo::BroadcastOrOp);
+  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalAndOp, chlo::BroadcastAndOp);
+  INSERT_BINARY_LOGICAL_PATTERN(AtenLogicalXorOp, chlo::BroadcastXorOp);
+#undef INSERT_BINARY_LOGICAL_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                      \
   target.addIllegalOp<AtenOp>();                                           \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context, options)

@@ -1401,6 +1447,7 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
   INSERT_ATENOP_PATTERN(AtenReciprocalOp);
   INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
   INSERT_ATENOP_PATTERN(AtenContiguousOp);
+  INSERT_ATENOP_PATTERN(AtenLogicalNotOp);
 
   INSERT_ATENOP_PATTERN(AtenReluOp);
   INSERT_ATENOP_PATTERN(AtenGeluOp);

@@ -6501,6 +6501,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logical_and\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logical_xor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.logical_not\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.threshold\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"

@@ -701,7 +701,8 @@ void TypeAnalysis::visitOperation(Operation *op,
   // Dtype is always i1.
   if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
           AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
-          AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp>(op)) {
+          AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
+          AtenLogicalXorOp, AtenLogicalNotOp>(op)) {
     auto knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     knowledge.dtype = IntegerType::get(op->getContext(), 1);

@@ -681,6 +681,15 @@ def aten〇bitwise_not〡shape(self: List[int]) -> List[int]:
 def aten〇logical_or〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
+def aten〇logical_and〡shape(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
+def aten〇logical_xor〡shape(self: List[int], other: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, other)
+
+def aten〇logical_not〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇threshold〡shape(self: List[int], threshold: float, value: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 

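These shape functions just defer to the upstream broadcast/unary helpers, so the new logical ops broadcast like any other binary elementwise op and logical_not preserves its input shape. A quick plain-PyTorch sanity check of the behavior they encode (not part of this commit):

import torch

a = torch.zeros(5)
b = torch.zeros(4, 5, dtype=torch.int64)
assert torch.logical_and(a, b).shape == (4, 5)   # broadcasts like other binary ops
assert torch.logical_xor(a, b).dtype == torch.bool
assert torch.logical_not(b).shape == (4, 5)      # unary: shape preserved
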
@@ -262,6 +262,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::bitwise_not : (Tensor) -> (Tensor)",
         "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::logical_or : (Tensor, Tensor) -> (Tensor)",
+        "aten::logical_and : (Tensor, Tensor) -> (Tensor)",
+        "aten::logical_xor : (Tensor, Tensor) -> (Tensor)",
+        "aten::logical_not : (Tensor) -> (Tensor)",
         "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
         "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
         "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",

@@ -2192,6 +2192,130 @@ def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseAtenLogicalAndOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.logical_and(x, y)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpModule())
+def ElementwiseAtenLogicalAndOpModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 5, high=2).bool(), tu.randint(4, 5, high=2).bool())
+
+
+# ==============================================================================
+
+
+class ElementwiseAtenLogicalAndOpPromoteBroadcastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.logical_and(x, y)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalAndOpPromoteBroadcastModule())
+def ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5), tu.randint(4, 5, low=-1, high=2))
+
+
+# ==============================================================================
+
+
+class ElementwiseAtenLogicalXorOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.logical_xor(x, y)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpModule())
+def ElementwiseAtenLogicalXorOpModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 5, high=2).bool(), tu.randint(4, 5, high=2).bool())
+
+
+# ==============================================================================
+
+
+class ElementwiseAtenLogicalXorOpPromoteBroadcastModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.float64, True),
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.logical_xor(x, y)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalXorOpPromoteBroadcastModule())
+def ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5), tu.randint(4, 5, low=-1, high=2))
+
+
+# ==============================================================================
+
+
+class ElementwiseAtenLogicalNotOpModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.logical_not(x)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalNotOpModule())
+def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 5, high=2).bool())
+
+
+# ==============================================================================
+
+
+class ElementwiseAtenLogicalNotOpPromoteModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.int64, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.logical_not(x)
+
+@register_test_case(module_factory=lambda: ElementwiseAtenLogicalNotOpPromoteModule())
+def ElementwiseAtenLogicalNotOpPromoteModule_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 5, low=-1, high=2))
+
+
+# ==============================================================================
+
+
 class ElementwiseAtenFloorDivideModule(torch.nn.Module):
 
     def __init__(self):

@@ -624,3 +624,119 @@ func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }
 
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_or$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = chlo.broadcast_or %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_or$promote(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
+// CHECK: %[[T4:.*]] = chlo.broadcast_or %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.logical_or$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_and$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = chlo.broadcast_and %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.logical_and$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_and$promote(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
+// CHECK: %[[T4:.*]] = chlo.broadcast_and %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.logical_and$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_xor$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>, %[[ARG1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = chlo.broadcast_xor %[[T0]], %[[T1]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.logical_xor$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_xor$promote(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
+// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = mhlo.convert %[[T1]] : (tensor<?x?xi32>) -> tensor<?x?xi1>
+// CHECK: %[[T4:.*]] = chlo.broadcast_xor %[[T2]], %[[T3]] : (tensor<?x?xi1>, tensor<?x?xi1>) -> tensor<?x?xi1>
+// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.logical_xor$promote(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_not$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],i1> -> tensor<?x?xi1>
+// CHECK: %[[T1:.*]] = mhlo.not %[[T0]] : tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T2]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.logical_not$basic(%arg0: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_not %arg0 : !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.logical_not$promote(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.convert %[[T0]] : (tensor<?x?xf32>) -> tensor<?x?xi1>
+// CHECK: %[[T2:.*]] = mhlo.not %[[T1]] : tensor<?x?xi1>
+// CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
+// CHECK: return %[[T3]] : !torch.vtensor<[?,?],i1>
+func.func @torch.aten.logical_not$promote(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
+  %0 = torch.aten.logical_not %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
+  return %0 : !torch.vtensor<[?,?],i1>
+}