From 8853dfbc747c116ae65f56e3383ffd959f6858d4 Mon Sep 17 00:00:00 2001 From: George Petterson Date: Tue, 19 Oct 2021 04:25:08 -0400 Subject: [PATCH] Add broadcast --- e2e_testing/torchscript/basic.py | 18 +++- .../Dialect/Torch/IR/GeneratedAtenOps.td | 14 +++ .../TorchToLinalg/TorchToLinalg.cpp | 101 ++++++++++++++++++ .../Transforms/MaximizeValueSemantics.cpp | 3 +- lib/Dialect/Torch/Transforms/RefineTypes.cpp | 19 ++++ .../jit_ir/build_tools/torch_ods_gen.py | 1 + 6 files changed, 154 insertions(+), 2 deletions(-) diff --git a/e2e_testing/torchscript/basic.py b/e2e_testing/torchscript/basic.py index 21002e03b..d1c49acbe 100644 --- a/e2e_testing/torchscript/basic.py +++ b/e2e_testing/torchscript/basic.py @@ -377,7 +377,23 @@ class SoftmaxIntArgTypeF64Module(torch.nn.Module): def forward(self, tensor): return self.softmax.forward(tensor) - @register_test_case(module_factory=lambda: SoftmaxIntArgTypeF64Module()) def SoftmaxIntArgTypeF64Module_basic(module, tu: TestUtils): module.forward(torch.randn(3, 2, 4).double()) + +class BroadcastToModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, 1], torch.float32, True), + ]) + def forward(self, x): + return torch.broadcast_to(x, [1, -1, -1, 4]) + + +@register_test_case(module_factory=lambda: BroadcastToModule()) +def BroadcastToModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 1, 1)) \ No newline at end of file diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td index 0796a19ff..ed6053880 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedAtenOps.td @@ -1475,6 +1475,20 @@ def Torch_AtenExpandOp : Torch_Op<"aten.expand", [ let assemblyFormat = "$self `,` $size `,` $implicit attr-dict `:` type($self) `,` type($size) `,` type($implicit) `->` type($result)"; } +def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [ + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::broadcast_to : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + TorchIntListType:$size + ); + let results = (outs + AnyTorchTensorType:$result + ); + let assemblyFormat = "$self `,` $size attr-dict `:` type($self) `,` type($size) `->` type($result)"; +} + def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [ AllowsTypeRefinement, HasValueSemantics diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index 7e4680dfb..f40c72e61 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -2134,6 +2134,105 @@ public: }; } // namespace +namespace { +class ConvertAtenBroadcastToOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenBroadcastToOp op, llvm::ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + AtenBroadcastToOp::Adaptor adaptor(operands); + Value self = adaptor.self(); + auto selfType = self.getType().cast(); + ArrayRef selfShape = selfType.getShape(); + Type elementType = selfType.getElementType(); + Location loc = op.getLoc(); + MLIRContext *context = op->getContext(); + + SmallVector inShape, outShape; + if (!getListConstructElements(adaptor.size(), inShape)) { + return rewriter.notifyMatchFailure( + op, "unimplemented: the size list is not from list construct"); + } + SmallVector inShapeConverted = + getTypeConvertedValues(rewriter, loc, getTypeConverter(), inShape); + if (inShape.size() < selfShape.size()) + return rewriter.notifyMatchFailure( + op, "invalid shape: must not be smaller than rank of tensor"); + size_t diff = inShape.size() - selfShape.size(); + + // Create affine map and shapes for tensor initialization. + SmallVector outExpr; + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + for (size_t i = 0; i < inShape.size(); i++) { + Value shapeValue = inShapeConverted[i]; + size_t j = i - diff; + if (i < diff) { + Value isValid = rewriter.create( + loc, arith::CmpIPredicate::sge, shapeValue, zero); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "negative values not allowed in new dimensions")); + outShape.push_back(castIntToIndex(rewriter, loc, shapeValue)); + continue; + } + if (selfShape[j] == 1) { + // Broadcast singleton dimension + Value one = + rewriter.create(loc, rewriter.getIndexAttr(1)); + Value isNegative = rewriter.create( + loc, arith::CmpIPredicate::slt, shapeValue, zero); + Value select = rewriter.create( + loc, isNegative, one, castIntToIndex(rewriter, loc, shapeValue)); + outShape.push_back(select); + outExpr.push_back(mlir::getAffineConstantExpr(0, context)); + continue; + } + // Non-broadcast case + Value dim = getDimOp(rewriter, loc, self, j); + Value isNegative = rewriter.create( + loc, arith::CmpIPredicate::slt, shapeValue, zero); + Value isEqual = rewriter.create( + loc, arith::CmpIPredicate::eq, castIndexToInt(rewriter, loc, dim), + shapeValue); + Value isValid = rewriter.create(loc, isNegative, isEqual); + rewriter.create( + loc, isValid, + rewriter.getStringAttr( + "only broadcasting singleton dimensions supported")); + outShape.push_back(dim); + outExpr.push_back(mlir::getAffineDimExpr(i, context)); + } + + Value outTensor = + rewriter.create(loc, outShape, elementType); + + SmallVector indexingMaps = { + AffineMap::get(inShape.size(), 0, outExpr, context), + rewriter.getMultiDimIdentityMap(inShape.size())}; + SmallVector iteratorTypes(inShape.size(), "parallel"); + Value result = rewriter + .create( + loc, outTensor.getType(), self, outTensor, + indexingMaps, iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); + + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, result); + + return success(); + } +}; +} // namespace + // ----------------------------------------------------------------------------- // The pass // ----------------------------------------------------------------------------- @@ -2195,6 +2294,8 @@ public: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 42a8ea33c..dc97fe62e 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -90,7 +90,8 @@ public: if (auto copyToValueTensor = dyn_cast(op)) { copyToValueTensorOps.push_back(copyToValueTensor); } else if (isa(op)) { + AtenTransposeIntOp, TensorStaticInfoCastOp>(op), + AtenBroadcastToOp > (op)) { viewLikeOps.push_back(op); llvm::append_range(workList, op->getResult(0).getUsers()); } else { diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp index 8db7cefbe..d9cdbab00 100644 --- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp +++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp @@ -347,6 +347,8 @@ public: targetDim = size == -1 ? inputDim : size; }; return visitExpandLikeOp(expand, expand.size(), operands, setDim); + } else if (auto broadcast = dyn_cast(op)) { + return visitBroadcastToOp(broadcast, broadcast.size(), operands); } else if (auto repeat = dyn_cast(op)) { // The repeats list specify the number of times to repeat along each dim // of the original tensor. @@ -447,6 +449,9 @@ private: ArrayRef *> operands, SetDimSizePerListItemFn setDim); ChangeResult + visitBroadcastToOp(Operation *op, Value list, + ArrayRef *> operands); + ChangeResult visitAtenCatOp(AtenCatOp op, ArrayRef *> operands); ChangeResult @@ -997,6 +1002,20 @@ ChangeResult TypeAnalyzer::visitExpandLikeOp( return getLatticeElement(op->getResult(0)).join(knowledge); } +ChangeResult TypeAnalyzer::visitBroadcastToOp( + Operation *op, Value list, + ArrayRef *> operands) { + auto input = operands[0]->getValue(); + auto knowledge = + ValueKnowledge::getNotNonePessimisticValueState(op->getContext()); + knowledge.dtype = input.dtype; + if (!input.hasSizes) + return getLatticeElement(op->getResult(0)).join(knowledge); + + fillInSizesGivenSizesList(knowledge, list); + return getLatticeElement(op->getResult(0)).join(knowledge); +} + // `torch.aten.cat` concatenates the given sequence of seq tensors in the given // dimension. The output has the same sizes as the input for all dimensions // except the given dimension. diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 06692db52..fca55f443 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -520,6 +520,7 @@ def emit_aten_ops(torch_ir_dir: str, registry: Registry): emit("aten::embedding : (Tensor, Tensor, int, bool, bool) -> (Tensor)") emit("aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::expand : (Tensor, int[], bool) -> (Tensor)") + emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)") emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") emit("aten::item : (Tensor) -> (Scalar)")