diff --git a/frontends/pytorch/python/torch_mlir_torchscript/e2e_test/reporting.py b/frontends/pytorch/python/torch_mlir_torchscript/e2e_test/reporting.py index b0b523b0d..9c0fa9369 100644 --- a/frontends/pytorch/python/torch_mlir_torchscript/e2e_test/reporting.py +++ b/frontends/pytorch/python/torch_mlir_torchscript/e2e_test/reporting.py @@ -22,9 +22,10 @@ class TensorSummary: self.min = torch.min(tensor) self.max = torch.max(tensor) self.mean = torch.mean(tensor) + self.shape = list(tensor.shape) def __str__(self): - return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}' + return f'Tensor with shape={self.shape} min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}' class ErrorContext: diff --git a/frontends/pytorch/test/torchscript_e2e_test/error_reports.py b/frontends/pytorch/test/torchscript_e2e_test/error_reports.py index a87ac8002..00e4f5295 100644 --- a/frontends/pytorch/test/torchscript_e2e_test/error_reports.py +++ b/frontends/pytorch/test/torchscript_e2e_test/error_reports.py @@ -116,7 +116,7 @@ class ErroneousModule(torch.nn.Module): # CHECK-NEXT: @ trace item #8 - call to "test_tensor_value_mismatch" # CHECK-NEXT: @ output of call to "test_tensor_value_mismatch" - # CHECK-NEXT: ERROR: value (Tensor with min=+1.0, max=+3.0, mean=+2.0000) is not close to golden value (Tensor with min=+1.5, max=+3.5, mean=+2.5000) + # CHECK-NEXT: ERROR: value (Tensor with shape=[3] min=+1.0, max=+3.0, mean=+2.0) is not close to golden value (Tensor with shape=[3] min=+1.5, max=+3.5, mean=+2.5) @torch.jit.export def test_tensor_value_mismatch(self): if torch.jit.is_scripting(): diff --git a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp index cfde7595b..c5c5eb71e 100644 --- a/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp +++ b/lib/Conversion/TorchToLinalg/TorchToLinalg.cpp @@ -755,6 +755,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } +static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc, + Operation *op, + Type elementType) { + if (isa(op) && + elementType.isa()) + return b.create(loc, b.getFloatAttr(elementType, 0.0)); + + op->emitError("unimplemented lowering in " + "createLinalgNeutralElementForReduceOp"); + return nullptr; +} + static Value createLinalgPayloadCalculationForReduceOp( OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op, ArrayRef operands, Type elementType) { @@ -981,11 +993,16 @@ struct ConvertReductionOp : ConversionPattern { auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs}); Value initTensor = rewriter.create( loc, resultShape, resultType.getElementType()); + Value initValue = createLinalgNeutralElementForReduceOp( + rewriter, loc, op, resultType.getElementType()); + Value accumulator = + rewriter.create(loc, initValue, initTensor) + .getResult(0); bool hadErrorCreatingPayload = false; auto generic = rewriter.create( - loc, /*resultTensorTypes=*/initTensor.getType(), + loc, /*resultTensorTypes=*/accumulator.getType(), /*inputs=*/tensorOperand, - /*outputs=*/initTensor, + /*outputs=*/accumulator, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {