mirror of https://github.com/llvm/torch-mlir
E2E support for AtenRemainderScalarOp (#1200)
parent
79b9cf9468
commit
dd2da5a038
|
@ -179,7 +179,7 @@ TOSA_PASS_SET = {
|
|||
"TransposeIntNegDimsModule_basic",
|
||||
"ArgmaxModule_keepDim",
|
||||
"ArgmaxModule_with_dim",
|
||||
"_LogSoftmaxModuleStable_basic",
|
||||
"_LogSoftmaxModuleStable_basic",
|
||||
}
|
||||
|
||||
LTC_XFAIL_SET = {
|
||||
|
@ -338,4 +338,8 @@ LTC_XFAIL_SET = {
|
|||
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
|
||||
"AtenEmbeddingBagSumExample_basic",
|
||||
"Aten_EmbeddingBagExample_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
||||
"ElementwiseRemainderScalarModule_Float_basic",
|
||||
"ElementwiseRemainderScalarModule_Int_basic",
|
||||
"ElementwiseRemainderScalarModule_Bool_basic",
|
||||
}
|
||||
|
|
|
@ -7910,6 +7910,30 @@ def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$other
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenRemainderScalarOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void AtenRemainderScalarOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -803,6 +803,26 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||
return b.create<arith::DivFOp>(loc, self, other);
|
||||
}
|
||||
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
|
||||
Type newResultType = converter->convertType(remScalar.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
|
||||
Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
|
||||
Value result;
|
||||
|
||||
if (newResultType.isa<mlir::FloatType>()) {
|
||||
result = b.create<arith::RemFOp>(loc, self, other);
|
||||
} else if (newResultType.isa<mlir::IntegerType>()) {
|
||||
result = b.create<arith::RemSIOp>(loc, self, other);
|
||||
} else {
|
||||
remScalar.emitError(
|
||||
"Unsupported type encountered for AtenRemainderScalarOp.");
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
|
||||
Type dtype = converter->convertType(reciprocal.getType())
|
||||
.cast<RankedTensorType>()
|
||||
|
@ -943,14 +963,14 @@ public:
|
|||
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
|
||||
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
|
||||
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
|
||||
AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp,
|
||||
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp,
|
||||
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
|
||||
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp,
|
||||
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenLogicalOrOp,
|
||||
AtenTriuOp>(op))
|
||||
AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp,
|
||||
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
|
||||
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
|
||||
AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
|
||||
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
|
||||
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
|
||||
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
|
||||
AtenLogicalOrOp, AtenTriuOp>(op))
|
||||
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
|
||||
|
||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||
|
@ -1688,7 +1708,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
|||
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
|
||||
AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp,
|
||||
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
|
||||
AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>();
|
||||
AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
|
||||
patterns.add<ConvertElementwiseOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenNllLossForwardOp>();
|
||||
patterns.add<ConvertAtenDetachOp>(typeConverter, context);
|
||||
|
|
|
@ -770,7 +770,7 @@ void TypeAnalysis::visitOperation(Operation *op,
|
|||
// Promote LHS with scalar RHS.
|
||||
if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
|
||||
AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
|
||||
AtenRsubScalarOp, AtenLeakyReluOp>(op)) {
|
||||
AtenRsubScalarOp, AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
|
||||
auto lhs = operands[0]->getValue();
|
||||
Value scalar = op->getOperand(1);
|
||||
auto knowledge =
|
||||
|
|
|
@ -5441,6 +5441,10 @@ module {
|
|||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
}
|
||||
func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
|
||||
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
|
||||
return %0 : !torch.list<int>
|
||||
|
|
|
@ -481,6 +481,9 @@ def aten〇mul〇Scalar(self: List[int], other: float) -> List[int]:
|
|||
def aten〇div〇Scalar(self: List[int], other: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇remainder〇Scalar(self: List[int], other: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇floor_divide〇Scalar(self: List[int], other: float) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
|
|
@ -550,6 +550,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
|
||||
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
|
||||
emit("aten::add.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
|
||||
emit("aten::mul.int : (int, int) -> (int)", has_folder=True)
|
||||
|
|
|
@ -1304,6 +1304,90 @@ class ElementwiseDivScalarModule(torch.nn.Module):
|
|||
def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.remainder(x, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
|
||||
def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3,), dtype=torch.int32))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.remainder(x, 2.0)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float())
|
||||
def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
|
||||
module.forward(torch.rand(10, 3))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.remainder(x, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int())
|
||||
def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 2), dtype=torch.int32))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseRemainderScalarModule_Bool(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.bool, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.remainder(x, 2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Bool())
|
||||
def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils):
|
||||
module.forward(torch.tensor([True, False, True, True, True]))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue