Add lowering to linalg for softplus and log1p

Follows existing conventions for unary operators.
pull/1115/head
Kevin Kiningham 2022-07-17 12:00:29 -07:00 committed by Vivek Khandelwal
parent 44ead68772
commit e8f327cc00
10 changed files with 200 additions and 11 deletions
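
For reference, a minimal eager-mode sketch of the two PyTorch-level entry points this commit makes lowerable (illustrative only, not a torch-mlir invocation; the defaults shown are PyTorch's):

import torch
import torch.nn.functional as F

x = torch.randn(3, 4)

# aten::log1p : (Tensor) -> (Tensor)
# Lowered directly to math.log1p in the linalg elementwise payload (see below).
y = torch.log1p(x)

# aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)
# Decomposed into mul.Scalar / exp / log1p / div.Scalar / where.self before lowering.
z = F.softplus(x, beta=1, threshold=20)  # PyTorch defaults: beta=1, threshold=20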

@@ -2069,6 +2069,51 @@ def Torch_AtenSqrt_Op : Torch_Op<"aten.sqrt_", [
}];
}
def Torch_AtenLog1pOp : Torch_Op<"aten.log1p", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log1p : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog1pOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog1pOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log1p_ : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog1p_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog1p_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -2708,6 +2753,31 @@ def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [
}];
}
def Torch_AtenSoftplusOp : Torch_Op<"aten.softplus", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchScalarType:$beta,
AnyTorchScalarType:$threshold
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSoftplusOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenSoftplusOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
AllowsTypeRefinement
]> {

@@ -139,6 +139,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
b, converter, payloadArgs[0], op);
}
if (isa<AtenLog1pOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>(
b, converter, payloadArgs[0], op);
}
if (isa<AtenErfOp>(op)) {
return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
b, converter, payloadArgs[0], op);
@@ -922,14 +926,15 @@ public:
AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp,
AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
AtenLogicalOrOp, AtenTriuOp>(op))
AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenLogicalOrOp,
AtenTriuOp>(op))
return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
@@ -1662,7 +1667,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
AtenLog2Op, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,

@@ -1172,6 +1172,40 @@ public:
};
} // namespace
// Softplus(x, beta, threshold) =
// x * beta > threshold ? x : log(1 + exp(x * beta)) / beta
namespace {
class DecomposeAtenSoftplusOp : public OpRewritePattern<AtenSoftplusOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSoftplusOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.self();
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
Value inputTimesBeta = rewriter.create<AtenMulScalarOp>(
loc, inputType, input, op.beta());
// out = log1p(exp(input * beta)) / beta
Value exp = rewriter.create<AtenExpOp>(loc, inputType, inputTimesBeta);
Value log1p = rewriter.create<AtenLog1pOp>(loc, inputType, exp);
Value out = rewriter.create<AtenDivScalarOp>(
loc, inputType, log1p, op.beta());
// Select where x * beta > threshold
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
rewriter.getI1Type());
Value condition = rewriter.create<AtenGtScalarOp>(
loc, boolResType, inputTimesBeta, op.threshold());
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(
op, op.getType(), condition, input, out);
return success();
}
};
} // namespace
// Hardsigmoid(x) = max(0, min(1, (x+3)/6))
namespace {
class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {
@@ -2344,6 +2378,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenHardsigmoidOp>();
patterns.add<DecomposeAtenHardswishOp>(context);
target.addIllegalOp<AtenHardswishOp>();
patterns.add<DecomposeAtenSoftplusOp>(context);
target.addIllegalOp<AtenSoftplusOp>();
patterns.add<DecomposeAtenSiluOp>(context);
target.addIllegalOp<AtenSiluOp>();
patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(
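
A quick eager-mode sanity check of the identity used by DecomposeAtenSoftplusOp above, comparing the where(x*beta > threshold, x, log1p(exp(x*beta))/beta) form against torch.nn.functional.softplus (an illustrative sketch, not part of the change itself):

import torch
import torch.nn.functional as F

def softplus_decomposed(x, beta=1.0, threshold=20.0):
    # Mirrors the decomposition: mul.Scalar, exp, log1p, div.Scalar, gt.Scalar, where.self.
    xb = x * beta
    out = torch.log1p(torch.exp(xb)) / beta
    return torch.where(xb > threshold, x, out)

x = torch.randn(1024) * 30  # values on both sides of the threshold
torch.testing.assert_close(softplus_decomposed(x), F.softplus(x))
torch.testing.assert_close(softplus_decomposed(x, beta=2.0, threshold=10.0),
                           F.softplus(x, beta=2.0, threshold=10.0))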

@@ -664,8 +664,8 @@ void TypeAnalysis::visitOperation(Operation *op,
// Dtype is always float32, except for bfloat16, float64 and nullptr.
if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
AtenErfOp>(op)) {
AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
AtenRsqrtOp, AtenErfOp, AtenSoftplusOp>(op)) {
ValueKnowledge knowledge =
ValueKnowledge::getTensorPessimisticValueState(op->getContext());
Type dtype = operands[0]->getValue().dtype;
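
The dtype rule above can be spot-checked in eager PyTorch for the newly covered ops: integer inputs promote to the default float32 and float64 inputs keep their dtype (illustrative only):

import torch

assert torch.log1p(torch.arange(5, dtype=torch.int32)).dtype == torch.float32
assert torch.log1p(torch.ones(5, dtype=torch.float64)).dtype == torch.float64
assert torch.nn.functional.softplus(torch.ones(5, dtype=torch.float64)).dtype == torch.float64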

@@ -5317,6 +5317,10 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.softplus"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.square"(%arg0: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
@@ -5365,6 +5369,10 @@ module {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.log1p"(%arg0: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>
}
func.func @"__torch_mlir_shape_fn.aten.rsqrt"(%arg0: !torch.list<int>) -> !torch.list<int> {
%0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
return %0 : !torch.list<int>

@@ -300,6 +300,9 @@ def aten〇sigmoid(self: List[int]) -> List[int]:
def aten〇hardsigmoid(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇softplus(self: List[int], beta: float = 1, threshold: float = 20) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇square(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

@@ -336,6 +339,9 @@ def aten〇detach(self: List[int]) -> List[int]:
def aten〇log2(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇log1p(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇rsqrt(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)
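
Both new shape functions delegate to upstream_shape_functions.unary, which for these elementwise ops simply reproduces the input sizes. A rough stand-in for what that helper computes (an assumption about the upstream helper's behavior, not its actual source):

from typing import List

def unary(self: List[int]) -> List[int]:
    # Elementwise ops preserve shape: return a copy of the input sizes.
    out: List[int] = []
    for dim in self:
        out.append(dim)
    return out

assert unary([3, 4]) == [3, 4]  # e.g. aten〇log1p([3, 4]) -> [3, 4]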

@@ -282,6 +282,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
"aten::log2 : (Tensor) -> (Tensor)",
"aten::sqrt : (Tensor) -> (Tensor)",
"aten::log1p : (Tensor) -> (Tensor)",
"aten::rsqrt : (Tensor) -> (Tensor)",
"aten::abs : (Tensor) -> (Tensor)",
"aten::reciprocal : (Tensor) -> (Tensor)",
@@ -304,6 +305,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
# Ops without value semantics but the corresponding without trailing
# underscore variant doesn't exist.

@@ -959,6 +959,28 @@ def _LogSoftmaxModuleStable_basic(module, tu: TestUtils):
# ==============================================================================
class SoftplusModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.softplus(x)


@register_test_case(module_factory=lambda: SoftplusModule())
def SoftplusModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 3))
# ==============================================================================
class HardsigmoidModule(torch.nn.Module):
def __init__(self):

@@ -844,6 +844,27 @@ class ElementwiseLogIntModule(torch.nn.Module):
def ElementwiseLogIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
# ==============================================================================
class ElementwiseLog1pModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, a):
        return torch.log1p(a)


@register_test_case(module_factory=lambda: ElementwiseLog1pModule())
def ElementwiseLog1pModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4))
# ==============================================================================

@@ -1189,3 +1189,22 @@ func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vten
%0 = torch.aten.var.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32>
return %0 : !torch.vtensor<[3,4,1],f32>
}
// -----
// CHECK-LABEL: func.func @torch.aten.softplus(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.tensor<[2,3],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.aten.mul.Scalar %[[VAL_0]], %[[VAL_1]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],f32>
// CHECK: %[[VAL_4:.*]] = torch.aten.exp %[[VAL_3]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[VAL_5:.*]] = torch.aten.log1p %[[VAL_4]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: %[[VAL_6:.*]] = torch.aten.div.Scalar %[[VAL_5]], %[[VAL_1]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],f32>
// CHECK: %[[VAL_7:.*]] = torch.aten.gt.Scalar %[[VAL_3]], %[[VAL_2]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],i1>
// CHECK: %[[VAL_8:.*]] = torch.aten.where.self %[[VAL_7]], %[[VAL_0]], %[[VAL_6]] : !torch.tensor<[2,3],i1>, !torch.tensor<[2,3],f32>, !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
// CHECK: return %[[VAL_8]] : !torch.tensor<[2,3],f32>
// CHECK: }
func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor<[2,3],f32> {
%int0 = torch.constant.int 0
%ret = torch.aten.softplus %t, %dim, %int0: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
return %ret : !torch.tensor<[2,3],f32>
}