mirror of https://github.com/llvm/torch-mlir
[Torch] Eliminate getWithLeastStaticInformation in DecomposeAtenLinspaceOp and DecomposeAtenFakeQuantizePerTensorAffineOp (#3539)
as title (pull/3543/head)
parent
fe9db78120
commit
e5d1677894
|
@ -7592,7 +7592,6 @@ public:
|
|||
Location loc = op.getLoc();
|
||||
MLIRContext *context = getContext();
|
||||
|
||||
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value zero =
|
||||
|
@ -7602,13 +7601,25 @@ public:
|
|||
|
||||
Value addStart;
|
||||
int64_t steps;
|
||||
auto si64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned*/ true);
|
||||
auto fp32Type = rewriter.getF32Type();
|
||||
auto arangeIntType =
|
||||
getTensorTypeFromShapeValues({op.getSteps()}, si64Type);
|
||||
auto arangeFp32Type =
|
||||
getTensorTypeFromShapeValues({op.getSteps()}, fp32Type);
|
||||
if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) {
|
||||
// specifically handle steps == 1
|
||||
Value arange = rewriter.create<AtenArangeStartOp>(
|
||||
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
|
||||
op.getDevice(), op.getPinMemory());
|
||||
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, arange,
|
||||
op.getStart(), one);
|
||||
loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none,
|
||||
op.getLayout(), op.getDevice(), op.getPinMemory());
|
||||
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
|
||||
isa<Torch::FloatType>(op.getStart().getType())) {
|
||||
addStart = rewriter.create<AtenAddScalarOp>(loc, arangeFp32Type, arange,
|
||||
op.getStart(), one);
|
||||
} else {
|
||||
addStart = rewriter.create<AtenAddScalarOp>(loc, arangeIntType, arange,
|
||||
op.getStart(), one);
|
||||
}
|
||||
} else {
|
||||
// handle steps != 1 or dynamic steps
|
||||
Value neOrNot = rewriter.create<AtenNeIntOp>(loc, op.getSteps(), one);
|
||||
|
@ -7617,8 +7628,8 @@ public:
|
|||
rewriter.getStringAttr("linspace's dynamic steps must not be 1"));
|
||||
// create arange: [0, ..., steps - 1]
|
||||
Value arange = rewriter.create<AtenArangeStartOp>(
|
||||
loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(),
|
||||
op.getDevice(), op.getPinMemory());
|
||||
loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none,
|
||||
op.getLayout(), op.getDevice(), op.getPinMemory());
|
||||
// calculate (end - start) / (steps - 1)
|
||||
Value sub;
|
||||
if (isa<Torch::FloatType>(op.getEnd().getType()) ||
|
||||
|
@ -7632,15 +7643,16 @@ public:
|
|||
loc, sub, rewriter.create<AtenSubIntOp>(loc, op.getSteps(), one));
|
||||
// calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start
|
||||
Value mulScalar =
|
||||
rewriter.create<AtenMulScalarOp>(loc, baseType, arange, div);
|
||||
addStart = rewriter.create<AtenAddScalarOp>(loc, baseType, mulScalar,
|
||||
op.getStart(), one);
|
||||
rewriter.create<AtenMulScalarOp>(loc, arangeFp32Type, arange, div);
|
||||
addStart = rewriter.create<AtenAddScalarOp>(
|
||||
loc, arangeFp32Type, mulScalar, op.getStart(), one);
|
||||
}
|
||||
// to dtype
|
||||
Value result;
|
||||
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||
result = rewriter.create<AtenToDtypeOp>(
|
||||
loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal,
|
||||
loc, op.getType(), addStart, op.getDtype(),
|
||||
/*non_blocking=*/falseVal,
|
||||
/*copy=*/falseVal, /*memory_format=*/none);
|
||||
} else {
|
||||
Value f32Type = rewriter.create<ConstantIntOp>(
|
||||
|
@ -8557,7 +8569,6 @@ public:
|
|||
Value falseVal = rewriter.create<ConstantBoolOp>(loc, false);
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
auto baseType = ValueTensorType::getWithLeastStaticInformation(context);
|
||||
|
||||
// input/scale
|
||||
Value divScale = rewriter.create<AtenDivScalarOp>(
|
||||
|
@ -8568,16 +8579,19 @@ public:
|
|||
Value addZeroPoint = rewriter.create<AtenAddScalarOp>(
|
||||
loc, op.getType(), round, op.getZeroPoint(), one);
|
||||
// max(quant_min, std::nearby_int(input/scale) + zero_point)
|
||||
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
|
||||
auto tensorIntType =
|
||||
ValueTensorType::get(context, ArrayRef<int64_t>{1}, si64Type);
|
||||
Value max = rewriter.create<AtenMaximumOp>(
|
||||
loc, op.getType(), addZeroPoint,
|
||||
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMin(),
|
||||
rewriter.create<AtenTensorIntOp>(loc, tensorIntType, op.getQuantMin(),
|
||||
/*dtype=*/none,
|
||||
/*device=*/none,
|
||||
/*requires_grad=*/falseVal));
|
||||
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
|
||||
Value min = rewriter.create<AtenMinimumOp>(
|
||||
loc, op.getType(), max,
|
||||
rewriter.create<AtenTensorIntOp>(loc, baseType, op.getQuantMax(),
|
||||
rewriter.create<AtenTensorIntOp>(loc, tensorIntType, op.getQuantMax(),
|
||||
/*dtype=*/none, /*device=*/none,
|
||||
/*requires_grad=*/falseVal));
|
||||
// min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point))
|
||||
|
|
|
@ -402,10 +402,6 @@ FX_IMPORTER_XFAIL_SET = {
|
|||
"ElementwiseRreluTrainStaticModule_basic",
|
||||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
"EqIntModule_basic",
|
||||
"FakeQuantizePerTensorAffineCachemaskModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
"FloatImplicitModule_basic",
|
||||
"GeFloatIntModule_basic",
|
||||
"GeFloatModule_basic",
|
||||
|
@ -597,9 +593,6 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
|
|||
"ElementwiseToDtypeI64ToUI8Module_basic",
|
||||
"EmptyModule_uint8",
|
||||
"EqIntModule_basic",
|
||||
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
|
||||
"FakeQuantizePerTensorAffineModule_basic",
|
||||
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
|
||||
"Fill_TensorFloat32WithFloat32_basic",
|
||||
"Fill_TensorFloat32WithFloat64_basic",
|
||||
"Fill_TensorFloat32WithInt64_basic",
|
||||
|
|
Loading…
Reference in New Issue