Add broadcast

pull/366/head
George Petterson 2021-10-19 04:25:08 -04:00 committed by Yi Zhang
parent a459e09ab7
commit 8853dfbc74
6 changed files with 154 additions and 2 deletions

View File

@@ -377,7 +377,23 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module):
    def forward(self, tensor):
        return self.softmax.forward(tensor)


@register_test_case(module_factory=lambda: SoftmaxIntArgTypeF64Module())
def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils):
    module.forward(torch.randn(3, 2, 4).double())


class BroadcastToModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, 1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.broadcast_to(x, [1, -1, -1, 4])


@register_test_case(module_factory=lambda: BroadcastToModule())
def BroadcastToModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 1, 1))
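Note: outside the e2e harness, the semantics the new test exercises can be checked in plain PyTorch. A minimal sketch (assuming only stock torch): -1 entries keep the matching input dimension, the trailing size-1 dimension broadcasts to 4, and the explicit leading 1 is a new dimension.

import torch

x = torch.rand(3, 1, 1)
# -1 keeps the matching input dim; the trailing size-1 dim expands to 4,
# and a new leading dim of size 1 is prepended.
y = torch.broadcast_to(x, [1, -1, -1, 4])
assert y.shape == (1, 3, 1, 4)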

View File

@@ -1475,6 +1475,20 @@ def Torch_AtenExpandOp : Torch_Op<"aten.expand", [
  let assemblyFormat = "$self `,` $size `,` $implicit attr-dict `:` type($self) `,` type($size) `,` type($implicit) `->` type($result)";
}

def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::broadcast_to : (Tensor, int[]) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    TorchIntListType:$size
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)";
}

def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
    AllowsTypeRefinement,
    HasValueSemantics
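For cross-checking the summary string, the registered JIT schema can be printed from Python. A sketch, assuming a PyTorch build that exposes the private _jit_get_schemas_for_operator binding:

import torch

# Prints the ATen schema this generated ODS op mirrors, e.g.
# aten::broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
for schema in torch._C._jit_get_schemas_for_operator("aten::broadcast_to"):
    print(schema)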

View File

@@ -2134,6 +2134,105 @@ public:
};
} // namespace

namespace {
class ConvertAtenBroadcastToOp : public OpConversionPattern<AtenBroadcastToOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AtenBroadcastToOp op, llvm::ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
      return failure();

    AtenBroadcastToOp::Adaptor adaptor(operands);
    Value self = adaptor.self();
    auto selfType = self.getType().cast<RankedTensorType>();
    ArrayRef<int64_t> selfShape = selfType.getShape();
    Type elementType = selfType.getElementType();
    Location loc = op.getLoc();
    MLIRContext *context = op->getContext();

    SmallVector<Value> inShape, outShape;
    if (!getListConstructElements(adaptor.size(), inShape)) {
      return rewriter.notifyMatchFailure(
          op, "unimplemented: the size list is not from list construct");
    }
    SmallVector<Value> inShapeConverted =
        getTypeConvertedValues(rewriter, loc, getTypeConverter(), inShape);
    if (inShape.size() < selfShape.size())
      return rewriter.notifyMatchFailure(
          op, "invalid shape: must not be smaller than rank of tensor");
    size_t diff = inShape.size() - selfShape.size();

    // Create affine map and shapes for tensor initialization.
    SmallVector<AffineExpr> outExpr;
    Value zero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
    for (size_t i = 0; i < inShape.size(); i++) {
      Value shapeValue = inShapeConverted[i];
      size_t j = i - diff;
      if (i < diff) {
        Value isValid = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::sge, shapeValue, zero);
        rewriter.create<AssertOp>(
            loc, isValid,
            rewriter.getStringAttr(
                "negative values not allowed in new dimensions"));
        outShape.push_back(castIntToIndex(rewriter, loc, shapeValue));
        continue;
      }
      if (selfShape[j] == 1) {
        // Broadcast singleton dimension
        Value one =
            rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
        Value isNegative = rewriter.create<arith::CmpIOp>(
            loc, arith::CmpIPredicate::slt, shapeValue, zero);
        Value select = rewriter.create<SelectOp>(
            loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue));
        outShape.push_back(select);
        outExpr.push_back(mlir::getAffineConstantExpr(0, context));
        continue;
      }
      // Non-broadcast case
      Value dim = getDimOp(rewriter, loc, self, j);
      Value isNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, shapeValue, zero);
      Value isEqual = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim),
          shapeValue);
      Value isValid = rewriter.create<arith::OrIOp>(loc, isNegative, isEqual);
      rewriter.create<AssertOp>(
          loc, isValid,
          rewriter.getStringAttr(
              "only broadcasting singleton dimensions supported"));
      outShape.push_back(dim);
      outExpr.push_back(mlir::getAffineDimExpr(i, context));
    }

    Value outTensor =
        rewriter.create<linalg::InitTensorOp>(loc, outShape, elementType);
    SmallVector<AffineMap> indexingMaps = {
        AffineMap::get(inShape.size(), 0, outExpr, context),
        rewriter.getMultiDimIdentityMap(inShape.size())};
    SmallVector<StringRef> iteratorTypes(inShape.size(), "parallel");
    Value result = rewriter
                       .create<linalg::GenericOp>(
                           loc, outTensor.getType(), self, outTensor,
                           indexingMaps, iteratorTypes,
                           [](OpBuilder &b, Location loc, ValueRange args) {
                             b.create<linalg::YieldOp>(loc, args[0]);
                           })
                       .getResult(0);

    Type newResultType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
    return success();
  }
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@@ -2195,6 +2294,8 @@ public:
    patterns.add<ConvertAtenGatherOp>(typeConverter, context);
    target.addIllegalOp<AtenLayerNormOp>();
    patterns.add<ConvertAtenLayerNormOp>(typeConverter, context);
    target.addIllegalOp<AtenBroadcastToOp>();
    patterns.add<ConvertAtenBroadcastToOp>(typeConverter, context);
    target.addIllegalOp<AtenArgmaxOp>();
    patterns.add<ConvertAtenArgmaxOp>(typeConverter, context);
    target.addIllegalOp<AtenSizeIntOp>();
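The per-dimension rules the pattern asserts at runtime read more easily as a reference model. A hypothetical Python sketch of the same logic (infer_broadcast_shape is illustrative only, not part of this patch):

def infer_broadcast_shape(self_shape, size):
    """Mirror of ConvertAtenBroadcastToOp's per-dimension shape logic."""
    if len(size) < len(self_shape):
        raise ValueError("must not be smaller than rank of tensor")
    diff = len(size) - len(self_shape)
    out = []
    for i, s in enumerate(size):
        if i < diff:
            # New leading dimension: -1 is not allowed here.
            if s < 0:
                raise ValueError("negative values not allowed in new dimensions")
            out.append(s)
        elif self_shape[i - diff] == 1:
            # Singleton input dim: broadcast to s, or keep 1 when s == -1.
            out.append(1 if s < 0 else s)
        else:
            # Non-singleton dim: must be -1 or match the input exactly.
            if s >= 0 and s != self_shape[i - diff]:
                raise ValueError("only broadcasting singleton dimensions supported")
            out.append(self_shape[i - diff])
    return out

# Matches the e2e test: [3, 1, 1] broadcast to [1, -1, -1, 4].
assert infer_broadcast_shape([3, 1, 1], [1, -1, -1, 4]) == [1, 3, 1, 4]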

View File

@@ -90,7 +90,8 @@ public:
      if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(op)) {
        copyToValueTensorOps.push_back(copyToValueTensor);
      } else if (isa<AtenUnsqueezeOp, AtenFlattenUsingIntsOp,
                     AtenTransposeIntOp, TensorStaticInfoCastOp>(op)) {
                     AtenTransposeIntOp, TensorStaticInfoCastOp,
                     AtenBroadcastToOp>(op)) {
        viewLikeOps.push_back(op);
        llvm::append_range(workList, op->getResult(0).getUsers());
      } else {
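aten::broadcast_to belongs in this view-like list because, in eager PyTorch, it returns a view aliasing the input's storage rather than a copy. A quick sanity check of that assumption:

import torch

x = torch.rand(3, 1)
y = torch.broadcast_to(x, (3, 4))
# y is a view: a write through the base tensor is visible in the
# broadcast result, and no new storage is allocated.
x[0, 0] = 7.0
assert y[0, 0].item() == 7.0
assert y.data_ptr() == x.data_ptr()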

View File

@@ -347,6 +347,8 @@ public:
      targetDim = size == -1 ? inputDim : size;
    };
    return visitExpandLikeOp(expand, expand.size(), operands, setDim);
  } else if (auto broadcast = dyn_cast<AtenBroadcastToOp>(op)) {
    return visitBroadcastToOp(broadcast, broadcast.size(), operands);
  } else if (auto repeat = dyn_cast<AtenRepeatOp>(op)) {
    // The repeats list specifies the number of times to repeat along each dim
    // of the original tensor.
@@ -447,6 +449,9 @@ private:
                    ArrayRef<LatticeElement<ValueKnowledge> *> operands,
                    SetDimSizePerListItemFn setDim);
  ChangeResult
  visitBroadcastToOp(Operation *op, Value list,
                     ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
  visitAtenCatOp(AtenCatOp op,
                 ArrayRef<LatticeElement<ValueKnowledge> *> operands);
  ChangeResult
@@ -997,6 +1002,20 @@ ChangeResult TypeAnalyzer::visitExpandLikeOp(
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

ChangeResult TypeAnalyzer::visitBroadcastToOp(
    Operation *op, Value list,
    ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
  auto input = operands[0]->getValue();
  auto knowledge =
      ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
  knowledge.dtype = input.dtype;
  if (!input.hasSizes)
    return getLatticeElement(op->getResult(0)).join(knowledge);

  fillInSizesGivenSizesList(knowledge, list);
  return getLatticeElement(op->getResult(0)).join(knowledge);
}

// `torch.aten.cat` concatenates the given sequence of seq tensors in the given
// dimension. The output has the same sizes as the input for all dimensions
// except the given dimension.
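visitBroadcastToOp propagates the input dtype and takes the result sizes straight from the size list, leaving -1 and non-constant entries unknown via fillInSizesGivenSizesList. Roughly, in Python terms (a sketch; None stands in for kUnknownSize):

UNKNOWN = None  # stands in for MLIR's kUnknownSize

def visit_broadcast_to(input_dtype, input_sizes, size_list):
    """Sketch of visitBroadcastToOp: dtype propagates; sizes come from
    the size list, with -1 (or non-constant) entries left unknown."""
    if input_sizes is None:  # input rank unknown: only dtype is inferred
        return input_dtype, None
    sizes = [s if isinstance(s, int) and s >= 0 else UNKNOWN
             for s in size_list]
    return input_dtype, sizes

# e.g. broadcasting a [3, 1, 1] input with size [1, -1, -1, 4]:
assert visit_broadcast_to("f32", [3, 1, 1], [1, -1, -1, 4]) == \
    ("f32", [1, UNKNOWN, UNKNOWN, 4])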

View File

@@ -520,6 +520,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry):
    emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)")
    emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)")
    emit("aten::expand : (Tensor, int[], bool) -> (Tensor)")
    emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)")
    emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)")
    emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)")
    emit("aten::item : (Tensor) -> (Scalar)")