mirror of https://github.com/llvm/torch-mlir
parent
abf5c94a1b
commit
fc419b1e7d
|
@ -982,6 +982,53 @@ def Torch_AtenDiv_TensorOp : Torch_Op<"aten.div_.Tensor", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenLogicalOrOp : Torch_Op<"aten.logical_or", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::logical_or : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenLogicalOrOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenLogicalOrOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenLogicalOr_Op : Torch_Op<"aten.logical_or_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::logical_or_ : (Tensor, Tensor) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchTensorType:$other
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenLogicalOr_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenLogicalOr_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
|
def Torch_AtenLerpTensorOp : Torch_Op<"aten.lerp.Tensor", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -190,6 +190,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
return b.create<arith::AndIOp>(loc, lhs, rhs);
|
return b.create<arith::AndIOp>(loc, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
if (auto logicalOr = dyn_cast<AtenLogicalOrOp>(op)) {
|
||||||
|
MLIRContext *context = op->getContext();
|
||||||
|
Type floatDtype = mlir::FloatType::getF64(context);
|
||||||
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype);
|
||||||
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype);
|
||||||
|
Value zero =
|
||||||
|
b.create<arith::ConstantOp>(loc, b.getFloatAttr(floatDtype, 0));
|
||||||
|
Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero);
|
||||||
|
Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero);
|
||||||
|
return b.create<arith::OrIOp>(loc, lhsTest, rhsTest);
|
||||||
|
}
|
||||||
if (isa<AtenAbsOp>(op))
|
if (isa<AtenAbsOp>(op))
|
||||||
return b.create<math::AbsOp>(loc, payloadArgs[0]);
|
return b.create<math::AbsOp>(loc, payloadArgs[0]);
|
||||||
if (isa<AtenSigmoidOp>(op)) {
|
if (isa<AtenSigmoidOp>(op)) {
|
||||||
|
@ -844,7 +855,8 @@ public:
|
||||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp>(op))
|
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||||
|
AtenLogicalOrOp>(op))
|
||||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||||
|
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
|
@ -1581,7 +1593,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||||
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||||
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp,
|
||||||
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp>();
|
AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
|
||||||
|
AtenLogicalOrOp>();
|
||||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||||
|
|
|
@ -673,7 +673,7 @@ ChangeResult TypeAnalyzer::visitOperation(
|
||||||
// Dtype is always i1.
|
// Dtype is always i1.
|
||||||
if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
|
if (isa<AtenEqScalarOp, AtenGeScalarOp, AtenGtScalarOp, AtenLtScalarOp,
|
||||||
AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
|
AtenLeScalarOp, AtenNeScalarOp, AtenAnyOp, AtenAllOp, AtenEqTensorOp,
|
||||||
AtenGtTensorOp, AtenLtTensorOp>(op)) {
|
AtenGtTensorOp, AtenLtTensorOp, AtenLogicalOrOp>(op)) {
|
||||||
auto knowledge =
|
auto knowledge =
|
||||||
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
|
||||||
knowledge.dtype = IntegerType::get(op->getContext(), 1);
|
knowledge.dtype = IntegerType::get(op->getContext(), 1);
|
||||||
|
|
|
@ -2155,6 +2155,10 @@ module {
|
||||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func.func @"__torch_mlir_shape_fn.aten.logical_or"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {
|
||||||
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>
|
||||||
|
return %0 : !torch.list<int>
|
||||||
|
}
|
||||||
func.func @"__torch_mlir_shape_fn.aten.threshold"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
|
func.func @"__torch_mlir_shape_fn.aten.threshold"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
|
||||||
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
%0 = call @__torch__.torch_mlir.dialects.torch.importer.jit_ir.build_tools.upstream_shape_helpers.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||||
return %0 : !torch.list<int>
|
return %0 : !torch.list<int>
|
||||||
|
|
|
@ -710,6 +710,9 @@ def aten〇maximum(self: List[int], other: List[int]) -> List[int]:
|
||||||
def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
def aten〇bitwise_and〇Tensor(self: List[int], other: List[int]) -> List[int]:
|
||||||
return upstream_shape_helpers.broadcast(self, other)
|
return upstream_shape_helpers.broadcast(self, other)
|
||||||
|
|
||||||
|
def aten〇logical_or(self: List[int], other: List[int]) -> List[int]:
|
||||||
|
return upstream_shape_helpers.broadcast(self, other)
|
||||||
|
|
||||||
def aten〇threshold(self: List[int], threshold: float, value: float) -> List[int]:
|
def aten〇threshold(self: List[int], threshold: float, value: float) -> List[int]:
|
||||||
return upstream_shape_helpers.unary(self)
|
return upstream_shape_helpers.unary(self)
|
||||||
|
|
||||||
|
|
|
@ -249,6 +249,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
|
"aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)",
|
||||||
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
|
"aten::logical_or : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
|
"aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
|
"aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)",
|
||||||
|
|
|
@ -1302,3 +1302,156 @@ class ElementwiseNegModule(torch.nn.Module):
|
||||||
@register_test_case(module_factory=lambda: ElementwiseNegModule())
|
@register_test_case(module_factory=lambda: ElementwiseNegModule())
|
||||||
def ElementwiseNegModule_basic(module, tu: TestUtils):
|
def ElementwiseNegModule_basic(module, tu: TestUtils):
|
||||||
module.forward(tu.rand(3, 4))
|
module.forward(tu.rand(3, 4))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseAtenLogicalOrOpModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.bool, True),
|
||||||
|
([-1], torch.bool, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.logical_or(x, y)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpModule())
|
||||||
|
def ElementwiseAtenLogicalOrOpModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([False, True]), torch.tensor([False, False]))
|
||||||
|
|
||||||
|
class ElementwiseAtenLogicalOrOpDiffArgs1Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.float64, True),
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.logical_or(x, y)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs1Module())
|
||||||
|
def ElementwiseAtenLogicalOrOpDiffArgs1Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([0.2, 0.1]), torch.tensor([0, 1]))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseAtenLogicalOrOpDiffArgs2Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.bool, True),
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.logical_or(x, y)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs2Module())
|
||||||
|
def ElementwiseAtenLogicalOrOpDiffArgs2Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([True, False]), torch.tensor([0, 1]))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseAtenLogicalOrOpDiffArgs3Module(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
([-1], torch.bool, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.logical_or(x, y)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpDiffArgs3Module())
|
||||||
|
def ElementwiseAtenLogicalOrOpDiffArgs3Module_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.tensor([1, 2]), torch.tensor([False, True]))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseAtenLogicalOrOpRandomModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.int64, True),
|
||||||
|
([-1, -1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.logical_or(x, y)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomModule())
|
||||||
|
def ElementwiseAtenLogicalOrOpRandomModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(3, 10, (2, 3, 4, 5)), torch.randint(10, 100, (2, 3, 4, 5)))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseAtenLogicalOrOpRandomFloatModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
([-1, -1, -1, -1], torch.float32, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.logical_or(x, y)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpRandomFloatModule())
|
||||||
|
def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.rand(2, 3, 3, 5), torch.rand(2, 3, 3, 5))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1, -1, -1], torch.int64, True),
|
||||||
|
([-1, -1, -1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.logical_or(x, y)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpNegativeModule())
|
||||||
|
def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.neg(torch.randint(3, 10, (2, 3, 4, 5))), torch.neg(torch.randint(10, 100, (2, 3, 4, 5))))
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1], torch.int64, True),
|
||||||
|
([-1, -1], torch.int64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x, y):
|
||||||
|
return torch.ops.aten.logical_or(x, y)
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseAtenLogicalOrOpBrodcastModule())
|
||||||
|
def ElementwiseAtenLogicalOrOpBrodcastModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(torch.randint(3, (3,)), torch.randint(3, (4, 3)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue