Bump LLVM to 881ff4e4ebe8cc0cc045c7c167cffb01f94f27f8 (#539)

stephenneuendorffer 2022-01-25 22:16:30 -08:00 committed by GitHub
parent 17a4843cf7
commit 3fd9b7789e
12 changed files with 47 additions and 63 deletions

@@ -13,6 +13,10 @@ if(POLICY CMP0077)
   cmake_policy(SET CMP0077 NEW)
 endif()
+if(POLICY CMP0116)
+  cmake_policy(SET CMP0116 OLD)
+endif()
 #-------------------------------------------------------------------------------
 # Project setup and globals
 #-------------------------------------------------------------------------------

@@ -1 +1 @@
-Subproject commit 63f0c00d38ee7879239975a6743d4e6c7847b725
+Subproject commit 881ff4e4ebe8cc0cc045c7c167cffb01f94f27f8

@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -275,14 +276,14 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
                              SmallVectorImpl<int64_t> &highPaddingInts,
                              Value pad) {
   Location loc = op->getLoc();
-  Type rankedTensorType = linalg::PadTensorOp::inferResultType(
+  Type rankedTensorType = tensor::PadOp::inferResultType(
       input.getType().cast<RankedTensorType>(), lowPaddingInts,
       highPaddingInts);
   SmallVector<OpFoldResult> lowPaddings =
       getAsOpFoldResult(b, loc, lowPaddingInts);
   SmallVector<OpFoldResult> highPaddings =
       getAsOpFoldResult(b, loc, highPaddingInts);
-  Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
+  Value paddedInput = tensor::createPadScalarOp(
       rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
       /*packing=*/false, loc, b);
   return paddedInput;
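With this bump the pad helpers move out of linalg: result-type inference lives on tensor::PadOp and the scalar-pad builder comes from the Tensor/Utils/Utils.h include added above. Below is a minimal sketch of a caller, assuming the signature used in the hunk; the padWithScalar helper and its arguments are illustrative, not part of the patch.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

// Illustrative helper (not from the patch): pad a ranked tensor with a scalar
// `pad` value, `low`/`high` elements per dimension, via the relocated
// tensor-dialect utilities.
static Value padWithScalar(OpBuilder &b, Location loc, Value input, Value pad,
                           ArrayRef<int64_t> low, ArrayRef<int64_t> high) {
  auto inputTy = input.getType().cast<RankedTensorType>();
  // Result-type inference now lives on tensor::PadOp.
  Type resultTy = tensor::PadOp::inferResultType(inputTy, low, high);
  SmallVector<OpFoldResult> lowOfr, highOfr;
  for (int64_t v : low)
    lowOfr.push_back(b.getIndexAttr(v));
  for (int64_t v : high)
    highOfr.push_back(b.getIndexAttr(v));
  // createPadScalarOp comes from Tensor/Utils/Utils.h; the trailing boolean
  // mirrors the /*packing=*/false argument in the hunk above.
  return tensor::createPadScalarOp(resultTy, input, pad, lowOfr, highOfr,
                                   /*packing=*/false, loc, b);
}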

@@ -1059,7 +1059,7 @@ public:
     // Step: generate the common dim/shape information
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
       bool isDynamicDim =
-          lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]);
+          ShapedType::isDynamic(lhsBroadcastedShape[dim]);
       if (isDynamicDim ||
           lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
         commonValue *= lhsBroadcastedShape[dim];
@@ -1071,7 +1071,7 @@ public:
     bool hasDynamicDims = false;
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
       bool isDynamicDim =
-          lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]);
+          ShapedType::isDynamic(lhsBroadcastedShape[dim]);
       hasDynamicDims |= isDynamicDim;
       if (!isDynamicDim &&
           lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
@@ -1156,7 +1156,7 @@ public:
     hasDynamicDims = false;
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
       bool isDynamicDim =
-          rhsBroadcastedTy.isDynamic(rhsBroadcastedShape[dim]);
+          ShapedType::isDynamic(rhsBroadcastedShape[dim]);
       hasDynamicDims |= isDynamicDim;
       if (!isDynamicDim &&
           rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) {
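The TOSA lowering now tests dynamic extents with the static ShapedType::isDynamic(int64_t) query instead of an instance method on the broadcasted tensor type. A small sketch of the idiom; the hasAnyDynamicDim helper is illustrative only.

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"

// Illustrative only: the static query works on raw extents, so no particular
// RankedTensorType instance is needed to ask whether a dimension is dynamic.
static bool hasAnyDynamicDim(llvm::ArrayRef<int64_t> shape) {
  for (int64_t extent : shape)
    if (mlir::ShapedType::isDynamic(extent))
      return true;
  return false;
}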

@@ -63,28 +63,6 @@ void TorchDialect::initialize() {
   addInterfaces<TorchInlinerInterface>();
 }
-
-//===----------------------------------------------------------------------===//
-// Type-related Dialect methods.
-//===----------------------------------------------------------------------===//
-
-Type TorchDialect::parseType(DialectAsmParser &parser) const {
-  StringRef keyword;
-  if (parser.parseKeyword(&keyword))
-    return Type();
-  Type type;
-  if (generatedTypeParser(parser, keyword, type).hasValue())
-    return type;
-  parser.emitError(parser.getNameLoc(), "invalid 'torch' type: `")
-      << keyword << "'";
-  return Type();
-}
-
-void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
-  if (failed(generatedTypePrinter(type, printer)))
-    llvm_unreachable("unknown 'torch' type");
-}
-
 //===----------------------------------------------------------------------===//
 // Dialect-level verifiers.
 //===----------------------------------------------------------------------===//

@@ -321,9 +321,9 @@ static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
 static void print(OpAsmPrinter &p, PrimIfOp op) {
   p << " " << op.condition();
-  p << " -> (" << op.getResultTypes() << ")";
+  p << " -> (" << op.getResultTypes() << ") ";
   p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false);
-  p << " else";
+  p << " else ";
   p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict(op->getAttrs());

@@ -38,7 +38,7 @@ public:
   matchAndRewrite(FuncOp func, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *context = func.getContext();
-    auto typeBoundIdent = Identifier::get("torch.type_bound", context);
+    auto typeBoundIdent = StringAttr::get(context, "torch.type_bound");
     TypeConverter::SignatureConversion conversion(func.getNumArguments());
     // The TypeConverter hooks for type conversion are "context free", so we
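At this LLVM revision the Identifier class is folded into StringAttr, so uniqued names such as the torch.type_bound argument attribute are built with StringAttr::get(context, name); note the argument order is swapped relative to the old Identifier::get(name, context). A minimal sketch, assuming only an MLIRContext:

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

// Illustrative only: the handle used to look up the "torch.type_bound"
// argument attribute is now a StringAttr rather than an Identifier.
static mlir::StringAttr getTypeBoundName(mlir::MLIRContext *context) {
  return mlir::StringAttr::get(context, "torch.type_bound");
}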

@@ -164,7 +164,7 @@ struct FuncBackendTypeConversionPass
     typeConverter.addConversion([](Type type) { return type; });
     TorchConversion::setupBackendTypeConversion(target, typeConverter);
-    populateFuncOpTypeConversionPattern(patterns, typeConverter);
+    populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter);
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
       return typeConverter.isSignatureLegal(op.getType()) &&
              typeConverter.isLegal(&op.getBody());

@@ -173,7 +173,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
       torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
       mlirRegionCreate());
   MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
-  mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
+  mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
   MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
   auto inserter = caffe2::MakeGuard([&]() {
     mlirBlockInsertOwnedOperationBefore(
@@ -441,7 +441,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
       mlirStringAttrGet(
           context, toMlirStringRef(classType->name()->qualifiedName()))));
   MlirRegion region = mlirOperationGetRegion(op, 0);
-  mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
+  mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr));
   MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
   ClassAnnotation &classAnnotation =

@@ -303,7 +303,8 @@ MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
   MlirLocation loc = getMlirLocationFromNode(context, paramNode);
   std::vector<MlirType> blockArgTypes =
       getMlirTypesFromValues(loc, paramNode->outputs());
-  MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
+  std::vector<MlirLocation> blockArgLocs(blockArgTypes.size(), loc);
+  MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data(), blockArgLocs.data());
   for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
     Value *jitValue = paramNode->outputs()[i];
     MlirValue value = mlirBlockGetArgument(block, i);
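The MLIR C API now requires a location for every block argument: mlirBlockCreate takes a third array of MlirLocation, parallel to the argument types, which is why the importer builds blockArgLocs above (passing nullptr is only meaningful for zero-argument blocks, as in the ivalue-importer hunks). A small hedged sketch of the new call:

#include "mlir-c/IR.h"

// Illustrative only: create a block with one argument, supplying a location
// for it as the new third parameter of mlirBlockCreate.
static MlirBlock makeSingleArgBlock(MlirContext ctx, MlirType argType) {
  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirType types[] = {argType};
  MlirLocation locs[] = {loc}; // one location per block argument
  return mlirBlockCreate(/*nArgs=*/1, types, locs);
}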

@@ -11,7 +11,7 @@ builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?
   %int7 = torch.constant.int 7
   %int8 = torch.constant.int 8
   %false = torch.constant.bool false
-  // CHECK: %[[PADDED:.*]] = linalg.pad_tensor %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
+  // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
   // CHECK: %[[NEUTRAL:.*]] = arith.constant -1.401300e-45 : f32
   // CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
   // CHECK: %[[C1:.*]] = arith.constant 1 : index

@@ -663,7 +663,7 @@ builtin.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.ten
   return %res: !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.tensor.float(
 // CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
@@ -680,7 +680,7 @@ builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.tensor.float$specified_dtype(
 // CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
@@ -699,7 +699,7 @@ builtin.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torc
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.tensor(
 // CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
@@ -718,7 +718,7 @@ builtin.func @torch.aten.tensor(%t: !torch.list<!torch.list<!torch.float>>) -> !
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.tensor$empty_list() -> !torch.tensor {
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@@ -735,7 +735,7 @@ builtin.func @torch.aten.tensor$empty_list() -> !torch.tensor {
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.tensor$specified_dtype(
 // CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
@@ -754,7 +754,7 @@ builtin.func @torch.aten.tensor$specified_dtype(%t: !torch.list<!torch.list<!tor
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.zeros(
 // CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
@@ -773,7 +773,7 @@ builtin.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.index_select(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -789,7 +789,7 @@ builtin.func @torch.aten.index_select(%input: !torch.tensor<[2,3,4], f32>, %inde
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.index_select$unknown_indexes(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -805,7 +805,7 @@ builtin.func @torch.aten.index_select$unknown_indexes(%input: !torch.tensor<[2,3
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.index_select$unknown_dim(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -820,7 +820,7 @@ builtin.func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4],
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.select.int(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -836,7 +836,7 @@ builtin.func @torch.aten.select.int(%input: !torch.tensor<[2,3,4], f32>, %index:
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.type_as(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
 // CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
@@ -849,7 +849,7 @@ builtin.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch
   return %ret: !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.gather(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -866,7 +866,7 @@ builtin.func @torch.aten.gather(%input: !torch.tensor<[2,3,4], f32>, %dim: !torc
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.expand(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
 // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@@ -888,7 +888,7 @@ builtin.func @torch.aten.expand(%input: !torch.tensor<[2,1,4], f32>) -> !torch.t
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.expand$higher_rank(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
@@ -913,7 +913,7 @@ builtin.func @torch.aten.expand$higher_rank(%input: !torch.tensor<[2,1,4], f32>)
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.expand$unknown_sizes(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
 // CHECK-SAME: %[[SIZEX:.*]]: !torch.int) -> !torch.tensor {
@@ -933,7 +933,7 @@ builtin.func @torch.aten.expand$unknown_sizes(%input: !torch.tensor<[2,1,4], f32
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.repeat(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
 // CHECK-SAME: %[[REPEATX:.*]]: !torch.int) -> !torch.tensor {
@@ -952,7 +952,7 @@ builtin.func @torch.aten.repeat(%input: !torch.tensor<[2,1,4], f32>, %repeat: !t
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.cat(
 // CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
@@ -970,7 +970,7 @@ builtin.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tenso
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.cat$unknown_dim(
 // CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
 // CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -986,7 +986,7 @@ builtin.func @torch.aten.cat$unknown_dim(%t0: !torch.tensor<[?,1,4], f32>, %t1:
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten._shape_as_tensor(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
 // CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<[3],si64>
@@ -997,7 +997,7 @@ builtin.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten._shape_as_tensor$unknown_input_shape(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<[?],si64>
@@ -1008,7 +1008,7 @@ builtin.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.ten
   return %ret : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.embedding(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
 // CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
@@ -1024,7 +1024,7 @@ builtin.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indic
   return %ret: !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.softmax.int(
 // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
 // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
@@ -1039,7 +1039,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.softmax.int$specified_dtype(
 // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
 // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
@@ -1054,7 +1054,7 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
@@ -1067,7 +1067,7 @@ func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>,
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -> !torch.tensor {
@@ -1096,7 +1096,7 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
   return %0 : !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
 // CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
 // CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
@@ -1107,7 +1107,7 @@ func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
   return %0: !torch.tensor
 }
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.BinaryBroadcasting(
 // CHECK-SAME: %[[T0:.*]]: !torch.vtensor<[5,4,3,3,1],f32>,
 // CHECK-SAME: %[[T1:.*]]: !torch.vtensor<[?,3,1,2],f32>,