mirror of https://github.com/llvm/torch-mlir

reimplement linear lowering torchToMhlo (#1524)

parent 15b249777b
commit 50b524546f
@@ -373,10 +373,10 @@ public:
     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();

-    if (lhsRank != 2 && lhsRank != 3)
-      return op.emitError("aten.Linear called but input rank not 2 or 3");
-    if (rhsRank != 2 && rhsRank != 3)
-      return op.emitError("aten.Linear called but weight rank not 2 or 3");
+    if (lhsRank < 1)
+      return op.emitError("aten.Linear called but input rank 0");
+    if (rhsRank != 2)
+      return op.emitError("aten.Linear called but weight rank not 2");

     return success();
   }
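Before this change the pattern only admitted 2-D or 3-D inputs and weights; afterwards any input of rank >= 1 is accepted, while the weight must be exactly 2-D ([out_features, in_features]). A minimal standalone sketch of the new admissibility rule, in plain C++ outside the conversion pattern (isLinearLowerable is a hypothetical helper, not part of the patch):

#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the new rank checks: the input may have any
// rank >= 1, while the weight must be a 2-D [out_features, in_features] tensor.
static bool isLinearLowerable(int64_t inputRank, int64_t weightRank) {
  return inputRank >= 1 && weightRank == 2;
}

int main() {
  assert(isLinearLowerable(1, 2));   // vector input, now accepted
  assert(isLinearLowerable(4, 2));   // 4-D batched input, now accepted
  assert(!isLinearLowerable(0, 2));  // rank-0 input still rejected
  assert(!isLinearLowerable(3, 3));  // 3-D weight no longer supported
  return 0;
}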
@@ -406,33 +406,59 @@ public:

     auto lhsTy = lhs.getType().cast<RankedTensorType>();
     auto rhsTy = rhs.getType().cast<RankedTensorType>();
-    auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
-                                rhsTy.getRank() - lhsTy.getRank());
+    auto lhsRank = lhsTy.getRank();

-    const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
-    getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
-                    options.dimSizeIndexBits);
-    auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
-    auto nBatchDims = resultRank - 2;
-    auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
-
-    auto lhsResultDim = nBatchDims;
-    auto rhsResultDim = nBatchDims + 1;
-    auto lhsContractingDim = nBatchDims + 1;
-    auto rhsContractingDim = nBatchDims;
-
-    auto outTy =
-        castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
-                           lhsContractingDim, rhsContractingDim);
-    mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
-        mhlo::DotDimensionNumbersAttr::get(
-            rewriter.getContext(),
-            /*lhsBatchingDimensions=*/batchDims,
-            /*rhsBatchingDimensions=*/batchDims,
-            /*lhsContractingDimensions=*/{lhsContractingDim},
-            /*rhsContractingDimensions=*/{rhsContractingDim});
-    Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
-        op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
+    auto loc = op->getLoc();
+    Value dotLhs;
+    SmallVector<Value> resultDims;
+    // vector * matrix or matrix * matrix can directly use mhlo.dot
+    if (lhsTy.getRank() <= 2) {
+      dotLhs = lhs;
+    } else {
+      // Broadcasting the weight and then using bmm would lead to too much
+      // data copying and extra compute, and thus decrease performance.
+      // Instead, reshape the input to a 2-D tensor, then use dot to perform
+      // the matrix-matrix multiply, and finally reshape to the output shape,
+      // which gives better performance.
+
+      // [x_1, x_2, ..., x_n, in_features] * [in_features, out_features]
+      // -> [x_1 * x_2 * ... * x_n, in_features] * [in_features, out_features]
+      auto dotLhsTy = RankedTensorType::get(
+          {ShapedType::kDynamicSize, lhsTy.getShape()[lhsRank - 1]},
+          lhsTy.getElementType());
+      const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
+      Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
+      Value numel = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(intType, 1));
+
+      for (int i = 0; i < lhsRank - 1; ++i) {
+        Value dimValue = rewriter.create<tensor::DimOp>(loc, lhs, i);
+        resultDims.push_back(dimValue);
+        numel = rewriter.create<arith::MulIOp>(
+            loc, numel,
+            rewriter.create<arith::IndexCastOp>(loc, intType, dimValue));
+      }
+      Value lhsLastRankDim = rewriter.create<arith::IndexCastOp>(
+          loc, intType, rewriter.create<tensor::DimOp>(loc, lhs, lhsRank - 1));
+      resultDims.push_back(rewriter.create<tensor::DimOp>(loc, rhs, 1));
+      Value reshapeDim =
+          rewriter
+              .create<mlir::tensor::FromElementsOp>(
+                  op->getLoc(), ValueRange{numel, lhsLastRankDim})
+              .getResult();
+      dotLhs = rewriter.create<mhlo::DynamicReshapeOp>(loc, dotLhsTy, lhs,
+                                                       reshapeDim);
+    }
+    Value matmulOutput =
+        rewriter.create<mhlo::DotOp>(loc, dotLhs, rhs, nullptr);
+    auto outTy =
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
+    // reshape to [x_1, x_2, ..., x_n, out_features]
+    if (dotLhs != lhs) {
+      matmulOutput = rewriter.create<mhlo::DynamicReshapeOp>(
+          loc, outTy, matmulOutput,
+          rewriter.create<mlir::tensor::FromElementsOp>(loc, resultDims));
+    }

     Value matmulPlusBias = matmulOutput;
     if (!biasTy.template isa<Torch::NoneType>()) {
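The heart of the rewrite is the comment above: rather than broadcasting the weight and lowering to a batched matmul, a rank-n input [x_1, ..., x_n, in_features] is flattened to [x_1 * ... * x_n, in_features], multiplied against the transposed weight with a single mhlo.dot, and reshaped back to [x_1, ..., x_n, out_features]. The sketch below is a plain-C++ check of that flatten-multiply-reshape identity, independent of MLIR; all names and values are illustrative and not taken from the patch:

// Numerical sketch of the identity the new lowering relies on: computing
// linear([x_1, ..., x_n, in], [out, in]) by flattening the leading dims to a
// single [x_1 * ... * x_n, in] matrix, doing one matrix-matrix multiply, and
// reshaping back. The weight is kept in torch's [out, in] layout and
// contracted over "in", which is equivalent to multiplying by its transpose.
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// y[r][o] = sum_i x[r][i] * w[o][i], with x flattened to [rows, in].
static std::vector<float> linearFlattened(const std::vector<float> &x,
                                          const std::vector<float> &w,
                                          size_t rows, size_t in, size_t out) {
  std::vector<float> y(rows * out, 0.0f);
  for (size_t r = 0; r < rows; ++r)
    for (size_t o = 0; o < out; ++o)
      for (size_t i = 0; i < in; ++i)
        y[r * out + o] += x[r * in + i] * w[o * in + i];
  return y;
}

int main() {
  // A [2, 3, 4] input (leading dims 2 and 3) and a [5, 4] weight.
  const size_t d0 = 2, d1 = 3, in = 4, out = 5;
  std::vector<float> x(d0 * d1 * in), w(out * in);
  for (size_t i = 0; i < x.size(); ++i) x[i] = 0.1f * float(i);
  for (size_t i = 0; i < w.size(); ++i) w[i] = 0.05f * float(i) - 0.3f;

  // Reference: apply the 2-D linear separately to each [3, 4] slice.
  std::vector<float> ref;
  for (size_t b = 0; b < d0; ++b) {
    std::vector<float> slice(x.begin() + b * d1 * in,
                             x.begin() + (b + 1) * d1 * in);
    std::vector<float> y = linearFlattened(slice, w, d1, in, out);
    ref.insert(ref.end(), y.begin(), y.end());
  }

  // Flattened form: one multiply over all 2 * 3 = 6 rows, then "reshape" back
  // (a no-op on the flat buffer) -- this is what the mhlo.dot path computes.
  std::vector<float> flat = linearFlattened(x, w, d0 * d1, in, out);

  for (size_t i = 0; i < ref.size(); ++i)
    assert(std::fabs(ref[i] - flat[i]) < 1e-5f);
  return 0;
}

The assertion holds because flattening the leading dimensions only concatenates the rows that the batched form would process slice by slice; no values move relative to the contraction axis.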
@@ -443,9 +469,7 @@ public:
               .getResult();
     }

-    auto resultTy =
-        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, matmulPlusBias);
+    rewriter.replaceOp(op, matmulPlusBias);
     return success();
   }
 };
@@ -499,3 +499,60 @@ func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7
   %3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32>
   return %3 : !torch.vtensor<[1,4,15,15],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.linear(
+// CHECK-NOT: mhlo.dynamic_reshape
+// CHECK: mhlo.transpose
+// CHECK: mhlo.dot
+// CHECK: chlo.broadcast_add
+func.func @torch.aten.linear(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[4,5],f32> {
+  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[4,5],f32>
+  return %1 : !torch.vtensor<[4,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.linear$nobias(
+// CHECK-NOT: mhlo.dynamic_reshape
+// CHECK: mhlo.transpose
+// CHECK: mhlo.dot
+// CHECK-NOT: chlo.broadcast_add
+func.func @torch.aten.linear$nobias(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[4,5],f32> {
+  %none = torch.constant.none
+  %1 = torch.aten.linear %arg0, %arg1, %none : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.none -> !torch.vtensor<[4,5],f32>
+  return %1 : !torch.vtensor<[4,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.linear$dynamic(
+// CHECK: mhlo.transpose
+// CHECK: arith.muli
+// CHECK: arith.muli
+// CHECK: tensor.from_elements
+// CHECK: mhlo.dynamic_reshape
+// CHECK: mhlo.dot
+// CHECK: mhlo.dynamic_reshape
+// CHECK: chlo.broadcast_add
+func.func @torch.aten.linear$dynamic(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,5],f32> {
+  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,5],f32>
+  return %1 : !torch.vtensor<[?,?,5],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.linear$dynamic4D(
+// CHECK: mhlo.transpose
+// CHECK: arith.muli
+// CHECK: arith.muli
+// CHECK: tensor.from_elements
+// CHECK: mhlo.dynamic_reshape
+// CHECK: mhlo.dot
+// CHECK: mhlo.dynamic_reshape
+// CHECK: chlo.broadcast_add
+func.func @torch.aten.linear$dynamic4D(%arg0: !torch.vtensor<[?,?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,?,5],f32> {
+  %1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,?,5],f32>
+  return %1 : !torch.vtensor<[?,?,?,5],f32>
+}
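The $dynamic and $dynamic4D tests pin down the runtime shape bookkeeping: an arith.muli for each leading dimension builds the flattened row count, tensor.from_elements packs the reshape operands, and mhlo.dynamic_reshape wraps the dot on both sides. A small sketch of that bookkeeping with made-up runtime dimensions (plain C++, hypothetical values, not the lowering code itself):

// Sketch of the shape bookkeeping checked by the $dynamic tests: multiply the
// dynamic leading dims at runtime (the arith.muli ops), pack {numel, in}
// for the pre-dot reshape, and {d0, d1, ..., out} for the post-dot reshape.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Pretend runtime dims for a [?, ?, 3] input with a [5, 3] weight.
  std::vector<int64_t> inputDims = {7, 11, 3};
  int64_t outFeatures = 5;

  int64_t numel = 1;
  std::vector<int64_t> resultDims;
  for (size_t i = 0; i + 1 < inputDims.size(); ++i) { // leading dims only
    numel *= inputDims[i];                            // one arith.muli each
    resultDims.push_back(inputDims[i]);
  }
  resultDims.push_back(outFeatures);

  int64_t inFeatures = inputDims.back();
  std::cout << "pre-dot reshape:  [" << numel << ", " << inFeatures << "]\n";
  std::cout << "post-dot reshape: [" << resultDims[0] << ", " << resultDims[1]
            << ", " << resultDims[2] << "]\n"; // [7, 11, 5]
  return 0;
}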