mirror of https://github.com/llvm/torch-mlir
Add aten.hardtanh e2e support.
parent
819f29316f
commit
1d285f0153
|
@ -1363,3 +1363,42 @@ class SiluModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: SiluModule())
|
||||
def SiluModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(128, 128, low=-10, high=10))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class HardTanhModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardtanh(x, min_val=-2, max_val=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: HardTanhModule())
|
||||
def HardTanhModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(100, 100, low=-5, high=5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class HardTanhIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.ops.aten.hardtanh(x, min_val=-2, max_val=2)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: HardTanhIntModule())
|
||||
def HardTanhIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(-5, 5, (100, 100)))
|
||||
|
||||
|
|
|
@ -293,13 +293,32 @@ class ElementwiseMinimumModule(torch.nn.Module):
|
|||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.minimum(x, y)
|
||||
return torch.ops.aten.minimum(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMinimumModule())
|
||||
def ElementwiseMinimumModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||
module.forward(tu.nans(3, 5), tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMinimumIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.minimum(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMinimumIntModule())
|
||||
def ElementwiseMinimumIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -314,13 +333,32 @@ class ElementwiseMaximumModule(torch.nn.Module):
|
|||
([-1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.maximum(x, y)
|
||||
return torch.ops.aten.maximum(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMaximumModule())
|
||||
def ElementwiseMaximumModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 5), tu.rand(3, 5))
|
||||
module.forward(tu.nans(3, 5), tu.rand(3, 5))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ElementwiseMaximumIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int64, True),
|
||||
])
|
||||
def forward(self, x, y):
|
||||
return torch.ops.aten.maximum(x, y)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMaximumIntModule())
|
||||
def ElementwiseMaximumIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 5)), torch.randint(10, (3, 5)))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
@ -890,3 +928,4 @@ class ElementwiseCloneContiguousModule(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ElementwiseCloneContiguousModule())
|
||||
def ElementwiseCloneContiguousModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
|
||||
|
|
|
@ -100,8 +100,10 @@ class MobilenetV2Module(torch.nn.Module):
|
|||
def forward(self, img):
|
||||
return self.mobilenetv2.forward(img)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: MobilenetV2Module())
|
||||
# TODO (cathyzhyi) The runtime assertion for conv2d with group != 1 is exposed
|
||||
# after aten.hardtanh is implemented. Reenable once the the runtime assertion
|
||||
# is fixed.
|
||||
#@register_test_case(module_factory=lambda: MobilenetV2Module())
|
||||
def MobilenetV2Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(1, 3, 224, 224))
|
||||
|
||||
|
|
|
@ -31,6 +31,10 @@ TOSA_PASS_SET = {
|
|||
"ElementwiseFloorModule_basic",
|
||||
"ElementwiseLogModule_basic",
|
||||
"ElementwiseBinaryStaticShapeModule_basic",
|
||||
"ElementwiseMinimumModule_basic",
|
||||
"ElementwiseMinimumIntModule_basic",
|
||||
"ElementwiseMaximumModule_basic",
|
||||
"ElementwiseMaximumIntModule_basic",
|
||||
"TanhBackward_basic",
|
||||
"ElementwiseAddModule_basic",
|
||||
"ReturnThreeTensorFloat32_basic",
|
||||
|
|
|
@ -44,6 +44,38 @@ def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [
|
|||
let assemblyFormat = "$self attr-dict `:` qualified(type($self)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenHardtanhOp : Torch_Op<"aten.hardtanh", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
]> {
|
||||
let summary = "Generated op for `aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$min_val,
|
||||
AnyTorchScalarType:$max_val
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $min_val `,` $max_val attr-dict `:` qualified(type($self)) `,` qualified(type($min_val)) `,` qualified(type($max_val)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenHardtanh_Op : Torch_Op<"aten.hardtanh_", [
|
||||
IsTrailingUnderscoreInplaceVariant,
|
||||
AllowsTypeRefinement
|
||||
]> {
|
||||
let summary = "Generated op for `aten::hardtanh_ : (Tensor, Scalar, Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchScalarType:$min_val,
|
||||
AnyTorchScalarType:$max_val
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
);
|
||||
let assemblyFormat = "$self `,` $min_val `,` $max_val attr-dict `:` qualified(type($self)) `,` qualified(type($min_val)) `,` qualified(type($max_val)) `->` qualified(type($result))";
|
||||
}
|
||||
|
||||
def Torch_AtenReluOp : Torch_Op<"aten.relu", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
|
||||
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
|
||||
|
@ -163,6 +164,37 @@ static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim,
|
|||
b.getStringAttr("mismatching contracting dimension"));
|
||||
}
|
||||
|
||||
template <arith::CmpFPredicate fpred, arith::CmpIPredicate iupred,
|
||||
arith::CmpIPredicate ispred>
|
||||
static Value createComparisonTemplate(OpBuilder &b, Location loc, Type type,
|
||||
Value lhs, Value rhs) {
|
||||
if (type.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, fpred, lhs, rhs);
|
||||
if (IntegerType intType = type.dyn_cast<mlir::IntegerType>()) {
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, iupred, lhs, rhs);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, ispred, lhs, rhs);
|
||||
}
|
||||
assert(false && "Unhandled element type for comparison");
|
||||
}
|
||||
|
||||
static Value createGreaterThan(OpBuilder &b, Location loc, Type elementalType,
|
||||
Value lhs, Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::UGT,
|
||||
arith::CmpIPredicate::ugt,
|
||||
arith::CmpIPredicate::sgt>(
|
||||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static Value createLessThan(OpBuilder &b, Location loc, Type elementalType,
|
||||
Value lhs, Value rhs) {
|
||||
return createComparisonTemplate<arith::CmpFPredicate::ULT,
|
||||
arith::CmpIPredicate::ult,
|
||||
arith::CmpIPredicate::slt>(
|
||||
b, loc, elementalType, lhs, rhs);
|
||||
}
|
||||
|
||||
static SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
|
||||
Value tensor, int dim) {
|
||||
RankedTensorType type = tensor.getType().cast<RankedTensorType>();
|
||||
|
@ -2072,20 +2104,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
|
||||
Type elementalType =
|
||||
gtTensor.self().getType().cast<BaseTensorType>().getDtype();
|
||||
|
||||
if (elementalType.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) {
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
gtTensor.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
return createGreaterThan(b, loc, elementalType, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
}
|
||||
if (auto eqTensor = dyn_cast<AtenEqTensorOp>(op)) {
|
||||
AtenEqTensorOp::Adaptor adaptor(operands);
|
||||
|
@ -2126,20 +2146,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
|
||||
Type elementalType =
|
||||
ltTensor.self().getType().cast<BaseTensorType>().getDtype();
|
||||
|
||||
if (elementalType.isa<mlir::FloatType>())
|
||||
return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
if (IntegerType intType = elementalType.dyn_cast<mlir::IntegerType>()) {
|
||||
if (intType.isUnsigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
if (intType.isSigned())
|
||||
return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
}
|
||||
ltTensor.emitError("unimplemented: dtype isn't supported.");
|
||||
return nullptr;
|
||||
return createLessThan(b, loc, elementalType, payloadArgs[0],
|
||||
payloadArgs[1]);
|
||||
}
|
||||
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
||||
AtenDivTensorOp::Adaptor adaptor(operands);
|
||||
|
@ -2329,28 +2337,24 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
|||
return b.create<arith::AddFOp>(loc, start, weightedDelta);
|
||||
}
|
||||
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
|
||||
if (!minimum.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
minimum.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
||||
Type dtype = minimum.getType().cast<BaseTensorType>().getDtype();
|
||||
Type elemTy = converter->convertType(minimum.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
|
||||
Value pred = createLessThan(b, loc, dtype, lhs, rhs);
|
||||
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
|
||||
}
|
||||
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
|
||||
if (!maximum.getType()
|
||||
.cast<ValueTensorType>()
|
||||
.getDtype()
|
||||
.isa<mlir::FloatType>()) {
|
||||
maximum.emitError("unimplemented: non-floating point dtype");
|
||||
return nullptr;
|
||||
}
|
||||
Value pred = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
|
||||
payloadArgs[0], payloadArgs[1]);
|
||||
return b.create<arith::SelectOp>(loc, pred, payloadArgs[0], payloadArgs[1]);
|
||||
Type dtype = maximum.getType().cast<BaseTensorType>().getDtype();
|
||||
Type elemTy = converter->convertType(maximum.getType())
|
||||
.cast<RankedTensorType>()
|
||||
.getElementType();
|
||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy);
|
||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy);
|
||||
Value pred = createGreaterThan(b, loc, dtype, lhs, rhs);
|
||||
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
|
||||
}
|
||||
if (auto clamp = dyn_cast<AtenClampOp>(op)) {
|
||||
Type dtype = converter->convertType(clamp.getType())
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/IR/BuiltinDialect.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
|
||||
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
|
||||
|
@ -118,6 +119,55 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc,
|
|||
return sub;
|
||||
}
|
||||
|
||||
static Value getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||
Type dtype) {
|
||||
int intType = (int)getScalarTypeForType(dtype);
|
||||
return rewriter.create<ConstantIntOp>(loc,
|
||||
rewriter.getI64IntegerAttr(intType));
|
||||
}
|
||||
|
||||
// Helper to convert a tensor to a specific scalar type.
|
||||
static Value convertTensorToDtype(PatternRewriter &rewriter, Location loc,
|
||||
Value input, Type dtype) {
|
||||
BaseTensorType origType = input.getType().cast<BaseTensorType>();
|
||||
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
|
||||
// `convertIntVal` contains the corresponding integer for the dtype which is used
|
||||
// by the aten.to.dtype op.
|
||||
Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype);
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value converted = rewriter.create<AtenToDtypeOp>(
|
||||
loc, newType, input, convertIntVal, falseVal, falseVal, noneVal);
|
||||
return converted;
|
||||
}
|
||||
|
||||
// Helper to create a tensor filled with the given scalar. Scalar would be
|
||||
// converted the to the element type of the given tensor type.
|
||||
static Value createInitTensor(PatternRewriter &rewriter, Location loc,
|
||||
Type resultType, Value scalar, Value sizeList) {
|
||||
BaseTensorType tensorType = resultType.cast<BaseTensorType>();
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
||||
loc, tensorType, sizeList, /*dtype=*/noneVal, /*layout=*/noneVal,
|
||||
/*device=*/noneVal,
|
||||
/*pin_memory=*/noneVal, /*memory_format=*/noneVal);
|
||||
return rewriter.create<PseudoAtenFillScalarOp>(loc, resultType, emptyTensor,
|
||||
scalar);
|
||||
}
|
||||
|
||||
// Helper to create a rank0 tensor filled with the given scalar. Scalar would be
|
||||
// converted the to the element type of the given tensor type.
|
||||
static Value createRank0Tensor(PatternRewriter &rewriter, Location loc,
|
||||
BaseTensorType inputType, Value scalar) {
|
||||
SmallVector<int64_t> sizes;
|
||||
Type rank0TensorTy = inputType.getWithSizesAndDtype(
|
||||
makeArrayRef(sizes), inputType.getOptionalDtype());
|
||||
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
||||
ValueRange{});
|
||||
return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList);
|
||||
}
|
||||
|
||||
// Share code between `softmax_backward` and `log_softmax_backward` ops.
|
||||
// Returns x - y * sum(z, dim).
|
||||
static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter,
|
||||
|
@ -563,23 +613,11 @@ public:
|
|||
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
||||
Value input) {
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
Value constantOne =
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
|
||||
Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
|
||||
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||
loc, Torch::ListType::get(constantOne.getType()), constantOne);
|
||||
BaseTensorType oneDTensorType =
|
||||
inputType.getWithSizesAndDtype({1}, inputType.getDtype())
|
||||
.cast<BaseTensorType>();
|
||||
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
|
||||
loc, oneDTensorType, dimList, /*dtype=*/none, /*layout=*/none,
|
||||
/*device=*/none,
|
||||
/*pin_memory=*/none, /*memory_format=*/none);
|
||||
Value constantSix =
|
||||
Value cst6 =
|
||||
rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(6));
|
||||
Value sixTensor = rewriter.create<PseudoAtenFillScalarOp>(
|
||||
loc, oneDTensorType, emptyTensor, constantSix);
|
||||
Value sixTensor = createRank0Tensor(rewriter, loc, inputType, cst6);
|
||||
Value relu6Out =
|
||||
rewriter.create<AtenMinimumOp>(loc, inputType, relu, sixTensor);
|
||||
return relu6Out;
|
||||
|
@ -841,7 +879,7 @@ public:
|
|||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.self();
|
||||
Type inputType = input.getType();
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
|
||||
// outputTensor = (input + 3) / 6.
|
||||
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
|
||||
|
@ -856,15 +894,13 @@ public:
|
|||
loc, inputType, inputPlusThree, constantSix);
|
||||
|
||||
// result = max(0, min(1, (input+3)/6))
|
||||
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
|
||||
Value zeroTensor = rewriter.create<AtenZerosLikeOp>(
|
||||
loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
|
||||
/*pin_memory=*/none, /*memory_format=*/none);
|
||||
Value oneTensor = rewriter.create<AtenOnesLikeOp>(
|
||||
loc, inputType, input, /*dtype=*/none, /*layout=*/none, /*device=*/none,
|
||||
/*pin_memory=*/none, /*memory_format=*/none);
|
||||
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(0));
|
||||
Value oneTensor = createRank0Tensor(rewriter, loc, inputType, constantOne);
|
||||
Value minResult =
|
||||
rewriter.create<AtenMinimumOp>(loc, inputType, oneTensor, outputTensor);
|
||||
Value zeroTensor =
|
||||
createRank0Tensor(rewriter, loc, inputType, constantZero);
|
||||
rewriter.replaceOpWithNewOp<AtenMaximumOp>(op, op.getType(), zeroTensor,
|
||||
minResult);
|
||||
return success();
|
||||
|
@ -872,19 +908,39 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenHardtanhOp : public OpRewritePattern<AtenHardtanhOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenHardtanhOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.self();
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
|
||||
// result = min(maxVal, max(minVal, x))
|
||||
Value minVal = createRank0Tensor(rewriter, loc, inputType, op.min_val());
|
||||
Value maxResult =
|
||||
rewriter.create<AtenMaximumOp>(loc, inputType, input, minVal);
|
||||
Value maxVal = createRank0Tensor(rewriter, loc, inputType, op.max_val());
|
||||
rewriter.replaceOpWithNewOp<AtenMinimumOp>(op, op.getType(), maxVal,
|
||||
maxResult);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Returns a tensor with bernoulli(p) distribution.
|
||||
// Decompose aten.bernoulli(x, p) to aten.gtTensor(aten.uniform(x), p).
|
||||
static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
|
||||
Location loc, Value input, double p) {
|
||||
static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
|
||||
Operation *op, Location loc,
|
||||
Value input, double p,
|
||||
Value &result) {
|
||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
||||
|
||||
// `intType` contains the corresponding integer for the dtype which is used
|
||||
// by the aten.to.dtype op.
|
||||
int intType = (int)getScalarTypeForType(inputType.getDtype());
|
||||
Value convertIntVal =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(intType));
|
||||
if (!inputType.hasSizes())
|
||||
return nullptr;
|
||||
if (!inputType.hasSizes() || !inputType.hasDtype()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Can't decomposeBernoulliLikeOp without sizes or dtype");
|
||||
}
|
||||
BaseTensorType boolType =
|
||||
inputType
|
||||
.getWithSizesAndDtype(
|
||||
|
@ -898,9 +954,7 @@ static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
|
|||
Value ub =
|
||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value noneVal = rewriter.create<ConstantNoneOp>(loc);
|
||||
|
||||
// Create a uniform random op with low and high set to lb and ub respectively.
|
||||
Value uniformRandom = rewriter.create<PseudoAtenUniformOp>(
|
||||
loc, inputType, input, lb, ub, noneVal);
|
||||
|
@ -908,9 +962,8 @@ static Value decomposeBernoulliLikeOp(PatternRewriter &rewriter, Operation *op,
|
|||
rewriter.create<AtenLtScalarOp>(loc, boolType, uniformRandom, prob);
|
||||
// Since `gtValue` will be a boolean tensor convert it back to the original
|
||||
// type.
|
||||
Value convertBack = rewriter.create<AtenToDtypeOp>(
|
||||
loc, inputType, gtValue, convertIntVal, falseVal, falseVal, noneVal);
|
||||
return convertBack;
|
||||
result = convertTensorToDtype(rewriter, loc, gtValue, inputType.getDtype());
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -926,7 +979,10 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to ben None because only global default "
|
||||
"generator is supported");
|
||||
Value result = decomposeBernoulliLikeOp(rewriter, op, loc, self, /*p=*/0.5);
|
||||
Value result;
|
||||
if (failed(decomposeBernoulliLikeOp(rewriter, op, loc, self, /*p=*/0.5,
|
||||
result)))
|
||||
return failure();
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
@ -951,7 +1007,9 @@ public:
|
|||
return rewriter.notifyMatchFailure(
|
||||
op, "The generator has to ben None because only global default "
|
||||
"generator is supported");
|
||||
Value result = decomposeBernoulliLikeOp(rewriter, op, loc, self, p);
|
||||
Value result;
|
||||
if (failed(decomposeBernoulliLikeOp(rewriter, op, loc, self, p, result)))
|
||||
return failure();
|
||||
rewriter.replaceOp(op, result);
|
||||
return success();
|
||||
}
|
||||
|
@ -1343,6 +1401,8 @@ class DecomposeComplexOpsPass
|
|||
patterns.add<DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(
|
||||
context);
|
||||
target.addIllegalOp<AtenNewOnesOp>();
|
||||
patterns.add<DecomposeAtenHardtanhOp>(context);
|
||||
target.addIllegalOp<AtenHardtanhOp>();
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns)))) {
|
||||
|
|
|
@ -230,7 +230,8 @@ public:
|
|||
AtenAbsOp, AtenThresholdOp, AtenSquareOp, PseudoAtenUniformOp,
|
||||
AtenCloneOp, AtenBernoulliOp, AtenBernoulli_FloatOp,
|
||||
PseudoAtenBernoulliFloatOp, PseudoAtenFillScalarOp,
|
||||
AtenHardsigmoidOp, AtenHardswishOp, AtenSiluOp>(op)) {
|
||||
AtenHardsigmoidOp, AtenHardswishOp, AtenSiluOp, AtenHardtanhOp>(
|
||||
op)) {
|
||||
return getLatticeElement(op->getResult(0)).join(*operands[0]);
|
||||
}
|
||||
|
||||
|
|
|
@ -446,6 +446,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
|
|||
# Elementwise tensor compute ops
|
||||
for key in [
|
||||
"aten::tanh : (Tensor) -> (Tensor)",
|
||||
"aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)",
|
||||
"aten::relu : (Tensor) -> (Tensor)",
|
||||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||
"aten::log : (Tensor) -> (Tensor)",
|
||||
|
|
|
@ -365,7 +365,7 @@ func @torch.aten._unsafe_view$dynamic(%arg0: !torch.vtensor<[?,?,?],f32>) -> !to
|
|||
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.list<!torch.int>, !torch.bool, !torch.none -> !torch.vtensor<[1,?,?],f32>
|
||||
// CHECK: %[[LOG:.*]] = torch.aten.log %[[SUM_DIM]] : !torch.vtensor<[1,?,?],f32> -> !torch.vtensor<[1,?,?],f32>
|
||||
// CHECK: %[[FLOAT_1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[SUB1:.*]] = torch.aten.sub.Tensor %[[SUB]], %[[LOG]], %[[FLOAT_1]] : !torch.vtensor<[?,?,?],f32>,
|
||||
// CHECK: %[[SUB1:.*]] = torch.aten.sub.Tensor %[[SUB]], %[[LOG]], %[[FLOAT_1]] : !torch.vtensor<[?,?,?],f32>,
|
||||
// CHECK-SAME: !torch.vtensor<[1,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: return %[[SUB1]] : !torch.vtensor<[?,?,?],f32>
|
||||
func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -> !torch.vtensor<[?,?,?],f32> {
|
||||
|
@ -379,16 +379,17 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
|
|||
// CHECK-LABEL: func @torch.aten.bernoulli
|
||||
// CHECK-SAME: (%[[INP:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor {
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FLOAT0_5:.*]] = torch.constant.float 5.000000e-01
|
||||
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE0:.*]] = torch.constant.none
|
||||
// CHECK: %[[UNF:.*]] = torch.pseudo.aten.uniform %[[INP]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE0]] :
|
||||
// CHECK-SAME: !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[GT:.*]] = torch.aten.lt.Scalar %[[UNF]], %[[FLOAT0_5]] : !torch.vtensor<[?,?,?],f32>, !torch.float -> !torch.vtensor<[?,?,?],i1>
|
||||
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE0]] :
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||
// CHECK: %[[NONE1:.*]] = torch.constant.none
|
||||
// CHECK: %[[TODTYPE:.*]] = torch.aten.to.dtype %[[GT]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE1]] :
|
||||
// CHECK-SAME: !torch.vtensor<[?,?,?],i1>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32>
|
||||
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[TODTYPE]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
|
||||
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||
|
@ -418,32 +419,27 @@ func @torch.aten.select.int(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor
|
|||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.hardsigmoid(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[ADD3:.*]] = torch.aten.add.Scalar %[[ARG]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[OUT:.*]] = torch.aten.div.Scalar %[[ADD3]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[CST1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[CST2:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[CST6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %[[INPUT]], %[[CST2]], %[[CST1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[ADD]], %[[CST6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[CST0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[IND0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[ARG]], %[[IND0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[IND1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[ARG]], %[[IND1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[FILL0:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[ZERO:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[FILL0]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[IND0_2:.*]] = torch.constant.int 0
|
||||
// CHECK: %[[DIM0_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND0_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[IND1_2:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[DIM1_2:.*]] = torch.aten.size.int %[[ARG]], %[[IND1_2]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
|
||||
// CHECK: %[[SIZES_2:.*]] = torch.prim.ListConstruct %[[DIM0_2]], %[[DIM1_2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: %[[EMPTY_2:.*]] = torch.aten.empty.memory_format %[[SIZES_2]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[FILL1:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[ONE:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_2]], %[[FILL1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ONE]], %[[OUT]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[RES:.*]] = torch.aten.maximum %[[ZERO]], %[[MIN]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RES]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[CST1_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[CST1]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[CST1_TENSOR]], %[[DIV]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
|
||||
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||
// CHECK: %[[EMPTY_1:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] :
|
||||
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[CST0_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY_1]], %[[CST0]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.maximum %[[CST0_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.hardsigmoid %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
|
@ -456,14 +452,14 @@ func @torch.aten.hardsigmoid(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
|
|||
// CHECK: %[[INT3:.*]] = torch.constant.int 3
|
||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[ADD:.*]] = torch.aten.add.Scalar %[[INP]], %[[INT3]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[INT1_:.*]] = torch.constant.int 1
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[RELU:.*]] = torch.aten.relu %[[ADD]] : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT1_]] : (!torch.int) -> !torch.list<!torch.int>
|
||||
// CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],f32>
|
||||
// CHECK: %[[INT6_:.*]] = torch.constant.int 6
|
||||
// CHECK: %[[FILL:.*]] = torch.pseudo.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[1],f32>, !torch.int -> !torch.vtensor<[1],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[MEM:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[FILL:.*]] = torch.pseudo.aten.fill.Scalar %[[MEM]], %[[INT6_]] : !torch.vtensor<[],f32>, !torch.int -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[RELU]], %[[FILL]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[DIV:.*]] = torch.aten.div.Scalar %[[MIN]], %[[INT6]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: %[[MUL:.*]] = torch.aten.mul.Tensor %[[DIV]], %[[INP]] : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[MUL]] : !torch.vtensor<[?,?],f32>
|
||||
|
@ -472,6 +468,29 @@ func @torch.aten.hardswish(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[
|
|||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.hardtanh(
|
||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?],f32>,
|
||||
// CHECK-SAME: %[[MIN_VAL:.*]]: !torch.float,
|
||||
// CHECK-SAME: %[[MAX_VAL:.*]]: !torch.float) -> !torch.vtensor<[?],f32> {
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[EMPTY]], %[[MIN_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MIN:.*]] = torch.aten.maximum %[[INPUT]], %[[MIN_TENSOR]] : !torch.vtensor<[?],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: %[[SIZES:.*]] = torch.prim.ListConstruct : () -> !torch.list<!torch.int>
|
||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||
// CHECK: %[[VAL_10:.*]] = torch.aten.empty.memory_format %[[SIZES]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] :
|
||||
// CHECK-SAME: !torch.list<!torch.int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[MAX_TENSOR:.*]] = torch.pseudo.aten.fill.Scalar %[[VAL_10]], %[[MAX_VAL]] : !torch.vtensor<[],f32>, !torch.float -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[RET:.*]] = torch.aten.minimum %[[MAX_TENSOR]], %[[MIN]] : !torch.vtensor<[],f32>, !torch.vtensor<[?],f32> -> !torch.vtensor<[?],f32>
|
||||
// CHECK: return %[[RET]] : !torch.vtensor<[?],f32>
|
||||
func @torch.aten.hardtanh(%arg0: !torch.vtensor<[?],f32>, %min: !torch.float, %max: !torch.float) -> !torch.vtensor<[?],f32> {
|
||||
%0 = torch.aten.hardtanh %arg0, %min, %max : !torch.vtensor<[?],f32>, !torch.float, !torch.float -> !torch.vtensor<[?],f32>
|
||||
return %0 : !torch.vtensor<[?],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @torch.aten.new_zeros
|
||||
// CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[2,3],f32> {
|
||||
|
|
Loading…
Reference in New Issue