mirror of https://github.com/llvm/torch-mlir
parent
267052df2a
commit
285b087a5d
|
@ -256,6 +256,106 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenRreluOp : Torch_Op<"aten.rrelu", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchScalarType:$lower,
|
||||||
|
AnyTorchScalarType:$upper,
|
||||||
|
Torch_BoolType:$training,
|
||||||
|
AnyTorchOptionalGeneratorType:$generator
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenRreluOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||||
|
}
|
||||||
|
void AtenRreluOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 5, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::rrelu_ : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self,
|
||||||
|
AnyTorchScalarType:$lower,
|
||||||
|
AnyTorchScalarType:$upper,
|
||||||
|
Torch_BoolType:$training,
|
||||||
|
AnyTorchOptionalGeneratorType:$generator
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalNonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenRrelu_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 5, 1);
|
||||||
|
}
|
||||||
|
void AtenRrelu_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 5, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchScalarType:$alpha
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenCeluOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
|
||||||
|
IsTrailingUnderscoreInplaceVariant,
|
||||||
|
AllowsTypeRefinement
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
Torch_NonValueTensorType:$self,
|
||||||
|
AnyTorchScalarType:$alpha
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalNonValueTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||||
|
}
|
||||||
|
void AtenCelu_Op::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 2, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
|
def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
@ -4810,53 +4910,6 @@ def Torch_AtenPreluOp : Torch_Op<"aten.prelu", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Torch_AtenCeluOp : Torch_Op<"aten.celu", [
|
|
||||||
AllowsTypeRefinement,
|
|
||||||
HasValueSemantics,
|
|
||||||
ReadOnly
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::celu : (Tensor, Scalar) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
AnyTorchTensorType:$self,
|
|
||||||
AnyTorchScalarType:$alpha
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTorchOptionalTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenCeluOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
|
||||||
}
|
|
||||||
void AtenCeluOp::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenCelu_Op : Torch_Op<"aten.celu_", [
|
|
||||||
IsTrailingUnderscoreInplaceVariant,
|
|
||||||
AllowsTypeRefinement
|
|
||||||
]> {
|
|
||||||
let summary = "Generated op for `aten::celu_ : (Tensor, Scalar) -> (Tensor)`";
|
|
||||||
let arguments = (ins
|
|
||||||
Torch_NonValueTensorType:$self,
|
|
||||||
AnyTorchScalarType:$alpha
|
|
||||||
);
|
|
||||||
let results = (outs
|
|
||||||
AnyTorchOptionalNonValueTensorType:$result
|
|
||||||
);
|
|
||||||
let hasCustomAssemblyFormat = 1;
|
|
||||||
let extraClassDefinition = [{
|
|
||||||
ParseResult AtenCelu_Op::parse(OpAsmParser &parser, OperationState &result) {
|
|
||||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
|
||||||
}
|
|
||||||
void AtenCelu_Op::print(OpAsmPrinter &printer) {
|
|
||||||
printDefaultTorchOp(printer, *this, 2, 1);
|
|
||||||
}
|
|
||||||
}];
|
|
||||||
}
|
|
||||||
|
|
||||||
def Torch_AtenRealOp : Torch_Op<"aten.real", [
|
def Torch_AtenRealOp : Torch_Op<"aten.real", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
ReadOnly
|
ReadOnly
|
||||||
|
|
|
@ -7074,6 +7074,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.rrelu\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.list<int> {\n"
|
||||||
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
|
" return %0 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
|
@ -10610,6 +10614,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" %2 = torch.prim.If %1 -> (!torch.bool) {\n"
|
||||||
|
" torch.prim.If.yield %true : !torch.bool\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
|
||||||
|
" torch.prim.If.yield %3 : !torch.bool\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If %2 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str = torch.constant.str \"AssertionError: \"\n"
|
" %str = torch.constant.str \"AssertionError: \"\n"
|
||||||
|
|
|
@ -2520,6 +2520,77 @@ public:
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// rrelu = max(0, x) + min(0, alpha * x)
|
||||||
|
// if in training mode, the alpha is sampled from uniform distribution (lower,
|
||||||
|
// upper) if in testing mode, the alpha is (lower + upper) / 2
|
||||||
|
namespace {
|
||||||
|
class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
|
||||||
|
public:
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
LogicalResult matchAndRewrite(AtenRreluOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
Location loc = op.getLoc();
|
||||||
|
Value self = op.getSelf();
|
||||||
|
Value lower = op.getLower();
|
||||||
|
Value upper = op.getUpper();
|
||||||
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
|
if (!resType.hasDtype()) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool training;
|
||||||
|
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "training should be a constant");
|
||||||
|
}
|
||||||
|
|
||||||
|
Value constantZeroFloat =
|
||||||
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||||
|
Value constantOneFloat =
|
||||||
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||||
|
Value constantTwoFloat =
|
||||||
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
|
||||||
|
|
||||||
|
Value alpha;
|
||||||
|
if (training) {
|
||||||
|
// Create a uniform random op with low and high set to `lower` and
|
||||||
|
// `upper`, respectively.
|
||||||
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
|
||||||
|
loc, resType, self, constantZeroFloat, /*dtype=*/none,
|
||||||
|
/*layout=*/none,
|
||||||
|
/*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
|
||||||
|
alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
|
||||||
|
/*from=*/lower, /*to=*/upper,
|
||||||
|
/*generator=*/none);
|
||||||
|
} else {
|
||||||
|
Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
|
||||||
|
lower, upper);
|
||||||
|
alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
|
||||||
|
constantTwoFloat);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value zeroTensor =
|
||||||
|
createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
|
||||||
|
Value positiveOutput =
|
||||||
|
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
|
||||||
|
|
||||||
|
Value scaledSelf;
|
||||||
|
if (training) {
|
||||||
|
scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
|
||||||
|
} else {
|
||||||
|
scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
|
||||||
|
}
|
||||||
|
|
||||||
|
Value negativeOutput =
|
||||||
|
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
|
||||||
|
Value rreluOutput = rewriter.create<AtenAddTensorOp>(
|
||||||
|
loc, resType, positiveOutput, negativeOutput, constantOneFloat);
|
||||||
|
rewriter.replaceOp(op, rreluOutput);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
|
// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
|
||||||
namespace {
|
namespace {
|
||||||
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
|
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
|
||||||
|
@ -8065,6 +8136,7 @@ public:
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenPreluOp>(patterns);
|
||||||
|
addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenCeluOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenEinsumOp>(patterns);
|
||||||
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
|
addPatternIfTargetOpIsIllegal<DecomposeAtenTraceOp>(patterns);
|
||||||
|
|
|
@ -483,6 +483,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
|
target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
|
||||||
target.addIllegalOp<AtenPadOp>();
|
target.addIllegalOp<AtenPadOp>();
|
||||||
target.addIllegalOp<AtenPreluOp>();
|
target.addIllegalOp<AtenPreluOp>();
|
||||||
|
target.addIllegalOp<AtenRreluOp>();
|
||||||
target.addIllegalOp<AtenCeluOp>();
|
target.addIllegalOp<AtenCeluOp>();
|
||||||
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
target.addIllegalOp<AtenToDtypeLayoutOp>();
|
||||||
target.addIllegalOp<AtenToDeviceOp>();
|
target.addIllegalOp<AtenToDeviceOp>();
|
||||||
|
|
|
@ -387,6 +387,10 @@ FX_IMPORTER_XFAIL_SET = {
|
||||||
"ElementwiseDequantizePerTensorModule_basic",
|
"ElementwiseDequantizePerTensorModule_basic",
|
||||||
"ElementwiseQuantizePerTensorModule_basic",
|
"ElementwiseQuantizePerTensorModule_basic",
|
||||||
"ElementwiseQuantizePerTensorUIntModule_basic",
|
"ElementwiseQuantizePerTensorUIntModule_basic",
|
||||||
|
"ElementwiseRreluEvalModule_basic",
|
||||||
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
|
"ElementwiseRreluTrainModule_basic",
|
||||||
|
"ElementwiseRreluTrainStaticModule_basic",
|
||||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||||
"EqIntModule_basic",
|
"EqIntModule_basic",
|
||||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||||
|
@ -1014,6 +1018,8 @@ STABLEHLO_PASS_SET = {
|
||||||
"ElementwiseRemainderTensorModule_Float_basic",
|
"ElementwiseRemainderTensorModule_Float_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
"ElementwiseRemainderTensorModule_Int_Float_basic",
|
||||||
"ElementwiseRemainderTensorModule_Int_basic",
|
"ElementwiseRemainderTensorModule_Int_basic",
|
||||||
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
|
"ElementwiseRreluTrainStaticModule_basic",
|
||||||
"ElementwiseRsqrtModule_basic",
|
"ElementwiseRsqrtModule_basic",
|
||||||
"ElementwiseSigmoidModule_basic",
|
"ElementwiseSigmoidModule_basic",
|
||||||
"ElementwiseSinModule_basic",
|
"ElementwiseSinModule_basic",
|
||||||
|
@ -1692,6 +1698,8 @@ TOSA_PASS_SET = {
|
||||||
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
"ElementwiseRemainderScalarModule_Int_Float_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_basic",
|
"ElementwiseRemainderScalarModule_Int_basic",
|
||||||
"ElementwiseRemainderScalarModule_Int_basic",
|
"ElementwiseRemainderScalarModule_Int_basic",
|
||||||
|
"ElementwiseRreluEvalModule_basic",
|
||||||
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
"ElementwiseRsqrtModule_basic",
|
"ElementwiseRsqrtModule_basic",
|
||||||
"ElementwiseSeluModule_basic",
|
"ElementwiseSeluModule_basic",
|
||||||
"ElementwiseSigmoidModule_basic",
|
"ElementwiseSigmoidModule_basic",
|
||||||
|
@ -1978,6 +1986,9 @@ MAKE_FX_TOSA_PASS_SET = (
|
||||||
"ElementwisePreluModule_basic",
|
"ElementwisePreluModule_basic",
|
||||||
"ElementwisePreluStaticModule_basic",
|
"ElementwisePreluStaticModule_basic",
|
||||||
"ElementwiseLogSigmoidModule_basic",
|
"ElementwiseLogSigmoidModule_basic",
|
||||||
|
# failed to legalize operation 'torch.aten.rrelu_with_noise'
|
||||||
|
"ElementwiseRreluEvalModule_basic",
|
||||||
|
"ElementwiseRreluEvalStaticModule_basic",
|
||||||
# Shape Related failures
|
# Shape Related failures
|
||||||
"PrimListUnpackNumMismatchModule_basic",
|
"PrimListUnpackNumMismatchModule_basic",
|
||||||
"ReshapeExpandModule_basic",
|
"ReshapeExpandModule_basic",
|
||||||
|
|
|
@ -555,6 +555,9 @@ def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
|
||||||
def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
|
def aten〇celu〡shape(self: List[int], alpha: float = 1.) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]:
|
||||||
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
def aten〇selu〡shape(self: List[int]) -> List[int]:
|
def aten〇selu〡shape(self: List[int]) -> List[int]:
|
||||||
return upstream_shape_functions.unary(self)
|
return upstream_shape_functions.unary(self)
|
||||||
|
|
||||||
|
@ -2723,6 +2726,12 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
return self_dtype
|
return self_dtype
|
||||||
|
|
||||||
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, *all_integer_dtypes()}))
|
||||||
|
def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype)
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
|
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
|
||||||
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
|
|
@ -301,6 +301,8 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
"aten::relu : (Tensor) -> (Tensor)",
|
"aten::relu : (Tensor) -> (Tensor)",
|
||||||
"aten::relu6 : (Tensor) -> (Tensor)",
|
"aten::relu6 : (Tensor) -> (Tensor)",
|
||||||
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
|
||||||
|
"aten::rrelu : (Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)",
|
||||||
|
"aten::celu : (Tensor, Scalar) -> (Tensor)",
|
||||||
"aten::selu : (Tensor) -> (Tensor)",
|
"aten::selu : (Tensor) -> (Tensor)",
|
||||||
"aten::sigmoid : (Tensor) -> (Tensor)",
|
"aten::sigmoid : (Tensor) -> (Tensor)",
|
||||||
"aten::sinh : (Tensor) -> (Tensor)",
|
"aten::sinh : (Tensor) -> (Tensor)",
|
||||||
|
@ -472,7 +474,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::floor_divide : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
|
emit("aten::softplus : (Tensor, Scalar, Scalar) -> (Tensor)")
|
||||||
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
|
emit("aten::prelu : (Tensor, Tensor) -> (Tensor)")
|
||||||
emit_with_mutating_variants("aten::celu : (Tensor, Scalar) -> (Tensor)")
|
|
||||||
emit("aten::real : (Tensor) -> (Tensor)")
|
emit("aten::real : (Tensor) -> (Tensor)")
|
||||||
emit("aten::imag : (Tensor) -> (Tensor)")
|
emit("aten::imag : (Tensor) -> (Tensor)")
|
||||||
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
|
||||||
|
|
|
@ -1062,6 +1062,100 @@ def ElementwiseCeluModule_basic(module, tu: TestUtils):
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRreluTrainModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
res = torch.ops.aten.rrelu(x, 0.4, 0.6, True)
|
||||||
|
return torch.mean(res), torch.std(res)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRreluTrainModule())
|
||||||
|
def ElementwiseRreluTrainModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1024, 1536))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRreluTrainStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([1024, 1536], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
res = torch.ops.aten.rrelu(x, 0.1, 0.9, True)
|
||||||
|
return torch.mean(res), torch.std(res)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRreluTrainStaticModule())
|
||||||
|
def ElementwiseRreluTrainStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1024, 1536))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRreluEvalModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.rrelu(x, 0.4, 0.6, False)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRreluEvalModule())
|
||||||
|
def ElementwiseRreluEvalModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 3, low=-1, high=1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ElementwiseRreluEvalStaticModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args(
|
||||||
|
[
|
||||||
|
None,
|
||||||
|
([5, 3], torch.float32, True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.ops.aten.rrelu(x, 0.1, 0.9, False)
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: ElementwiseRreluEvalStaticModule())
|
||||||
|
def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(5, 3, low=-1, high=1))
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
|
||||||
class ElementwiseCeluStaticModule(torch.nn.Module):
|
class ElementwiseCeluStaticModule(torch.nn.Module):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue