E2E support for AtenRemainderScalarOp (#1119)

* E2E support for AtenRemainderScalarOp
pull/1183/head
Vidush Singhal 2022-08-08 20:02:52 -04:00 committed by GitHub
parent b70548edff
commit 34e207eeb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 134 additions and 18 deletions

View File

@ -178,7 +178,10 @@ TOSA_PASS_SET = {
"TransposeIntNegDimsModule_basic", "TransposeIntNegDimsModule_basic",
"ArgmaxModule_keepDim", "ArgmaxModule_keepDim",
"ArgmaxModule_with_dim", "ArgmaxModule_with_dim",
"_LogSoftmaxModuleStable_basic", "_LogSoftmaxModuleStable_basic",
"ElementwiseRemainderScalarModule_Int_Float_basic",
"ElementwiseRemainderScalarModule_Float_basic",
"ElementwiseRemainderScalarModule_Int_basic",
} }
LTC_XFAIL_SET = { LTC_XFAIL_SET = {

View File

@ -7883,6 +7883,30 @@ def Torch_AtenRemainderIntOp : Torch_Op<"aten.remainder.int", [
let hasFolder = 1; let hasFolder = 1;
} }
def Torch_AtenRemainderScalarOp : Torch_Op<"aten.remainder.Scalar", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$other
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenRemainderScalarOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenRemainderScalarOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [ def Torch_AtenAddIntOp : Torch_Op<"aten.add.int", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -803,6 +803,23 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
Value other = convertScalarToDtype(b, loc, operands[1], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype);
return b.create<arith::DivFOp>(loc, self, other); return b.create<arith::DivFOp>(loc, self, other);
} }
if (auto remScalar = dyn_cast<AtenRemainderScalarOp>(op)) {
Type newResultType = converter->convertType(remScalar.getType())
.cast<RankedTensorType>()
.getElementType();
Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
Value other = convertScalarToDtype(b, loc, operands[1], newResultType);
Value result;
if (newResultType.isa<mlir::FloatType>()) {
result = b.create<arith::RemFOp>(loc, self, other);
} else {
result = b.create<arith::RemSIOp>(loc, self, other);
}
return result;
}
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) { if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
Type dtype = converter->convertType(reciprocal.getType()) Type dtype = converter->convertType(reciprocal.getType())
.cast<RankedTensorType>() .cast<RankedTensorType>()
@ -943,14 +960,14 @@ public:
AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenMinimumOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp,
AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenRsubScalarOp, AtenMulScalarOp, AtenLogOp, AtenErfOp,
AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op, AtenSqrtOp, AtenFloorOp, AtenPowTensorScalarOp, AtenLog2Op,
AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenAbsOp, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, AtenRemainderScalarOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenTriuOp>(op)) AtenLogicalOrOp, AtenTriuOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@ -1681,14 +1698,14 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp, AtenTanhOp, AtenReluOp, AtenLeakyReluOp, AtenGeluOp, AtenGeluBackwardOp,
AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp,
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillScalarOp,
AtenNeScalarOp, AtenMaskedFillScalarOp, AtenLogicalOrOp, AtenTriuOp>(); AtenLogicalOrOp, AtenTriuOp, AtenRemainderScalarOp>();
patterns.add<ConvertElementwiseOp>(typeConverter, context); patterns.add<ConvertElementwiseOp>(typeConverter, context);
target.addIllegalOp<AtenNllLossForwardOp>(); target.addIllegalOp<AtenNllLossForwardOp>();
patterns.add<ConvertAtenDetachOp>(typeConverter, context); patterns.add<ConvertAtenDetachOp>(typeConverter, context);

View File

@ -770,7 +770,7 @@ void TypeAnalysis::visitOperation(Operation *op,
// Promote LHS with scalar RHS. // Promote LHS with scalar RHS.
if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp, if (isa<AtenAddScalarOp, AtenSubScalarOp, AtenMulScalarOp, AtenDivScalarOp,
AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp, AtenFmodScalarOp, AtenFloorDivideScalarOp, AtenPowTensorScalarOp,
AtenRsubScalarOp, AtenLeakyReluOp>(op)) { AtenRsubScalarOp, AtenLeakyReluOp, AtenRemainderScalarOp>(op)) {
auto lhs = operands[0]->getValue(); auto lhs = operands[0]->getValue();
Value scalar = op->getOperand(1); Value scalar = op->getOperand(1);
auto knowledge = auto knowledge =

View File

@ -5441,6 +5441,10 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int> %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int> return %0 : !torch.list<int>
} }
func.func @"__torch_mlir_shape_fn.aten.remainder.Scalar"(%arg0: !torch.list<int>, %arg1: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> { func.func @"__torch_mlir_shape_fn.aten.to.dtype"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool, %arg3: !torch.bool, %arg4: !torch.optional<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int> %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int> return %0 : !torch.list<int>

View File

@ -478,6 +478,9 @@ def atenmulScalar(self: List[int], other: float) -> List[int]:
def atendivScalar(self: List[int], other: float) -> List[int]: def atendivScalar(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
def atenremainderScalar(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self)
def atenfloor_divideScalar(self: List[int], other: float) -> List[int]: def atenfloor_divideScalar(self: List[int], other: float) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)

View File

@ -549,6 +549,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::eq.int : (int, int) -> (bool)", has_folder=True) emit("aten::eq.int : (int, int) -> (bool)", has_folder=True)
emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True) emit("aten::floordiv.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.int : (int, int) -> (int)", has_folder=True) emit("aten::remainder.int : (int, int) -> (int)", has_folder=True)
emit("aten::remainder.Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::add.int : (int, int) -> (int)", has_folder=True) emit("aten::add.int : (int, int) -> (int)", has_folder=True)
emit("aten::sub.int : (int, int) -> (int)", has_folder=True) emit("aten::sub.int : (int, int) -> (int)", has_folder=True)
emit("aten::mul.int : (int, int) -> (int)", has_folder=True) emit("aten::mul.int : (int, int) -> (int)", has_folder=True)

View File

@ -1304,6 +1304,70 @@ class ElementwiseDivScalarModule(torch.nn.Module):
def ElementwiseDivScalarModule_basic(module, tu: TestUtils): def ElementwiseDivScalarModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4)) module.forward(tu.rand(3, 4))
# ==============================================================================
class ElementwiseRemainderScalarModule_Int_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int64, True),
])
def forward(self, x):
return torch.remainder(x, 2.0)
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int_Float())
def ElementwiseRemainderScalarModule_Int_Float_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3,)))
# ==============================================================================
class ElementwiseRemainderScalarModule_Float(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.float, True),
])
def forward(self, x):
return torch.remainder(x, 2.0)
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Float())
def ElementwiseRemainderScalarModule_Float_basic(module, tu: TestUtils):
module.forward(torch.rand(10, 3))
# ==============================================================================
class ElementwiseRemainderScalarModule_Int(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
])
def forward(self, x):
return torch.remainder(x, 2)
@register_test_case(module_factory=lambda: ElementwiseRemainderScalarModule_Int())
def ElementwiseRemainderScalarModule_Int_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 2)))
# ============================================================================== # ==============================================================================