mirror of https://github.com/llvm/torch-mlir
Add E2E support for aten.is_floating_point
parent 246c2df65a
commit 708a51ae2e
@@ -4149,6 +4149,29 @@ def Torch_AtenBoolTensorOp : Torch_Op<"aten.Bool.Tensor", [
   }];
 }

 def Torch_AtenIsFloatingPointOp : Torch_Op<"aten.is_floating_point", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
   let summary = "Generated op for `aten::is_floating_point : (Tensor) -> (bool)`";
   let arguments = (ins
     AnyTorchTensorType:$self
   );
   let results = (outs
     Torch_BoolType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
     ParseResult AtenIsFloatingPointOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 1, 1);
     }
     void AtenIsFloatingPointOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
 }

 def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
     AllowsTypeRefinement,
     HasValueSemantics,
@@ -50,6 +50,24 @@ public:
 };
 } // namespace

 namespace {
 class ConvertAtenIsFloatingPointOp
     : public OpConversionPattern<AtenIsFloatingPointOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
   matchAndRewrite(AtenIsFloatingPointOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto tensorType = op.self().getType().cast<BaseTensorType>();
     bool result =
         tensorType.hasDtype() && tensorType.getDtype().isa<mlir::FloatType>();
     rewriter.replaceOpWithNewOp<arith::ConstantOp>(
         op, BoolAttr::get(getContext(), result));
     return success();
   }
 };
 } // namespace

 namespace {
 class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
 public:
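Note on the lowering above: because the torch-dialect tensor type carries the dtype, the pattern emits no runtime computation and folds `aten.is_floating_point` to a single boolean `arith.constant`. Below is a minimal Python sketch of that compile-time decision (illustrative only; `folded_result` is not part of the patch). As written in the pattern, a tensor type with no statically known dtype also folds to false:

from typing import Optional
import torch

def folded_result(dtype: Optional[torch.dtype]) -> bool:
    # Mirrors: tensorType.hasDtype() && tensorType.getDtype().isa<mlir::FloatType>()
    # None stands in for a tensor type whose dtype is unknown.
    return dtype is not None and dtype.is_floating_point

assert folded_result(torch.float32) is True
assert folded_result(torch.int32) is False
assert folded_result(None) is False   # unknown dtype folds to false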
@@ -301,6 +319,8 @@ public:
     RewritePatternSet patterns(context);
     target.addIllegalOp<AtenDimOp>();
     patterns.add<ConvertAtenDimOp>(typeConverter, context);
     target.addIllegalOp<AtenIsFloatingPointOp>();
     patterns.add<ConvertAtenIsFloatingPointOp>(typeConverter, context);
     target.addIllegalOp<RuntimeAssertOp>();
     patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);
     target.addIllegalOp<AtenNeIntOp, AtenEqIntOp, AtenGtIntOp>();
@@ -392,6 +392,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::dim : (Tensor) -> (int)", has_folder=True)
     emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True)
     emit("aten::Bool.Tensor : (Tensor) -> (bool)")
     emit("aten::is_floating_point : (Tensor) -> (bool)")
     emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
     emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)")
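For context, the registered schema `aten::is_floating_point : (Tensor) -> (bool)` matches eager PyTorch, where the answer depends only on the tensor's dtype. A quick eager-mode check (illustrative only, not part of the patch), which is exactly what the two new e2e tests below exercise:

import torch

print(torch.is_floating_point(torch.ones(2, 2, dtype=torch.float32)))  # True
print(torch.is_floating_point(torch.ones(2, 2, dtype=torch.int64)))    # False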
@@ -64,6 +64,50 @@ def BmmModule_basic(module, tu: TestUtils):
 # ==============================================================================


 class IsFloatingPointInt(torch.nn.Module):

     def __init__(self):
         super().__init__()

     @export
     @annotate_args([
         None,
         ([-1, -1], torch.int32, True),
     ])
     def forward(self, x):
         return torch.is_floating_point(x)


 @register_test_case(module_factory=lambda: IsFloatingPointInt())
 def IsFloatingPointInt_False(module, tu: TestUtils):
     module.forward(torch.randint(100, (3, 3)))


 # ==============================================================================


 class IsFloatingPointFloat(torch.nn.Module):

     def __init__(self):
         super().__init__()

     @export
     @annotate_args([
         None,
         ([-1], torch.float32, True),
     ])
     def forward(self, x):
         return torch.is_floating_point(x)


 @register_test_case(module_factory=lambda: IsFloatingPointFloat())
 def IsFloatingPointFloat_True(module, tu: TestUtils):
     module.forward(tu.rand(3))


 # ==============================================================================


 # A subgraph with multiple mm ops.
 class MmDagModule(torch.nn.Module):