[TOSA] Fix broadcast_to input and output different shape support (#1855)

Chi_Liu 2023-02-09 09:15:14 -08:00 committed by GitHub
parent 83534370c3
commit cc819e73dd
2 changed files with 59 additions and 16 deletions


@@ -679,6 +679,10 @@ TOSA_PASS_SET = {
     "HardsigmoidRandomModule_basic",
     "HardswishModule_basic",
     "HardswishRandomModule_basic",
+    "FullLikeModuleInt2DStatic_basic",
+    "FullModuleInt3D_basic",
+    "FullModuleFloat2D_basic",
+    "RepeatModule_basic"
 }

 LTC_XFAIL_SET = {

@@ -3144,31 +3144,70 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
         op, "Only floating-point or integer datatype legalization supported");
   }
-  SmallVector<int64_t> outShape;
-  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outShape)))
+  SmallVector<int64_t> resultShape;
+  if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(resultShape)))
     return rewriter.notifyMatchFailure(op,
                                        "size must consist of Scalar constants");
+  // Get the result type.
+  auto resultType = getTypeConverter()->convertType(op.getType());
   SmallVector<int64_t> inputShape(
       makeShapeTorchCompatible(selfType.getShape()));
-  if (inputShape.size() == outShape.size() || inputShape.size() == 0) {
-    // Check for the identity case, e.g. [a, b, c] -> [a, b, c]. If this
-    // holds, we can replace the op result with the input operand
-    // irrespective of the users of the op result.
-    if (!llvm::equal(inputShape, outShape)) {
-      for (auto user : op->getResult(0).getUsers()) {
-        // This case is only supported if the result of the `broadcast_to` op
-        // is not used by a view-like op.
-        if (isViewLikeOp(user)) {
+  // Check for the identity case, e.g. [a, b, c] -> [a, b, c]. If this
+  // holds, we can replace the op result with the input operand directly.
+  if (llvm::equal(inputShape, resultShape)) {
+    // If we reach here, no broadcasting is required since the input and
+    // result have the same shape.
+    op.replaceAllUsesWith(op.getSelf());
+    rewriter.eraseOp(op);
+    return success();
+  } else if (selfType.hasRank() &&
+             (selfType.getRank() == (int64_t)resultShape.size() ||
+              selfType.getRank() == 0)) {
+    // For now, to support the limited cases where the input and result
+    // shapes are not equal, we constrain the input to either be of rank 0
+    // or of the same rank as the result, and then check broadcasting
+    // compatibility for the latter case. For broadcasting compatibility,
+    // the input and result shapes must be equal at each dimension, or one
+    // of them must be 1.
+    if (selfType.getRank() != 0) {
+      for (unsigned i = 0; i < inputShape.size(); i++) {
+        if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
+            resultShape[i] != 1) {
           return rewriter.notifyMatchFailure(
-              op, "unimplemented: broadcast not supported for this case");
+              op, "unimplemented: the shapes of input and result must be "
+                  "equal at each dimension, or one of them must be 1.");
         }
       }
     }
-    // If we reach here, then it means the given case is handled by the
-    // implicit broadcasting done by tosa.
-    op.replaceAllUsesWith(op.getSelf());
-    rewriter.eraseOp(op);
+    // If the above condition holds, we can directly create a constant zero
+    // tensor with the same shape as the result.
+    SmallVector<int64_t> zeroTensorShape{resultShape};
+    // Create the zero constant tensor.
+    int64_t totalNumElements = 1;
+    for (auto dimSize : zeroTensorShape) {
+      totalNumElements = dimSize * totalNumElements;
+    }
+    // There is some danger here: for floating-point edge cases, x + 0 != x.
+    // The cases are denormalized values, which may get flushed, and
+    // -0 + 0 = +0 (the sign bit flips). These are probably acceptable in
+    // the short term, but we should put a comment acknowledging the danger,
+    // as there isn't an op that avoids the denorm flushing.
+    SmallVector<int64_t> intValues(totalNumElements, 0);
+    SmallVector<float> floatValues(totalNumElements, 0.0);
+    Value zeroTensor = selfType.getElementType().isa<mlir::FloatType>()
+                           ? tosa::getConstTensor<float>(
+                                 rewriter, op, floatValues, zeroTensorShape)
+                                 .value()
+                           : tosa::getConstTensor<int64_t>(
+                                 rewriter, op, intValues, zeroTensorShape)
+                                 .value();
+    // Use tosa.add, whose implicit broadcasting expands the input to the
+    // result shape.
+    rewriter.replaceOpWithNewOp<tosa::AddOp>(op, resultType,
+                                             adaptor.getSelf(), zeroTensor);
     return success();
   }
   return rewriter.notifyMatchFailure(
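
Note: the compatibility rule the new per-dimension loop enforces can be stated as a small standalone predicate. The following is a minimal plain-C++ sketch, not part of the commit; the helper name isBroadcastCompatible is hypothetical.

// Minimal sketch of the broadcast-compatibility rule enforced above.
// Hypothetical helper, not torch-mlir code: two same-rank shapes are
// compatible iff, at every dimension, the sizes match or one of them is 1.
#include <cassert>
#include <cstdint>
#include <vector>

static bool isBroadcastCompatible(const std::vector<int64_t> &inputShape,
                                  const std::vector<int64_t> &resultShape) {
  if (inputShape.size() != resultShape.size())
    return false; // the lowering only handles equal ranks (or a rank-0 input)
  for (size_t i = 0; i < inputShape.size(); ++i)
    if (inputShape[i] != resultShape[i] && inputShape[i] != 1 &&
        resultShape[i] != 1)
      return false;
  return true;
}

int main() {
  assert(isBroadcastCompatible({1, 3}, {4, 3}));  // dim 0 expands 1 -> 4
  assert(!isBroadcastCompatible({2, 3}, {4, 3})); // 2 vs 4, neither is 1
  return 0;
}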
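The rewrite itself materializes broadcast_to as input + zeros(resultShape), leaning on the implicit broadcasting of tosa.add. The sketch below is plain C++ with no MLIR involved; it only illustrates the numeric effect of the trick and the floating-point caveat called out in the code comment.

// Sketch of the "add a zero tensor to broadcast" trick and its caveat.
// Plain C++, illustrative only; no TOSA ops here.
#include <cmath>
#include <cstdio>

int main() {
  // Broadcast shape [1] -> [3]: adding a zero tensor of the result shape
  // repeats the single input element, which is what the implicit
  // broadcasting of tosa.add does for the lowering.
  float x[1] = {2.5f};
  float zeros[3] = {0.0f, 0.0f, 0.0f};
  float out[3];
  for (int i = 0; i < 3; ++i)
    out[i] = x[0] + zeros[i];
  std::printf("%g %g %g\n", out[0], out[1], out[2]); // 2.5 2.5 2.5

  // Caveat from the commit comment: x + 0 is not always bit-identical to x.
  // Under IEEE 754 round-to-nearest, -0.0 + 0.0 == +0.0 (sign bit lost),
  // and denormal inputs may be flushed to zero on some hardware.
  std::printf("%d\n", std::signbit(-0.0f + 0.0f)); // prints 0: sign flipped
  return 0;
}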