Fix lowering of reduce ops

We were not filling the `outs` with the neutral element of the
reduction, which resulted in reading uninitialized values (we were
getting lucky that sometimes the uninitialized buffers were all zero's).

Also,
- Slight tweak to error messages in the e2e framework.
pull/302/head
Sean Silva 2021-09-08 21:58:15 +00:00
parent 6724de7692
commit 5f3eb637c4
3 changed files with 22 additions and 4 deletions

View File

@ -22,9 +22,10 @@ class TensorSummary:
self.min = torch.min(tensor)
self.max = torch.max(tensor)
self.mean = torch.mean(tensor)
self.shape = list(tensor.shape)
def __str__(self):
return f'Tensor with min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4f}'
return f'Tensor with shape={self.shape} min={self.min:+0.4}, max={self.max:+0.4}, mean={self.mean:+0.4}'
class ErrorContext:

View File

@ -116,7 +116,7 @@ class ErroneousModule(torch.nn.Module):
# CHECK-NEXT: @ trace item #8 - call to "test_tensor_value_mismatch"
# CHECK-NEXT: @ output of call to "test_tensor_value_mismatch"
# CHECK-NEXT: ERROR: value (Tensor with min=+1.0, max=+3.0, mean=+2.0000) is not close to golden value (Tensor with min=+1.5, max=+3.5, mean=+2.5000)
# CHECK-NEXT: ERROR: value (Tensor with shape=[3] min=+1.0, max=+3.0, mean=+2.0) is not close to golden value (Tensor with shape=[3] min=+1.5, max=+3.5, mean=+2.5)
@torch.jit.export
def test_tensor_value_mismatch(self):
if torch.jit.is_scripting():

View File

@ -755,6 +755,18 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
return nullptr;
}
static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
Operation *op,
Type elementType) {
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
elementType.isa<mlir::FloatType>())
return b.create<mlir::ConstantOp>(loc, b.getFloatAttr(elementType, 0.0));
op->emitError("unimplemented lowering in "
"createLinalgNeutralElementForReduceOp");
return nullptr;
}
static Value createLinalgPayloadCalculationForReduceOp(
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
ArrayRef<Value> operands, Type elementType) {
@ -981,11 +993,16 @@ struct ConvertReductionOp : ConversionPattern {
auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs});
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, resultShape, resultType.getElementType());
Value initValue = createLinalgNeutralElementForReduceOp(
rewriter, loc, op, resultType.getElementType());
Value accumulator =
rewriter.create<linalg::FillOp>(loc, initValue, initTensor)
.getResult(0);
bool hadErrorCreatingPayload = false;
auto generic = rewriter.create<linalg::GenericOp>(
loc, /*resultTensorTypes=*/initTensor.getType(),
loc, /*resultTensorTypes=*/accumulator.getType(),
/*inputs=*/tensorOperand,
/*outputs=*/initTensor,
/*outputs=*/accumulator,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {