mirror of https://github.com/llvm/torch-mlir
parent
dd5992514d
commit
ddea56a832
|
@ -103,6 +103,7 @@ Torch::getTypeForScalarType(MLIRContext *context,
|
|||
case torch_upstream::ScalarType::Half:
|
||||
return mlir::FloatType::getF16(context);
|
||||
case torch_upstream::ScalarType::Byte:
|
||||
return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned);
|
||||
case torch_upstream::ScalarType::Char:
|
||||
return mlir::IntegerType::get(context, 8, signedness);
|
||||
case torch_upstream::ScalarType::ComplexHalf:
|
||||
|
|
|
@ -15,3 +15,13 @@ func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
|
|||
%0 = torch.aten.argmax %arg0, %int0, %true : !torch.vtensor<[?,?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[1,?],si64>
|
||||
return %0 : !torch.vtensor<[1,?],si64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @torch.uint8
|
||||
func.func @torch.uint8(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[3,4],ui8>}) -> !torch.tensor {
|
||||
%int12 = torch.constant.int 12
|
||||
%0 = torch.prim.ListConstruct %int12 : (!torch.int) -> !torch.list<int>
|
||||
// CHECK: torch.aten.view
|
||||
// CHECK-SAME: !torch.vtensor<[12],ui8>
|
||||
%1 = torch.aten.reshape %arg0, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
|
||||
return %1 : !torch.tensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue