mirror of https://github.com/llvm/torch-mlir
parent
17a4843cf7
commit
3fd9b7789e
|
@ -13,6 +13,10 @@ if(POLICY CMP0077)
|
||||||
cmake_policy(SET CMP0077 NEW)
|
cmake_policy(SET CMP0077 NEW)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(POLICY CMP0116)
|
||||||
|
cmake_policy(SET CMP0116 OLD)
|
||||||
|
endif()
|
||||||
|
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
# Project setup and globals
|
# Project setup and globals
|
||||||
#-------------------------------------------------------------------------------
|
#-------------------------------------------------------------------------------
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 63f0c00d38ee7879239975a6743d4e6c7847b725
|
Subproject commit 881ff4e4ebe8cc0cc045c7c167cffb01f94f27f8
|
|
@ -14,6 +14,7 @@
|
||||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Math/IR/Math.h"
|
#include "mlir/Dialect/Math/IR/Math.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Tensor/Utils/Utils.h"
|
||||||
#include "mlir/Dialect/Traits.h"
|
#include "mlir/Dialect/Traits.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
@ -275,14 +276,14 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
|
||||||
SmallVectorImpl<int64_t> &highPaddingInts,
|
SmallVectorImpl<int64_t> &highPaddingInts,
|
||||||
Value pad) {
|
Value pad) {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Type rankedTensorType = linalg::PadTensorOp::inferResultType(
|
Type rankedTensorType = tensor::PadOp::inferResultType(
|
||||||
input.getType().cast<RankedTensorType>(), lowPaddingInts,
|
input.getType().cast<RankedTensorType>(), lowPaddingInts,
|
||||||
highPaddingInts);
|
highPaddingInts);
|
||||||
SmallVector<OpFoldResult> lowPaddings =
|
SmallVector<OpFoldResult> lowPaddings =
|
||||||
getAsOpFoldResult(b, loc, lowPaddingInts);
|
getAsOpFoldResult(b, loc, lowPaddingInts);
|
||||||
SmallVector<OpFoldResult> highPaddings =
|
SmallVector<OpFoldResult> highPaddings =
|
||||||
getAsOpFoldResult(b, loc, highPaddingInts);
|
getAsOpFoldResult(b, loc, highPaddingInts);
|
||||||
Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
|
Value paddedInput = tensor::createPadScalarOp(
|
||||||
rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
|
rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
|
||||||
/*packing=*/false, loc, b);
|
/*packing=*/false, loc, b);
|
||||||
return paddedInput;
|
return paddedInput;
|
||||||
|
|
|
@ -1059,7 +1059,7 @@ public:
|
||||||
// Step: generate the common dim/shape information
|
// Step: generate the common dim/shape information
|
||||||
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
|
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
|
||||||
bool isDynamicDim =
|
bool isDynamicDim =
|
||||||
lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]);
|
ShapedType::isDynamic(lhsBroadcastedShape[dim]);
|
||||||
if (isDynamicDim ||
|
if (isDynamicDim ||
|
||||||
lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
|
lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
|
||||||
commonValue *= lhsBroadcastedShape[dim];
|
commonValue *= lhsBroadcastedShape[dim];
|
||||||
|
@ -1071,7 +1071,7 @@ public:
|
||||||
bool hasDynamicDims = false;
|
bool hasDynamicDims = false;
|
||||||
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
|
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
|
||||||
bool isDynamicDim =
|
bool isDynamicDim =
|
||||||
lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]);
|
ShapedType::isDynamic(lhsBroadcastedShape[dim]);
|
||||||
hasDynamicDims |= isDynamicDim;
|
hasDynamicDims |= isDynamicDim;
|
||||||
if (!isDynamicDim &&
|
if (!isDynamicDim &&
|
||||||
lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
|
lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
|
||||||
|
@ -1156,7 +1156,7 @@ public:
|
||||||
hasDynamicDims = false;
|
hasDynamicDims = false;
|
||||||
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
|
for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
|
||||||
bool isDynamicDim =
|
bool isDynamicDim =
|
||||||
rhsBroadcastedTy.isDynamic(rhsBroadcastedShape[dim]);
|
ShapedType::isDynamic(rhsBroadcastedShape[dim]);
|
||||||
hasDynamicDims |= isDynamicDim;
|
hasDynamicDims |= isDynamicDim;
|
||||||
if (!isDynamicDim &&
|
if (!isDynamicDim &&
|
||||||
rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) {
|
rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) {
|
||||||
|
|
|
@ -63,28 +63,6 @@ void TorchDialect::initialize() {
|
||||||
addInterfaces<TorchInlinerInterface>();
|
addInterfaces<TorchInlinerInterface>();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// Type-related Dialect methods.
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
Type TorchDialect::parseType(DialectAsmParser &parser) const {
|
|
||||||
StringRef keyword;
|
|
||||||
if (parser.parseKeyword(&keyword))
|
|
||||||
return Type();
|
|
||||||
Type type;
|
|
||||||
if (generatedTypeParser(parser, keyword, type).hasValue())
|
|
||||||
return type;
|
|
||||||
|
|
||||||
parser.emitError(parser.getNameLoc(), "invalid 'torch' type: `")
|
|
||||||
<< keyword << "'";
|
|
||||||
return Type();
|
|
||||||
}
|
|
||||||
|
|
||||||
void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
|
|
||||||
if (failed(generatedTypePrinter(type, printer)))
|
|
||||||
llvm_unreachable("unknown 'torch' type");
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Dialect-level verifiers.
|
// Dialect-level verifiers.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -38,7 +38,7 @@ public:
|
||||||
matchAndRewrite(FuncOp func, OpAdaptor adaptor,
|
matchAndRewrite(FuncOp func, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
MLIRContext *context = func.getContext();
|
MLIRContext *context = func.getContext();
|
||||||
auto typeBoundIdent = Identifier::get("torch.type_bound", context);
|
auto typeBoundIdent = StringAttr::get(context, "torch.type_bound");
|
||||||
TypeConverter::SignatureConversion conversion(func.getNumArguments());
|
TypeConverter::SignatureConversion conversion(func.getNumArguments());
|
||||||
|
|
||||||
// The TypeConverter hooks for type conversion are "context free", so we
|
// The TypeConverter hooks for type conversion are "context free", so we
|
||||||
|
|
|
@ -164,7 +164,7 @@ struct FuncBackendTypeConversionPass
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
TorchConversion::setupBackendTypeConversion(target, typeConverter);
|
||||||
|
|
||||||
populateFuncOpTypeConversionPattern(patterns, typeConverter);
|
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter);
|
||||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||||
return typeConverter.isSignatureLegal(op.getType()) &&
|
return typeConverter.isSignatureLegal(op.getType()) &&
|
||||||
typeConverter.isLegal(&op.getBody());
|
typeConverter.isLegal(&op.getBody());
|
||||||
|
|
|
@ -173,7 +173,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
|
||||||
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||||
mlirRegionCreate());
|
mlirRegionCreate());
|
||||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
|
||||||
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
|
||||||
auto inserter = caffe2::MakeGuard([&]() {
|
auto inserter = caffe2::MakeGuard([&]() {
|
||||||
mlirBlockInsertOwnedOperationBefore(
|
mlirBlockInsertOwnedOperationBefore(
|
||||||
|
@ -441,7 +441,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
|
||||||
mlirStringAttrGet(
|
mlirStringAttrGet(
|
||||||
context, toMlirStringRef(classType->name()->qualifiedName()))));
|
context, toMlirStringRef(classType->name()->qualifiedName()))));
|
||||||
MlirRegion region = mlirOperationGetRegion(op, 0);
|
MlirRegion region = mlirOperationGetRegion(op, 0);
|
||||||
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
|
mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr));
|
||||||
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
|
MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
|
||||||
|
|
||||||
ClassAnnotation &classAnnotation =
|
ClassAnnotation &classAnnotation =
|
||||||
|
|
|
@ -303,7 +303,8 @@ MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
|
||||||
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
|
MlirLocation loc = getMlirLocationFromNode(context, paramNode);
|
||||||
std::vector<MlirType> blockArgTypes =
|
std::vector<MlirType> blockArgTypes =
|
||||||
getMlirTypesFromValues(loc, paramNode->outputs());
|
getMlirTypesFromValues(loc, paramNode->outputs());
|
||||||
MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
|
std::vector<MlirLocation> blockArgLocs(blockArgTypes.size(), loc);
|
||||||
|
MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data(), blockArgLocs.data());
|
||||||
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
|
for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
|
||||||
Value *jitValue = paramNode->outputs()[i];
|
Value *jitValue = paramNode->outputs()[i];
|
||||||
MlirValue value = mlirBlockGetArgument(block, i);
|
MlirValue value = mlirBlockGetArgument(block, i);
|
||||||
|
|
|
@ -11,7 +11,7 @@ builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?
|
||||||
%int7 = torch.constant.int 7
|
%int7 = torch.constant.int 7
|
||||||
%int8 = torch.constant.int 8
|
%int8 = torch.constant.int 8
|
||||||
%false = torch.constant.bool false
|
%false = torch.constant.bool false
|
||||||
// CHECK: %[[PADDED:.*]] = linalg.pad_tensor %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
|
// CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
|
||||||
// CHECK: %[[NEUTRAL:.*]] = arith.constant -1.401300e-45 : f32
|
// CHECK: %[[NEUTRAL:.*]] = arith.constant -1.401300e-45 : f32
|
||||||
// CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
|
// CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
|
||||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||||
|
|
|
@ -663,7 +663,7 @@ builtin.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.ten
|
||||||
return %res: !torch.tensor
|
return %res: !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.tensor.float(
|
// CHECK-LABEL: func @torch.aten.tensor.float(
|
||||||
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
|
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
|
||||||
|
@ -680,7 +680,7 @@ builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.tensor.float$specified_dtype(
|
// CHECK-LABEL: func @torch.aten.tensor.float$specified_dtype(
|
||||||
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
|
// CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
|
||||||
|
@ -699,7 +699,7 @@ builtin.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torc
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.tensor(
|
// CHECK-LABEL: func @torch.aten.tensor(
|
||||||
// CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
|
// CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
|
||||||
|
@ -718,7 +718,7 @@ builtin.func @torch.aten.tensor(%t: !torch.list<!torch.list<!torch.float>>) -> !
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.tensor$empty_list() -> !torch.tensor {
|
// CHECK-LABEL: func @torch.aten.tensor$empty_list() -> !torch.tensor {
|
||||||
// CHECK: %[[NONE:.*]] = torch.constant.none
|
// CHECK: %[[NONE:.*]] = torch.constant.none
|
||||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
@ -735,7 +735,7 @@ builtin.func @torch.aten.tensor$empty_list() -> !torch.tensor {
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.tensor$specified_dtype(
|
// CHECK-LABEL: func @torch.aten.tensor$specified_dtype(
|
||||||
// CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
|
// CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
|
||||||
|
@ -754,7 +754,7 @@ builtin.func @torch.aten.tensor$specified_dtype(%t: !torch.list<!torch.list<!tor
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.zeros(
|
// CHECK-LABEL: func @torch.aten.zeros(
|
||||||
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
|
// CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
|
||||||
|
@ -773,7 +773,7 @@ builtin.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.index_select(
|
// CHECK-LABEL: func @torch.aten.index_select(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
||||||
|
@ -789,7 +789,7 @@ builtin.func @torch.aten.index_select(%input: !torch.tensor<[2,3,4], f32>, %inde
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.index_select$unknown_indexes(
|
// CHECK-LABEL: func @torch.aten.index_select$unknown_indexes(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
||||||
|
@ -805,7 +805,7 @@ builtin.func @torch.aten.index_select$unknown_indexes(%input: !torch.tensor<[2,3
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.index_select$unknown_dim(
|
// CHECK-LABEL: func @torch.aten.index_select$unknown_dim(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
||||||
|
@ -820,7 +820,7 @@ builtin.func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4],
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.select.int(
|
// CHECK-LABEL: func @torch.aten.select.int(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
||||||
|
@ -836,7 +836,7 @@ builtin.func @torch.aten.select.int(%input: !torch.tensor<[2,3,4], f32>, %index:
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.type_as(
|
// CHECK-LABEL: func @torch.aten.type_as(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
|
||||||
// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
|
// CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
|
||||||
|
@ -849,7 +849,7 @@ builtin.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch
|
||||||
return %ret: !torch.tensor
|
return %ret: !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.gather(
|
// CHECK-LABEL: func @torch.aten.gather(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
|
||||||
|
@ -866,7 +866,7 @@ builtin.func @torch.aten.gather(%input: !torch.tensor<[2,3,4], f32>, %dim: !torc
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.expand(
|
// CHECK-LABEL: func @torch.aten.expand(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
|
||||||
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
|
||||||
|
@ -888,7 +888,7 @@ builtin.func @torch.aten.expand(%input: !torch.tensor<[2,1,4], f32>) -> !torch.t
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.expand$higher_rank(
|
// CHECK-LABEL: func @torch.aten.expand$higher_rank(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
|
||||||
|
@ -913,7 +913,7 @@ builtin.func @torch.aten.expand$higher_rank(%input: !torch.tensor<[2,1,4], f32>)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.expand$unknown_sizes(
|
// CHECK-LABEL: func @torch.aten.expand$unknown_sizes(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
|
||||||
// CHECK-SAME: %[[SIZEX:.*]]: !torch.int) -> !torch.tensor {
|
// CHECK-SAME: %[[SIZEX:.*]]: !torch.int) -> !torch.tensor {
|
||||||
|
@ -933,7 +933,7 @@ builtin.func @torch.aten.expand$unknown_sizes(%input: !torch.tensor<[2,1,4], f32
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.repeat(
|
// CHECK-LABEL: func @torch.aten.repeat(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
|
||||||
// CHECK-SAME: %[[REPEATX:.*]]: !torch.int) -> !torch.tensor {
|
// CHECK-SAME: %[[REPEATX:.*]]: !torch.int) -> !torch.tensor {
|
||||||
|
@ -952,7 +952,7 @@ builtin.func @torch.aten.repeat(%input: !torch.tensor<[2,1,4], f32>, %repeat: !t
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @torch.aten.cat(
|
// CHECK-LABEL: func @torch.aten.cat(
|
||||||
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
|
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
|
||||||
|
@ -970,7 +970,7 @@ builtin.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tenso
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.cat$unknown_dim(
|
// CHECK-LABEL: func @torch.aten.cat$unknown_dim(
|
||||||
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
|
// CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
|
||||||
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>,
|
// CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>,
|
||||||
|
@ -986,7 +986,7 @@ builtin.func @torch.aten.cat$unknown_dim(%t0: !torch.tensor<[?,1,4], f32>, %t1:
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten._shape_as_tensor(
|
// CHECK-LABEL: func @torch.aten._shape_as_tensor(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
|
||||||
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<[3],si64>
|
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<[3],si64>
|
||||||
|
@ -997,7 +997,7 @@ builtin.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten._shape_as_tensor$unknown_input_shape(
|
// CHECK-LABEL: func @torch.aten._shape_as_tensor$unknown_input_shape(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
|
||||||
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<[?],si64>
|
// CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<[?],si64>
|
||||||
|
@ -1008,7 +1008,7 @@ builtin.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.ten
|
||||||
return %ret : !torch.tensor
|
return %ret : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.embedding(
|
// CHECK-LABEL: func @torch.aten.embedding(
|
||||||
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
|
||||||
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
|
// CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
|
||||||
|
@ -1024,7 +1024,7 @@ builtin.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indic
|
||||||
return %ret: !torch.tensor
|
return %ret: !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.softmax.int(
|
// CHECK-LABEL: func @torch.aten.softmax.int(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
|
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
|
||||||
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
|
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
|
||||||
|
@ -1039,7 +1039,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.softmax.int$specified_dtype(
|
// CHECK-LABEL: func @torch.aten.softmax.int$specified_dtype(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
|
// CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
|
||||||
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
|
// CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
|
||||||
|
@ -1054,7 +1054,7 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
|
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
|
||||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
|
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
|
||||||
|
@ -1067,7 +1067,7 @@ func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
|
// CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
|
||||||
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
// CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
|
||||||
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -> !torch.tensor {
|
// CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -> !torch.tensor {
|
||||||
|
@ -1096,7 +1096,7 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
|
||||||
return %0 : !torch.tensor
|
return %0 : !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
|
// CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
|
||||||
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
|
// CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
|
||||||
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
|
// CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
|
||||||
|
@ -1107,7 +1107,7 @@ func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
|
||||||
return %0: !torch.tensor
|
return %0: !torch.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// ----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.BinaryBroadcasting(
|
// CHECK-LABEL: func @torch.aten.BinaryBroadcasting(
|
||||||
// CHECK-SAME: %[[T0:.*]]: !torch.vtensor<[5,4,3,3,1],f32>,
|
// CHECK-SAME: %[[T0:.*]]: !torch.vtensor<[5,4,3,3,1],f32>,
|
||||||
// CHECK-SAME: %[[T1:.*]]: !torch.vtensor<[?,3,1,2],f32>,
|
// CHECK-SAME: %[[T1:.*]]: !torch.vtensor<[?,3,1,2],f32>,
|
||||||
|
|
Loading…
Reference in New Issue