mirror of https://github.com/llvm/torch-mlir
commit 3fd9b7789e (parent 17a4843cf7)

@@ -13,6 +13,10 @@ if(POLICY CMP0077)
   cmake_policy(SET CMP0077 NEW)
 endif()
 
+if(POLICY CMP0116)
+  cmake_policy(SET CMP0116 OLD)
+endif()
+
 #-------------------------------------------------------------------------------
 # Project setup and globals
 #-------------------------------------------------------------------------------
@@ -1 +1 @@
-Subproject commit 63f0c00d38ee7879239975a6743d4e6c7847b725
+Subproject commit 881ff4e4ebe8cc0cc045c7c167cffb01f94f27f8
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -275,14 +276,14 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
                              SmallVectorImpl<int64_t> &highPaddingInts,
                              Value pad) {
   Location loc = op->getLoc();
-  Type rankedTensorType = linalg::PadTensorOp::inferResultType(
+  Type rankedTensorType = tensor::PadOp::inferResultType(
       input.getType().cast<RankedTensorType>(), lowPaddingInts,
       highPaddingInts);
   SmallVector<OpFoldResult> lowPaddings =
       getAsOpFoldResult(b, loc, lowPaddingInts);
   SmallVector<OpFoldResult> highPaddings =
       getAsOpFoldResult(b, loc, highPaddingInts);
-  Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
+  Value paddedInput = tensor::createPadScalarOp(
       rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
       /*packing=*/false, loc, b);
   return paddedInput;
@@ -1059,7 +1059,7 @@ public:
     // Step: generate the common dim/shape information
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
       bool isDynamicDim =
-          lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]);
+          ShapedType::isDynamic(lhsBroadcastedShape[dim]);
       if (isDynamicDim ||
           lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
         commonValue *= lhsBroadcastedShape[dim];
@@ -1071,7 +1071,7 @@ public:
     bool hasDynamicDims = false;
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
       bool isDynamicDim =
-          lhsBroadcastedTy.isDynamic(lhsBroadcastedShape[dim]);
+          ShapedType::isDynamic(lhsBroadcastedShape[dim]);
       hasDynamicDims |= isDynamicDim;
       if (!isDynamicDim &&
           lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
@@ -1156,7 +1156,7 @@ public:
     hasDynamicDims = false;
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
       bool isDynamicDim =
-          rhsBroadcastedTy.isDynamic(rhsBroadcastedShape[dim]);
+          ShapedType::isDynamic(rhsBroadcastedShape[dim]);
       hasDynamicDims |= isDynamicDim;
       if (!isDynamicDim &&
           rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) {
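Side note (not part of the patch): after this LLVM bump, dynamic extents are queried through the static ShapedType::isDynamic(int64_t) helper instead of a member call on the broadcasted tensor type. A minimal illustrative sketch follows; the standalone helper and its name are assumptions, not torch-mlir code.

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/ArrayRef.h"

// Hypothetical helper: true if any batch (non-matrix) dimension of a
// broadcasted matmul operand shape is dynamic, mirroring the loops above.
static bool hasDynamicBatchDim(llvm::ArrayRef<int64_t> broadcastedShape,
                               int64_t maxInputRank) {
  for (int64_t dim = 0; dim < maxInputRank - 2; ++dim)
    if (mlir::ShapedType::isDynamic(broadcastedShape[dim]))
      return true;
  return false;
}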
@@ -63,28 +63,6 @@ void TorchDialect::initialize() {
   addInterfaces<TorchInlinerInterface>();
 }
 
-//===----------------------------------------------------------------------===//
-// Type-related Dialect methods.
-//===----------------------------------------------------------------------===//
-
-Type TorchDialect::parseType(DialectAsmParser &parser) const {
-  StringRef keyword;
-  if (parser.parseKeyword(&keyword))
-    return Type();
-  Type type;
-  if (generatedTypeParser(parser, keyword, type).hasValue())
-    return type;
-
-  parser.emitError(parser.getNameLoc(), "invalid 'torch' type: `")
-      << keyword << "'";
-  return Type();
-}
-
-void TorchDialect::printType(Type type, DialectAsmPrinter &printer) const {
-  if (failed(generatedTypePrinter(type, printer)))
-    llvm_unreachable("unknown 'torch' type");
-}
-
 //===----------------------------------------------------------------------===//
 // Dialect-level verifiers.
 //===----------------------------------------------------------------------===//
@@ -321,9 +321,9 @@ static ParseResult parsePrimIfOp(OpAsmParser &parser, OperationState &result) {
 
 static void print(OpAsmPrinter &p, PrimIfOp op) {
   p << " " << op.condition();
-  p << " -> (" << op.getResultTypes() << ")";
+  p << " -> (" << op.getResultTypes() << ") ";
   p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false);
-  p << " else";
+  p << " else ";
   p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/false);
 
   p.printOptionalAttrDict(op->getAttrs());
@@ -38,7 +38,7 @@ public:
   matchAndRewrite(FuncOp func, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     MLIRContext *context = func.getContext();
-    auto typeBoundIdent = Identifier::get("torch.type_bound", context);
+    auto typeBoundIdent = StringAttr::get(context, "torch.type_bound");
     TypeConverter::SignatureConversion conversion(func.getNumArguments());
 
     // The TypeConverter hooks for type conversion are "context free", so we
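Side note (not part of the patch): upstream folded Identifier into StringAttr, and the argument order flips so the MLIRContext comes first. A minimal sketch, assuming a valid context pointer; the wrapper function below is illustrative only.

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"

// Hypothetical wrapper: builds the attribute name used for type bounds.
// Old spelling: Identifier::get("torch.type_bound", context)
// New spelling: StringAttr::get(context, "torch.type_bound")
mlir::StringAttr makeTypeBoundKey(mlir::MLIRContext *context) {
  return mlir::StringAttr::get(context, "torch.type_bound");
}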
@@ -164,7 +164,7 @@ struct FuncBackendTypeConversionPass
     typeConverter.addConversion([](Type type) { return type; });
     TorchConversion::setupBackendTypeConversion(target, typeConverter);
 
-    populateFuncOpTypeConversionPattern(patterns, typeConverter);
+    populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns, typeConverter);
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
       return typeConverter.isSignatureLegal(op.getType()) &&
              typeConverter.isLegal(&op.getBody());
@@ -173,7 +173,7 @@ MlirValue IValueImporter::importModule(torch::jit::Module currentModule) {
       torchMlirTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
       mlirRegionCreate());
   MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
-  mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
+  mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr, nullptr));
   MlirBlock nnModuleBody = mlirRegionGetFirstBlock(nnModuleRegion);
   auto inserter = caffe2::MakeGuard([&]() {
     mlirBlockInsertOwnedOperationBefore(
@@ -441,7 +441,7 @@ void IValueImporter::importClassType(c10::ClassType *classType) {
       mlirStringAttrGet(
           context, toMlirStringRef(classType->name()->qualifiedName()))));
   MlirRegion region = mlirOperationGetRegion(op, 0);
-  mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr));
+  mlirRegionAppendOwnedBlock(region, mlirBlockCreate(0, nullptr, nullptr));
   MlirBlock classTypeBody = mlirRegionGetFirstBlock(region);
 
   ClassAnnotation &classAnnotation =
@@ -303,7 +303,8 @@ MlirBlock NodeImporter::createBlockFor(Block *jitBlock) {
   MlirLocation loc = getMlirLocationFromNode(context, paramNode);
   std::vector<MlirType> blockArgTypes =
       getMlirTypesFromValues(loc, paramNode->outputs());
-  MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data());
+  std::vector<MlirLocation> blockArgLocs(blockArgTypes.size(), loc);
+  MlirBlock block = mlirBlockCreate(blockArgTypes.size(), blockArgTypes.data(), blockArgLocs.data());
   for (int i = 0, e = mlirBlockGetNumArguments(block); i < e; i++) {
     Value *jitValue = paramNode->outputs()[i];
     MlirValue value = mlirBlockGetArgument(block, i);
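Side note (not part of the patch): the MLIR C API now takes one MlirLocation per block argument, so callers pass a parallel location array alongside the argument types. A minimal sketch under that assumption; the helper below is illustrative, not importer code.

#include <vector>
#include "mlir-c/IR.h"

// Hypothetical helper: create a detached block whose arguments all share a
// single unknown location, using the three-argument mlirBlockCreate.
static MlirBlock createBlockWithArgs(MlirContext ctx,
                                     const std::vector<MlirType> &argTypes) {
  std::vector<MlirLocation> argLocs(argTypes.size(),
                                    mlirLocationUnknownGet(ctx));
  return mlirBlockCreate(static_cast<intptr_t>(argTypes.size()),
                         argTypes.data(), argLocs.data());
}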
@@ -11,7 +11,7 @@ builtin.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?
   %int7 = torch.constant.int 7
   %int8 = torch.constant.int 8
   %false = torch.constant.bool false
-  // CHECK: %[[PADDED:.*]] = linalg.pad_tensor %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
+  // CHECK: %[[PADDED:.*]] = tensor.pad %{{.*}} low[0, 0, 5, 6] high[0, 0, 5, 6]
   // CHECK: %[[NEUTRAL:.*]] = arith.constant -1.401300e-45 : f32
   // CHECK: %[[OUT:.*]] = linalg.fill(%[[NEUTRAL]], %{{.*}}) : f32, tensor<?x?x?x?xf32> -> tensor<?x?x?x?xf32>
   // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -663,7 +663,7 @@ builtin.func @prim.if$refined_type_conflicting(%none: !torch.none) -> !torch.ten
   return %res: !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.tensor.float(
 // CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
@@ -680,7 +680,7 @@ builtin.func @torch.aten.tensor.float(%t: !torch.float) -> !torch.tensor {
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.tensor.float$specified_dtype(
 // CHECK-SAME: %[[t:.*]]: !torch.float) -> !torch.tensor {
@@ -699,7 +699,7 @@ builtin.func @torch.aten.tensor.float$specified_dtype(%t: !torch.float) -> !torc
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.tensor(
 // CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
@@ -718,7 +718,7 @@ builtin.func @torch.aten.tensor(%t: !torch.list<!torch.list<!torch.float>>) -> !
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.tensor$empty_list() -> !torch.tensor {
 // CHECK: %[[NONE:.*]] = torch.constant.none
 // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@@ -735,7 +735,7 @@ builtin.func @torch.aten.tensor$empty_list() -> !torch.tensor {
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.tensor$specified_dtype(
 // CHECK-SAME: %[[DATA:.*]]: !torch.list<!torch.list<!torch.float>>) -> !torch.tensor {
@@ -754,7 +754,7 @@ builtin.func @torch.aten.tensor$specified_dtype(%t: !torch.list<!torch.list<!tor
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.zeros(
 // CHECK-SAME: %[[DIM0:.*]]: !torch.int) -> !torch.tensor {
@@ -773,7 +773,7 @@ builtin.func @torch.aten.zeros(%dim0: !torch.int) -> !torch.tensor {
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.index_select(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -789,7 +789,7 @@ builtin.func @torch.aten.index_select(%input: !torch.tensor<[2,3,4], f32>, %inde
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.index_select$unknown_indexes(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -805,7 +805,7 @@ builtin.func @torch.aten.index_select$unknown_indexes(%input: !torch.tensor<[2,3
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.index_select$unknown_dim(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -820,7 +820,7 @@ builtin.func @torch.aten.index_select$unknown_dim(%input: !torch.tensor<[2,3,4],
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.select.int(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -836,7 +836,7 @@ builtin.func @torch.aten.select.int(%input: !torch.tensor<[2,3,4], f32>, %index:
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.type_as(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?],si64>,
 // CHECK-SAME: %[[OTHER:.*]]: !torch.tensor<[?,2],f32>) -> !torch.tensor {
@@ -849,7 +849,7 @@ builtin.func @torch.aten.type_as(%self: !torch.tensor<[?], si64>, %other: !torch
   return %ret: !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.gather(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -866,7 +866,7 @@ builtin.func @torch.aten.gather(%input: !torch.tensor<[2,3,4], f32>, %dim: !torc
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.expand(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
 // CHECK: %[[FALSE:.*]] = torch.constant.bool false
@@ -888,7 +888,7 @@ builtin.func @torch.aten.expand(%input: !torch.tensor<[2,1,4], f32>) -> !torch.t
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.expand$higher_rank(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>) -> !torch.tensor {
@@ -913,7 +913,7 @@ builtin.func @torch.aten.expand$higher_rank(%input: !torch.tensor<[2,1,4], f32>)
 }
 
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.expand$unknown_sizes(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
 // CHECK-SAME: %[[SIZEX:.*]]: !torch.int) -> !torch.tensor {
@@ -933,7 +933,7 @@ builtin.func @torch.aten.expand$unknown_sizes(%input: !torch.tensor<[2,1,4], f32
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.repeat(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[2,1,4],f32>,
 // CHECK-SAME: %[[REPEATX:.*]]: !torch.int) -> !torch.tensor {
@@ -952,7 +952,7 @@ builtin.func @torch.aten.repeat(%input: !torch.tensor<[2,1,4], f32>, %repeat: !t
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @torch.aten.cat(
 // CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
@@ -970,7 +970,7 @@ builtin.func @torch.aten.cat(%t0: !torch.tensor<[?,1,4], f32>, %t1: !torch.tenso
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.cat$unknown_dim(
 // CHECK-SAME: %[[T1:.*]]: !torch.tensor<[?,1,4],f32>,
 // CHECK-SAME: %[[T2:.*]]: !torch.tensor<[2,3,4],f32>,
@@ -986,7 +986,7 @@ builtin.func @torch.aten.cat$unknown_dim(%t0: !torch.tensor<[?,1,4], f32>, %t1:
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten._shape_as_tensor(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[?,1,4],f32>) -> !torch.tensor {
 // CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor<[?,1,4],f32> -> !torch.tensor<[3],si64>
@@ -997,7 +997,7 @@ builtin.func @torch.aten._shape_as_tensor(%input: !torch.tensor<[?,1,4], f32>) -
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten._shape_as_tensor$unknown_input_shape(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor) -> !torch.tensor {
 // CHECK: %[[RET:.*]] = torch.aten._shape_as_tensor %[[INPUT]] : !torch.tensor -> !torch.tensor<[?],si64>
@@ -1008,7 +1008,7 @@ builtin.func @torch.aten._shape_as_tensor$unknown_input_shape(%input: !torch.ten
   return %ret : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.embedding(
 // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor<[104,512],f32>,
 // CHECK-SAME: %[[INDEXES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor {
@@ -1024,7 +1024,7 @@ builtin.func @torch.aten.embedding(%weight: !torch.tensor<[104,512],f32>, %indic
   return %ret: !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.softmax.int(
 // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
 // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
@@ -1039,7 +1039,7 @@ func @torch.aten.softmax.int(%t: !torch.tensor<[2,3],f32>, %dim: !torch.int) ->
 }
 
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.softmax.int$specified_dtype(
 // CHECK-SAME: %[[T:.*]]: !torch.tensor<[2,3],f32>,
 // CHECK-SAME: %[[DIM:.*]]: !torch.int) -> !torch.tensor {
@@ -1054,7 +1054,7 @@ func @torch.aten.softmax.int$specified_dtype(%t: !torch.tensor<[2,3],f32>, %dim:
 }
 
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Matrix(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.tensor {
@@ -1067,7 +1067,7 @@ func @torch.aten.Matmul.Broadcast.Matrix(%arg0: !torch.vtensor<[?,?,?,?,?],f32>,
 }
 
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.Matmul.Broadcast.Vector(
 // CHECK-SAME: %[[LHS:.*]]: !torch.vtensor<[?,?,?,?,?],f32>,
 // CHECK-SAME: %[[RHS:.*]]: !torch.vtensor<[?],f32>) -> !torch.tensor {
@@ -1096,7 +1096,7 @@ func @torch.aten.to.dtype(%arg0: !torch.tensor<[?,?],f32>) -> !torch.tensor{
   return %0 : !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.prim.NumToTensor.Scalar(
 // CHECK-SAME: %[[SELF:.*]]: !torch.int) -> !torch.tensor {
 // CHECK: %[[NTT:.*]] = torch.prim.NumToTensor.Scalar %[[SELF]] : !torch.int -> !torch.tensor<[],si64>
@@ -1107,7 +1107,7 @@ func @torch.prim.NumToTensor.Scalar(%arg0: !torch.int) -> !torch.tensor {
   return %0: !torch.tensor
 }
 
-// ----
+// -----
 // CHECK-LABEL: func @torch.aten.BinaryBroadcasting(
 // CHECK-SAME: %[[T0:.*]]: !torch.vtensor<[5,4,3,3,1],f32>,
 // CHECK-SAME: %[[T1:.*]]: !torch.vtensor<[?,3,1,2],f32>,