mirror of https://github.com/llvm/torch-mlir
[Torch] add AtenAsStridedOp in torch dialect (#3706)
parent
3f07077ff9
commit
edf725ef42
|
@ -13195,6 +13195,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$self,
|
||||||
|
AnyTorchListOfTorchIntType:$size,
|
||||||
|
AnyTorchListOfTorchIntType:$stride,
|
||||||
|
AnyTorchOptionalIntType:$storage_offset
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||||
|
}
|
||||||
|
void AtenAsStridedOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 4, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
|
def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
ReadOnly
|
ReadOnly
|
||||||
|
|
|
@ -10002,6 +10002,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
|
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
|
||||||
" return %0 : !torch.list<int>\n"
|
" return %0 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||||
|
" return %arg1 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
" func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||||
" %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
" %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||||
" return %0 : !torch.tuple<list<int>, list<int>>\n"
|
" return %0 : !torch.tuple<list<int>, list<int>>\n"
|
||||||
|
@ -12297,6 +12300,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.as_strided\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||||
" return %arg3 : !torch.int\n"
|
" return %arg3 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
|
|
@ -247,11 +247,12 @@ bool Torch::isViewLikeOp(Operation *op) {
|
||||||
// correct. We could potentially be more precise and identify the cases
|
// correct. We could potentially be more precise and identify the cases
|
||||||
// that it does not return a view and treat those as having value
|
// that it does not return a view and treat those as having value
|
||||||
// semantics.
|
// semantics.
|
||||||
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
|
return isa<AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
|
||||||
AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
|
AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
|
||||||
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
|
AtenUnflattenIntOp, AtenPermuteOp, AtenReshapeOp,
|
||||||
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
|
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
|
||||||
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
|
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
|
||||||
|
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
|
||||||
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
|
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
|
||||||
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
|
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
|
||||||
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
|
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
|
||||||
|
|
|
@ -1849,6 +1849,9 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int =
|
||||||
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||||
return upstream_shape_functions.slice(self, dim, start, end, step)
|
return upstream_shape_functions.slice(self, dim, start, end, step)
|
||||||
|
|
||||||
|
def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
|
||||||
|
return size
|
||||||
|
|
||||||
def aten〇sort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> Tuple[List[int], List[int]]:
|
def aten〇sort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> Tuple[List[int], List[int]]:
|
||||||
return self, self
|
return self, self
|
||||||
|
|
||||||
|
@ -3377,6 +3380,10 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0
|
||||||
self_rank, self_dtype = self_rank_dtype
|
self_rank, self_dtype = self_rank_dtype
|
||||||
return self_dtype
|
return self_dtype
|
||||||
|
|
||||||
|
def aten〇as_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> int:
|
||||||
|
self_rank, self_dtype = self_rank_dtype
|
||||||
|
return self_dtype
|
||||||
|
|
||||||
@check_dtype_function(
|
@check_dtype_function(
|
||||||
_check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) +
|
_check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) +
|
||||||
_check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) +
|
_check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) +
|
||||||
|
|
|
@ -968,6 +968,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||||
emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
|
emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
|
||||||
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
|
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
|
||||||
|
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
|
||||||
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
|
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
|
||||||
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
|
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||||
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
|
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
|
||||||
|
|
Loading…
Reference in New Issue