mirror of https://github.com/llvm/torch-mlir
[Torch] Support Aten_CastLongOp. (#3160)
By canonicalize Aten_CastLongOp into AtenToDtypeOppull/3134/head
parent
e4b11a0ab4
commit
d2ba956e69
|
@ -11047,6 +11047,31 @@ def Torch_Aten_CastFloatOp : Torch_Op<"aten._cast_Float", [
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_Aten_CastLongOp : Torch_Op<"aten._cast_Long", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::_cast_Long : (Tensor, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
Torch_BoolType:$non_blocking
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult Aten_CastLongOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 2, 1);
|
||||
}
|
||||
void Aten_CastLongOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 2, 1);
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenTypeAsOp : Torch_Op<"aten.type_as", [
|
||||
AllowsTypeRefinement,
|
||||
HasValueSemantics,
|
||||
|
|
|
@ -959,6 +959,27 @@ void Aten_CastFloatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Aten_CastLongOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
// `aten.cast_long` -> `aten.to.dtype`
|
||||
patterns.add(+[](Aten_CastLongOp op, PatternRewriter &rewriter) {
|
||||
auto self = op.getSelf();
|
||||
auto loc = op.getLoc();
|
||||
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value longType = rewriter.create<ConstantIntOp>(
|
||||
loc, (int)torch_upstream::ScalarType::Long);
|
||||
Value constFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), self, longType,
|
||||
op.getNonBlocking(), constFalse,
|
||||
constNone);
|
||||
return success();
|
||||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenViewOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -6790,6 +6790,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten._cast_Long\"(%arg0: !torch.list<int>, %arg1: !torch.bool) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.type_as\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
|
||||
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
|
@ -12793,6 +12797,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %int6 = torch.constant.int 6\n"
|
||||
" return %int6 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._cast_Long\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.bool) -> !torch.int {\n"
|
||||
" %int4 = torch.constant.int 4\n"
|
||||
" return %int4 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.type_as\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
|
|
|
@ -577,6 +577,7 @@ STABLEHLO_PASS_SET = {
|
|||
"AtenSubFloatModule_basic",
|
||||
"AtenToDeviceModule_basic",
|
||||
"Aten_CastFloatModule_basic",
|
||||
"Aten_CastLongModule_basic",
|
||||
"AvgPool1dStaticModule_basic",
|
||||
"AvgPool2dStaticModule_basic",
|
||||
"BaddbmmBroadcast1DInputModule_basic",
|
||||
|
|
|
@ -394,6 +394,9 @@ def aten〇to〇other〡shape(self: List[int], other: List[int], non_blocking: b
|
|||
def aten〇_cast_Float〡shape(self: List[int], non_blocking: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇_cast_Long〡shape(self: List[int], non_blocking: bool = False) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
def aten〇type_as〡shape(self: List[int], other: List[int]) -> List[int]:
|
||||
return upstream_shape_functions.unary(self)
|
||||
|
||||
|
@ -4366,6 +4369,9 @@ def aten〇to〇other〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype
|
|||
def aten〇_cast_Float〡dtype(self_rank_dtype: Tuple[int, int], non_blocking: bool = False) -> int:
|
||||
return torch.float32
|
||||
|
||||
def aten〇_cast_Long〡dtype(self_rank_dtype: Tuple[int, int], non_blocking: bool = False) -> int:
|
||||
return torch.int64
|
||||
|
||||
@check_dtype_function(_check_two_tensor_op())
|
||||
def aten〇type_as〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
|
||||
other_rank, other_dtype = other_rank_dtype
|
||||
|
|
|
@ -673,6 +673,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)")
|
||||
emit("aten::to.device : (Tensor, Device, int, bool, bool, int?) -> (Tensor)")
|
||||
emit("aten::_cast_Float : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::_cast_Long : (Tensor, bool) -> (Tensor)", has_canonicalizer=True)
|
||||
emit("aten::type_as : (Tensor, Tensor) -> (Tensor)")
|
||||
emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True)
|
||||
emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)")
|
||||
|
|
|
@ -4257,7 +4257,25 @@ class Aten_CastFloatModule(torch.nn.Module):
|
|||
def Aten_CastFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.randint(2, 4))
|
||||
|
||||
|
||||
class Aten_CastLongModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 4], torch.float32, True),
|
||||
])
|
||||
|
||||
def forward(self, val):
|
||||
return torch.ops.aten._cast_Long(val)
|
||||
|
||||
@register_test_case(module_factory=lambda: Aten_CastLongModule())
|
||||
def Aten_CastLongModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 4))
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class UpSampleNearest2dBackward(torch.nn.Module):
|
||||
|
|
Loading…
Reference in New Issue