mirror of https://github.com/llvm/torch-mlir
[Torch] add AtenAsStridedOp in torch dialect (#3706)
parent
3f07077ff9
commit
edf725ef42
|
@ -13195,6 +13195,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$size,
|
||||
AnyTorchListOfTorchIntType:$stride,
|
||||
AnyTorchOptionalIntType:$storage_offset
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchOptionalTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 4, 1);
|
||||
}
|
||||
void AtenAsStridedOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -10002,6 +10002,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list<int>, !torch.int, !torch.optional<int>, !torch.optional<int>, !torch.int) -> !torch.list<int>\n"
|
||||
" return %0 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
|
||||
" return %arg1 : !torch.list<int>\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_shape_fn.aten.sort\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple<list<int>, list<int>> {\n"
|
||||
" %0 = torch.prim.TupleConstruct %arg0, %arg0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
|
||||
" return %0 : !torch.tuple<list<int>, list<int>>\n"
|
||||
|
@ -12297,6 +12300,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
|||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten.as_strided\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
|
||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||
" return %0#1 : !torch.int\n"
|
||||
" }\n"
|
||||
" func.func @\"__torch_mlir_dtype_fn.aten._softmax_backward_data\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
|
||||
" return %arg3 : !torch.int\n"
|
||||
" }\n"
|
||||
|
|
|
@ -247,11 +247,12 @@ bool Torch::isViewLikeOp(Operation *op) {
|
|||
// correct. We could potentially be more precise and identify the cases
|
||||
// that it does not return a view and treat those as having value
|
||||
// semantics.
|
||||
return isa<AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp, AtenExpandAsOp,
|
||||
AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
|
||||
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
|
||||
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
|
||||
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
|
||||
return isa<AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
|
||||
AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
|
||||
AtenUnflattenIntOp, AtenPermuteOp, AtenReshapeOp,
|
||||
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
|
||||
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
|
||||
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
|
||||
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
|
||||
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
|
||||
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
|
||||
|
|
|
@ -1849,6 +1849,9 @@ def aten〇_weight_norm_interface〡shape(v: List[int], g: List[int], dim: int =
|
|||
def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]:
|
||||
return upstream_shape_functions.slice(self, dim, start, end, step)
|
||||
|
||||
def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
|
||||
return size
|
||||
|
||||
def aten〇sort〡shape(self: List[int], dim: int = -1, descending: bool = False) -> Tuple[List[int], List[int]]:
|
||||
return self, self
|
||||
|
||||
|
@ -3377,6 +3380,10 @@ def aten〇slice〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], dim: int = 0
|
|||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
def aten〇as_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> int:
|
||||
self_rank, self_dtype = self_rank_dtype
|
||||
return self_dtype
|
||||
|
||||
@check_dtype_function(
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float32) +
|
||||
_check_tensors_with_the_same_dtype(num_of_tensors=2, dim=0, input_dtype=torch.float64) +
|
||||
|
|
|
@ -968,6 +968,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::alias_copy : (Tensor) -> (Tensor)")
|
||||
emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True)
|
||||
emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
|
||||
emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
|
||||
emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
|
||||
emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
|
||||
|
|
Loading…
Reference in New Issue