mirror of https://github.com/llvm/torch-mlir
Fix dynamic shapes type verifications (#1409)
* Fix dynamic shapes type verifications
parent 72e422b589
commit 16dd7e2e5f
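
This commit reconciles static and dynamic dimension sizes between operands before emitting shape-verified MHLO ops: a new castContractingDim helper in the matmul/linear lowerings and an explicit cast in the batch-norm lowering insert tensor.cast ops so that mhlo.dot_general and mhlo.batch_norm_inference see matching dimensions, and the result is cast back to the type-converted result type via tensor.cast instead of mhlo.convert. As a rough, hand-written sketch (not taken from the patch; %lhs, %rhs and the shapes are made up for illustration), the reconciliation when one contracting dimension is static and the other dynamic looks like:

// Illustrative only: lhs has a dynamic contracting dim, rhs a static one (4),
// so the lowering casts lhs before building mhlo.dot_general and refines the
// dot result type to tensor<?x5xf32> before the final cast to the converted type.
%lhs_cast = tensor.cast %lhs : tensor<?x?xf32> to tensor<?x4xf32>
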
@@ -878,10 +878,19 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
     return success();
   } else {
     Type outputTy = getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<mhlo::BatchNormInferenceOp>(
-        op, outputTy, input, weight, bias, runningMean, runningVar,
-        rewriter.getFloatAttr(inputTy.getElementType(), eps),
-        rewriter.getI64IntegerAttr(1));
+    SmallVector<int64_t, 4> castShape{inputTy.getShape().begin(),
+                                      inputTy.getShape().end()};
+    castShape[1] = weightTy.getShape()[0];
+    auto castTy = RankedTensorType::get(castShape, inputTy.getElementType());
+    // Feature counts must match among operands of mhlo::BatchNormInferenceOp.
+    Value inputCasted =
+        rewriter.create<tensor::CastOp>(op.getLoc(), castTy, input);
+    Value output = rewriter.create<mhlo::BatchNormInferenceOp>(
+        op.getLoc(), inputCasted.getType(), inputCasted, weight, bias,
+        runningMean, runningVar,
+        // 'epsilon' must satisfy constraint: 32-bit float attribute.
+        rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1));
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, output);
     return success();
   }
 }
@@ -71,6 +71,63 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
   return result.getResult();
 }
 
+RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
+                                    Value &lhs, Value &rhs,
+                                    int64_t lhsResultDim, int64_t rhsResultDim,
+                                    int64_t lhsContractingDim,
+                                    int64_t rhsContractingDim) {
+  auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
+  auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
+
+  auto oldLhsShape = lhsTy.getShape();
+  auto oldRhsShape = rhsTy.getShape();
+  SmallVector<int64_t> lhsShape;
+  SmallVector<int64_t> rhsShape;
+  lhsShape.append(oldLhsShape.begin(), oldLhsShape.end());
+  rhsShape.append(oldRhsShape.begin(), oldRhsShape.end());
+  auto lhsContractingDimSize = lhsShape[lhsContractingDim];
+  auto rhsContractingDimSize = rhsShape[rhsContractingDim];
+  if (lhsContractingDimSize != rhsContractingDimSize) {
+    if (lhsContractingDimSize == ShapedType::kDynamicSize &&
+        rhsContractingDimSize >= 0) {
+      lhsShape[lhsContractingDim] = rhsContractingDimSize;
+      auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType());
+      lhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, lhs);
+    } else if (rhsContractingDimSize == ShapedType::kDynamicSize &&
+               lhsContractingDimSize >= 0) {
+      rhsShape[rhsContractingDim] = lhsContractingDimSize;
+      auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType());
+      rhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, rhs);
+    }
+  }
+  SmallVector<int64_t> outShape;
+  // set batch dims, will skip invalid dimensions
+  for (size_t k = 0; k < lhsShape.size(); ++k) {
+    if (k == lhsResultDim || k == lhsContractingDim)
+      continue;
+    outShape.push_back(lhsShape[k]);
+  }
+  for (size_t k = 0, b = 0; k < rhsShape.size(); ++k) {
+    if (b >= outShape.size())
+      break;
+    if (k == rhsResultDim || k == rhsContractingDim)
+      continue;
+    if (outShape[b] == ShapedType::kDynamicSize && rhsShape[k] >= 0) {
+      outShape[b] = rhsShape[k];
+    }
+    b++;
+  }
+
+  // set result dimensions
+  if (lhsResultDim < lhsShape.size() && lhsResultDim >= 0) {
+    outShape.push_back(lhsShape[lhsResultDim]);
+  }
+  if (rhsResultDim < rhsShape.size() && rhsResultDim >= 0) {
+    outShape.push_back(rhsShape[rhsResultDim]);
+  }
+  return RankedTensorType::get(outShape, lhsTy.getElementType());
+}
+
 void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
                      Value &inpRhs, int64_t leadingRank,
                      size_t dimSizeIndexBits) {
@@ -183,10 +240,15 @@ public:
           options.dimSizeIndexBits);
     }
     auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
+
+    auto lhsResultDim = nBatchDims;
+    auto rhsResultDim = nBatchDims + 1;
     auto lhsContractingDim = nBatchDims + 1;
     auto rhsContractingDim = nBatchDims;
-    if (lhsRank == 1)
+    if (lhsRank == 1) {
+      lhsResultDim = nBatchDims + 1;
       lhsContractingDim = nBatchDims;
+    }
 
     mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
         mhlo::DotDimensionNumbersAttr::get(
@@ -195,15 +257,13 @@ public:
             /*rhsBatchingDimensions=*/batchDims,
             /*lhsContractingDimensions=*/{lhsContractingDim},
             /*rhsContractingDimensions=*/{rhsContractingDim});
-    auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
-                        ->convertType(op.getType())
-                        .template cast<RankedTensorType>();
-
+    auto outTy =
+        castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
+                           lhsContractingDim, rhsContractingDim);
     output = rewriter
-                 .create<mhlo::DotGeneralOp>(op->getLoc(), resultTy, lhs, rhs,
+                 .create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
                                              dotDimensionNumbers, nullptr)
                  .getResult();
-
     return success();
   }
 
@@ -221,7 +281,7 @@ public:
     if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
       return op.emitError("failed to perform matmul operation");
 
-    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
         op,
         ConvertAtenOp<AtenOpT>::getTypeConverter()
             ->convertType(op.getType())
@@ -355,9 +415,15 @@ public:
     auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
     auto nBatchDims = resultRank - 2;
     auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
+
+    auto lhsResultDim = nBatchDims;
+    auto rhsResultDim = nBatchDims + 1;
     auto lhsContractingDim = nBatchDims + 1;
     auto rhsContractingDim = nBatchDims;
 
+    auto outTy =
+        castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
+                           lhsContractingDim, rhsContractingDim);
     mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
         mhlo::DotDimensionNumbersAttr::get(
             rewriter.getContext(),
@@ -365,24 +431,21 @@ public:
             /*rhsBatchingDimensions=*/batchDims,
             /*lhsContractingDimensions=*/{lhsContractingDim},
             /*rhsContractingDimensions=*/{rhsContractingDim});
-
-    auto resultTy =
-        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
-
     Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
-        op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);
+        op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
 
     Value matmulPlusBias = matmulOutput;
     if (!biasTy.template isa<Torch::NoneType>()) {
       // Bias addition broadcasts to the matmul output shape.
-      matmulPlusBias =
-          rewriter
-              .create<chlo::BroadcastAddOp>(op->getLoc(), resultTy,
-                                            matmulOutput, bias, nullptr)
-              .getResult();
+      matmulPlusBias = rewriter
+                           .create<chlo::BroadcastAddOp>(
+                               op->getLoc(), outTy, matmulOutput, bias, nullptr)
+                           .getResult();
     }
 
-    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, resultTy, matmulPlusBias);
+    auto resultTy =
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, matmulPlusBias);
     return success();
   }
 };
@@ -159,22 +159,24 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>)
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.batch_norm$training(
-// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
-// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
-// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
-// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
-// CHECK: %true = torch.constant.bool true
-// CHECK: %false = torch.constant.bool false
-// CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01
-// CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x3x?x?xf32>
-// CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex>
-// CHECK: %[[VAL_7:.*]] = "mhlo.batch_norm_inference"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]], %[[VAL_2]], %[[VAL_3]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
-// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
-// CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32>
-func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+// CHECK-LABEL: func.func @torch.aten.batch_norm$inference(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
+// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor<?x3x?x?xf32>
+// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32>
+// CHECK: %[[T2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32>
+// CHECK: %[[TRUE:.*]] = torch.constant.bool true
+// CHECK: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01
+// CHECK: %[[FLOAT1:.*]].000000e-05 = torch.constant.float 1.000000e-05
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<?x3x?x?xf32>
+// CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex>
+// CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
+// CHECK: %[[T6:.*]] = "mhlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<?x3x?x?xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<?x3x?x?xf32>
+// CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor<?x3x?x?xf32> to tensor<?x3x?x?xf32>
+// CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor<?x3x?x?xf32> -> !torch.vtensor<[?,3,?,?],f32>
+// CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32>
+func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> {
   %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
   %1 = torch.vtensor.literal(dense<1.000000e+00> : tensor<3xf32>) : !torch.vtensor<[3],f32>
   %true = torch.constant.bool true
@@ -5,7 +5,7 @@
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<2x3xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32>
 func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> {
@@ -20,7 +20,7 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor<?x3xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x3xf32>, tensor<3x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x?xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -46,7 +46,7 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1:
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<10x3x5xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32>
 func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg1: !torch.vtensor<[10,4,5],f32>) -> !torch.vtensor<[10,3,5],f32> {
@@ -72,7 +72,7 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<?x4x?xf32>, tensor<3xi64>) -> tensor<?x4x?xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x4xf32>, tensor<?x4x?xf32>) -> tensor<?x?x?xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<?x?x?xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x?xf32> to tensor<?x?x?xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x?xf32> -> !torch.vtensor<[?,?,?],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32>
 func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg1: !torch.vtensor<[?,4,?],f32>) -> !torch.vtensor<[?,?,?],f32> {
@@ -98,7 +98,7 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x256x256xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32>
 func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, %arg1: !torch.vtensor<[4,120,256],f32>) -> !torch.vtensor<[4,256,256],f32> {
@@ -124,7 +124,7 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>,
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<4x?x?xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32>
 func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[4,?,?],f32> {
@@ -147,7 +147,7 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>,
 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
 // CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32>
 // CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32>
-// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<1x?xf32>
+// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32>
 // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32>
 // CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32>
 func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[1,?],f32> {
@@ -170,7 +170,7 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1:
 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64>
 // CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<?x256xf32>
 // CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1], rhs_contracting_dimensions = [1]>} : (tensor<?x256xf32>, tensor<?x256x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[T9:.*]] = mhlo.convert %[[T8]] : tensor<?x?xf32>
+// CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32>
 func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[?,256,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -185,7 +185,7 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256xf32>) -> tensor<?xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
 func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> {
@@ -200,7 +200,7 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor<?xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?xf32> to tensor<?xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?xf32> -> !torch.vtensor<[?],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?],f32>
 func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> {
@@ -215,7 +215,7 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<f32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<f32> to tensor<f32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<f32> -> !torch.vtensor<[],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[],f32>
 func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> {
@@ -241,7 +241,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor
 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64>
 // CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor<?x256x256xf32>
 // CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [2], rhs_contracting_dimensions = [1]>} : (tensor<?x?x256xf32>, tensor<?x256x256xf32>) -> tensor<?x?x256xf32>
-// CHECK: %[[T11:.*]] = mhlo.convert %[[T10]] : tensor<?x?x256xf32>
+// CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<?x?x256xf32> to tensor<?x?x256xf32>
 // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<?x?x256xf32> -> !torch.vtensor<[?,?,256],f32>
 // CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32>
 func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> {
@@ -257,7 +257,7 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc
 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor<?x256xf32>
 // CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32>
 // CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<?x256xf32>, tensor<256x256xf32>) -> tensor<?x256xf32>
-// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : tensor<?x256xf32>
+// CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<?x256xf32> to tensor<?x256xf32>
 // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<?x256xf32> -> !torch.vtensor<[?,256],f32>
 // CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32>
 func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> {