[Torch] Eliminate getWithLeastStaticInformation in DecomposeAtenLinspaceOp and DecomposeAtenFakeQuantizePerTensorAffineOp (#3539)

as title
pull/3543/head
Xinyu Yang 2024-07-15 10:02:36 +08:00 committed by GitHub
parent fe9db78120
commit e5d1677894
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 21 deletions

View File

@ -7592,7 +7592,6 @@ public:
Location loc = op.getLoc();
MLIRContext *context = getContext();
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
Value none = rewriter.create<ConstantNoneOp>(loc);
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value zero =
@ -7602,13 +7601,25 @@ public:
Value addStart;
int64_t steps;
auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true);
auto fp32Type = rewriter.getF32Type();
auto arangeIntType =
getTensorTypeFromShapeValues({op.getSteps()}, si64Type);
auto arangeFp32Type =
getTensorTypeFromShapeValues({op.getSteps()}, fp32Type);
if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) {
// specifically handle steps == 1
Value arange = rewriter.create<AtenArangeStartOp>(
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
op.getDevice(), op.getPinMemory());
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, arange,
op.getStart(), one);
loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none,
op.getLayout(), op.getDevice(), op.getPinMemory());
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
isa<Torch::FloatType>(op.getStart().getType())) {
addStart = rewriter.create<AtenAddScalarOp>(loc, arangeFp32Type, arange,
op.getStart(), one);
} else {
addStart = rewriter.create<AtenAddScalarOp>(loc, arangeIntType, arange,
op.getStart(), one);
}
} else {
// handle steps != 1 or dynamic steps
Value neOrNot = rewriter.create<AtenNeIntOp>(loc, op.getSteps(), one);
@ -7617,8 +7628,8 @@ public:
rewriter.getStringAttr("linspace's dynamic steps must not be 1"));
// create arange: [0, ..., steps - 1]
Value arange = rewriter.create<AtenArangeStartOp>(
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
op.getDevice(), op.getPinMemory());
loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none,
op.getLayout(), op.getDevice(), op.getPinMemory());
// calculate (end - start) / (steps - 1)
Value sub;
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
@ -7632,15 +7643,16 @@ public:
loc, sub, rewriter.create<AtenSubIntOp>(loc, op.getSteps(), one));
// calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
Value mulScalar =
rewriter.create<AtenMulScalarOp>(loc, baseType, arange, div);
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, mulScalar,
op.getStart(), one);
rewriter.create<AtenMulScalarOp>(loc, arangeFp32Type, arange, div);
addStart = rewriter.create<AtenAddScalarOp>(
loc, arangeFp32Type, mulScalar, op.getStart(), one);
}
// to dtype
Value result;
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
result = rewriter.create<AtenToDtypeOp>(
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
loc, op.getType(), addStart, op.getDtype(),
/*non_blocking=*/falseVal,
/*copy=*/falseVal, /*memory_format=*/none);
} else {
Value f32Type = rewriter.create<ConstantIntOp>(
@ -8557,7 +8569,6 @@ public:
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
// input/scale
Value divScale = rewriter.create<AtenDivScalarOp>(
@ -8568,16 +8579,19 @@ public:
Value addZeroPoint = rewriter.create<AtenAddScalarOp>(
loc, op.getType(), round, op.getZeroPoint(), one);
// max(quant_min, std::nearby_int(input/scale) + zero_point)
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
auto tensorIntType =
ValueTensorType::get(context, ArrayRef<int64_t>{1}, si64Type);
Value max = rewriter.create<AtenMaximumOp>(
loc, op.getType(), addZeroPoint,
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMin(),
rewriter.create<AtenTensorIntOp>(loc, tensorIntType, op.getQuantMin(),
/*dtype=*/none,
/*device=*/none,
/*requires_grad=*/falseVal));
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
Value min = rewriter.create<AtenMinimumOp>(
loc, op.getType(), max,
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMax(),
rewriter.create<AtenTensorIntOp>(loc, tensorIntType, op.getQuantMax(),
/*dtype=*/none, /*device=*/none,
/*requires_grad=*/falseVal));
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))

View File

@ -402,10 +402,6 @@ FX_IMPORTER_XFAIL_SET = {
"ElementwiseRreluTrainStaticModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"EqIntModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"FloatImplicitModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
@ -597,9 +593,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseToDtypeI64ToUI8Module_basic",
"EmptyModule_uint8",
"EqIntModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"Fill_TensorFloat32WithFloat32_basic",
"Fill_TensorFloat32WithFloat64_basic",
"Fill_TensorFloat32WithInt64_basic",