mirror of https://github.com/llvm/torch-mlir
Rob's atenTensor folder (#2867)
If a tensor is initialized by a list with a single constant integer, this folder turns it into a torch.vtensor.literal --------- Co-authored-by: Dave Liddell <dliddell@xilinx.com>pull/2870/head
parent
041a54ae0c
commit
1cb14f6879
|
@ -8582,6 +8582,7 @@ def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [
|
|||
printDefaultTorchOp(printer, *this, 4, 1);
|
||||
}
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [
|
||||
|
|
|
@ -2758,6 +2758,27 @@ void AtenDeviceWithIndexOp::getCanonicalizationPatterns(
|
|||
});
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) {
|
||||
// If a torch.aten.tensor op is initialized by a list with a constant, single
|
||||
// element, fold it into a torch.vtensor.literal
|
||||
auto resultTy = dyn_cast<ValueTensorType>(getType());
|
||||
Type eTy = resultTy.getDtype();
|
||||
ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy);
|
||||
|
||||
SmallVector<int64_t> data;
|
||||
if (matchPattern(getData(), m_TorchListOfConstantInts(data)) &&
|
||||
data.size() == 1) {
|
||||
Attribute attribute = IntegerAttr::get(eTy, data[0]);
|
||||
return DenseElementsAttr::get(shapedTy, attribute);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AtenIntTensorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -570,7 +570,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True)
|
||||
emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)")
|
||||
emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)")
|
||||
|
|
|
@ -1461,6 +1461,17 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to
|
|||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.tensor$one_elem(
|
||||
// CHECK-NEXT: torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64>
|
||||
func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) {
|
||||
%none = torch.constant.none
|
||||
%false = torch.constant.bool false
|
||||
%int42 = torch.constant.int 42
|
||||
%66 = torch.prim.ListConstruct %int42 : (!torch.int) -> !torch.list<int>
|
||||
%67 = torch.aten.tensor %66, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
|
||||
return %67 : !torch.vtensor<[1],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype(
|
||||
// CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> {
|
||||
// CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32>
|
||||
|
|
Loading…
Reference in New Issue