mirror of https://github.com/llvm/torch-mlir

Add lowering to linalg for softplus and log1p

Follows existing conventions for unary operators.

pull/1115/head
parent 44ead68772
commit e8f327cc00

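The change wires up the two ops end to end: Torch dialect op definitions, shape and dtype transfer functions, a decomposition for softplus, and a direct linalg lowering for log1p. For orientation, this is the eager PyTorch behaviour being matched; a minimal sketch assuming only a stock PyTorch install (values are arbitrary):

    import torch

    x = torch.randn(3, 4)

    # aten::log1p : (Tensor) -> (Tensor) -- numerically stable log(1 + x)
    y = torch.log1p(x)

    # aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)
    # softplus(x, beta, threshold) = log(1 + exp(beta * x)) / beta,
    # falling back to x itself where beta * x > threshold.
    z = torch.ops.aten.softplus(x, 1, 20)
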
@@ -2069,6 +2069,51 @@ def Torch_AtenSqrt_Op : Torch_Op<"aten.sqrt_", [
   }];
 }
 
+def Torch_AtenLog1pOp : Torch_Op<"aten.log1p", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::log1p : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLog1pOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenLog1pOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::log1p_ : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenLog1p_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenLog1p_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [
     AllowsTypeRefinement,
     HasValueSemantics,

@@ -2708,6 +2753,31 @@ def Torch_AtenFloorDivideOp : Torch_Op<"aten.floor_divide", [
   }];
 }
 
+def Torch_AtenSoftplusOp : Torch_Op<"aten.softplus", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchScalarType:$beta,
+    AnyTorchScalarType:$threshold
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenSoftplusOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenSoftplusOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [
     AllowsTypeRefinement
   ]> {

@@ -139,6 +139,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return createCalculationForMathOpWithDtypeConversion<math::Log2Op>(
         b, converter, payloadArgs[0], op);
   }
+  if (isa<AtenLog1pOp>(op)) {
+    return createCalculationForMathOpWithDtypeConversion<math::Log1pOp>(
+        b, converter, payloadArgs[0], op);
+  }
   if (isa<AtenErfOp>(op)) {
     return createCalculationForMathOpWithDtypeConversion<math::ErfOp>(
         b, converter, payloadArgs[0], op);

@@ -922,14 +926,15 @@ public:
          AtenLerpTensorOp, AtenSigmoidOp, AtenExpOp, AtenMinimumOp,
          AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp,
          AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp,
-         AtenPowTensorScalarOp, AtenLog2Op, AtenRsqrtOp, AtenDivScalarOp,
-         AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-         AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp,
-         AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp,
-         AtenEqTensorOp, AtenLtTensorOp, AtenSubScalarOp, AtenAddScalarOp,
-         AtenThresholdOp, AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp,
-         AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp,
-         AtenLogicalOrOp, AtenTriuOp>(op))
+         AtenPowTensorScalarOp, AtenLog2Op, AtenLog1pOp, AtenRsqrtOp,
+         AtenDivScalarOp, AtenAbsOp, AtenReciprocalOp,
+         AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp,
+         AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
+         AtenCeilOp, AtenGtTensorOp, AtenEqTensorOp, AtenLtTensorOp,
+         AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
+         AtenThresholdBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
+         AtenNeScalarOp, AtenNegOp, AtenMaskedFillScalarOp, AtenLogicalOrOp,
+         AtenTriuOp>(op))
     return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
 
   if (failed(verifyLinalgCompatibleTypes(op, rewriter)))

@@ -1662,7 +1667,7 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp,
       AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenRsubScalarOp, AtenLogOp,
       AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenCeilOp, AtenPowTensorScalarOp,
-      AtenLog2Op, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
+      AtenLog2Op, AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp,
       AtenBitwiseAndTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp,
       AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp,
       AtenEqTensorOp, AtenLtTensorOp, AtenThresholdOp, AtenThresholdBackwardOp,

@@ -1172,6 +1172,40 @@ public:
 };
 } // namespace
 
+// Softplus(x, beta, threshold) =
+// x * beta > threshold ? x : log(1 + exp(x * beta)) / beta
+namespace {
+class DecomposeAtenSoftplusOp : public OpRewritePattern<AtenSoftplusOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSoftplusOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value input = op.self();
+    BaseTensorType inputType = input.getType().cast<BaseTensorType>();
+
+    Value inputTimesBeta = rewriter.create<AtenMulScalarOp>(
+        loc, inputType, input, op.beta());
+
+    // out = log1p(exp(input * beta)) / beta
+    Value exp = rewriter.create<AtenExpOp>(loc, inputType, inputTimesBeta);
+    Value log1p = rewriter.create<AtenLog1pOp>(loc, inputType, exp);
+    Value out = rewriter.create<AtenDivScalarOp>(
+        loc, inputType, log1p, op.beta());
+
+    // Select where x * beta > threshold
+    auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(),
+                                                      rewriter.getI1Type());
+    Value condition = rewriter.create<AtenGtScalarOp>(
+        loc, boolResType, inputTimesBeta, op.threshold());
+
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(
+        op, op.getType(), condition, input, out);
+    return success();
+  }
+};
+} // namespace
+
 // Hardsigmoid(x) = max(0, min(1, (x+3)/6))
 namespace {
 class DecomposeAtenHardsigmoidOp : public OpRewritePattern<AtenHardsigmoidOp> {

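The decomposition above can be sanity-checked against eager PyTorch; a small sketch assuming a stock PyTorch install (the beta and threshold values are arbitrary):

    import torch

    def softplus_decomposed(x, beta, threshold):
        # Mirrors DecomposeAtenSoftplusOp:
        # x * beta > threshold ? x : log1p(exp(x * beta)) / beta
        xb = x * beta
        return torch.where(xb > threshold, x, torch.log1p(torch.exp(xb)) / beta)

    x = torch.randn(2, 3)
    assert torch.allclose(softplus_decomposed(x, 2.0, 20.0),
                          torch.ops.aten.softplus(x, 2.0, 20.0))
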
@@ -2344,6 +2378,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<AtenHardsigmoidOp>();
     patterns.add<DecomposeAtenHardswishOp>(context);
     target.addIllegalOp<AtenHardswishOp>();
+    patterns.add<DecomposeAtenSoftplusOp>(context);
+    target.addIllegalOp<AtenSoftplusOp>();
     patterns.add<DecomposeAtenSiluOp>(context);
     target.addIllegalOp<AtenSiluOp>();
     patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewZerosOp, AtenZerosOp>>(

@@ -664,8 +664,8 @@ void TypeAnalysis::visitOperation(Operation *op,
 
   // Dtype is always float32, except for bfloat16, float64 and nullptr.
   if (isa<AtenTanhOp, AtenExpOp, AtenSinOp, AtenCosOp, AtenSigmoidOp,
-          AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenRsqrtOp,
-          AtenErfOp>(op)) {
+          AtenReciprocalOp, AtenLogOp, AtenSqrtOp, AtenLog2Op, AtenLog1pOp,
+          AtenRsqrtOp, AtenErfOp, AtenSoftplusOp>(op)) {
     ValueKnowledge knowledge =
         ValueKnowledge::getTensorPessimisticValueState(op->getContext());
     Type dtype = operands[0]->getValue().dtype;

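The dtype rule extended here mirrors eager PyTorch, where these ops produce a floating-point result (float32 by default, with bfloat16/float64 inputs preserved); a small sketch assuming a stock PyTorch install:

    import torch

    print(torch.log1p(torch.arange(4)).dtype)                        # torch.float32
    print(torch.log1p(torch.arange(4, dtype=torch.float64)).dtype)   # torch.float64
    print(torch.ops.aten.softplus(torch.randn(4), 1, 20).dtype)      # torch.float32
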
@@ -5317,6 +5317,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.softplus"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.square"(%arg0: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>

@@ -5365,6 +5369,10 @@ module {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>
   }
+  func.func @"__torch_mlir_shape_fn.aten.log1p"(%arg0: !torch.list<int>) -> !torch.list<int> {
+    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
+    return %0 : !torch.list<int>
+  }
   func.func @"__torch_mlir_shape_fn.aten.rsqrt"(%arg0: !torch.list<int>) -> !torch.list<int> {
     %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>
     return %0 : !torch.list<int>

@@ -300,6 +300,9 @@ def aten〇sigmoid(self: List[int]) -> List[int]:
 def aten〇hardsigmoid(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇softplus(self: List[int], beta: float = 1, threshold: float = 20) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇square(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 

@@ -336,6 +339,9 @@ def aten〇detach(self: List[int]) -> List[int]:
 def aten〇log2(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇log1p(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇rsqrt(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 

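Both new shape transfer functions delegate to the upstream unary helper, i.e. the result has the same shape as the input, and softplus's scalar beta/threshold arguments do not affect it. An illustrative stand-in for that contract (not the upstream implementation, just what a unary shape function is expected to compute):

    from typing import List

    def unary_shape(self: List[int]) -> List[int]:
        # Elementwise unary ops keep the input shape unchanged.
        return list(self)

    assert unary_shape([2, 3]) == [2, 3]
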
@@ -282,6 +282,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
             "aten::clamp_max : (Tensor, Scalar) -> (Tensor)",
             "aten::log2 : (Tensor) -> (Tensor)",
             "aten::sqrt : (Tensor) -> (Tensor)",
+            "aten::log1p : (Tensor) -> (Tensor)",
             "aten::rsqrt : (Tensor) -> (Tensor)",
             "aten::abs : (Tensor) -> (Tensor)",
             "aten::reciprocal : (Tensor) -> (Tensor)",

@@ -304,6 +305,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::pow.Tensor_Scalar : (Tensor, Scalar) -> (Tensor)")
     emit("aten::threshold_backward : (Tensor, Tensor, Scalar) -> (Tensor)")
     emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
+    emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
 
     # Ops without value semantics but the corresponding without trailing
     # underscore variant doesn't exist.

@@ -959,6 +959,28 @@ def _LogSoftmaxModuleStable_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class SoftplusModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.softplus(x)
+
+
+@register_test_case(module_factory=lambda: SoftplusModule())
+def SoftplusModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3))
+
+
+# ==============================================================================
+
+
 class HardsigmoidModule(torch.nn.Module):
 
     def __init__(self):

@@ -844,6 +844,27 @@ class ElementwiseLogIntModule(torch.nn.Module):
 def ElementwiseLogIntModule_basic(module, tu: TestUtils):
     module.forward(torch.randint(1, 10, (3, 4), dtype=torch.int32))
 
+# ==============================================================================
+
+
+class ElementwiseLog1pModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, a):
+        return torch.log1p(a)
+
+
+@register_test_case(module_factory=lambda: ElementwiseLog1pModule())
+def ElementwiseLog1pModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 4))
+
 
 # ==============================================================================
 

@@ -1189,3 +1189,22 @@ func.func @torch.aten.var.dim(%arg0: !torch.vtensor<[3,4,7],f32>) -> !torch.vtensor<[3,4,1],f32> {
   %0 = torch.aten.var.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,7],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32>
   return %0 : !torch.vtensor<[3,4,1],f32>
 }
+
+// -----
+// CHECK-LABEL:   func.func @torch.aten.softplus(
+// CHECK-SAME:        %[[VAL_0:.*]]: !torch.tensor<[2,3],f32>,
+// CHECK-SAME:        %[[VAL_1:.*]]: !torch.int) -> !torch.tensor<[2,3],f32> {
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK:           %[[VAL_3:.*]] = torch.aten.mul.Scalar %[[VAL_0]], %[[VAL_1]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],f32>
+// CHECK:           %[[VAL_4:.*]] = torch.aten.exp %[[VAL_3]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
+// CHECK:           %[[VAL_5:.*]] = torch.aten.log1p %[[VAL_4]] : !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
+// CHECK:           %[[VAL_6:.*]] = torch.aten.div.Scalar %[[VAL_5]], %[[VAL_1]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],f32>
+// CHECK:           %[[VAL_7:.*]] = torch.aten.gt.Scalar %[[VAL_3]], %[[VAL_2]] : !torch.tensor<[2,3],f32>, !torch.int -> !torch.tensor<[2,3],i1>
+// CHECK:           %[[VAL_8:.*]] = torch.aten.where.self %[[VAL_7]], %[[VAL_0]], %[[VAL_6]] : !torch.tensor<[2,3],i1>, !torch.tensor<[2,3],f32>, !torch.tensor<[2,3],f32> -> !torch.tensor<[2,3],f32>
+// CHECK:           return %[[VAL_8]] : !torch.tensor<[2,3],f32>
+// CHECK:         }
+func.func @torch.aten.softplus(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) -> !torch.tensor<[2,3],f32> {
+  %int0 = torch.constant.int 0
+  %ret = torch.aten.softplus %t, %dim, %int0: !torch.tensor<[2,3],f32>, !torch.int, !torch.int -> !torch.tensor<[2,3],f32>
+  return %ret : !torch.tensor<[2,3],f32>
+}