[Torch] add AtenAsStridedOp in torch dialect (#3706)

branch pull/3710/head
yyp0 authored 2024-09-12 19:07:11 +08:00, committed by GitHub
parent 3f07077ff9
commit edf725ef42
5 changed files with 46 additions and 5 deletions

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

@@ -13195,6 +13195,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
   }];
 }

+def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchOptionalIntType:$storage_offset
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenAsStridedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
   AllowsTypeRefinement,
   ReadOnly
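For readers skimming the diff: `aten::as_strided` reinterprets its input's storage with an explicit `size`, `stride`, and optional `storage_offset`, so element (i, j) of a 2-D result reads storage slot `storage_offset + i*stride[0] + j*stride[1]`. A minimal PyTorch-level illustration of the semantics the new op models (example code, not part of the patch):

    import torch

    base = torch.arange(9, dtype=torch.float32)  # storage holds 0., 1., ..., 8.
    # size (2, 2), stride (3, 1), storage_offset 1 selects slots 1, 2, 4, 5.
    view = torch.as_strided(base, size=(2, 2), stride=(3, 1), storage_offset=1)
    print(view)  # tensor([[1., 2.], [4., 5.]])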

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

@@ -10002,6 +10002,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
+"    return %arg1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
 "    %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>>\n"
@@ -12297,6 +12300,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.as_strided\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
 "    return %arg3 : !torch.int\n"
 "  }\n"

lib/Dialect/Torch/Utils/Utils.cpp

@@ -247,11 +247,12 @@ bool Torch::isViewLikeOp(Operation *op) {
   // correct. We could potentially be more precise and identify the cases
   // that it does not return a view and treat those as having value
   // semantics.
-  return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
-             AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
-             AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
-             AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
-             AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
+  return isa<AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
+             AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
+             AtenUnflattenIntOp, AtenPermuteOp, AtenReshapeOp,
+             Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
+             AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
+             AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
              TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
              AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
             AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
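Adding the op to `isViewLikeOp` is the semantically important part of this change: `as_strided` aliases its input's storage rather than copying it, so passes that reason about value semantics (e.g. maximize-value-semantics) must assume the result may share memory with `self`. A short PyTorch sketch of the aliasing behavior being modeled (illustrative only):

    import torch

    base = torch.arange(9, dtype=torch.float32)
    view = torch.as_strided(base, size=(2, 2), stride=(3, 1), storage_offset=1)
    view[0, 0] = 42.0        # writing through the view...
    print(base[1].item())    # ...mutates the base tensor: prints 42.0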

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

@@ -1849,6 +1849,9 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int =
 def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
     return upstream_shape_functions.slice(self, dim, start, end, step)

+def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
+    return size
+
 def aten〇sort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> Tuple[List[int], List[int]]:
     return self, self
@@ -3377,6 +3380,10 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0
     self_rank, self_dtype = self_rank_dtype
     return self_dtype

+def aten〇as_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) +
     _check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) +
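These Python definitions are the source of truth from which the C++ strings in AbstractInterpLibrary.cpp above are regenerated. Because both are ordinary Python functions, they can be spot-checked directly; a hypothetical check (the dtype code 6 assumes PyTorch's ScalarType numbering, where 6 is `torch.float32`):

    assert aten〇as_strided〡shape(self=[9], size=[2, 2], stride=[3, 1], storage_offset=1) == [2, 2]
    assert aten〇as_strided〡dtype(self_rank_dtype=(1, 6), size=[2, 2], stride=[3, 1]) == 6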

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

@@ -968,6 +968,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::alias_copy : (Tensor) -> (Tensor)")
     emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
     emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
+    emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")