Add support for dtype argument in reduction ops

Many reduction ops take as an argument an optional output dtype that
can change the type of the input tensor before the reduction is
performed. This commit adds support for the optional dtype flag that
had been previously ignored.

Test:
/tools/torchscript_e2e_test.sh -f 'ReduceSumDtype'
/tools/torchscript_e2e_test.sh -f 'ReduceSumDImIntListDtype'
pull/451/head
Ramiro Leal-Cavazos 2021-11-30 14:57:36 +00:00 committed by Yi Zhang
parent 73b27b32dc
commit e6675a50d3
3 changed files with 71 additions and 10 deletions

View File

@ -30,6 +30,25 @@ def ReduceSumModule_basic(module, tu: TestUtils):
# ==============================================================================
class ReduceSumDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.sum(a, dtype=torch.float32)
@register_test_case(module_factory=lambda: ReduceSumDtypeModule())
def ReduceSumDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceSumDimIntListModule(torch.nn.Module):
def __init__(self):
super().__init__()
@ -49,6 +68,25 @@ def ReduceSumDimIntListModule_basic(module, tu: TestUtils):
# ==============================================================================
class ReduceSumDimIntListDtypeModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float64, True),
])
def forward(self, a):
return torch.sum(a, (0, 1), dtype=torch.float32)
@register_test_case(module_factory=lambda: ReduceSumDimIntListDtypeModule())
def ReduceSumDimIntListDtypeModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5).to(torch.float64))
# ==============================================================================
class ReduceSumDimIntListKeepDimModule(torch.nn.Module):
def __init__(self):
super().__init__()

View File

@ -1705,10 +1705,14 @@ static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
static Value createLinalgPayloadCalculationForReduceOp(
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
ArrayRef<Value> operands, Type elementType) {
ArrayRef<Value> operands, Type resultElementType) {
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
elementType.isa<mlir::FloatType>())
return b.create<arith::AddFOp>(loc, payloadArgs);
resultElementType.isa<mlir::FloatType>()) {
Value self =
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
Value result = payloadArgs[1];
return b.create<arith::AddFOp>(loc, self, result);
}
op->emitError("unimplemented lowering in "
"createLinalgPayloadCalculationForReduceOp");
return nullptr;

View File

@ -67,6 +67,16 @@ static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
return getTypeForScalarType(context, (ScalarType)dtypeInt);
}
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
Type defaultDtype) {
int64_t dtypeInt;
if (matchPattern(optionalDtype, m_TorchConstantInt(&dtypeInt)))
return getTypeForDTypeInteger(context, dtypeInt);
else if (optionalDtype.getType().isa<Torch::NoneType>())
return defaultDtype;
return Type();
}
static Type joinElementTypes(Type lhs, Type rhs) {
if (!lhs)
return rhs;
@ -309,14 +319,23 @@ public:
} else if (auto arangeStart = dyn_cast<AtenArangeStartOp>(op)) {
return visitAtenArangeStartOp(arangeStart);
} else if (auto sum = dyn_cast<AtenSumOp>(op)) {
Type dtype = operands[0]->getValue().dtype;
Type defaultDtype = operands[0]->getValue().dtype;
Type dtype =
getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
return visitReductionAlongAllDimsOp(sum, dtype, operands);
} else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
Type defaultDtype = operands[0]->getValue().dtype;
Type dtype = getDtypeOrDefault(sumDimIntList.getContext(),
sumDimIntList.dtype(), defaultDtype);
return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
sumDimIntList.keepdim(), operands);
sumDimIntList.keepdim(), dtype,
operands);
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
meanDim.keepdim(), operands);
Type defaultDtype = operands[0]->getValue().dtype;
Type dtype = getDtypeOrDefault(meanDim.getContext(), meanDim.dtype(),
defaultDtype);
return visitReductionAlongDimIntListOp(
meanDim, meanDim.dim(), meanDim.keepdim(), dtype, operands);
} else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
Value dim = argmax.dim();
Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
@ -483,7 +502,7 @@ private:
Operation *op, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitReductionAlongDimIntListOp(
Operation *op, Value dim, Value keepdim,
Operation *op, Value dim, Value keepdim, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
ChangeResult visitReductionAlongDimIntOp(
Operation *op, Value dim, Value keepdim, Type dtype,
@ -980,12 +999,12 @@ ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
// These ops do caculation along the dims given by the integer list and reduce
// each dim to size one. If \p keepdim is false, the dims are squeezed.
ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
Operation *op, Value dim, Value keepdim,
Operation *op, Value dim, Value keepdim, Type dtype,
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
auto input = operands[0]->getValue();
auto knowledge =
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
knowledge.dtype = input.dtype;
knowledge.dtype = dtype;
llvm::SmallVector<int64_t> dimList;
bool keepDim;
if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {