mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Add canonicalize pattern for aten.is_floating_point (#2194)
* [Torch Dialect] Add canonicalize pattern for aten.is_floating_point
* implement as fold
* add lit test
parent 816880774b
commit 5a7bf4e4cb
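The change teaches the canonicalizer to fold torch.aten.is_floating_point into a boolean constant whenever the operand's dtype is statically known. A minimal sketch of the intended effect, assuming a statically typed !torch.vtensor<[],f32> operand (the function name is invented for illustration; the behavior mirrors the lit tests added below):

// Before canonicalization: the operand dtype (f32) is statically known.
func.func @is_fp_example(%t: !torch.vtensor<[],f32>) -> !torch.bool {
  %0 = torch.aten.is_floating_point %t : !torch.vtensor<[],f32> -> !torch.bool
  return %0 : !torch.bool
}
// After canonicalization the op folds away, leaving:
//   %true = torch.constant.bool true
//   return %true : !torch.bool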
@@ -6284,6 +6284,7 @@ def Torch_AtenIsFloatingPointOp : Torch_Op<"aten.is_floating_point", [
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
+  let hasFolder = 1;
 }
 
 def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
@@ -51,24 +51,6 @@ public:
 };
 } // namespace
 
-namespace {
-class ConvertAtenIsFloatingPointOp
-    : public OpConversionPattern<AtenIsFloatingPointOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenIsFloatingPointOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto tensorType = op.getSelf().getType().cast<BaseTensorType>();
-    bool result =
-        tensorType.hasDtype() && tensorType.getDtype().isa<mlir::FloatType>();
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-        op, BoolAttr::get(getContext(), result));
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
 public:
@@ -400,8 +382,6 @@ public:
     RewritePatternSet patterns(context);
     target.addIllegalOp<AtenDimOp>();
     patterns.add<ConvertAtenDimOp>(typeConverter, context);
-    target.addIllegalOp<AtenIsFloatingPointOp>();
-    patterns.add<ConvertAtenIsFloatingPointOp>(typeConverter, context);
     target.addIllegalOp<RuntimeAssertOp>();
     patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
     target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp, AtenGeIntOp>();
@@ -1860,6 +1860,22 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenIsFloatingPointOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
+  auto operandType = getSelf().getType().dyn_cast<BaseTensorType>();
+  if (!operandType)
+    return nullptr;
+  if (operandType.hasDtype()) {
+    bool isFloatType = operandType.getDtype().isa<mlir::FloatType>();
+    return IntegerAttr::get(IntegerType::get(getContext(), 1), isFloatType);
+  }
+  // Cannot fold without a statically known dtype.
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AtenAddTOp
 //===----------------------------------------------------------------------===//
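The folder above deliberately returns nullptr when the operand type carries no statically known dtype, leaving the op in place for later type refinement. A hedged sketch of such a case (function name invented for illustration; !torch.vtensor is the fully unrefined value-tensor type):

// No dtype on the operand type, so canonicalization does not touch the op.
func.func @not_folded(%t: !torch.vtensor) -> !torch.bool {
  %0 = torch.aten.is_floating_point %t : !torch.vtensor -> !torch.bool
  return %0 : !torch.bool
}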
@@ -456,7 +456,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::dim : (Tensor) -> (int)", has_folder=True)
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
-    emit("aten::is_floating_point : (Tensor) -> (bool)")
+    emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True)
     emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
@@ -21,6 +21,22 @@ func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) {
   return %0, %1, %2, %3 : !torch.int, !torch.int, !torch.int, !torch.int
 }
 
+// CHECK-LABEL: func.func @torch.aten.is_floating_point$fold_true
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: return %[[TRUE]] : !torch.bool
+func.func @torch.aten.is_floating_point$fold_true(%arg0: !torch.vtensor<[], f32>) -> !torch.bool {
+  %0 = torch.aten.is_floating_point %arg0 : !torch.vtensor<[], f32> -> !torch.bool
+  return %0 : !torch.bool
+}
+
+// CHECK-LABEL: func.func @torch.aten.is_floating_point$fold_false
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: return %[[FALSE]] : !torch.bool
+func.func @torch.aten.is_floating_point$fold_false(%arg0: !torch.vtensor<[], si64>) -> !torch.bool {
+  %0 = torch.aten.is_floating_point %arg0 : !torch.vtensor<[], si64> -> !torch.bool
+  return %0 : !torch.bool
+}
+
 // CHECK-LABEL: func.func @torch.aten.__is__
 // CHECK: %[[FALSE:.*]] = torch.constant.bool false
 // CHECK: return %[[FALSE]] : !torch.bool