Register torch.aten._local_scalar_dense

pull/2925/head
Max Dawkins 2024-02-15 12:39:51 -05:00
parent ec2b80b433
commit f8332a4da2
4 changed files with 32 additions and 0 deletions

View File

@ -10317,6 +10317,29 @@ def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
}];
}
def Torch_Aten_LocalScalarDenseOp : Torch_Op<"aten._local_scalar_dense", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_local_scalar_dense : (Tensor) -> (Scalar)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchScalarType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_LocalScalarDenseOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void Aten_LocalScalarDenseOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}
def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [
AllowsTypeRefinement,
ReadOnly

View File

@ -10181,6 +10181,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten._local_scalar_dense\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.select.int\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"

View File

@ -2541,6 +2541,10 @@ def atenscatter_reducetwo〡dtype(self_rank_dtype: Tuple[int, int], dim: i
self_rank, self_dtype = self_rank_dtype
return self_dtype
def aten_local_scalar_dense〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, index=0))
def atenselectint〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index: int) -> int:
self_rank, self_dtype = self_rank_dtype

View File

@ -636,6 +636,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)")
emit("aten::resize : (Tensor, int[], int?) -> (Tensor)")
emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)")
emit("aten::_local_scalar_dense : (Tensor) -> (Scalar)")
emit("aten::select.int : (Tensor, int, int) -> (Tensor)", has_folder=1)
emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
emit("aten::sum : (Tensor, int?) -> (Tensor)")