mirror of https://github.com/llvm/torch-mlir
Add support for dtype argument in reduction ops
Many reduction ops take as an argument an optional output dtype that can change the type of the input tensor before the reduction is performed. This commit adds support for the optional dtype flag that had been previously ignored. Test: /tools/torchscript_e2e_test.sh -f 'ReduceSumDtype' /tools/torchscript_e2e_test.sh -f 'ReduceSumDImIntListDtype'pull/451/head
parent
73b27b32dc
commit
e6675a50d3
|
@ -30,6 +30,25 @@ def ReduceSumModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceSumDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDtypeModule())
|
||||
def ReduceSumDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceSumDimIntListModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -49,6 +68,25 @@ def ReduceSumDimIntListModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceSumDimIntListDtypeModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1, -1], torch.float64, True),
|
||||
])
|
||||
def forward(self, a):
|
||||
return torch.sum(a, (0, 1), dtype=torch.float32)
|
||||
|
||||
|
||||
@register_test_case(module_factory=lambda: ReduceSumDimIntListDtypeModule())
|
||||
def ReduceSumDimIntListDtypeModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4, 5).to(torch.float64))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ReduceSumDimIntListKeepDimModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -1705,10 +1705,14 @@ static Value createLinalgNeutralElementForReduceOp(OpBuilder &b, Location loc,
|
|||
|
||||
static Value createLinalgPayloadCalculationForReduceOp(
|
||||
OpBuilder &b, Location loc, ValueRange payloadArgs, Operation *op,
|
||||
ArrayRef<Value> operands, Type elementType) {
|
||||
ArrayRef<Value> operands, Type resultElementType) {
|
||||
if (isa<AtenSumOp, AtenSumDimIntListOp>(op) &&
|
||||
elementType.isa<mlir::FloatType>())
|
||||
return b.create<arith::AddFOp>(loc, payloadArgs);
|
||||
resultElementType.isa<mlir::FloatType>()) {
|
||||
Value self =
|
||||
convertScalarToDtype(b, loc, payloadArgs[0], resultElementType);
|
||||
Value result = payloadArgs[1];
|
||||
return b.create<arith::AddFOp>(loc, self, result);
|
||||
}
|
||||
op->emitError("unimplemented lowering in "
|
||||
"createLinalgPayloadCalculationForReduceOp");
|
||||
return nullptr;
|
||||
|
|
|
@ -67,6 +67,16 @@ static Type getTypeForDTypeInteger(MLIRContext *context, int64_t dtypeInt) {
|
|||
return getTypeForScalarType(context, (ScalarType)dtypeInt);
|
||||
}
|
||||
|
||||
static Type getDtypeOrDefault(MLIRContext *context, Value optionalDtype,
|
||||
Type defaultDtype) {
|
||||
int64_t dtypeInt;
|
||||
if (matchPattern(optionalDtype, m_TorchConstantInt(&dtypeInt)))
|
||||
return getTypeForDTypeInteger(context, dtypeInt);
|
||||
else if (optionalDtype.getType().isa<Torch::NoneType>())
|
||||
return defaultDtype;
|
||||
return Type();
|
||||
}
|
||||
|
||||
static Type joinElementTypes(Type lhs, Type rhs) {
|
||||
if (!lhs)
|
||||
return rhs;
|
||||
|
@ -309,14 +319,23 @@ public:
|
|||
} else if (auto arangeStart = dyn_cast<AtenArangeStartOp>(op)) {
|
||||
return visitAtenArangeStartOp(arangeStart);
|
||||
} else if (auto sum = dyn_cast<AtenSumOp>(op)) {
|
||||
Type dtype = operands[0]->getValue().dtype;
|
||||
Type defaultDtype = operands[0]->getValue().dtype;
|
||||
Type dtype =
|
||||
getDtypeOrDefault(sum.getContext(), sum.dtype(), defaultDtype);
|
||||
return visitReductionAlongAllDimsOp(sum, dtype, operands);
|
||||
} else if (auto sumDimIntList = dyn_cast<AtenSumDimIntListOp>(op)) {
|
||||
Type defaultDtype = operands[0]->getValue().dtype;
|
||||
Type dtype = getDtypeOrDefault(sumDimIntList.getContext(),
|
||||
sumDimIntList.dtype(), defaultDtype);
|
||||
return visitReductionAlongDimIntListOp(sumDimIntList, sumDimIntList.dim(),
|
||||
sumDimIntList.keepdim(), operands);
|
||||
sumDimIntList.keepdim(), dtype,
|
||||
operands);
|
||||
} else if (auto meanDim = dyn_cast<AtenMeanDimOp>(op)) {
|
||||
return visitReductionAlongDimIntListOp(meanDim, meanDim.dim(),
|
||||
meanDim.keepdim(), operands);
|
||||
Type defaultDtype = operands[0]->getValue().dtype;
|
||||
Type dtype = getDtypeOrDefault(meanDim.getContext(), meanDim.dtype(),
|
||||
defaultDtype);
|
||||
return visitReductionAlongDimIntListOp(
|
||||
meanDim, meanDim.dim(), meanDim.keepdim(), dtype, operands);
|
||||
} else if (auto argmax = dyn_cast<AtenArgmaxOp>(op)) {
|
||||
Value dim = argmax.dim();
|
||||
Type dtype = IntegerType::get(op->getContext(), 64, IntegerType::Signed);
|
||||
|
@ -483,7 +502,7 @@ private:
|
|||
Operation *op, Type dtype,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult visitReductionAlongDimIntListOp(
|
||||
Operation *op, Value dim, Value keepdim,
|
||||
Operation *op, Value dim, Value keepdim, Type dtype,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands);
|
||||
ChangeResult visitReductionAlongDimIntOp(
|
||||
Operation *op, Value dim, Value keepdim, Type dtype,
|
||||
|
@ -980,12 +999,12 @@ ChangeResult TypeAnalyzer::visitReductionAlongAllDimsOp(
|
|||
// These ops do caculation along the dims given by the integer list and reduce
|
||||
// each dim to size one. If \p keepdim is false, the dims are squeezed.
|
||||
ChangeResult TypeAnalyzer::visitReductionAlongDimIntListOp(
|
||||
Operation *op, Value dim, Value keepdim,
|
||||
Operation *op, Value dim, Value keepdim, Type dtype,
|
||||
ArrayRef<LatticeElement<ValueKnowledge> *> operands) {
|
||||
auto input = operands[0]->getValue();
|
||||
auto knowledge =
|
||||
ValueKnowledge::getNotNonePessimisticValueState(op->getContext());
|
||||
knowledge.dtype = input.dtype;
|
||||
knowledge.dtype = dtype;
|
||||
llvm::SmallVector<int64_t> dimList;
|
||||
bool keepDim;
|
||||
if (matchPattern(keepdim, m_TorchConstantBool(&keepDim))) {
|
||||
|
|
Loading…
Reference in New Issue