[Torch] emit and lowering frac, signbit, ldexp, copysign ops (#3851)

also fix `aten.exp2` with integer type
byteir
Yuanqiang Liu 2024-11-06 10:21:37 +08:00
parent f0f59d0f5b
commit c0ec22df2c
8 changed files with 663 additions and 2 deletions

View File

@ -1483,6 +1483,51 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [
}]; }];
} }
def Torch_AtenFracOp : Torch_Op<"aten.frac", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::frac : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFracOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenFracOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenFrac_Op : Torch_Op<"aten.frac_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::frac_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFrac_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenFrac_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -3400,6 +3445,53 @@ def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [
}]; }];
} }
def Torch_AtenCopysignTensorOp : Torch_Op<"aten.copysign.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::copysign.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCopysignTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCopysignTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenCopysign_TensorOp : Torch_Op<"aten.copysign_.Tensor", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::copysign_.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_NonValueTensorType:$other
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenCopysign_TensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenCopysign_TensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [
AllowsTypeRefinement, AllowsTypeRefinement,
ReadOnly ReadOnly
@ -3850,6 +3942,53 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [
}]; }];
} }
def Torch_AtenLdexpTensorOp : Torch_Op<"aten.ldexp.Tensor", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::ldexp.Tensor : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLdexpTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenLdexpTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}
def Torch_AtenSignbitOp : Torch_Op<"aten.signbit", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::signbit : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSignbitOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenSignbitOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,

View File

@ -9154,6 +9154,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.frac\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.signbit\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.ldexp.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.copysign.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.__and__.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.__and__.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
@ -12825,6 +12839,93 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n" " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n" " return %4 : !torch.int\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.frac\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.signbit\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.ldexp.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %false = torch.constant.bool false\n"
" %int7 = torch.constant.int 7\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = torch.aten.eq.int %0#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %6 = torch.prim.If %5 -> (!torch.int) {\n"
" torch.prim.If.yield %1#1 : !torch.int\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %8 = torch.prim.If %7 -> (!torch.bool) {\n"
" %10 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If.yield %10 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %9 = torch.prim.If %8 -> (!torch.int) {\n"
" torch.prim.If.yield %0#1 : !torch.int\n"
" } else {\n"
" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %11 = torch.prim.If %10 -> (!torch.bool) {\n"
" %13 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %13 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %13 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" torch.prim.If.yield %13 : !torch.int\n"
" }\n"
" torch.prim.If.yield %12 : !torch.int\n"
" }\n"
" torch.prim.If.yield %9 : !torch.int\n"
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.copysign.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %false = torch.constant.bool false\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %5 = torch.prim.If %4 -> (!torch.bool) {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If.yield %7 : !torch.bool\n"
" } else {\n"
" torch.prim.If.yield %false : !torch.bool\n"
" }\n"
" %6 = torch.prim.If %5 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" torch.prim.If.yield %7 : !torch.int\n"
" }\n"
" return %6 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.__and__.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.__and__.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"

View File

@ -7490,6 +7490,160 @@ class DecomposeAtenTruncOp : public OpRewritePattern<AtenTruncOp> {
}; };
} // namespace } // namespace
namespace {
// decompose `signbit(x)` to `view.dtype(x, si32/si64) < 0 `
class DecomposeAtenSignbitOp : public OpRewritePattern<AtenSignbitOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSignbitOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
auto operandTy = dyn_cast<ValueTensorType>(self.getType());
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
if (!operandTy || !operandTy.hasDtype() || !resultTy ||
!resultTy.hasDtype()) {
return rewriter.notifyMatchFailure(op,
"operand and result must have dtype");
}
if (isa<mlir::FloatType>(operandTy.getDtype())) {
mlir::IntegerType intType = rewriter.getIntegerType(
operandTy.getDtype().getIntOrFloatBitWidth(), /*isSigned*/ true);
Value dtype = getDtypeIntValueForType(rewriter, loc, intType);
Value view = rewriter.create<AtenViewDtypeOp>(
loc,
operandTy.getWithSizesAndDtype(operandTy.getOptionalSizes(), intType),
self, dtype);
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value shift = rewriter.create<AtenLtScalarOp>(loc, resultTy, view, zero);
rewriter.replaceOp(op, shift);
return success();
} else if (isa<mlir::IntegerType>(operandTy.getDtype())) {
Value zero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value shift = rewriter.create<AtenLtScalarOp>(loc, resultTy, self, zero);
rewriter.replaceOp(op, shift);
}
return failure();
}
};
} // namespace
namespace {
// decompose `frac(x)` to `x - trunc(x)`
class DecomposeAtenFracOp : public OpRewritePattern<AtenFracOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFracOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
auto resultTy = op.getType();
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value trunc = rewriter.create<AtenTruncOp>(loc, resultTy, self);
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, trunc,
/*alpha=*/one);
return success();
}
};
} // namespace
namespace {
// decompose `copysign(x, y)` to `signbit(y) ? -abs(x) : abs(x)`
class DecomposeAtenCopysignTensorOp
: public OpRewritePattern<AtenCopysignTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCopysignTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value other = op.getOther();
auto selfTy = self.getType();
auto otherTy = cast<BaseTensorType>(other.getType());
auto resultTy = op.getType();
Value signbit = rewriter.create<AtenSignbitOp>(
loc,
otherTy.getWithSizesAndDtype(otherTy.getOptionalSizes(),
rewriter.getI1Type()),
other);
Value abs = rewriter.create<AtenAbsOp>(loc, selfTy, self);
Value neg = rewriter.create<AtenNegOp>(loc, selfTy, abs);
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, signbit, neg,
abs);
return success();
}
};
} // namespace
namespace {
// decompose `ldexp(x, y)` to `x * 2^y`
class DecomposeAtenLdexpTensorOp : public OpRewritePattern<AtenLdexpTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLdexpTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value other = op.getOther();
auto otherTy = dyn_cast<BaseTensorType>(other.getType());
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result must have dtype");
}
Value exp2 = rewriter.create<AtenExp2Op>(
loc,
resultTy.getWithSizesAndDtype(otherTy.getOptionalSizes(),
resultTy.getDtype()),
other);
rewriter.replaceOpWithNewOp<AtenMulTensorOp>(op, resultTy, self, exp2);
return success();
}
};
} // namespace
namespace {
// decompose `fmod(x, y)` to `x - trunc(x/y) * y`
class DecomposeAtenFmodTensorOp : public OpRewritePattern<AtenFmodTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFmodTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value other = op.getOther();
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result must have dtype");
}
if (isa<mlir::IntegerType>(resultTy.getDtype())) {
Value div = rewriter.create<AtenDivTensorOp>(loc, resultTy, self, other);
Value mul = rewriter.create<AtenMulTensorOp>(loc, resultTy, div, other);
Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, mul,
alpha);
return success();
} else if (isa<mlir::FloatType>(resultTy.getDtype())) {
Value div = rewriter.create<AtenDivTensorOp>(loc, resultTy, self, other);
Value trunc = rewriter.create<AtenTruncOp>(loc, resultTy, div);
Value mul = rewriter.create<AtenMulTensorOp>(loc, resultTy, trunc, other);
Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
rewriter.replaceOpWithNewOp<AtenSubTensorOp>(op, resultTy, self, mul,
alpha);
return success();
}
return failure();
}
};
} // namespace
namespace { namespace {
// Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and // Decompose `aten.baddbmm` op into `aten.bmm`, `aten.mul.Scalar`, and
// `aten.add.Tensor` op. // `aten.add.Tensor` op.
@ -8598,10 +8752,16 @@ class DecomposeAtenExp2Op : public OpRewritePattern<AtenExp2Op> {
Location loc = op.getLoc(); Location loc = op.getLoc();
Value self = op.getSelf(); Value self = op.getSelf();
auto resultTy = dyn_cast<ValueTensorType>(op.getType());
if (!resultTy || !resultTy.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result must have dtype");
}
auto two = auto two =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2)); rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(2));
rewriter.replaceOpWithNewOp<AtenPowScalarOp>(op, op.getType(), two, self); Value to = convertTensorToDtype(rewriter, loc, self, resultTy.getDtype());
Value pow = rewriter.create<AtenPowScalarOp>(loc, resultTy, two, to);
rewriter.replaceOp(op, pow);
return success(); return success();
} }
}; };
@ -9696,6 +9856,11 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenRad2degOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenRad2degOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenCosineSimilarityOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenTruncOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignbitOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFracOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCopysignTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLdexpTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFmodTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenBaddbmmOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns); addPatternIfTargetOpIsIllegal<DecomposeAtenFloorDivideScalarOp>(patterns);

View File

@ -535,6 +535,10 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRad2degOp>(); target.addIllegalOp<AtenRad2degOp>();
target.addIllegalOp<AtenCosineSimilarityOp>(); target.addIllegalOp<AtenCosineSimilarityOp>();
target.addIllegalOp<AtenTruncOp>(); target.addIllegalOp<AtenTruncOp>();
target.addIllegalOp<AtenSignbitOp>();
target.addIllegalOp<AtenFracOp>();
target.addIllegalOp<AtenCopysignTensorOp>();
target.addIllegalOp<AtenLdexpTensorOp>();
target.addIllegalOp<AtenNewEmptyStridedOp>(); target.addIllegalOp<AtenNewEmptyStridedOp>();
target.addIllegalOp<AtenEmptyStridedOp>(); target.addIllegalOp<AtenEmptyStridedOp>();
target.addIllegalOp<AtenBucketizeTensorOp>(); target.addIllegalOp<AtenBucketizeTensorOp>();

View File

@ -500,6 +500,52 @@ FX_IMPORTER_XFAIL_SET = {
"ViewSizeFromOtherTensor_basic", "ViewSizeFromOtherTensor_basic",
"ViewDtypeStaticModule_basic", "ViewDtypeStaticModule_basic",
"WeightNormInterfaceModule_basic", "WeightNormInterfaceModule_basic",
# Error: `aten.as_strided` op is not supported
"ChunkListUnpackDynamic_Module_basic",
"ChunkListUnpackUnevenDynamic_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"ChunkListUnpack_Module_basic",
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"AdaptiveMaxPool1dDynamicNoBatch_basic",
"AdaptiveMaxPool1dDynamic_basic",
"AdaptiveMaxPool1dStatic_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl2DImplicitModule_basic",
"IndexPutImpl2DIndexModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"InterpolateDynamicModule_sizes_nearest",
"IouOfModule_basic",
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
"OneHotModule_basic",
# RuntimeError: cannot mutate tensors with frozen storage
"ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"ElementwiseRreluWithNoiseTrainModule_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"ElementwiseSignbitModule_basic",
"ElementwiseCopysignModule_basic",
} }
FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@ -2845,6 +2891,12 @@ ONNX_XFAIL_SET = {
"ElementwiseToDtypeI64ToUI8Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseUnaryIntModule_basic", "ElementwiseUnaryIntModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic", "ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseSignbitModule_basic",
"ElementwiseSignbitIntModule_basic",
"ElementwiseFracModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseLdexpModule_basic",
"Exp2StaticIntModule_basic",
"MaskedFillTensorFloatValueModule_basic", "MaskedFillTensorFloatValueModule_basic",
"MultinomialModule_basic", "MultinomialModule_basic",
"MultinomialModule2D_basic", "MultinomialModule2D_basic",

View File

@ -1471,6 +1471,18 @@ def atenfloor_divide〡shape(self: List[int], other: List[int]) -> List[int]:
def atenatan2〡shape(self: List[int], other: List[int]) -> List[int]: def atenatan2〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other) return upstream_shape_functions.broadcast(self, other)
def atenfrac〡shape(self: List[int]) -> List[int]:
return self
def atensignbit〡shape(self: List[int]) -> List[int]:
return self
def atenldexpTensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def atencopysignTensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)
def aten__and__Tensor〡shape(self: List[int], other: List[int]) -> List[int]: def aten__and__Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other) return upstream_shape_functions.broadcast(self, other)
@ -3746,6 +3758,42 @@ def atenrsubScalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype
return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)]) return promote_dtypes([self_rank, None], [self_dtype, get_dtype_of_scalar(other)])
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool}))
def atenfrac〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
assert self_dtype != torch.bool
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def atensignbit〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.bool
@check_dtype_function(_check_two_tensor_op())
def atenldexpTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
other_rank, other_dtype = other_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
if self_dtype == torch.double and is_complex_dtype(other_dtype):
return other_dtype
elif is_complex_dtype(self_dtype) and other_dtype == torch.double:
return self_dtype
elif is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype):
return torch.float
else:
return promote_dtypes(ranks, dtypes)
@check_dtype_function(_check_two_tensor_op())
def atencopysignTensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
other_rank, other_dtype = other_rank_dtype
ranks: List[Optional[int]] = [self_rank, other_rank]
dtypes = [self_dtype, other_dtype]
if is_integer_dtype(self_dtype) and is_integer_dtype(other_dtype):
return torch.float
else:
return promote_dtypes(ranks, dtypes)
@check_dtype_function( @check_dtype_function(
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0.0) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, other=0)) _check_tensors_with_the_same_dtype(num_of_tensors=1, other=0))

View File

@ -328,6 +328,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::atanh : (Tensor) -> (Tensor)", "aten::atanh : (Tensor) -> (Tensor)",
"aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)",
"aten::neg : (Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)",
"aten::frac : (Tensor) -> (Tensor)",
"aten::bitwise_not : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)",
"aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::logical_or : (Tensor, Tensor) -> (Tensor)", "aten::logical_or : (Tensor, Tensor) -> (Tensor)",
@ -369,6 +370,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::zero : (Tensor) -> (Tensor)", "aten::zero : (Tensor) -> (Tensor)",
"aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)",
"aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)",
"aten::copysign.Tensor : (Tensor, Tensor) -> (Tensor)",
]: ]:
emit_with_mutating_variants(key) emit_with_mutating_variants(key)
# Shape manipulations: # Shape manipulations:
@ -417,6 +419,12 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
has_canonicalizer=True, has_canonicalizer=True,
has_folder=True, has_folder=True,
) )
emit(
"aten::ldexp.Tensor : (Tensor, Tensor) -> (Tensor)",
)
emit(
"aten::signbit : (Tensor) -> (Tensor)",
)
emit_with_mutating_variants( emit_with_mutating_variants(
"aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True
) )

View File

@ -2701,6 +2701,130 @@ def ElementwiseTruncIntModule_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class ElementwiseSignbitModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([1, 8], torch.float32, True),
]
)
def forward(self, a):
return torch.signbit(a)
@register_test_case(module_factory=lambda: ElementwiseSignbitModule())
def ElementwiseSignbitModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor(
[[-torch.inf, torch.inf, torch.nan, -torch.nan, 2.3, -2.3, 0.0, -0.0]]
)
)
class ElementwiseSignbitIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3, 4], torch.int32, True),
]
)
def forward(self, a):
return torch.signbit(a)
@register_test_case(module_factory=lambda: ElementwiseSignbitIntModule())
def ElementwiseSignbitIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-100, high=100).to(torch.int32))
# ==============================================================================
class ElementwiseFracModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([1, 6], torch.float32, True),
]
)
def forward(self, a):
return torch.frac(a)
@register_test_case(module_factory=lambda: ElementwiseFracModule())
def ElementwiseFracModule_basic(module, tu: TestUtils):
module.forward(torch.tensor([[2.3, -2.3, 0.0, -0.0, 2.0, -2.0]]))
# ==============================================================================
class ElementwiseCopysignModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([1, 1], torch.float32, True),
([1, 6], torch.float32, True),
]
)
def forward(self, a, b):
return torch.copysign(a, b)
@register_test_case(module_factory=lambda: ElementwiseCopysignModule())
def ElementwiseCopysignModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[1.0]]),
torch.tensor([[2.3, -2.3, 0.0, -0.0, torch.inf, -torch.inf]]),
)
# ==============================================================================
class ElementwiseLdexpModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([1, 6], torch.float32, True),
([1, 1], torch.int64, True),
]
)
def forward(self, a, b):
return torch.ldexp(a, b)
@register_test_case(module_factory=lambda: ElementwiseLdexpModule())
def ElementwiseLdexpModule_basic(module, tu: TestUtils):
module.forward(
torch.tensor([[2.3, -2.3, 0.0, -0.0, 4.5, -4.5]]),
torch.tensor([[2]]),
)
# ==============================================================================
class ElementwiseSignModule(torch.nn.Module): class ElementwiseSignModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -2792,6 +2916,26 @@ def Exp2StaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 2)) module.forward(tu.rand(3, 2))
class Exp2StaticIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([3, 4], torch.int64, True),
]
)
def forward(self, x):
return torch.ops.aten.exp2(x)
@register_test_case(module_factory=lambda: Exp2StaticIntModule())
def Exp2StaticIntModule_basic(module, tu: TestUtils):
module.forward(tu.randint(3, 4, low=-20, high=20))
# ============================================================================== # ==============================================================================