mirror of https://github.com/llvm/torch-mlir
Fix deprecated uses of cast/dyn_cast/dyn_cast_or_null/isa (#3243)
Like #3130, gradually replace the deprecated code https://github.com/llvm/mlir-www/blob/main/website/content/deprecation/_index.md#deprecatedpull/3244/head
parent
466618e45e
commit
6679728c56
|
@ -23,7 +23,7 @@ static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter,
|
||||||
int64_t dimA, int64_t dimB,
|
int64_t dimA, int64_t dimB,
|
||||||
Value &transposed) {
|
Value &transposed) {
|
||||||
Type transposedType;
|
Type transposedType;
|
||||||
if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
|
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
|
||||||
dimA, dimB, transposedType)))
|
dimA, dimB, transposedType)))
|
||||||
return failure();
|
return failure();
|
||||||
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
@ -554,7 +554,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
// conversions which are not supported in Torch-MLIR right now.
|
// conversions which are not supported in Torch-MLIR right now.
|
||||||
|
|
||||||
Torch::ValueTensorType targetTy =
|
Torch::ValueTensorType targetTy =
|
||||||
target.getType().cast<Torch::ValueTensorType>();
|
cast<Torch::ValueTensorType>(target.getType());
|
||||||
if (!targetTy.hasDtype()) {
|
if (!targetTy.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
"target tensor must have a dtype");
|
"target tensor must have a dtype");
|
||||||
|
@ -753,9 +753,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
Type listElemType =
|
Type listElemType =
|
||||||
tensors[0]
|
cast<Torch::BaseTensorType>(tensors[0].getType())
|
||||||
.getType()
|
|
||||||
.cast<Torch::BaseTensorType>()
|
|
||||||
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
|
||||||
/*optionalDtype=*/nullptr);
|
/*optionalDtype=*/nullptr);
|
||||||
Type listType = Torch::ListType::get(listElemType);
|
Type listType = Torch::ListType::get(listElemType);
|
||||||
|
@ -869,7 +867,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
|
auto weightTensorType = cast<Torch::ValueTensorType>(weight.getType());
|
||||||
if (!weightTensorType || !weightTensorType.hasSizes()) {
|
if (!weightTensorType || !weightTensorType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "Expected weight type having sizes");
|
binder.op, "Expected weight type having sizes");
|
||||||
|
@ -1188,7 +1186,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto weightTensorType = weight.getType().cast<Torch::ValueTensorType>();
|
auto weightTensorType = cast<Torch::ValueTensorType>(weight.getType());
|
||||||
if (!weightTensorType || !weightTensorType.hasSizes()) {
|
if (!weightTensorType || !weightTensorType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "Expected weight type having sizes");
|
binder.op, "Expected weight type having sizes");
|
||||||
|
@ -1427,7 +1425,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
binder.customOpNameStringAttr(mode, "mode", "DCR") ||
|
binder.customOpNameStringAttr(mode, "mode", "DCR") ||
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
auto inputTy = input.getType().dyn_cast<Torch::BaseTensorType>();
|
auto inputTy = dyn_cast<Torch::BaseTensorType>(input.getType());
|
||||||
if (!inputTy || !inputTy.hasSizes()) {
|
if (!inputTy || !inputTy.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "Expected input type having sizes");
|
binder.op, "Expected input type having sizes");
|
||||||
|
@ -1536,9 +1534,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
Value scale = operands[1];
|
Value scale = operands[1];
|
||||||
Value zeropoint = operands[2];
|
Value zeropoint = operands[2];
|
||||||
|
|
||||||
auto operandTy = operand.getType().cast<Torch::ValueTensorType>();
|
auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
|
||||||
|
|
||||||
auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
|
auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
|
||||||
if (!scaleTy || !scaleTy.hasSizes())
|
if (!scaleTy || !scaleTy.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
|
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
|
||||||
if (!resultType.hasDtype())
|
if (!resultType.hasDtype())
|
||||||
|
@ -1611,7 +1609,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
|
ratio = rewriter.create<Torch::AtenFloatImplicitOp>(loc, operands[1]);
|
||||||
Value trainVal = operands[2];
|
Value trainVal = operands[2];
|
||||||
auto trainTensorType =
|
auto trainTensorType =
|
||||||
trainVal.getType().dyn_cast<Torch::BaseTensorType>();
|
dyn_cast<Torch::BaseTensorType>(trainVal.getType());
|
||||||
if (!trainTensorType)
|
if (!trainTensorType)
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
"train tensor must have a type");
|
"train tensor must have a type");
|
||||||
|
@ -1629,8 +1627,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
|
|
||||||
if (auto valueTensorLiteralOp =
|
if (auto valueTensorLiteralOp =
|
||||||
trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
|
trainVal.getDefiningOp<Torch::ValueTensorLiteralOp>()) {
|
||||||
auto val = valueTensorLiteralOp.getValue()
|
auto val = cast<DenseElementsAttr>(valueTensorLiteralOp.getValue())
|
||||||
.cast<DenseElementsAttr>()
|
|
||||||
.getSplatValue<bool>();
|
.getSplatValue<bool>();
|
||||||
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
|
trainingMode = rewriter.create<Torch::ConstantBoolOp>(loc, val);
|
||||||
} else {
|
} else {
|
||||||
|
@ -2072,7 +2069,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
|
||||||
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
|
dyn_cast<Torch::ValueTensorType>(shape.getType()).getSizes();
|
||||||
SmallVector<Value> dimList;
|
SmallVector<Value> dimList;
|
||||||
Torch::BaseTensorType shapeType =
|
Torch::BaseTensorType shapeType =
|
||||||
shape.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(shape.getType());
|
||||||
Type selectResultType = rewriter.getType<Torch::ValueTensorType>(
|
Type selectResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
ArrayRef<int64_t>({}), shapeType.getOptionalDtype());
|
ArrayRef<int64_t>({}), shapeType.getOptionalDtype());
|
||||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
|
|
@ -104,10 +104,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "operand grid_sampler bind failure");
|
binder.op, "operand grid_sampler bind failure");
|
||||||
|
|
||||||
auto inputTensorType = input.getType().cast<Torch::ValueTensorType>();
|
auto inputTensorType = cast<Torch::ValueTensorType>(input.getType());
|
||||||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||||
uint32_t inputRank = inputShape.size();
|
uint32_t inputRank = inputShape.size();
|
||||||
auto gridTensorType = grid.getType().cast<Torch::ValueTensorType>();
|
auto gridTensorType = cast<Torch::ValueTensorType>(grid.getType());
|
||||||
ArrayRef<int64_t> gridShape = gridTensorType.getSizes();
|
ArrayRef<int64_t> gridShape = gridTensorType.getSizes();
|
||||||
uint32_t gridRank = gridShape.size();
|
uint32_t gridRank = gridShape.size();
|
||||||
|
|
||||||
|
@ -233,7 +233,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
axis = rank + axis;
|
axis = rank + axis;
|
||||||
}
|
}
|
||||||
// need input type and sizes to flatten/unflatten later.
|
// need input type and sizes to flatten/unflatten later.
|
||||||
auto inputTy = input.getType().cast<Torch::ValueTensorType>();
|
auto inputTy = cast<Torch::ValueTensorType>(input.getType());
|
||||||
if (!inputTy || !inputTy.hasSizes())
|
if (!inputTy || !inputTy.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "failed to get input type or sizes");
|
binder.op, "failed to get input type or sizes");
|
||||||
|
@ -1065,7 +1065,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1));
|
||||||
|
|
||||||
auto transpose = [&](Value m) -> Value {
|
auto transpose = [&](Value m) -> Value {
|
||||||
auto tty = m.getType().cast<Torch::ValueTensorType>();
|
auto tty = cast<Torch::ValueTensorType>(m.getType());
|
||||||
auto shape = tty.getOptionalSizes();
|
auto shape = tty.getOptionalSizes();
|
||||||
if (shape.has_value()) {
|
if (shape.has_value()) {
|
||||||
llvm::SmallVector<int64_t> newShape(shape.value());
|
llvm::SmallVector<int64_t> newShape(shape.value());
|
||||||
|
@ -1134,7 +1134,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>();
|
auto inputTensorType = cast<Torch::ValueTensorType>(operand.getType());
|
||||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "Expected input type having sizes");
|
binder.op, "Expected input type having sizes");
|
||||||
|
@ -1228,7 +1228,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
rank = *maybeRank;
|
rank = *maybeRank;
|
||||||
SmallVector<Value> normalized;
|
SmallVector<Value> normalized;
|
||||||
axis = Torch::toPositiveDim(axis, rank);
|
axis = Torch::toPositiveDim(axis, rank);
|
||||||
auto xType = x.getType().cast<Torch::ValueTensorType>();
|
auto xType = cast<Torch::ValueTensorType>(x.getType());
|
||||||
if (!xType.hasSizes()) {
|
if (!xType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "Expected input (X) to have sizes");
|
binder.op, "Expected input (X) to have sizes");
|
||||||
|
@ -1307,7 +1307,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
|
|
||||||
// Get pads shape and rank. The pads tensor is expected to be 1-D
|
// Get pads shape and rank. The pads tensor is expected to be 1-D
|
||||||
// tensor.
|
// tensor.
|
||||||
auto padsTensorType = pads.getType().cast<Torch::ValueTensorType>();
|
auto padsTensorType = cast<Torch::ValueTensorType>(pads.getType());
|
||||||
if (!padsTensorType || !padsTensorType.hasSizes()) {
|
if (!padsTensorType || !padsTensorType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
return rewriter.notifyMatchFailure(binder.op,
|
||||||
"Expect non empty pad tensor");
|
"Expect non empty pad tensor");
|
||||||
|
@ -1323,7 +1323,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
// As per onnx.Pad documentation, padSize = 2*num_data_axes
|
// As per onnx.Pad documentation, padSize = 2*num_data_axes
|
||||||
// (if axes param not passed). Need to be updated when adding
|
// (if axes param not passed). Need to be updated when adding
|
||||||
// support for `axes` param.
|
// support for `axes` param.
|
||||||
auto dataOpTy = data.getType().cast<Torch::ValueTensorType>();
|
auto dataOpTy = cast<Torch::ValueTensorType>(data.getType());
|
||||||
TensorType dataTensor = dataOpTy.toBuiltinTensor();
|
TensorType dataTensor = dataOpTy.toBuiltinTensor();
|
||||||
if (!dataTensor || !dataTensor.hasRank())
|
if (!dataTensor || !dataTensor.hasRank())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1350,7 +1350,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!constantValue) {
|
if (!constantValue) {
|
||||||
auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
|
auto dataTensorType = cast<Torch::ValueTensorType>(data.getType());
|
||||||
if (dataTensorType.getDtype().isa<IntegerType>())
|
if (dataTensorType.getDtype().isa<IntegerType>())
|
||||||
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
constantValue = rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(0));
|
loc, rewriter.getI64IntegerAttr(0));
|
||||||
|
|
|
@ -54,7 +54,7 @@ LogicalResult reducedSumImpl(OpBinder binder,
|
||||||
SmallVector<Value> axesList;
|
SmallVector<Value> axesList;
|
||||||
Value axesVal;
|
Value axesVal;
|
||||||
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
||||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||||
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "unimplemented: expected input and result to have shapes");
|
binder.op, "unimplemented: expected input and result to have shapes");
|
||||||
|
@ -97,7 +97,7 @@ LogicalResult reducedSumImpl(OpBinder binder,
|
||||||
}
|
}
|
||||||
if (axesList.empty()) {
|
if (axesList.empty()) {
|
||||||
Torch::BaseTensorType axesType =
|
Torch::BaseTensorType axesType =
|
||||||
axesVal.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(axesVal.getType());
|
||||||
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
||||||
auto axesShape = axesTy.getSizes();
|
auto axesShape = axesTy.getSizes();
|
||||||
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
||||||
|
@ -177,7 +177,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
Value scale = operands[1];
|
Value scale = operands[1];
|
||||||
Value zeropoint = operands[2];
|
Value zeropoint = operands[2];
|
||||||
|
|
||||||
auto scaleTy = scale.getType().dyn_cast<Torch::ValueTensorType>();
|
auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
|
||||||
if (!scaleTy || !scaleTy.hasSizes())
|
if (!scaleTy || !scaleTy.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
|
return rewriter.notifyMatchFailure(binder.op, "requires known rank");
|
||||||
if (!resultType.hasDtype())
|
if (!resultType.hasDtype())
|
||||||
|
@ -241,7 +241,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
Value c = operands.size() == 9 ? operands[8] : nullptr;
|
Value c = operands.size() == 9 ? operands[8] : nullptr;
|
||||||
|
|
||||||
auto check = [](Value v) {
|
auto check = [](Value v) {
|
||||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
auto vTy = cast<Torch::ValueTensorType>(v.getType());
|
||||||
return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
|
return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; });
|
||||||
};
|
};
|
||||||
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
|
if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) ||
|
||||||
|
@ -250,7 +250,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.op, "not supported for non per-tensor quantization");
|
binder.op, "not supported for non per-tensor quantization");
|
||||||
|
|
||||||
auto extract = [&rewriter, &binder](Value v) {
|
auto extract = [&rewriter, &binder](Value v) {
|
||||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
auto vTy = cast<Torch::ValueTensorType>(v.getType());
|
||||||
Type extractTy = rewriter.getType<Torch::FloatType>();
|
Type extractTy = rewriter.getType<Torch::FloatType>();
|
||||||
if (isa<IntegerType>(vTy.getDtype()))
|
if (isa<IntegerType>(vTy.getDtype()))
|
||||||
extractTy = rewriter.getType<Torch::IntType>();
|
extractTy = rewriter.getType<Torch::IntType>();
|
||||||
|
@ -268,7 +268,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
auto make = [&rewriter, &binder](Value v, Value scale,
|
auto make = [&rewriter, &binder](Value v, Value scale,
|
||||||
Value zp) -> Value {
|
Value zp) -> Value {
|
||||||
auto ty = v.getType().cast<Torch::ValueTensorType>();
|
auto ty = cast<Torch::ValueTensorType>(v.getType());
|
||||||
auto newTy = getQTorchTypeFromTorchIntType(ty);
|
auto newTy = getQTorchTypeFromTorchIntType(ty);
|
||||||
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
|
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
|
||||||
binder.getLoc(), newTy, v, scale, zp);
|
binder.getLoc(), newTy, v, scale, zp);
|
||||||
|
@ -351,7 +351,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
Value cZp = operands[7];
|
Value cZp = operands[7];
|
||||||
|
|
||||||
auto check = [](Value v) {
|
auto check = [](Value v) {
|
||||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
auto vTy = cast<Torch::ValueTensorType>(v.getType());
|
||||||
for (auto dim : vTy.getSizes())
|
for (auto dim : vTy.getSizes())
|
||||||
if (dim != 1)
|
if (dim != 1)
|
||||||
return false;
|
return false;
|
||||||
|
@ -368,7 +368,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
rewriter.getType<Torch::IntType>()),
|
rewriter.getType<Torch::IntType>()),
|
||||||
ValueRange{});
|
ValueRange{});
|
||||||
auto extract = [&rewriter, &binder, &emptyList](Value v) {
|
auto extract = [&rewriter, &binder, &emptyList](Value v) {
|
||||||
auto vTy = v.getType().cast<Torch::ValueTensorType>();
|
auto vTy = cast<Torch::ValueTensorType>(v.getType());
|
||||||
if (!vTy.getSizes().empty()) {
|
if (!vTy.getSizes().empty()) {
|
||||||
vTy = rewriter.getType<Torch::ValueTensorType>(
|
vTy = rewriter.getType<Torch::ValueTensorType>(
|
||||||
ArrayRef<int64_t>({}), vTy.getOptionalDtype());
|
ArrayRef<int64_t>({}), vTy.getOptionalDtype());
|
||||||
|
@ -393,7 +393,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
auto make = [&rewriter, &binder](Value v, Value scale,
|
auto make = [&rewriter, &binder](Value v, Value scale,
|
||||||
Value zp) -> Value {
|
Value zp) -> Value {
|
||||||
auto ty = v.getType().cast<Torch::ValueTensorType>();
|
auto ty = cast<Torch::ValueTensorType>(v.getType());
|
||||||
auto newTy = getQTorchTypeFromTorchIntType(ty);
|
auto newTy = getQTorchTypeFromTorchIntType(ty);
|
||||||
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
|
return rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
|
||||||
binder.getLoc(), newTy, v, scale, zp);
|
binder.getLoc(), newTy, v, scale, zp);
|
||||||
|
@ -667,7 +667,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value data = inputOperands[0];
|
Value data = inputOperands[0];
|
||||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||||
if (!inputType.hasSizes() || !resultType.hasSizes())
|
if (!inputType.hasSizes() || !resultType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op,
|
binder.op,
|
||||||
|
@ -718,7 +718,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
if (dimList.empty()) {
|
if (dimList.empty()) {
|
||||||
Value axes = inputOperands[1];
|
Value axes = inputOperands[1];
|
||||||
Torch::BaseTensorType axesType =
|
Torch::BaseTensorType axesType =
|
||||||
axes.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(axes.getType());
|
||||||
SmallVector<int64_t> selectSizes{1};
|
SmallVector<int64_t> selectSizes{1};
|
||||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||||
selectSizes, axesType.getOptionalDtype());
|
selectSizes, axesType.getOptionalDtype());
|
||||||
|
@ -760,7 +760,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
if (binder.tensorOperands(data, axes) ||
|
if (binder.tensorOperands(data, axes) ||
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||||
if (!inputType.hasSizes() || !resultType.hasSizes())
|
if (!inputType.hasSizes() || !resultType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op,
|
binder.op,
|
||||||
|
@ -925,8 +925,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
// Perform an AtenToDtype op on the squared sum of the operand, stored
|
// Perform an AtenToDtype op on the squared sum of the operand, stored
|
||||||
// now in operand itself.
|
// now in operand itself.
|
||||||
auto size = operand.getType()
|
auto size = dyn_cast<Torch::ValueTensorType>(operand.getType())
|
||||||
.dyn_cast<Torch::ValueTensorType>()
|
|
||||||
.getOptionalSizes();
|
.getOptionalSizes();
|
||||||
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
|
auto f32ResultType = rewriter.getType<Torch::ValueTensorType>(
|
||||||
size, rewriter.getF32Type());
|
size, rewriter.getF32Type());
|
||||||
|
@ -1005,7 +1004,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
Value axesVal;
|
Value axesVal;
|
||||||
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
if (!binder.tensorOperandAtIndex(axesVal, 1)) {
|
||||||
auto inputType = data.getType().dyn_cast<Torch::ValueTensorType>();
|
auto inputType = dyn_cast<Torch::ValueTensorType>(data.getType());
|
||||||
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
if (!inputType.hasSizes() || !resultType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op,
|
binder.op,
|
||||||
|
@ -1053,7 +1052,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
if (axesList.empty()) {
|
if (axesList.empty()) {
|
||||||
Torch::BaseTensorType axesType =
|
Torch::BaseTensorType axesType =
|
||||||
axesVal.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(axesVal.getType());
|
||||||
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
auto axesTy = dyn_cast<Torch::ValueTensorType>(axesVal.getType());
|
||||||
auto axesShape = axesTy.getSizes();
|
auto axesShape = axesTy.getSizes();
|
||||||
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize)
|
||||||
|
@ -1191,7 +1190,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
// Extract the axes values from the axes operand:
|
// Extract the axes values from the axes operand:
|
||||||
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
||||||
Torch::BaseTensorType axesType =
|
Torch::BaseTensorType axesType =
|
||||||
axes.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(axes.getType());
|
||||||
SmallVector<int64_t> selectSizes{1};
|
SmallVector<int64_t> selectSizes{1};
|
||||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||||
selectSizes, axesType.getOptionalDtype());
|
selectSizes, axesType.getOptionalDtype());
|
||||||
|
@ -1344,7 +1343,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
// Extract the axes values from the axes operand:
|
// Extract the axes values from the axes operand:
|
||||||
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
||||||
Torch::BaseTensorType axesType =
|
Torch::BaseTensorType axesType =
|
||||||
axes.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(axes.getType());
|
||||||
SmallVector<int64_t> selectSizes{1};
|
SmallVector<int64_t> selectSizes{1};
|
||||||
Type selectResultType = axesType.getWithSizesAndDtype(
|
Type selectResultType = axesType.getWithSizesAndDtype(
|
||||||
selectSizes, axesType.getOptionalDtype());
|
selectSizes, axesType.getOptionalDtype());
|
||||||
|
@ -1467,12 +1466,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
|
|
||||||
auto loc = binder.getLoc();
|
auto loc = binder.getLoc();
|
||||||
auto result0Ty =
|
auto result0Ty =
|
||||||
binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
|
cast<Torch::ValueTensorType>(binder.op->getResult(0).getType());
|
||||||
auto resultNTy = binder.op->getResults()
|
auto resultNTy = cast<Torch::ValueTensorType>(
|
||||||
.back()
|
binder.op->getResults().back().getType());
|
||||||
.getType()
|
auto selfTy = cast<Torch::ValueTensorType>(self.getType());
|
||||||
.cast<Torch::ValueTensorType>();
|
|
||||||
auto selfTy = self.getType().cast<Torch::ValueTensorType>();
|
|
||||||
|
|
||||||
int64_t dim = axis;
|
int64_t dim = axis;
|
||||||
if (dim < 0)
|
if (dim < 0)
|
||||||
|
@ -1555,7 +1552,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.op, "Failed to get num_outputs attribute");
|
binder.op, "Failed to get num_outputs attribute");
|
||||||
|
|
||||||
auto result0Ty =
|
auto result0Ty =
|
||||||
binder.op->getResult(0).getType().cast<Torch::ValueTensorType>();
|
cast<Torch::ValueTensorType>(binder.op->getResult(0).getType());
|
||||||
auto selfTy =
|
auto selfTy =
|
||||||
cast<Torch::ValueTensorType>(binder.op->getOperand(0).getType());
|
cast<Torch::ValueTensorType>(binder.op->getOperand(0).getType());
|
||||||
|
|
||||||
|
@ -1617,7 +1614,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
if (binder.tensorOperand(operand) ||
|
if (binder.tensorOperand(operand) ||
|
||||||
binder.tensorResultType(resultType))
|
binder.tensorResultType(resultType))
|
||||||
return failure();
|
return failure();
|
||||||
auto operandType = operand.getType().cast<Torch::ValueTensorType>();
|
auto operandType = cast<Torch::ValueTensorType>(operand.getType());
|
||||||
TensorType tensorType = operandType.toBuiltinTensor();
|
TensorType tensorType = operandType.toBuiltinTensor();
|
||||||
if (!tensorType || !tensorType.hasRank())
|
if (!tensorType || !tensorType.hasRank())
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -1705,26 +1702,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
}
|
}
|
||||||
|
|
||||||
auto context = rewriter.getContext();
|
auto context = rewriter.getContext();
|
||||||
auto operandTorchTy = operand.getType().cast<Torch::ValueTensorType>();
|
auto operandTorchTy = cast<Torch::ValueTensorType>(operand.getType());
|
||||||
auto operandTy =
|
auto operandTy =
|
||||||
operandTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(operandTorchTy.toBuiltinTensor());
|
||||||
|
|
||||||
if (!operandTy)
|
if (!operandTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op,
|
binder.op,
|
||||||
"Expected tensor operator argument to be a ranked tensor type");
|
"Expected tensor operator argument to be a ranked tensor type");
|
||||||
|
|
||||||
auto startsTorchTy = starts.getType().cast<Torch::ValueTensorType>();
|
auto startsTorchTy = cast<Torch::ValueTensorType>(starts.getType());
|
||||||
auto startsTy =
|
auto startsTy =
|
||||||
startsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(startsTorchTy.toBuiltinTensor());
|
||||||
int startSize = startsTy.getDimSize(0);
|
int startSize = startsTy.getDimSize(0);
|
||||||
|
|
||||||
auto endsTorchTy = ends.getType().cast<Torch::ValueTensorType>();
|
auto endsTorchTy = cast<Torch::ValueTensorType>(ends.getType());
|
||||||
auto endsTy =
|
auto endsTy = dyn_cast<RankedTensorType>(endsTorchTy.toBuiltinTensor());
|
||||||
endsTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
|
||||||
int endSize = endsTy.getDimSize(0);
|
int endSize = endsTy.getDimSize(0);
|
||||||
auto resultTy =
|
auto resultTy =
|
||||||
resultTorchType.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(resultTorchType.toBuiltinTensor());
|
||||||
if (!resultTy)
|
if (!resultTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
binder.op, "Expected result type to be a ranked tensor type");
|
binder.op, "Expected result type to be a ranked tensor type");
|
||||||
|
@ -1768,9 +1764,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
"and their dimensions to match");
|
"and their dimensions to match");
|
||||||
|
|
||||||
if (axes) {
|
if (axes) {
|
||||||
auto axesTorchTy = axes.getType().cast<Torch::ValueTensorType>();
|
auto axesTorchTy = cast<Torch::ValueTensorType>(axes.getType());
|
||||||
auto axesTy =
|
auto axesTy =
|
||||||
axesTorchTy.toBuiltinTensor().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(axesTorchTy.toBuiltinTensor());
|
||||||
int64_t numAxes = axesTy.getDimSize(0);
|
int64_t numAxes = axesTy.getDimSize(0);
|
||||||
|
|
||||||
if (!(axesTy && numAxes == endSize))
|
if (!(axesTy && numAxes == endSize))
|
||||||
|
@ -1792,7 +1788,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||||
|
|
||||||
auto select = [&](Value v, Value k) -> Value {
|
auto select = [&](Value v, Value k) -> Value {
|
||||||
auto ty = v.getType().cast<Torch::ValueTensorType>();
|
auto ty = cast<Torch::ValueTensorType>(v.getType());
|
||||||
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
|
auto sel = rewriter.create<Torch::AtenIndexSelectOp>(
|
||||||
loc,
|
loc,
|
||||||
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
|
Torch::ValueTensorType::get(ty.getContext(), ArrayRef<int64_t>{1},
|
||||||
|
@ -1872,7 +1868,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
}
|
}
|
||||||
|
|
||||||
Torch::BaseTensorType shapeType =
|
Torch::BaseTensorType shapeType =
|
||||||
shape.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(shape.getType());
|
||||||
SmallVector<Value> dimList;
|
SmallVector<Value> dimList;
|
||||||
SmallVector<int64_t> selectSizes;
|
SmallVector<int64_t> selectSizes;
|
||||||
selectSizes.push_back(1);
|
selectSizes.push_back(1);
|
||||||
|
@ -2007,7 +2003,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
// instead of using the dynamic axes at operand[1].
|
// instead of using the dynamic axes at operand[1].
|
||||||
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
if (!binder.tensorOperandAtIndex(axes, 1)) {
|
||||||
Torch::BaseTensorType axesType =
|
Torch::BaseTensorType axesType =
|
||||||
axes.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(axes.getType());
|
||||||
auto sizes = axesType.getSizes();
|
auto sizes = axesType.getSizes();
|
||||||
for (int i = 0; i < sizes[0]; i++) {
|
for (int i = 0; i < sizes[0]; i++) {
|
||||||
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
@ -2136,7 +2132,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
// int32, int64 Assuming start, limit and delta to be same type (could
|
// int32, int64 Assuming start, limit and delta to be same type (could
|
||||||
// they be different?)
|
// they be different?)
|
||||||
Torch::BaseTensorType startTensorType =
|
Torch::BaseTensorType startTensorType =
|
||||||
start.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(start.getType());
|
||||||
bool isFloatDType = startTensorType.getDtype().isF64() ||
|
bool isFloatDType = startTensorType.getDtype().isF64() ||
|
||||||
startTensorType.getDtype().isF32();
|
startTensorType.getDtype().isF32();
|
||||||
bool isIntDType = startTensorType.getDtype().isInteger(16) ||
|
bool isIntDType = startTensorType.getDtype().isInteger(16) ||
|
||||||
|
@ -2222,7 +2218,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
SmallVector<int64_t> selectSizes;
|
SmallVector<int64_t> selectSizes;
|
||||||
selectSizes.push_back(1);
|
selectSizes.push_back(1);
|
||||||
Torch::BaseTensorType shapeType =
|
Torch::BaseTensorType shapeType =
|
||||||
repeatDims.getType().cast<Torch::BaseTensorType>();
|
cast<Torch::BaseTensorType>(repeatDims.getType());
|
||||||
Type selectResultType = shapeType.getWithSizesAndDtype(
|
Type selectResultType = shapeType.getWithSizesAndDtype(
|
||||||
llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
|
llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype());
|
||||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
|
|
@ -95,7 +95,7 @@ public:
|
||||||
Value input = adaptor.getA();
|
Value input = adaptor.getA();
|
||||||
Type resultType =
|
Type resultType =
|
||||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||||
if (!input.getType().isa<mlir::FloatType>())
|
if (!isa<mlir::FloatType>(input.getType()))
|
||||||
input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type());
|
input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type());
|
||||||
Value result = rewriter.create<UnaryOp>(loc, input);
|
Value result = rewriter.create<UnaryOp>(loc, input);
|
||||||
rewriter.replaceOp(op,
|
rewriter.replaceOp(op,
|
||||||
|
@ -172,8 +172,8 @@ public:
|
||||||
matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor,
|
matchAndRewrite(ValueTensorLiteralOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
MLIRContext *context = op->getContext();
|
MLIRContext *context = op->getContext();
|
||||||
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
|
||||||
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
if (auto type = dyn_cast<RankedTensorType>(elements.getType())) {
|
||||||
Type elemTy = op.getValueAttr().getElementType();
|
Type elemTy = op.getValueAttr().getElementType();
|
||||||
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
|
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
|
||||||
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
|
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
|
||||||
|
@ -187,9 +187,9 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (auto elements =
|
if (auto elements =
|
||||||
op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
|
dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
|
||||||
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
|
if (auto type = dyn_cast<RankedTensorType>(elements.getType())) {
|
||||||
if (auto intType = type.getElementType().dyn_cast<IntegerType>()) {
|
if (auto intType = dyn_cast<IntegerType>(type.getElementType())) {
|
||||||
Type builtinTensorElemTy =
|
Type builtinTensorElemTy =
|
||||||
IntegerType::get(context, intType.getIntOrFloatBitWidth());
|
IntegerType::get(context, intType.getIntOrFloatBitWidth());
|
||||||
auto shapedType =
|
auto shapedType =
|
||||||
|
|
|
@ -49,8 +49,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
SmallVector<Value> &strides) {
|
SmallVector<Value> &strides) {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
RankedTensorType inputType =
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
input.getType().template cast<RankedTensorType>();
|
|
||||||
|
|
||||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
@ -73,8 +72,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
Value builtinTypeStart = adaptor.getStart();
|
Value builtinTypeStart = adaptor.getStart();
|
||||||
Value builtinTypeEnd = adaptor.getEnd();
|
Value builtinTypeEnd = adaptor.getEnd();
|
||||||
|
|
||||||
if (torchTypeStart.getType().isa<OptionalType>() ||
|
if (isa<OptionalType>(torchTypeStart.getType()) ||
|
||||||
torchTypeEnd.getType().isa<OptionalType>())
|
isa<OptionalType>(torchTypeEnd.getType()))
|
||||||
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");
|
||||||
|
|
||||||
Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
|
Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep());
|
||||||
|
@ -84,7 +83,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
// We cannot use to positive valid dim as for negative strides we need to
|
// We cannot use to positive valid dim as for negative strides we need to
|
||||||
// clamp to `-1` so that the full tensor bounds are available:
|
// clamp to `-1` so that the full tensor bounds are available:
|
||||||
Value end = builtinTypeEnd;
|
Value end = builtinTypeEnd;
|
||||||
if (torchTypeEnd.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(torchTypeEnd.getType())) {
|
||||||
end = dimSize;
|
end = dimSize;
|
||||||
} else {
|
} else {
|
||||||
end = castIntToIndex(rewriter, loc, end);
|
end = castIntToIndex(rewriter, loc, end);
|
||||||
|
@ -594,7 +593,7 @@ public:
|
||||||
int64_t endDim;
|
int64_t endDim;
|
||||||
if (!matchPattern(op.getEndDim(), m_TorchConstantInt(&endDim)))
|
if (!matchPattern(op.getEndDim(), m_TorchConstantInt(&endDim)))
|
||||||
return rewriter.notifyMatchFailure(op, "end_dim must be constant");
|
return rewriter.notifyMatchFailure(op, "end_dim must be constant");
|
||||||
auto type = adaptor.getSelf().getType().cast<RankedTensorType>();
|
auto type = cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
auto inputRank = type.getRank();
|
auto inputRank = type.getRank();
|
||||||
if (inputRank == 1) {
|
if (inputRank == 1) {
|
||||||
// If input rank is equal to 1, then there's no scope for flattening the
|
// If input rank is equal to 1, then there's no scope for flattening the
|
||||||
|
@ -604,7 +603,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resultType =
|
auto resultType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
if (startDim < 0)
|
if (startDim < 0)
|
||||||
startDim += inputRank;
|
startDim += inputRank;
|
||||||
if (endDim < 0)
|
if (endDim < 0)
|
||||||
|
@ -652,7 +651,7 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
|
BaseTensorType outputTensorType = cast<BaseTensorType>(op.getType());
|
||||||
if (!outputTensorType.hasSizes())
|
if (!outputTensorType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: output must have known sizes");
|
op, "unimplemented: output must have known sizes");
|
||||||
|
@ -660,7 +659,7 @@ public:
|
||||||
std::optional<unsigned> maybeRank = getTensorRank(self);
|
std::optional<unsigned> maybeRank = getTensorRank(self);
|
||||||
if (!maybeRank)
|
if (!maybeRank)
|
||||||
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
||||||
auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
|
auto inputTensorType = cast<Torch::ValueTensorType>(self.getType());
|
||||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Expected input type having sizes");
|
"Expected input type having sizes");
|
||||||
|
@ -901,7 +900,7 @@ public:
|
||||||
getInputAndOutputShape(Value inputTorchTensor,
|
getInputAndOutputShape(Value inputTorchTensor,
|
||||||
SmallVector<Value> outputSizeTorchInt) {
|
SmallVector<Value> outputSizeTorchInt) {
|
||||||
SmallVector<int64_t> inputShape(
|
SmallVector<int64_t> inputShape(
|
||||||
inputTorchTensor.getType().cast<BaseTensorType>().getSizes());
|
cast<BaseTensorType>(inputTorchTensor.getType()).getSizes());
|
||||||
SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
|
SmallVector<int64_t> outputShape(outputSizeTorchInt.size(), kUnknownSize);
|
||||||
for (auto [outputDim, outputDimSize] :
|
for (auto [outputDim, outputDimSize] :
|
||||||
llvm::enumerate(outputSizeTorchInt)) {
|
llvm::enumerate(outputSizeTorchInt)) {
|
||||||
|
@ -945,11 +944,11 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
int64_t inputRank = inputType.getRank();
|
int64_t inputRank = inputType.getRank();
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
auto resultType =
|
auto resultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
int64_t resultRank = resultType.getRank();
|
int64_t resultRank = resultType.getRank();
|
||||||
if (resultRank == 0) {
|
if (resultRank == 0) {
|
||||||
rewriter
|
rewriter
|
||||||
|
@ -1349,7 +1348,7 @@ public:
|
||||||
auto outputDims = b.create<tensor::FromElementsOp>(ty, sizes);
|
auto outputDims = b.create<tensor::FromElementsOp>(ty, sizes);
|
||||||
|
|
||||||
auto resultType =
|
auto resultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, resultType, self,
|
rewriter.replaceOpWithNewOp<tensor::ReshapeOp>(op, resultType, self,
|
||||||
outputDims);
|
outputDims);
|
||||||
return success();
|
return success();
|
||||||
|
@ -1367,13 +1366,13 @@ public:
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
int64_t inputRank = inputType.getRank();
|
int64_t inputRank = inputType.getRank();
|
||||||
|
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
auto resultType =
|
auto resultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
auto resultShape = resultType.getShape();
|
auto resultShape = resultType.getShape();
|
||||||
int64_t resultRank = resultType.getRank();
|
int64_t resultRank = resultType.getRank();
|
||||||
|
|
||||||
|
@ -1437,7 +1436,7 @@ public:
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
int64_t inputRank = inputType.getRank();
|
int64_t inputRank = inputType.getRank();
|
||||||
|
|
||||||
if (inputRank == 0) {
|
if (inputRank == 0) {
|
||||||
|
@ -1460,7 +1459,7 @@ public:
|
||||||
|
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
auto resultType =
|
auto resultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
int64_t resultRank = resultType.getRank();
|
int64_t resultRank = resultType.getRank();
|
||||||
|
|
||||||
// If the dim(th) dimension of operand tensor type is not statically unit,
|
// If the dim(th) dimension of operand tensor type is not statically unit,
|
||||||
|
@ -1510,7 +1509,7 @@ public:
|
||||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
return rewriter.notifyMatchFailure(op, "dim must be constant");
|
||||||
auto inputRank =
|
auto inputRank =
|
||||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
dim = toPositiveDim(dim, inputRank + 1);
|
dim = toPositiveDim(dim, inputRank + 1);
|
||||||
if (!isValidDim(dim, inputRank + 1))
|
if (!isValidDim(dim, inputRank + 1))
|
||||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
@ -1535,9 +1534,8 @@ public:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto resultType = getTypeConverter()
|
auto resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
|
||||||
op, resultType, adaptor.getSelf(), reassociationMap);
|
op, resultType, adaptor.getSelf(), reassociationMap);
|
||||||
return success();
|
return success();
|
||||||
|
@ -1564,11 +1562,10 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
||||||
|
|
||||||
auto inVector = adaptor.getSelf();
|
auto inVector = adaptor.getSelf();
|
||||||
auto inType = inVector.getType().cast<RankedTensorType>();
|
auto inType = cast<RankedTensorType>(inVector.getType());
|
||||||
auto inputRank = inType.getRank();
|
auto inputRank = inType.getRank();
|
||||||
auto outType = getTypeConverter()
|
auto outType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
auto elementType = inType.getElementType();
|
auto elementType = inType.getElementType();
|
||||||
|
|
||||||
dim0 = toPositiveDim(dim0, inputRank);
|
dim0 = toPositiveDim(dim0, inputRank);
|
||||||
|
@ -1634,11 +1631,10 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op, "all dimensions must be constant");
|
return rewriter.notifyMatchFailure(op, "all dimensions must be constant");
|
||||||
|
|
||||||
Value inVector = adaptor.getSelf();
|
Value inVector = adaptor.getSelf();
|
||||||
auto inType = inVector.getType().cast<RankedTensorType>();
|
auto inType = cast<RankedTensorType>(inVector.getType());
|
||||||
int64_t inputRank = inType.getRank();
|
int64_t inputRank = inType.getRank();
|
||||||
auto outType = getTypeConverter()
|
auto outType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
Type elementType = inType.getElementType();
|
Type elementType = inType.getElementType();
|
||||||
|
|
||||||
// Check if the dimensions are a valid constants.
|
// Check if the dimensions are a valid constants.
|
||||||
|
@ -1747,7 +1743,7 @@ public:
|
||||||
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
|
getTypeConvertedValues(rewriter, loc, typeConverter, tensorsTorchType);
|
||||||
|
|
||||||
RankedTensorType newResultType =
|
RankedTensorType newResultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
int rank = newResultType.getRank();
|
int rank = newResultType.getRank();
|
||||||
Value dimValue = op.getDim();
|
Value dimValue = op.getDim();
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
|
@ -1802,7 +1798,7 @@ public:
|
||||||
// which in this case is `inShapeConverted` because this shape will yield
|
// which in this case is `inShapeConverted` because this shape will yield
|
||||||
// us the dimension size of the output.
|
// us the dimension size of the output.
|
||||||
SmallVector<bool> useBroadcastToShape;
|
SmallVector<bool> useBroadcastToShape;
|
||||||
int64_t inputRank = self.getType().cast<RankedTensorType>().getRank();
|
int64_t inputRank = cast<RankedTensorType>(self.getType()).getRank();
|
||||||
for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e;
|
for (size_t i = inShape.size() - inputRank, e = inShape.size(); i < e;
|
||||||
++i) {
|
++i) {
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
|
@ -1821,7 +1817,7 @@ public:
|
||||||
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
|
SmallVector<Value> inShapeConverted = getTypeConvertedValues(
|
||||||
rewriter, op.getLoc(), getTypeConverter(), inShape);
|
rewriter, op.getLoc(), getTypeConverter(), inShape);
|
||||||
auto newResultType =
|
auto newResultType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
Value result;
|
Value result;
|
||||||
if (failed(torch_to_linalg::broadcastToGivenShape(
|
if (failed(torch_to_linalg::broadcastToGivenShape(
|
||||||
op, rewriter, self, inShapeConverted, newResultType, result,
|
op, rewriter, self, inShapeConverted, newResultType, result,
|
||||||
|
@ -1869,7 +1865,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
Value src = adaptor.getSrc();
|
Value src = adaptor.getSrc();
|
||||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||||
|
|
||||||
// The non_blocking should be a constant `False`.
|
// The non_blocking should be a constant `False`.
|
||||||
bool nonBlocking;
|
bool nonBlocking;
|
||||||
|
@ -1954,7 +1950,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
Value src = adaptor.getSrc();
|
Value src = adaptor.getSrc();
|
||||||
auto srcType = src.getType().cast<RankedTensorType>();
|
auto srcType = cast<RankedTensorType>(src.getType());
|
||||||
int64_t srcRank = srcType.getRank();
|
int64_t srcRank = srcType.getRank();
|
||||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||||
// TODO: audit possibility of sparsity on these tensor
|
// TODO: audit possibility of sparsity on these tensor
|
||||||
|
@ -1992,7 +1988,7 @@ public:
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
|
|
||||||
RankedTensorType resultType =
|
RankedTensorType resultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
|
||||||
auto elementType = resultType.getElementType();
|
auto elementType = resultType.getElementType();
|
||||||
SmallVector<Value> resultShape;
|
SmallVector<Value> resultShape;
|
||||||
|
@ -2070,9 +2066,9 @@ public:
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
|
|
||||||
RankedTensorType resultType =
|
RankedTensorType resultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto inputElementType = getElementTypeOrSelf(input.getType());
|
auto inputElementType = getElementTypeOrSelf(input.getType());
|
||||||
if (!isa<ComplexType>(inputElementType)) {
|
if (!isa<ComplexType>(inputElementType)) {
|
||||||
return op.emitError("only ComplexType is allowed as input type");
|
return op.emitError("only ComplexType is allowed as input type");
|
||||||
|
@ -2157,7 +2153,7 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op, "dim2 must be constant");
|
return rewriter.notifyMatchFailure(op, "dim2 must be constant");
|
||||||
|
|
||||||
Value inputMatrix = adaptor.getSelf();
|
Value inputMatrix = adaptor.getSelf();
|
||||||
RankedTensorType inputType = inputMatrix.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = cast<RankedTensorType>(inputMatrix.getType());
|
||||||
int64_t inputRank = inputType.getRank();
|
int64_t inputRank = inputType.getRank();
|
||||||
|
|
||||||
if (inputRank < 2)
|
if (inputRank < 2)
|
||||||
|
@ -2277,7 +2273,7 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern<AtenDiagEmbedOp> {
|
||||||
static SmallVector<Value>
|
static SmallVector<Value>
|
||||||
getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor,
|
getDiagEmbedResultShape(OpBuilder &b, Location loc, Value tensor,
|
||||||
int64_t offset, int64_t dim1, int64_t dim2) {
|
int64_t offset, int64_t dim1, int64_t dim2) {
|
||||||
auto inputType = tensor.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(tensor.getType());
|
||||||
auto inputRank = inputType.getRank();
|
auto inputRank = inputType.getRank();
|
||||||
|
|
||||||
// output tensor always has 1 extra dimension
|
// output tensor always has 1 extra dimension
|
||||||
|
@ -2314,7 +2310,7 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto inputRank = inputType.getRank();
|
auto inputRank = inputType.getRank();
|
||||||
auto resultRank = inputRank + 1;
|
auto resultRank = inputRank + 1;
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,7 @@ public:
|
||||||
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
|
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
|
||||||
return op.emitError("unimplemented: dim is not constant");
|
return op.emitError("unimplemented: dim is not constant");
|
||||||
int64_t inputRank =
|
int64_t inputRank =
|
||||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
dim = toPositiveDim(dim, inputRank);
|
dim = toPositiveDim(dim, inputRank);
|
||||||
if (!isValidDim(dim, inputRank))
|
if (!isValidDim(dim, inputRank))
|
||||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
@ -88,7 +88,7 @@ public:
|
||||||
Value indices = adaptor.getIndex();
|
Value indices = adaptor.getIndex();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
RankedTensorType newResultTy =
|
RankedTensorType newResultTy =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
int64_t rank = newResultTy.getRank();
|
int64_t rank = newResultTy.getRank();
|
||||||
|
|
||||||
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
|
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
|
||||||
|
@ -128,9 +128,9 @@ public:
|
||||||
Value weight = adaptor.getWeight();
|
Value weight = adaptor.getWeight();
|
||||||
Value indices = adaptor.getIndices();
|
Value indices = adaptor.getIndices();
|
||||||
RankedTensorType newResultType =
|
RankedTensorType newResultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
if (weightTy.getRank() != 2)
|
if (weightTy.getRank() != 2)
|
||||||
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
|
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
|
||||||
Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
|
Value embeddingDim = getDimOp(rewriter, loc, weight, 1);
|
||||||
|
@ -140,7 +140,7 @@ public:
|
||||||
sizes.push_back(embeddingDim);
|
sizes.push_back(embeddingDim);
|
||||||
int64_t resultRank = sizes.size();
|
int64_t resultRank = sizes.size();
|
||||||
|
|
||||||
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
||||||
int64_t indicesRank = indicesTy.getRank();
|
int64_t indicesRank = indicesTy.getRank();
|
||||||
SmallVector<AffineExpr> indicesExprs;
|
SmallVector<AffineExpr> indicesExprs;
|
||||||
for (int i = 0; i < indicesRank; i++)
|
for (int i = 0; i < indicesRank; i++)
|
||||||
|
@ -274,15 +274,15 @@ public:
|
||||||
"include_last_offset is expected to be a constant boolean value.");
|
"include_last_offset is expected to be a constant boolean value.");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
if (weightTy.getRank() != 2)
|
if (weightTy.getRank() != 2)
|
||||||
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
|
return rewriter.notifyMatchFailure(op, "weight must be rank 2");
|
||||||
|
|
||||||
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
||||||
if (indicesTy.getRank() != 1)
|
if (indicesTy.getRank() != 1)
|
||||||
return rewriter.notifyMatchFailure(op, "indices must be a vector");
|
return rewriter.notifyMatchFailure(op, "indices must be a vector");
|
||||||
|
|
||||||
auto offsetsTy = offsets.getType().cast<RankedTensorType>();
|
auto offsetsTy = cast<RankedTensorType>(offsets.getType());
|
||||||
if (offsetsTy.getRank() != 1)
|
if (offsetsTy.getRank() != 1)
|
||||||
return rewriter.notifyMatchFailure(op, "offsets much be a vector");
|
return rewriter.notifyMatchFailure(op, "offsets much be a vector");
|
||||||
|
|
||||||
|
@ -471,10 +471,9 @@ public:
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
Value indices = adaptor.getIndex();
|
Value indices = adaptor.getIndex();
|
||||||
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
RankedTensorType resultType = getTypeConverter()
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
Type elementType = resultType.getElementType();
|
Type elementType = resultType.getElementType();
|
||||||
unsigned inputRank = inputType.getRank();
|
unsigned inputRank = inputType.getRank();
|
||||||
|
|
||||||
|
@ -604,10 +603,9 @@ public:
|
||||||
op, "aten.index.Tensor: index tensor must not be None");
|
op, "aten.index.Tensor: index tensor must not be None");
|
||||||
}
|
}
|
||||||
|
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
RankedTensorType resultType = getTypeConverter()
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
Type elementType = resultType.getElementType();
|
Type elementType = resultType.getElementType();
|
||||||
int inputRank = inputType.getRank();
|
int inputRank = inputType.getRank();
|
||||||
int resultRank = resultType.getRank();
|
int resultRank = resultType.getRank();
|
||||||
|
@ -625,7 +623,7 @@ public:
|
||||||
int maxRank = -1;
|
int maxRank = -1;
|
||||||
for (auto indexTensor : indexTensors) {
|
for (auto indexTensor : indexTensors) {
|
||||||
RankedTensorType indexTensorType =
|
RankedTensorType indexTensorType =
|
||||||
indexTensor.getType().cast<RankedTensorType>();
|
cast<RankedTensorType>(indexTensor.getType());
|
||||||
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
|
maxRank = std::max(maxRank, (int)indexTensorType.getRank());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -639,7 +637,7 @@ public:
|
||||||
int64_t staticDimSize = -1;
|
int64_t staticDimSize = -1;
|
||||||
for (auto indexTensor : indexTensors) {
|
for (auto indexTensor : indexTensors) {
|
||||||
RankedTensorType indexTensorType =
|
RankedTensorType indexTensorType =
|
||||||
indexTensor.getType().cast<RankedTensorType>();
|
cast<RankedTensorType>(indexTensor.getType());
|
||||||
int64_t indexTensorRank = indexTensorType.getRank();
|
int64_t indexTensorRank = indexTensorType.getRank();
|
||||||
if ((maxRank - indexTensorRank) > (i - startIndex))
|
if ((maxRank - indexTensorRank) > (i - startIndex))
|
||||||
continue;
|
continue;
|
||||||
|
@ -714,7 +712,7 @@ public:
|
||||||
|
|
||||||
for (auto indexTensor : indexTensors) {
|
for (auto indexTensor : indexTensors) {
|
||||||
RankedTensorType indexTensorType =
|
RankedTensorType indexTensorType =
|
||||||
indexTensor.getType().cast<RankedTensorType>();
|
cast<RankedTensorType>(indexTensor.getType());
|
||||||
auto indexTensorShape =
|
auto indexTensorShape =
|
||||||
makeShapeTorchCompatible(indexTensorType.getShape());
|
makeShapeTorchCompatible(indexTensorType.getShape());
|
||||||
int rank = indexTensorShape.size();
|
int rank = indexTensorShape.size();
|
||||||
|
@ -828,7 +826,7 @@ public:
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
|
|
||||||
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
|
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto inputRank = inputType.getRank();
|
auto inputRank = inputType.getRank();
|
||||||
Type elementType = inputType.getElementType();
|
Type elementType = inputType.getElementType();
|
||||||
|
|
||||||
|
@ -989,7 +987,7 @@ public:
|
||||||
Value gradOutput = adaptor.getGradOutput();
|
Value gradOutput = adaptor.getGradOutput();
|
||||||
|
|
||||||
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
|
Type resultType = getTypeConverter()->convertType(op.getResult().getType());
|
||||||
auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
|
auto gradOutputType = cast<RankedTensorType>(gradOutput.getType());
|
||||||
auto gradOutputRank = gradOutputType.getRank();
|
auto gradOutputRank = gradOutputType.getRank();
|
||||||
Type elementType = gradOutputType.getElementType();
|
Type elementType = gradOutputType.getElementType();
|
||||||
|
|
||||||
|
|
|
@ -48,7 +48,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
|
||||||
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
|
||||||
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
arg = torch_to_linalg::createElementwiseLinalgGeneric(
|
||||||
rewriter, loc, ValueRange{arg},
|
rewriter, loc, ValueRange{arg},
|
||||||
arg.getType().cast<TensorType>().getElementType(),
|
cast<TensorType>(arg.getType()).getElementType(),
|
||||||
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
[&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
|
||||||
Value result =
|
Value result =
|
||||||
rewriter.create<arith::AddIOp>(loc, payloadArgs[0], minSIValue);
|
rewriter.create<arith::AddIOp>(loc, payloadArgs[0], minSIValue);
|
||||||
|
@ -58,7 +58,7 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
|
||||||
|
|
||||||
static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
|
static Value transposeValue(Location loc, Value value, ArrayRef<int64_t> perms,
|
||||||
PatternRewriter &rewriter) {
|
PatternRewriter &rewriter) {
|
||||||
auto valueTy = value.getType().cast<RankedTensorType>();
|
auto valueTy = cast<RankedTensorType>(value.getType());
|
||||||
auto inShape = valueTy.getShape();
|
auto inShape = valueTy.getShape();
|
||||||
llvm::SmallVector<int64_t> outShape;
|
llvm::SmallVector<int64_t> outShape;
|
||||||
llvm::SmallVector<Value> dynDims;
|
llvm::SmallVector<Value> dynDims;
|
||||||
|
@ -100,8 +100,8 @@ public:
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
RankedTensorType lhsType = cast<RankedTensorType>(lhs.getType());
|
||||||
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
RankedTensorType rhsType = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
|
if (lhsType.getRank() != 2 || rhsType.getRank() != 2) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -109,9 +109,9 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
ValueTensorType lhsTorchType =
|
ValueTensorType lhsTorchType =
|
||||||
op.getSelf().getType().cast<ValueTensorType>();
|
cast<ValueTensorType>(op.getSelf().getType());
|
||||||
ValueTensorType rhsTorchType =
|
ValueTensorType rhsTorchType =
|
||||||
op.getMat2().getType().cast<ValueTensorType>();
|
cast<ValueTensorType>(op.getMat2().getType());
|
||||||
|
|
||||||
Value lhsZeroPoint, rhsZeroPoint;
|
Value lhsZeroPoint, rhsZeroPoint;
|
||||||
getZeroPoint(op.getSelf(), lhsZeroPoint);
|
getZeroPoint(op.getSelf(), lhsZeroPoint);
|
||||||
|
@ -148,7 +148,7 @@ public:
|
||||||
"mismatching contracting dimension for torch.aten.mm"));
|
"mismatching contracting dimension for torch.aten.mm"));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto resultTy = op.getType().cast<ValueTensorType>();
|
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||||
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
Type elementType = cast<TensorType>(newResultType).getElementType();
|
Type elementType = cast<TensorType>(newResultType).getElementType();
|
||||||
|
@ -176,9 +176,9 @@ public:
|
||||||
|
|
||||||
// change uint8 quantization -> int8 quantization
|
// change uint8 quantization -> int8 quantization
|
||||||
int64_t numBits =
|
int64_t numBits =
|
||||||
lhsType.getElementType().cast<mlir::IntegerType>().getWidth();
|
cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
|
||||||
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
|
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
|
||||||
numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
|
numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
|
||||||
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
|
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
|
||||||
|
|
||||||
matmul =
|
matmul =
|
||||||
|
@ -229,9 +229,9 @@ public:
|
||||||
MLIRContext *context = op.getContext();
|
MLIRContext *context = op.getContext();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfRank =
|
auto selfRank =
|
||||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
Type elementType =
|
Type elementType =
|
||||||
adaptor.getSelf().getType().cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getElementType();
|
||||||
Value c1 =
|
Value c1 =
|
||||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));
|
||||||
|
|
||||||
|
@ -299,8 +299,8 @@ public:
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
auto lhsType = cast<RankedTensorType>(lhs.getType());
|
||||||
auto rhsType = rhs.getType().cast<RankedTensorType>();
|
auto rhsType = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
|
auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
|
||||||
auto rhsTorchType = cast<ValueTensorType>(op.getOther().getType());
|
auto rhsTorchType = cast<ValueTensorType>(op.getOther().getType());
|
||||||
|
@ -348,9 +348,9 @@ public:
|
||||||
|
|
||||||
// change uint8 quantization -> int8 quantization
|
// change uint8 quantization -> int8 quantization
|
||||||
int64_t numBits =
|
int64_t numBits =
|
||||||
lhsType.getElementType().cast<mlir::IntegerType>().getWidth();
|
cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
|
||||||
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
|
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
|
||||||
numBits = rhsType.getElementType().cast<mlir::IntegerType>().getWidth();
|
numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
|
||||||
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
|
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
|
||||||
|
|
||||||
// for quantized vec-vec, vec-mat, and mat-vec cases, lower to
|
// for quantized vec-vec, vec-mat, and mat-vec cases, lower to
|
||||||
|
@ -726,8 +726,8 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
Value rhs = adaptor.getMat2();
|
Value rhs = adaptor.getMat2();
|
||||||
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
|
RankedTensorType lhsType = cast<RankedTensorType>(lhs.getType());
|
||||||
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
|
RankedTensorType rhsType = cast<RankedTensorType>(rhs.getType());
|
||||||
Type newResultType = getTypeConverter()->convertType(op.getType());
|
Type newResultType = getTypeConverter()->convertType(op.getType());
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
cast<RankedTensorType>(newResultType).getElementType();
|
cast<RankedTensorType>(newResultType).getElementType();
|
||||||
|
@ -794,7 +794,7 @@ public:
|
||||||
Value input = adaptor.getInput(); /* in form of N*C*H*W */
|
Value input = adaptor.getInput(); /* in form of N*C*H*W */
|
||||||
Value weight = adaptor.getWeight(); /* in form of F*C*H*W */
|
Value weight = adaptor.getWeight(); /* in form of F*C*H*W */
|
||||||
Value bias = adaptor.getBias();
|
Value bias = adaptor.getBias();
|
||||||
auto resultTy = op.getType().cast<ValueTensorType>();
|
auto resultTy = cast<ValueTensorType>(op.getType());
|
||||||
|
|
||||||
Value inputZp, weightZp;
|
Value inputZp, weightZp;
|
||||||
if (auto make = op.getInput()
|
if (auto make = op.getInput()
|
||||||
|
@ -826,7 +826,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
|
if (inputZp && weightZp && !isa<Torch::NoneType>(bias.getType())) {
|
||||||
auto biasDTy = bias.getType().cast<RankedTensorType>().getElementType();
|
auto biasDTy = cast<RankedTensorType>(bias.getType()).getElementType();
|
||||||
if (!biasDTy.isInteger(32)) {
|
if (!biasDTy.isInteger(32)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "quantized result ty should be i32 accumulator");
|
op, "quantized result ty should be i32 accumulator");
|
||||||
|
@ -838,15 +838,15 @@ public:
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: only constant transposed supported");
|
op, "unimplemented: only constant transposed supported");
|
||||||
|
|
||||||
auto inputDTy = input.getType().cast<RankedTensorType>().getElementType();
|
auto inputDTy = cast<RankedTensorType>(input.getType()).getElementType();
|
||||||
auto weightDTy = weight.getType().cast<RankedTensorType>().getElementType();
|
auto weightDTy = cast<RankedTensorType>(weight.getType()).getElementType();
|
||||||
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
auto resultDTy = resultTy.toBuiltinTensor().getElementType();
|
||||||
|
|
||||||
if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
|
if (!isa<mlir::FloatType, mlir::IntegerType>(inputDTy) ||
|
||||||
!isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
|
!isa<mlir::FloatType, mlir::IntegerType>(weightDTy) ||
|
||||||
!isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
|
!isa<mlir::FloatType, mlir::IntegerType>(resultDTy))
|
||||||
return op.emitError("unimplemented: non-fp not-int type");
|
return op.emitError("unimplemented: non-fp not-int type");
|
||||||
size_t inRank = input.getType().cast<RankedTensorType>().getRank();
|
size_t inRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||||
size_t numSpatialDims = inRank - 2;
|
size_t numSpatialDims = inRank - 2;
|
||||||
if (numSpatialDims < 1 || numSpatialDims > 3)
|
if (numSpatialDims < 1 || numSpatialDims > 3)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1067,11 +1067,11 @@ public:
|
||||||
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
rewriter.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
auto biasType = bias.getType().cast<RankedTensorType>();
|
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||||
if (biasType.getRank() != 1)
|
if (biasType.getRank() != 1)
|
||||||
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
|
return rewriter.notifyMatchFailure(op, "expect bias to be rank 1");
|
||||||
|
|
||||||
auto resultRank = initTensor.getType().cast<RankedTensorType>().getRank();
|
auto resultRank = cast<RankedTensorType>(initTensor.getType()).getRank();
|
||||||
SmallVector<AffineMap> indexingMaps = {
|
SmallVector<AffineMap> indexingMaps = {
|
||||||
// bias is used to initialize the channels - dimension 1 of output
|
// bias is used to initialize the channels - dimension 1 of output
|
||||||
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
|
AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
|
||||||
|
@ -1228,9 +1228,9 @@ public:
|
||||||
|
|
||||||
// Special depthwise case
|
// Special depthwise case
|
||||||
auto inShape = makeShapeTorchCompatible(
|
auto inShape = makeShapeTorchCompatible(
|
||||||
input.getType().cast<RankedTensorType>().getShape());
|
cast<RankedTensorType>(input.getType()).getShape());
|
||||||
auto weightShape = makeShapeTorchCompatible(
|
auto weightShape = makeShapeTorchCompatible(
|
||||||
weight.getType().cast<RankedTensorType>().getShape());
|
cast<RankedTensorType>(weight.getType()).getShape());
|
||||||
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
|
if (weightShape[0] != kUnknownSize && inShape[1] == groupSize &&
|
||||||
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
|
weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) {
|
||||||
// Collapse weight shape
|
// Collapse weight shape
|
||||||
|
@ -1264,7 +1264,7 @@ public:
|
||||||
|
|
||||||
// Grouped case, use the grouped conv linalg op
|
// Grouped case, use the grouped conv linalg op
|
||||||
auto expandGroups = [&](Value tensor, size_t dim) {
|
auto expandGroups = [&](Value tensor, size_t dim) {
|
||||||
auto inType = tensor.getType().cast<RankedTensorType>();
|
auto inType = cast<RankedTensorType>(tensor.getType());
|
||||||
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
||||||
|
|
||||||
SmallVector<int64_t> outShape;
|
SmallVector<int64_t> outShape;
|
||||||
|
@ -1297,7 +1297,7 @@ public:
|
||||||
|
|
||||||
// expand F,C,H,W -> G,F/G,C,H,W
|
// expand F,C,H,W -> G,F/G,C,H,W
|
||||||
auto expandWeight = [&](Value tensor) {
|
auto expandWeight = [&](Value tensor) {
|
||||||
auto inType = tensor.getType().cast<RankedTensorType>();
|
auto inType = cast<RankedTensorType>(tensor.getType());
|
||||||
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
auto inShape = makeShapeTorchCompatible(inType.getShape());
|
||||||
|
|
||||||
SmallVector<int64_t> outShape{
|
SmallVector<int64_t> outShape{
|
||||||
|
|
|
@ -80,7 +80,7 @@ computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
|
||||||
SmallVectorImpl<int64_t> &dilationInts,
|
SmallVectorImpl<int64_t> &dilationInts,
|
||||||
SmallVectorImpl<Value> &kernelSizeIntValues,
|
SmallVectorImpl<Value> &kernelSizeIntValues,
|
||||||
SmallVectorImpl<Value> &outTensorShape, Value initValue) {
|
SmallVectorImpl<Value> &outTensorShape, Value initValue) {
|
||||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
|
||||||
Value N = getDimOp(rewriter, loc, self, 0);
|
Value N = getDimOp(rewriter, loc, self, 0);
|
||||||
|
@ -116,7 +116,7 @@ static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter,
|
||||||
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
|
SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
|
||||||
SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
|
SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
|
||||||
|
|
||||||
unsigned selfRank = self.getType().cast<RankedTensorType>().getRank();
|
unsigned selfRank = cast<RankedTensorType>(self.getType()).getRank();
|
||||||
unsigned paddingIntsSize = paddingInts.size();
|
unsigned paddingIntsSize = paddingInts.size();
|
||||||
|
|
||||||
if (paddingIntsSize == 2 * (selfRank - 2)) {
|
if (paddingIntsSize == 2 * (selfRank - 2)) {
|
||||||
|
@ -153,7 +153,7 @@ static LogicalResult createPoolingOp(
|
||||||
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
|
SmallVectorImpl<int64_t> &dilationInts, Attribute initValueAttr,
|
||||||
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
|
SmallVectorImpl<Value> &outTensorShape, Value &paddedInput, Value &result) {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||||
if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
|
if (!isa<mlir::FloatType>(elementType) && !supportNonFPInput)
|
||||||
return op->emitError("unimplemented: non-floating point type");
|
return op->emitError("unimplemented: non-floating point type");
|
||||||
|
|
||||||
|
@ -214,7 +214,7 @@ private:
|
||||||
bool ceilMode) const {
|
bool ceilMode) const {
|
||||||
SmallVector<Value, 5> outTensorShape;
|
SmallVector<Value, 5> outTensorShape;
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||||
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
TypedAttr smallestFPValueAttr = rewriter.getFloatAttr(
|
||||||
elementType,
|
elementType,
|
||||||
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
APFloat::getInf(cast<mlir::FloatType>(elementType).getFloatSemantics(),
|
||||||
|
@ -307,7 +307,7 @@ public:
|
||||||
|
|
||||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
int64_t selfRank = self.getType().cast<RankedTensorType>().getRank();
|
int64_t selfRank = cast<RankedTensorType>(self.getType()).getRank();
|
||||||
|
|
||||||
if (selfRank != Dim + 2)
|
if (selfRank != Dim + 2)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -326,7 +326,7 @@ public:
|
||||||
strideInts, paddingInts)))
|
strideInts, paddingInts)))
|
||||||
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
return rewriter.notifyMatchFailure(op, "invalid pooling parameters");
|
||||||
|
|
||||||
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = cast<RankedTensorType>(self.getType()).getElementType();
|
||||||
|
|
||||||
if constexpr (Dim == 2) {
|
if constexpr (Dim == 2) {
|
||||||
SmallVector<Value, 4> outTensorShape;
|
SmallVector<Value, 4> outTensorShape;
|
||||||
|
@ -389,7 +389,7 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||||
Type elementType = selfType.getElementType();
|
Type elementType = selfType.getElementType();
|
||||||
RankedTensorType indicesRankedTensorType =
|
RankedTensorType indicesRankedTensorType =
|
||||||
getTypeConverter()
|
getTypeConverter()
|
||||||
|
@ -552,7 +552,7 @@ public:
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
|
|
||||||
Type inputElementType =
|
Type inputElementType =
|
||||||
self.getType().cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(self.getType()).getElementType();
|
||||||
Type resultType = typeConverter->convertType(op.getType());
|
Type resultType = typeConverter->convertType(op.getType());
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
cast<RankedTensorType>(resultType).getElementType();
|
cast<RankedTensorType>(resultType).getElementType();
|
||||||
|
@ -592,10 +592,9 @@ public:
|
||||||
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
if constexpr (std::is_same<OpTy, AtenAvgPool2dOp>()) {
|
||||||
Value kHtimeskW = rewriter.create<arith::MulIOp>(
|
Value kHtimeskW = rewriter.create<arith::MulIOp>(
|
||||||
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
|
loc, kernelSizeIntValues[0], kernelSizeIntValues[1]);
|
||||||
divisor =
|
divisor = isa<Torch::NoneType>(op.getDivisorOverride().getType())
|
||||||
op.getDivisorOverride().getType().template isa<Torch::NoneType>()
|
? kHtimeskW
|
||||||
? kHtimeskW
|
: adaptor.getDivisorOverride();
|
||||||
: adaptor.getDivisorOverride();
|
|
||||||
} else {
|
} else {
|
||||||
divisor = kernelSizeIntValues[0];
|
divisor = kernelSizeIntValues[0];
|
||||||
}
|
}
|
||||||
|
@ -901,7 +900,7 @@ public:
|
||||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
|
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
const Type elementType = inputType.getElementType();
|
const Type elementType = inputType.getElementType();
|
||||||
|
|
||||||
// get rank of input (same as rank of output)
|
// get rank of input (same as rank of output)
|
||||||
|
|
|
@ -127,7 +127,7 @@ public:
|
||||||
Value from = adaptor.getFrom();
|
Value from = adaptor.getFrom();
|
||||||
Value to = adaptor.getTo();
|
Value to = adaptor.getTo();
|
||||||
Value generator = adaptor.getGenerator();
|
Value generator = adaptor.getGenerator();
|
||||||
RankedTensorType resultType = self.getType().cast<RankedTensorType>();
|
RankedTensorType resultType = cast<RankedTensorType>(self.getType());
|
||||||
Type elemTy = resultType.getElementType();
|
Type elemTy = resultType.getElementType();
|
||||||
Type f64Ty = rewriter.getF64Type();
|
Type f64Ty = rewriter.getF64Type();
|
||||||
|
|
||||||
|
|
|
@ -66,8 +66,7 @@ public:
|
||||||
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
|
cast<RankedTensorType>(typec->convertType(op.getResult(0).getType()));
|
||||||
auto idxResultType =
|
auto idxResultType =
|
||||||
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
|
cast<RankedTensorType>(typec->convertType(op.getResult(1).getType()));
|
||||||
RankedTensorType inputType =
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
input.getType().template cast<RankedTensorType>();
|
|
||||||
Type idxElementType =
|
Type idxElementType =
|
||||||
getElementTypeOrSelf(typec->convertType(idxResultType));
|
getElementTypeOrSelf(typec->convertType(idxResultType));
|
||||||
if (!isa<IntegerType>(idxElementType))
|
if (!isa<IntegerType>(idxElementType))
|
||||||
|
@ -472,7 +471,7 @@ private:
|
||||||
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}};
|
||||||
typename T::Adaptor adaptor(operands);
|
typename T::Adaptor adaptor(operands);
|
||||||
opInfo.tensorOperand = adaptor.getSelf();
|
opInfo.tensorOperand = adaptor.getSelf();
|
||||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
||||||
|
|
||||||
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim)))
|
if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&opInfo.keepDim)))
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -480,8 +479,7 @@ private:
|
||||||
|
|
||||||
SmallVector<int64_t> dimList;
|
SmallVector<int64_t> dimList;
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
bool isNoneOrEmptyDimList =
|
bool isNoneOrEmptyDimList = isa<Torch::NoneType>(op.getDim().getType());
|
||||||
op.getDim().getType().template isa<Torch::NoneType>();
|
|
||||||
if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
|
if (matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) {
|
||||||
// Fix negative dimensions, if any, before adding to the list.
|
// Fix negative dimensions, if any, before adding to the list.
|
||||||
for (int64_t dim : dimList) {
|
for (int64_t dim : dimList) {
|
||||||
|
@ -522,7 +520,7 @@ private:
|
||||||
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
|
if (isa<AtenAnyOp, AtenAllOp, AtenMaxOp, AtenMinOp, AtenSumOp, AtenProdOp,
|
||||||
AtenNormScalarOp>(op)) {
|
AtenNormScalarOp>(op)) {
|
||||||
opInfo.tensorOperand = operands[0];
|
opInfo.tensorOperand = operands[0];
|
||||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
||||||
|
|
||||||
// `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
|
// `AtenAny`, `AtenAll`, `AtenSumOp`, `AtenProdOp`, `AtenMaxOp`, and
|
||||||
// `AtenMinOp` each reduce along all the dimensions of the input tensor.
|
// `AtenMinOp` each reduce along all the dimensions of the input tensor.
|
||||||
|
|
|
@ -42,7 +42,7 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto type = self.getType().cast<RankedTensorType>();
|
auto type = cast<RankedTensorType>(self.getType());
|
||||||
int64_t rank = type.getRank();
|
int64_t rank = type.getRank();
|
||||||
|
|
||||||
auto primList = op.getPad().getDefiningOp<Torch::PrimListConstructOp>();
|
auto primList = op.getPad().getDefiningOp<Torch::PrimListConstructOp>();
|
||||||
|
@ -105,7 +105,7 @@ public:
|
||||||
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
|
convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType);
|
||||||
|
|
||||||
Type padType = tensor::PadOp::inferResultType(
|
Type padType = tensor::PadOp::inferResultType(
|
||||||
self.getType().cast<RankedTensorType>(), staticLow, staticHigh);
|
cast<RankedTensorType>(self.getType()), staticLow, staticHigh);
|
||||||
Value paddedInput = rewriter.create<tensor::PadOp>(
|
Value paddedInput = rewriter.create<tensor::PadOp>(
|
||||||
loc, padType, self, lowPad, highPad, castedValue);
|
loc, padType, self, lowPad, highPad, castedValue);
|
||||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, paddedInput);
|
||||||
|
@ -354,7 +354,7 @@ public:
|
||||||
|
|
||||||
// The pin_memory should be either `False` or `none`.
|
// The pin_memory should be either `False` or `none`.
|
||||||
bool pinMemory;
|
bool pinMemory;
|
||||||
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
|
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||||
pinMemory)) {
|
pinMemory)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -376,7 +376,7 @@ public:
|
||||||
auto resultType = typeConverter->convertType(op.getType())
|
auto resultType = typeConverter->convertType(op.getType())
|
||||||
.template cast<RankedTensorType>();
|
.template cast<RankedTensorType>();
|
||||||
Type resultElementType;
|
Type resultElementType;
|
||||||
if (op.getDtype().getType().template isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||||
resultElementType = resultType.getElementType();
|
resultElementType = resultType.getElementType();
|
||||||
} else {
|
} else {
|
||||||
int64_t dtypeInt;
|
int64_t dtypeInt;
|
||||||
|
@ -423,7 +423,7 @@ public:
|
||||||
|
|
||||||
// The pin_memory should be either `False` or `none`.
|
// The pin_memory should be either `False` or `none`.
|
||||||
bool pinMemory;
|
bool pinMemory;
|
||||||
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
|
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||||
pinMemory))
|
pinMemory))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -480,7 +480,7 @@ public:
|
||||||
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||||
|
|
||||||
auto resultType =
|
auto resultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
Type resultElementType;
|
Type resultElementType;
|
||||||
if (op.getDtype().getType().isa<Torch::NoneType>()) {
|
if (op.getDtype().getType().isa<Torch::NoneType>()) {
|
||||||
resultElementType = getDefaultDtypeForTorchScalar(
|
resultElementType = getDefaultDtypeForTorchScalar(
|
||||||
|
|
|
@ -38,7 +38,7 @@ public:
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
Value dim = adaptor.getDim();
|
Value dim = adaptor.getDim();
|
||||||
auto type = self.getType().cast<RankedTensorType>();
|
auto type = cast<RankedTensorType>(self.getType());
|
||||||
Value inputRank = rewriter.create<arith::ConstantOp>(
|
Value inputRank = rewriter.create<arith::ConstantOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(type.getRank()));
|
loc, rewriter.getI64IntegerAttr(type.getRank()));
|
||||||
Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
|
Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank);
|
||||||
|
@ -86,8 +86,7 @@ public:
|
||||||
Value input = adaptor.getA();
|
Value input = adaptor.getA();
|
||||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||||
int64_t inputRank = inputSizes.size();
|
int64_t inputRank = inputSizes.size();
|
||||||
Type inputDtype =
|
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
|
||||||
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
|
||||||
|
|
||||||
// The `input` tensor must contain exactly one element, i.e., either the
|
// The `input` tensor must contain exactly one element, i.e., either the
|
||||||
// `input` is a zero rank tensor or all the dimensions of the `input` tensor
|
// `input` is a zero rank tensor or all the dimensions of the `input` tensor
|
||||||
|
|
|
@ -34,7 +34,7 @@ using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
// Check if a ranked-tensor has the specified element type.
|
// Check if a ranked-tensor has the specified element type.
|
||||||
template <typename elementType> static bool hasElementType(Value tensor) {
|
template <typename elementType> static bool hasElementType(Value tensor) {
|
||||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
auto tensorType = cast<RankedTensorType>(tensor.getType());
|
||||||
Type tensorElementType = tensorType.getElementType();
|
Type tensorElementType = tensorType.getElementType();
|
||||||
return isa<elementType>(tensorElementType);
|
return isa<elementType>(tensorElementType);
|
||||||
}
|
}
|
||||||
|
@ -173,8 +173,7 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type elementalType =
|
Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
|
||||||
op.getSelf().getType().template cast<BaseTensorType>().getDtype();
|
|
||||||
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
|
if constexpr (std::is_same<OpTy, AtenLtTensorOp>()) {
|
||||||
return createLessThan(b, loc, elementalType, lhs, rhs);
|
return createLessThan(b, loc, elementalType, lhs, rhs);
|
||||||
}
|
}
|
||||||
|
@ -200,7 +199,7 @@ template <arith::CmpIPredicate predicate>
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
|
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
|
||||||
Operation *op, ArrayRef<Value> operands, Value &result) {
|
Operation *op, ArrayRef<Value> operands, Value &result) {
|
||||||
auto inputType = operands[0].getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(operands[0].getType());
|
||||||
uint64_t inputRank = inputType.getRank();
|
uint64_t inputRank = inputType.getRank();
|
||||||
|
|
||||||
// Use the indices of the two innermost dimensions.
|
// Use the indices of the two innermost dimensions.
|
||||||
|
@ -405,7 +404,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
bitwiseAndScalar.getType().cast<BaseTensorType>().getDtype();
|
cast<BaseTensorType>(bitwiseAndScalar.getType()).getDtype();
|
||||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
/*dstOriginalDtype=*/resultElementType);
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
|
@ -537,7 +536,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
if (auto relu = dyn_cast<AtenReluOp>(op)) {
|
if (auto relu = dyn_cast<AtenReluOp>(op)) {
|
||||||
Value zeroPoint = getZeroPoint(relu.getSelf());
|
Value zeroPoint = getZeroPoint(relu.getSelf());
|
||||||
Value arg = payloadArgs[0];
|
Value arg = payloadArgs[0];
|
||||||
auto intType = arg.getType().dyn_cast<mlir::IntegerType>();
|
auto intType = dyn_cast<mlir::IntegerType>(arg.getType());
|
||||||
if (zeroPoint && !intType) {
|
if (zeroPoint && !intType) {
|
||||||
relu.emitError("unimplemented: non-integer quantized Relu.");
|
relu.emitError("unimplemented: non-integer quantized Relu.");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -739,9 +738,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
if (auto add = dyn_cast<AtenAddTensorOp>(op)) {
|
||||||
AtenAddTensorOp::Adaptor adaptor(operands);
|
AtenAddTensorOp::Adaptor adaptor(operands);
|
||||||
Type resultElementType = add.getType().cast<BaseTensorType>().getDtype();
|
Type resultElementType = cast<BaseTensorType>(add.getType()).getDtype();
|
||||||
Type dtype = converter->convertType(add.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(add.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
|
@ -762,10 +760,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
|
if (auto sub = dyn_cast<AtenSubTensorOp>(op)) {
|
||||||
AtenSubTensorOp::Adaptor adaptor(operands);
|
AtenSubTensorOp::Adaptor adaptor(operands);
|
||||||
Type dtype = converter->convertType(sub.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(sub.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Type resultElementType = sub.getType().cast<BaseTensorType>().getDtype();
|
Type resultElementType = cast<BaseTensorType>(sub.getType()).getDtype();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
/*dstOriginalDtype=*/resultElementType);
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
|
@ -785,9 +782,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
|
if (auto subScalar = dyn_cast<AtenSubScalarOp>(op)) {
|
||||||
Type dtype = converter->convertType(subScalar.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(subScalar.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||||
Value alpha = convertScalarToDtype(
|
Value alpha = convertScalarToDtype(
|
||||||
|
@ -805,11 +802,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
|
if (auto addScalar = dyn_cast<AtenAddScalarOp>(op)) {
|
||||||
Type dtype = converter->convertType(addScalar.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(addScalar.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Type resultElementType =
|
Type resultElementType =
|
||||||
addScalar.getType().cast<BaseTensorType>().getDtype();
|
cast<BaseTensorType>(addScalar.getType()).getDtype();
|
||||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype,
|
||||||
/*srcOriginalDtype=*/std::nullopt,
|
/*srcOriginalDtype=*/std::nullopt,
|
||||||
/*dstOriginalDtype=*/resultElementType);
|
/*dstOriginalDtype=*/resultElementType);
|
||||||
|
@ -832,8 +829,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
if (auto mul = dyn_cast<AtenMulTensorOp>(op)) {
|
||||||
AtenMulTensorOp::Adaptor adaptor(operands);
|
AtenMulTensorOp::Adaptor adaptor(operands);
|
||||||
Type dtype = converter->convertType(mul.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(mul.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
@ -846,8 +842,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
|
if (auto atan2 = dyn_cast<AtenAtan2Op>(op)) {
|
||||||
Type dtype = converter->convertType(atan2.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(atan2.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!isa<mlir::FloatType>(dtype)) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
atan2.emitError("Atan2 requires floating point result type");
|
atan2.emitError("Atan2 requires floating point result type");
|
||||||
|
@ -883,8 +878,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
if (auto div = dyn_cast<AtenDivTensorOp>(op)) {
|
||||||
AtenDivTensorOp::Adaptor adaptor(operands);
|
AtenDivTensorOp::Adaptor adaptor(operands);
|
||||||
Type dtype = converter->convertType(div.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(div.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
|
@ -907,7 +901,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
operands);
|
operands);
|
||||||
}
|
}
|
||||||
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
if (auto pow = dyn_cast<AtenPowScalarOp>(op)) {
|
||||||
Type dtype = pow.getType().cast<ValueTensorType>().getDtype();
|
Type dtype = cast<ValueTensorType>(pow.getType()).getDtype();
|
||||||
if (!isa<mlir::FloatType>(dtype)) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
pow.emitError("unimplemented: non-floating point dtype");
|
pow.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -925,14 +919,13 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
pow.emitError("unimplemented: non-floating point dtype");
|
pow.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
Type dtype = pow.getSelf().getType().cast<ValueTensorType>().getDtype();
|
Type dtype = cast<ValueTensorType>(pow.getSelf().getType()).getDtype();
|
||||||
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
|
Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||||
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
|
return b.create<math::PowFOp>(loc, payloadArgs[0], expPromoted);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto pow = dyn_cast<AtenPowTensorTensorOp>(op)) {
|
if (auto pow = dyn_cast<AtenPowTensorTensorOp>(op)) {
|
||||||
Type dtype = converter->convertType(pow.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(pow.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!isa<mlir::FloatType>(dtype)) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
pow.emitError("unimplemented: non-floating point dtype");
|
pow.emitError("unimplemented: non-floating point dtype");
|
||||||
|
@ -944,8 +937,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto imag = dyn_cast<AtenImagOp>(op)) {
|
if (auto imag = dyn_cast<AtenImagOp>(op)) {
|
||||||
Type dtype = converter->convertType(imag.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(imag.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!isa<mlir::FloatType>(dtype)) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
imag.emitError("unimplemented: non-floating point dtype");
|
imag.emitError("unimplemented: non-floating point dtype");
|
||||||
|
@ -956,8 +948,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto real = dyn_cast<AtenRealOp>(op)) {
|
if (auto real = dyn_cast<AtenRealOp>(op)) {
|
||||||
Type dtype = converter->convertType(real.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(real.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!isa<mlir::FloatType>(dtype)) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
real.emitError("unimplemented: non-floating point dtype");
|
real.emitError("unimplemented: non-floating point dtype");
|
||||||
|
@ -968,7 +959,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
|
if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
|
||||||
Type dtype = gtScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
Type dtype = cast<BaseTensorType>(gtScalar.getSelf().getType()).getDtype();
|
||||||
|
|
||||||
// TODO: `gtTensor` and `gtScalar` share similar code and can be called from
|
// TODO: `gtTensor` and `gtScalar` share similar code and can be called from
|
||||||
// one static function.
|
// one static function.
|
||||||
|
@ -998,7 +989,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
|
if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
|
||||||
Type dtype = geScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
Type dtype = cast<BaseTensorType>(geScalar.getSelf().getType()).getDtype();
|
||||||
|
|
||||||
// TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that
|
// TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code that
|
||||||
// can be refactored.
|
// can be refactored.
|
||||||
|
@ -1028,7 +1019,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
|
if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
|
||||||
Type dtype = eqScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
Type dtype = cast<BaseTensorType>(eqScalar.getSelf().getType()).getDtype();
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
|
@ -1044,7 +1035,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
|
if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
|
||||||
Type dtype = neScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
Type dtype = cast<BaseTensorType>(neScalar.getSelf().getType()).getDtype();
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
|
@ -1060,7 +1051,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
|
||||||
Type dtype = ltScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
|
@ -1088,7 +1079,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
|
if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
|
||||||
Type dtype = leScalar.getSelf().getType().cast<BaseTensorType>().getDtype();
|
Type dtype = cast<BaseTensorType>(leScalar.getSelf().getType()).getDtype();
|
||||||
Value otherPromoted =
|
Value otherPromoted =
|
||||||
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());
|
||||||
|
|
||||||
|
@ -1116,9 +1107,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
|
if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
|
||||||
Type dtype = converter->convertType(whereSelf.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(whereSelf.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
|
Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype);
|
||||||
return b.create<arith::SelectOp>(loc, payloadArgs[0], lhs, rhs);
|
return b.create<arith::SelectOp>(loc, payloadArgs[0], lhs, rhs);
|
||||||
|
@ -1141,7 +1132,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<arith::AddFOp>(loc, start, weightedDelta);
|
return b.create<arith::AddFOp>(loc, start, weightedDelta);
|
||||||
}
|
}
|
||||||
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
|
if (auto minimum = dyn_cast<AtenMinimumOp>(op)) {
|
||||||
Type dtype = minimum.getType().cast<BaseTensorType>().getDtype();
|
Type dtype = cast<BaseTensorType>(minimum.getType()).getDtype();
|
||||||
Type elemTy = converter->convertType(minimum.getType())
|
Type elemTy = converter->convertType(minimum.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
|
@ -1151,7 +1142,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
|
return b.create<arith::SelectOp>(loc, pred, lhs, rhs);
|
||||||
}
|
}
|
||||||
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
|
if (auto maximum = dyn_cast<AtenMaximumOp>(op)) {
|
||||||
Type dtype = maximum.getType().cast<BaseTensorType>().getDtype();
|
Type dtype = cast<BaseTensorType>(maximum.getType()).getDtype();
|
||||||
Type elemTy = converter->convertType(maximum.getType())
|
Type elemTy = converter->convertType(maximum.getType())
|
||||||
.cast<RankedTensorType>()
|
.cast<RankedTensorType>()
|
||||||
.getElementType();
|
.getElementType();
|
||||||
|
@ -1170,15 +1161,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type dtype = converter->convertType(clamp.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(clamp.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) {
|
if (!isa<mlir::FloatType, mlir::IntegerType>(dtype)) {
|
||||||
clamp.emitError("unimplement type for clamp");
|
clamp.emitError("unimplement type for clamp");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Type dstOriginalDtype = clamp.getType().cast<BaseTensorType>().getDtype();
|
Type dstOriginalDtype = cast<BaseTensorType>(clamp.getType()).getDtype();
|
||||||
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
|
bool isUnsigned = isa<QUInt8Type>(dstOriginalDtype);
|
||||||
if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) {
|
if (auto intTy = dyn_cast<IntegerType>(dstOriginalDtype)) {
|
||||||
isUnsigned = intTy.isUnsigned();
|
isUnsigned = intTy.isUnsigned();
|
||||||
|
@ -1219,9 +1209,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
clampTensor.emitError("unimplemented: runtime optional type");
|
clampTensor.emitError("unimplemented: runtime optional type");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
Type dtype = converter->convertType(clampTensor.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(clampTensor.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
bool isMinNone = true;
|
bool isMinNone = true;
|
||||||
auto result = payloadArgs[0];
|
auto result = payloadArgs[0];
|
||||||
if (!min.getType().isa<Torch::NoneType>()) {
|
if (!min.getType().isa<Torch::NoneType>()) {
|
||||||
|
@ -1263,8 +1253,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
|
if (auto rsub = dyn_cast<AtenRsubScalarOp>(op)) {
|
||||||
Type dtype = converter->convertType(rsub.getType())
|
Type dtype = cast<RankedTensorType>(converter->convertType(rsub.getType()))
|
||||||
.cast<RankedTensorType>()
|
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
Value other = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||||
|
@ -1283,9 +1272,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
|
if (auto mulScalar = dyn_cast<AtenMulScalarOp>(op)) {
|
||||||
Type dtype = converter->convertType(mulScalar.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(mulScalar.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
|
Value rhs = convertScalarToDtype(b, loc, operands[1], dtype);
|
||||||
if (isa<mlir::FloatType>(dtype))
|
if (isa<mlir::FloatType>(dtype))
|
||||||
|
@ -1297,9 +1286,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
if (auto atenToDtype = dyn_cast<AtenToDtypeOp>(op)) {
|
||||||
Value input = payloadArgs[0];
|
Value input = payloadArgs[0];
|
||||||
Type dtype = converter->convertType(atenToDtype.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(atenToDtype.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Type resultElementType;
|
Type resultElementType;
|
||||||
int64_t dtypeInt;
|
int64_t dtypeInt;
|
||||||
if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) {
|
if (!matchPattern(atenToDtype.getDtype(), m_TorchConstantInt(&dtypeInt))) {
|
||||||
|
@ -1320,9 +1309,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
|
if (auto divScalar = dyn_cast<AtenDivScalarOp>(op)) {
|
||||||
Type dtype = converter->convertType(divScalar.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(divScalar.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
if (!isa<mlir::FloatType>(dtype)) {
|
if (!isa<mlir::FloatType>(dtype)) {
|
||||||
divScalar.emitError("unimplemented: non-floating point dtype");
|
divScalar.emitError("unimplemented: non-floating point dtype");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1395,9 +1384,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
|
if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
|
||||||
Type dtype = converter->convertType(reciprocal.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(reciprocal.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value arg = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
Type elementType = arg.getType();
|
Type elementType = arg.getType();
|
||||||
// assert(element != 0)
|
// assert(element != 0)
|
||||||
|
@ -1416,9 +1405,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
// The approach used here is as follows:
|
// The approach used here is as follows:
|
||||||
// result = self <= threshold ? value : self
|
// result = self <= threshold ? value : self
|
||||||
AtenThresholdOp::Adaptor adaptor(operands);
|
AtenThresholdOp::Adaptor adaptor(operands);
|
||||||
Type dtype = converter->convertType(thresholdOp.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(thresholdOp.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
|
|
||||||
Value self = payloadArgs[0];
|
Value self = payloadArgs[0];
|
||||||
Value threshold =
|
Value threshold =
|
||||||
|
@ -1438,8 +1427,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
// The approach used here is as follows:
|
// The approach used here is as follows:
|
||||||
// result = self <= threshold ? 0 : grad
|
// result = self <= threshold ? 0 : grad
|
||||||
AtenThresholdBackwardOp::Adaptor adaptor(operands);
|
AtenThresholdBackwardOp::Adaptor adaptor(operands);
|
||||||
Type dtype = converter->convertType(thresholdBackward.getType())
|
Type dtype = cast<RankedTensorType>(
|
||||||
.cast<RankedTensorType>()
|
converter->convertType(thresholdBackward.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
|
|
||||||
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
Value grad = convertScalarToDtype(b, loc, payloadArgs[0], dtype);
|
||||||
|
@ -1459,15 +1448,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto fillScalar = dyn_cast<AtenFillScalarOp>(op)) {
|
if (auto fillScalar = dyn_cast<AtenFillScalarOp>(op)) {
|
||||||
AtenFillScalarOp::Adaptor adaptor(operands);
|
AtenFillScalarOp::Adaptor adaptor(operands);
|
||||||
Type dtype = converter->convertType(fillScalar.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(fillScalar.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
|
return convertScalarToDtype(b, loc, adaptor.getValue(), dtype);
|
||||||
}
|
}
|
||||||
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
|
if (auto maskedFillTensor = dyn_cast<AtenMaskedFillTensorOp>(op)) {
|
||||||
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
|
AtenMaskedFillScalarOp::Adaptor adaptor(operands);
|
||||||
Type dtype = converter->convertType(maskedFillTensor.getType())
|
Type dtype = cast<RankedTensorType>(
|
||||||
.cast<RankedTensorType>()
|
converter->convertType(maskedFillTensor.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
|
|
||||||
Value input = payloadArgs[0];
|
Value input = payloadArgs[0];
|
||||||
|
@ -1477,9 +1466,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
}
|
}
|
||||||
if (auto fillTensor = dyn_cast<AtenFillTensorOp>(op)) {
|
if (auto fillTensor = dyn_cast<AtenFillTensorOp>(op)) {
|
||||||
AtenFillTensorOp::Adaptor adaptor(operands);
|
AtenFillTensorOp::Adaptor adaptor(operands);
|
||||||
Type dtype = converter->convertType(fillTensor.getType())
|
Type dtype =
|
||||||
.cast<RankedTensorType>()
|
cast<RankedTensorType>(converter->convertType(fillTensor.getType()))
|
||||||
.getElementType();
|
.getElementType();
|
||||||
return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
return convertScalarToDtype(b, loc, payloadArgs[1], dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1519,7 +1508,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
|
||||||
auto value = payloadArgs[0];
|
auto value = payloadArgs[0];
|
||||||
auto valueTy = value.getType();
|
auto valueTy = value.getType();
|
||||||
auto qtensor = op->getOperand(0);
|
auto qtensor = op->getOperand(0);
|
||||||
auto qtensorTy = qtensor.getType().cast<ValueTensorType>().getDtype();
|
auto qtensorTy = cast<ValueTensorType>(qtensor.getType()).getDtype();
|
||||||
|
|
||||||
Value zp, scale;
|
Value zp, scale;
|
||||||
if (auto makeQTensor =
|
if (auto makeQTensor =
|
||||||
|
@ -1744,8 +1733,8 @@ public:
|
||||||
Value ignoreIndex = adaptor.getIgnoreIndex();
|
Value ignoreIndex = adaptor.getIgnoreIndex();
|
||||||
Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
|
Value ignoreIndexVal = castIntToIndex(rewriter, loc, ignoreIndex);
|
||||||
|
|
||||||
unsigned inputRank = input.getType().cast<RankedTensorType>().getRank();
|
unsigned inputRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||||
unsigned targetRank = target.getType().cast<RankedTensorType>().getRank();
|
unsigned targetRank = cast<RankedTensorType>(target.getType()).getRank();
|
||||||
|
|
||||||
// TODO: Add support for k-dim loss.
|
// TODO: Add support for k-dim loss.
|
||||||
if (inputRank > 2) {
|
if (inputRank > 2) {
|
||||||
|
@ -1931,11 +1920,11 @@ public:
|
||||||
failed(checkNotNone(rewriter, op, runningVar)))
|
failed(checkNotNone(rewriter, op, runningVar)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto weightType = weight.getType().cast<RankedTensorType>();
|
auto weightType = cast<RankedTensorType>(weight.getType());
|
||||||
auto biasType = bias.getType().cast<RankedTensorType>();
|
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||||
auto runningMeanType = runningMean.getType().cast<RankedTensorType>();
|
auto runningMeanType = cast<RankedTensorType>(runningMean.getType());
|
||||||
auto runningVarType = runningVar.getType().cast<RankedTensorType>();
|
auto runningVarType = cast<RankedTensorType>(runningVar.getType());
|
||||||
|
|
||||||
auto inputRank = inputType.getRank();
|
auto inputRank = inputType.getRank();
|
||||||
if (inputRank < 2)
|
if (inputRank < 2)
|
||||||
|
@ -2032,9 +2021,9 @@ public:
|
||||||
Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
|
Value ignoreIndex = castIntToIndex(rewriter, loc, adaptor.getIgnoreIndex());
|
||||||
Value totalWeight = adaptor.getTotalWeight();
|
Value totalWeight = adaptor.getTotalWeight();
|
||||||
|
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
int inputRank = inputType.getRank();
|
int inputRank = inputType.getRank();
|
||||||
auto gradOutputType = gradOutput.getType().cast<RankedTensorType>();
|
auto gradOutputType = cast<RankedTensorType>(gradOutput.getType());
|
||||||
Type resultElementType = gradOutputType.getElementType();
|
Type resultElementType = gradOutputType.getElementType();
|
||||||
|
|
||||||
int64_t reduction;
|
int64_t reduction;
|
||||||
|
@ -2059,7 +2048,7 @@ public:
|
||||||
createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
|
createZeroInitTensor(rewriter, loc, outputSize, resultElementType);
|
||||||
|
|
||||||
auto getAffineMapForSingleElementTensor = [&](Value tensor) {
|
auto getAffineMapForSingleElementTensor = [&](Value tensor) {
|
||||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
auto tensorType = cast<RankedTensorType>(tensor.getType());
|
||||||
SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
|
SmallVector<AffineExpr> affineExprs(tensorType.getRank(),
|
||||||
rewriter.getAffineConstantExpr(0));
|
rewriter.getAffineConstantExpr(0));
|
||||||
return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
|
return AffineMap::get(inputRank, /*symbolCount=*/0, affineExprs,
|
||||||
|
@ -2188,12 +2177,12 @@ public:
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>();
|
auto aRankedTensorType = cast<RankedTensorType>(adaptor.getA().getType());
|
||||||
|
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
|
||||||
auto resultRankedTensorType =
|
auto resultRankedTensorType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
|
||||||
// The dimension being split must be statically known.
|
// The dimension being split must be statically known.
|
||||||
|
|
||||||
|
@ -2233,11 +2222,11 @@ public:
|
||||||
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto aRankedTensorType = adaptor.getA().getType().cast<RankedTensorType>();
|
auto aRankedTensorType = cast<RankedTensorType>(adaptor.getA().getType());
|
||||||
const TypeConverter *typeConverter = getTypeConverter();
|
const TypeConverter *typeConverter = getTypeConverter();
|
||||||
|
|
||||||
auto resultRankedTensorType =
|
auto resultRankedTensorType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
|
||||||
// Collapse range must be statically known.
|
// Collapse range must be statically known.
|
||||||
int64_t startInt;
|
int64_t startInt;
|
||||||
|
@ -2328,7 +2317,7 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto inputElementType = inputType.getElementType();
|
auto inputElementType = inputType.getElementType();
|
||||||
|
|
||||||
if (!isa<mlir::FloatType>(inputElementType)) {
|
if (!isa<mlir::FloatType>(inputElementType)) {
|
||||||
|
@ -2433,8 +2422,8 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto operandDTy = operand.getType().cast<ValueTensorType>().getDtype();
|
auto operandDTy = cast<ValueTensorType>(operand.getType()).getDtype();
|
||||||
auto zeropointDTy = zeropoint.getType().cast<ValueTensorType>().getDtype();
|
auto zeropointDTy = cast<ValueTensorType>(zeropoint.getType()).getDtype();
|
||||||
operand = converter->materializeTargetConversion(
|
operand = converter->materializeTargetConversion(
|
||||||
rewriter, loc, converter->convertType(operand.getType()), operand);
|
rewriter, loc, converter->convertType(operand.getType()), operand);
|
||||||
scale = converter->materializeTargetConversion(
|
scale = converter->materializeTargetConversion(
|
||||||
|
@ -2537,7 +2526,7 @@ public:
|
||||||
Value twoFloat = rewriter.create<arith::ConstantOp>(
|
Value twoFloat = rewriter.create<arith::ConstantOp>(
|
||||||
loc, rewriter.getFloatAttr(floatType, 2.0));
|
loc, rewriter.getFloatAttr(floatType, 2.0));
|
||||||
Value input = adaptor.getInput();
|
Value input = adaptor.getInput();
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
|
Value innerDim0a = rewriter.create<tensor::DimOp>(loc, input, 2);
|
||||||
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
|
Value innerDim1a = rewriter.create<tensor::DimOp>(loc, input, 3);
|
||||||
|
@ -2558,7 +2547,7 @@ public:
|
||||||
Value innerDim1e =
|
Value innerDim1e =
|
||||||
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
|
rewriter.create<arith::DivFOp>(loc, innerDim1d, twoFloat);
|
||||||
Value grid = adaptor.getGrid();
|
Value grid = adaptor.getGrid();
|
||||||
auto gridType = grid.getType().cast<RankedTensorType>();
|
auto gridType = cast<RankedTensorType>(grid.getType());
|
||||||
auto gridShape = gridType.getShape();
|
auto gridShape = gridType.getShape();
|
||||||
auto gridRank = gridType.getRank();
|
auto gridRank = gridType.getRank();
|
||||||
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
|
SmallVector<Value> extractGridOffsets0(gridRank, zeroIndex);
|
||||||
|
|
|
@ -37,9 +37,8 @@ Value torch_to_linalg::getPaddedTensor(
|
||||||
SmallVectorImpl<int64_t> &lowPaddingInts,
|
SmallVectorImpl<int64_t> &lowPaddingInts,
|
||||||
SmallVectorImpl<int64_t> &highPaddingInts, Value pad) {
|
SmallVectorImpl<int64_t> &highPaddingInts, Value pad) {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Type rankedTensorType =
|
Type rankedTensorType = tensor::PadOp::inferResultType(
|
||||||
tensor::PadOp::inferResultType(input.getType().cast<RankedTensorType>(),
|
cast<RankedTensorType>(input.getType()), lowPaddingInts, highPaddingInts);
|
||||||
lowPaddingInts, highPaddingInts);
|
|
||||||
SmallVector<OpFoldResult> lowPaddings =
|
SmallVector<OpFoldResult> lowPaddings =
|
||||||
getIndexIntsAsOpFoldResult(b, lowPaddingInts);
|
getIndexIntsAsOpFoldResult(b, lowPaddingInts);
|
||||||
SmallVector<OpFoldResult> highPaddings =
|
SmallVector<OpFoldResult> highPaddings =
|
||||||
|
@ -61,7 +60,7 @@ Value torch_to_linalg::getZeroPaddedTensor(
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Value c0 = b.create<arith::ConstantOp>(
|
Value c0 = b.create<arith::ConstantOp>(
|
||||||
loc,
|
loc,
|
||||||
b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType()));
|
b.getZeroAttr(cast<RankedTensorType>(input.getType()).getElementType()));
|
||||||
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
|
return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +72,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
|
||||||
int unpaddedDims, Value pad) {
|
int unpaddedDims, Value pad) {
|
||||||
assert(input.getType().isa<RankedTensorType>() &&
|
assert(input.getType().isa<RankedTensorType>() &&
|
||||||
"input must be RankedTensorType");
|
"input must be RankedTensorType");
|
||||||
unsigned int inRank = input.getType().cast<RankedTensorType>().getRank();
|
unsigned int inRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
|
|
||||||
SmallVector<Value> inputDims = getTensorSizes(b, loc, input);
|
SmallVector<Value> inputDims = getTensorSizes(b, loc, input);
|
||||||
|
@ -86,7 +85,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor(
|
||||||
pad < paddingIncludingUnchanged.end(); pad++)
|
pad < paddingIncludingUnchanged.end(); pad++)
|
||||||
*pad = castIntToIndex(b, loc, *pad);
|
*pad = castIntToIndex(b, loc, *pad);
|
||||||
|
|
||||||
Type elementType = input.getType().cast<RankedTensorType>().getElementType();
|
Type elementType = cast<RankedTensorType>(input.getType()).getElementType();
|
||||||
// TODO: audit possibility of sparsity on this tensor
|
// TODO: audit possibility of sparsity on this tensor
|
||||||
Type inputType =
|
Type inputType =
|
||||||
RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
|
RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef<int64_t>(
|
||||||
|
@ -158,7 +157,7 @@ Value torch_to_linalg::getOutputDimForConvTransposeOps(
|
||||||
Value torch_to_linalg::createReductionLinalgGeneric(
|
Value torch_to_linalg::createReductionLinalgGeneric(
|
||||||
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
|
OpBuilder &b, Location loc, const ReductionOpInfo &opInfo, Value initElem,
|
||||||
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
|
||||||
auto inputType = opInfo.tensorOperand.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(opInfo.tensorOperand.getType());
|
||||||
|
|
||||||
// Get the result shape by obtaining the size of each
|
// Get the result shape by obtaining the size of each
|
||||||
// dimension in the input tensor that is not getting reduced.
|
// dimension in the input tensor that is not getting reduced.
|
||||||
|
@ -237,7 +236,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
|
||||||
SmallVector<int64_t> operandRanks;
|
SmallVector<int64_t> operandRanks;
|
||||||
operandRanks.resize(tensorOperands.size());
|
operandRanks.resize(tensorOperands.size());
|
||||||
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
|
llvm::transform(tensorOperands, operandRanks.begin(), [](Value tensor) {
|
||||||
return tensor.getType().dyn_cast<RankedTensorType>().getRank();
|
return dyn_cast<RankedTensorType>(tensor.getType()).getRank();
|
||||||
});
|
});
|
||||||
|
|
||||||
auto resultRankIt =
|
auto resultRankIt =
|
||||||
|
@ -253,7 +252,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric(
|
||||||
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);
|
bool elideDynamicBroadcastCheck = isAssumingStrictSymbolicShapes(b);
|
||||||
for (Value tensorOperand : tensorOperands) {
|
for (Value tensorOperand : tensorOperands) {
|
||||||
SmallVector<AffineExpr> exprs;
|
SmallVector<AffineExpr> exprs;
|
||||||
auto type = tensorOperand.getType().cast<RankedTensorType>();
|
auto type = cast<RankedTensorType>(tensorOperand.getType());
|
||||||
for (auto size :
|
for (auto size :
|
||||||
llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) {
|
llvm::enumerate(makeShapeTorchCompatible(type.getShape()))) {
|
||||||
// If the size is statically known to be 1, we don't want any
|
// If the size is statically known to be 1, we don't want any
|
||||||
|
@ -327,7 +326,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
|
||||||
Operation *op, PatternRewriter &rewriter, Value input,
|
Operation *op, PatternRewriter &rewriter, Value input,
|
||||||
SmallVector<Value> broadcastToShape, RankedTensorType broadcastType,
|
SmallVector<Value> broadcastToShape, RankedTensorType broadcastType,
|
||||||
Value &result, SmallVector<bool> useBroadcastToShape) {
|
Value &result, SmallVector<bool> useBroadcastToShape) {
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
int64_t inputRank = inputType.getRank();
|
int64_t inputRank = inputType.getRank();
|
||||||
int64_t outputRank = broadcastToShape.size();
|
int64_t outputRank = broadcastToShape.size();
|
||||||
ArrayRef<int64_t> outputShape = broadcastType.getShape();
|
ArrayRef<int64_t> outputShape = broadcastType.getShape();
|
||||||
|
@ -525,7 +524,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape(
|
||||||
|
|
||||||
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
|
Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc,
|
||||||
Value tensor) {
|
Value tensor) {
|
||||||
auto tensorType = tensor.getType().cast<RankedTensorType>();
|
auto tensorType = cast<RankedTensorType>(tensor.getType());
|
||||||
auto rank = tensorType.getRank();
|
auto rank = tensorType.getRank();
|
||||||
SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
|
SmallVector<int64_t> unknownSizes(rank, kUnknownSize);
|
||||||
return b.create<tensor::CastOp>(
|
return b.create<tensor::CastOp>(
|
||||||
|
|
|
@ -66,8 +66,8 @@ Value getConstantLike(OpBuilder &b, Location loc, const APFloat &constant,
|
||||||
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
|
||||||
mlir::Value &self, mlir::Value &other,
|
mlir::Value &self, mlir::Value &other,
|
||||||
size_t dimSizeIndexBits) {
|
size_t dimSizeIndexBits) {
|
||||||
auto selfTy = self.getType().template dyn_cast<RankedTensorType>();
|
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||||
auto otherTy = other.getType().template dyn_cast<RankedTensorType>();
|
auto otherTy = dyn_cast<RankedTensorType>(other.getType());
|
||||||
auto selfRank = selfTy.getRank();
|
auto selfRank = selfTy.getRank();
|
||||||
auto otherRank = otherTy.getRank();
|
auto otherRank = otherTy.getRank();
|
||||||
if (selfRank == 0 || otherRank == 0)
|
if (selfRank == 0 || otherRank == 0)
|
||||||
|
@ -171,7 +171,7 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfType = self.getType().cast<TensorType>();
|
auto selfType = cast<TensorType>(self.getType());
|
||||||
if (!selfType) {
|
if (!selfType) {
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
}
|
}
|
||||||
|
@ -197,12 +197,12 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = cast<TensorType>(self.getType());
|
||||||
|
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
|
||||||
if (selfTy.getElementType().isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(selfTy.getElementType())) {
|
||||||
rewriter.replaceOpWithNewOp<StablehloOpT>(
|
rewriter.replaceOpWithNewOp<StablehloOpT>(
|
||||||
op,
|
op,
|
||||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||||
|
@ -229,14 +229,14 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = cast<TensorType>(self.getType());
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
auto resultTy = OpConversionPattern<AtenOpT>::getTypeConverter()
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.template cast<TensorType>();
|
.template cast<TensorType>();
|
||||||
|
|
||||||
if (resultTy.getElementType().template isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(resultTy.getElementType())) {
|
||||||
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
|
Value src = hlo::promoteType(rewriter, op.getLoc(), self, resultTy);
|
||||||
rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
|
rewriter.replaceOpWithNewOp<StablehloOpT>(op, resultTy, src);
|
||||||
return success();
|
return success();
|
||||||
|
@ -304,8 +304,7 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto inputType =
|
auto inputType = dyn_cast<RankedTensorType>(adaptor.getA().getType());
|
||||||
adaptor.getA().getType().template dyn_cast<RankedTensorType>();
|
|
||||||
if (!inputType)
|
if (!inputType)
|
||||||
|
|
||||||
op.emitError("only Tensor types supported in StableHLO");
|
op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
@ -313,8 +312,7 @@ public:
|
||||||
Value input = adaptor.getA();
|
Value input = adaptor.getA();
|
||||||
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
SmallVector<Value> inputSizes = getTensorSizes(rewriter, loc, input);
|
||||||
int64_t inputRank = inputSizes.size();
|
int64_t inputRank = inputSizes.size();
|
||||||
Type inputDtype =
|
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
|
||||||
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
|
||||||
|
|
||||||
Value constantOne =
|
Value constantOne =
|
||||||
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
|
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
|
@ -345,9 +343,9 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
auto lhsTy = cast<TensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
auto rhsTy = cast<TensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return op.emitError("only Tensor types supported");
|
return op.emitError("only Tensor types supported");
|
||||||
|
@ -378,9 +376,9 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
RankedTensorType lhsType = lhs.getType().dyn_cast<RankedTensorType>();
|
RankedTensorType lhsType = dyn_cast<RankedTensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
RankedTensorType rhsType = rhs.getType().dyn_cast<RankedTensorType>();
|
RankedTensorType rhsType = dyn_cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsType)
|
if (!lhsType)
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
@ -433,9 +431,9 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsType)
|
if (!lhsType)
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
@ -527,8 +525,8 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
|
RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
|
||||||
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy)
|
if (!lhsTy)
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
@ -616,8 +614,8 @@ public:
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
|
|
||||||
RankedTensorType lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
|
RankedTensorType lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
|
||||||
RankedTensorType rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
RankedTensorType rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy)
|
if (!lhsTy)
|
||||||
return op.emitError("lhs must be a ranked tensor type");
|
return op.emitError("lhs must be a ranked tensor type");
|
||||||
|
@ -659,11 +657,10 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
return rewriter.notifyMatchFailure(op, "dim1 must be constant");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto inType = self.getType().cast<RankedTensorType>();
|
auto inType = cast<RankedTensorType>(self.getType());
|
||||||
auto inputRank = inType.getRank();
|
auto inputRank = inType.getRank();
|
||||||
auto outType = getTypeConverter()
|
auto outType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
dim0 = toPositiveDim(dim0, inputRank);
|
dim0 = toPositiveDim(dim0, inputRank);
|
||||||
if (!isValidDim(dim0, inputRank)) {
|
if (!isValidDim(dim0, inputRank)) {
|
||||||
|
@ -691,7 +688,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
|
rewriter.replaceOpWithNewOp<stablehlo::ConvertOp>(op, outType, self);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -701,7 +698,7 @@ LogicalResult ConvertAtenOp<AtenSizeIntOp>::matchAndRewrite(
|
||||||
AtenSizeIntOp op, OpAdaptor adaptor,
|
AtenSizeIntOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return op.emitError("only tensor types are currently supported");
|
return op.emitError("only tensor types are currently supported");
|
||||||
|
|
||||||
|
@ -739,7 +736,7 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
||||||
Value other = adaptor.getOther();
|
Value other = adaptor.getOther();
|
||||||
|
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
// promote self and other types
|
// promote self and other types
|
||||||
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
self = hlo::promoteType(rewriter, op.getLoc(), self, outType);
|
||||||
other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
|
other = hlo::promoteType(rewriter, op.getLoc(), other, outType);
|
||||||
|
@ -764,10 +761,9 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
||||||
AtenBroadcastToOp op, OpAdaptor adaptor,
|
AtenBroadcastToOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
auto outType = getTypeConverter()
|
auto outType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
if (options.enableStaticShape && selfTy.hasStaticShape()) {
|
||||||
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
|
Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType);
|
||||||
|
@ -831,10 +827,9 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
// Not a ranked tensor type
|
// Not a ranked tensor type
|
||||||
auto inType = self.getType().dyn_cast<RankedTensorType>();
|
auto inType = dyn_cast<RankedTensorType>(self.getType());
|
||||||
auto outType = getTypeConverter()
|
auto outType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
if (!inType)
|
if (!inType)
|
||||||
return op.emitError("only ranked tensor types with static shapes are "
|
return op.emitError("only ranked tensor types with static shapes are "
|
||||||
"currently supported");
|
"currently supported");
|
||||||
|
@ -861,15 +856,14 @@ template <>
|
||||||
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||||
ValueTensorLiteralOp op, OpAdaptor adaptor,
|
ValueTensorLiteralOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
RankedTensorType resultType = getTypeConverter()
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
|
|
||||||
// Tensors with integer types need to be converted to signless integer
|
// Tensors with integer types need to be converted to signless integer
|
||||||
// element type. All tensors with element types other than integer can reuse
|
// element type. All tensors with element types other than integer can reuse
|
||||||
// existing elements attribute.
|
// existing elements attribute.
|
||||||
// TODO: what about unsigned integer?
|
// TODO: what about unsigned integer?
|
||||||
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
|
||||||
Type builtinTensorElemTy = resultType.getElementType();
|
Type builtinTensorElemTy = resultType.getElementType();
|
||||||
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
||||||
|
|
||||||
|
@ -892,9 +886,8 @@ template <>
|
||||||
LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenTensorIntOp>::matchAndRewrite(
|
||||||
AtenTensorIntOp op, OpAdaptor adaptor,
|
AtenTensorIntOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
RankedTensorType resultType = getTypeConverter()
|
RankedTensorType resultType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
Type outElementType = resultType.getElementType();
|
Type outElementType = resultType.getElementType();
|
||||||
Value innerValue = adaptor.getT();
|
Value innerValue = adaptor.getT();
|
||||||
Value stablehloTensor =
|
Value stablehloTensor =
|
||||||
|
@ -910,10 +903,10 @@ LogicalResult ConvertAtenOp<AtenReciprocalOp>::matchAndRewrite(
|
||||||
AtenReciprocalOp op, OpAdaptor adaptor,
|
AtenReciprocalOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto outTy =
|
auto outTy =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
if (!inputTy.getElementType().isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||||
return op.emitError("only floating-point datatype legalization supported "
|
return op.emitError("only floating-point datatype legalization supported "
|
||||||
"for AtenReciprocalOp");
|
"for AtenReciprocalOp");
|
||||||
}
|
}
|
||||||
|
@ -929,9 +922,9 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||||
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
AtenPowTensorScalarOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getExponent();
|
Value rhs = adaptor.getExponent();
|
||||||
TensorType rhsType = rhs.getType().dyn_cast<TensorType>();
|
TensorType rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsType)
|
if (!lhsType)
|
||||||
return op.emitError("only Tensor types supported in StableHLO");
|
return op.emitError("only Tensor types supported in StableHLO");
|
||||||
|
@ -1002,9 +995,8 @@ template <>
|
||||||
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<PrimNumToTensorScalarOp>::matchAndRewrite(
|
||||||
PrimNumToTensorScalarOp op, OpAdaptor adaptor,
|
PrimNumToTensorScalarOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
RankedTensorType outputType = getTypeConverter()
|
RankedTensorType outputType = cast<RankedTensorType>(
|
||||||
->convertType(op->getResult(0).getType())
|
getTypeConverter()->convertType(op->getResult(0).getType()));
|
||||||
.cast<RankedTensorType>();
|
|
||||||
auto outputElemType = outputType.getElementType();
|
auto outputElemType = outputType.getElementType();
|
||||||
Value stablehloTensor = hlo::scalarToStablehloTensor(
|
Value stablehloTensor = hlo::scalarToStablehloTensor(
|
||||||
rewriter, op, adaptor.getA(), outputElemType);
|
rewriter, op, adaptor.getA(), outputElemType);
|
||||||
|
@ -1018,8 +1010,7 @@ LogicalResult ConvertAtenOp<AtenScalarImplicitOp>::matchAndRewrite(
|
||||||
AtenScalarImplicitOp op, OpAdaptor adaptor,
|
AtenScalarImplicitOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Type inputDtype =
|
Type inputDtype = cast<BaseTensorType>(op.getA().getType()).getDtype();
|
||||||
op.getA().getType().template cast<BaseTensorType>().getDtype();
|
|
||||||
Type resultType =
|
Type resultType =
|
||||||
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
this->getTypeConverter()->convertType(op->getResult(0).getType());
|
||||||
auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
|
auto result = rewriter.create<tensor::ExtractOp>(loc, adaptor.getA());
|
||||||
|
@ -1037,7 +1028,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return op.emitError("only tensor types are currently supported");
|
return op.emitError("only tensor types are currently supported");
|
||||||
|
|
||||||
|
@ -1055,7 +1046,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
||||||
AtenReluOp op, OpAdaptor adaptor,
|
AtenReluOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
auto lhsElemTy = lhsTy.getElementType();
|
auto lhsElemTy = lhsTy.getElementType();
|
||||||
|
|
||||||
if (!isa<mlir::FloatType>(lhsElemTy)) {
|
if (!isa<mlir::FloatType>(lhsElemTy)) {
|
||||||
|
@ -1080,7 +1071,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
if (!inputTy) {
|
if (!inputTy) {
|
||||||
return op.emitError("only ranked tensor type is supported.");
|
return op.emitError("only ranked tensor type is supported.");
|
||||||
}
|
}
|
||||||
|
@ -1103,11 +1094,11 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
||||||
AtenLog2Op op, OpAdaptor adaptor,
|
AtenLog2Op op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
if (!inputTy) {
|
if (!inputTy) {
|
||||||
return op.emitError("only ranked tensor type is supported.");
|
return op.emitError("only ranked tensor type is supported.");
|
||||||
}
|
}
|
||||||
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||||
|
|
||||||
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
auto two = getConstantLike(rewriter, op.getLoc(), 2.0, input);
|
||||||
|
@ -1124,12 +1115,12 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
|
||||||
AtenLog10Op op, OpAdaptor adaptor,
|
AtenLog10Op op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
if (!inputTy) {
|
if (!inputTy) {
|
||||||
return op.emitError("only ranked tensor type is supported.");
|
return op.emitError("only ranked tensor type is supported.");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outTy = getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
auto outTy = cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||||
|
|
||||||
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
auto ten = getConstantLike(rewriter, op.getLoc(), 10.0, input);
|
||||||
|
@ -1146,8 +1137,8 @@ LogicalResult ConvertAtenOp<AtenErfOp>::matchAndRewrite(
|
||||||
AtenErfOp op, OpAdaptor adaptor,
|
AtenErfOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputType = input.getType().cast<TensorType>();
|
auto inputType = cast<TensorType>(input.getType());
|
||||||
if (!inputType.getElementType().isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(inputType.getElementType())) {
|
||||||
return rewriter.notifyMatchFailure(op, "only float tensor is supported");
|
return rewriter.notifyMatchFailure(op, "only float tensor is supported");
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<chlo::ErfOp>(
|
rewriter.replaceOpWithNewOp<chlo::ErfOp>(
|
||||||
|
@ -1161,7 +1152,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
AtenBatchNormOp op, OpAdaptor adaptor,
|
AtenBatchNormOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getInput();
|
Value input = adaptor.getInput();
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
Value weight = adaptor.getWeight();
|
Value weight = adaptor.getWeight();
|
||||||
Value bias = adaptor.getBias();
|
Value bias = adaptor.getBias();
|
||||||
Value runningMean = adaptor.getRunningMean();
|
Value runningMean = adaptor.getRunningMean();
|
||||||
|
@ -1174,10 +1165,10 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
// all of NC, NCL, NCHW, NCDHW's feature index is 1.
|
// all of NC, NCL, NCHW, NCDHW's feature index is 1.
|
||||||
int64_t feature_index = 1;
|
int64_t feature_index = 1;
|
||||||
|
|
||||||
if (!inputTy.getElementType().template isa<mlir::FloatType>()) {
|
if (!isa<mlir::FloatType>(inputTy.getElementType())) {
|
||||||
return op.emitError("only input tensor of float type is supported");
|
return op.emitError("only input tensor of float type is supported");
|
||||||
}
|
}
|
||||||
auto inputElemTy = inputTy.getElementType().cast<mlir::FloatType>();
|
auto inputElemTy = cast<mlir::FloatType>(inputTy.getElementType());
|
||||||
|
|
||||||
Value channelDim =
|
Value channelDim =
|
||||||
rewriter.create<tensor::DimOp>(op->getLoc(), input, feature_index);
|
rewriter.create<tensor::DimOp>(op->getLoc(), input, feature_index);
|
||||||
|
@ -1220,20 +1211,20 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
inputTy.getElementType()));
|
inputTy.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
auto biasTy = bias.getType().cast<RankedTensorType>();
|
auto biasTy = cast<RankedTensorType>(bias.getType());
|
||||||
auto runningMeanTy = runningMean.getType().cast<RankedTensorType>();
|
auto runningMeanTy = cast<RankedTensorType>(runningMean.getType());
|
||||||
auto runningVarTy = runningVar.getType().cast<RankedTensorType>();
|
auto runningVarTy = cast<RankedTensorType>(runningVar.getType());
|
||||||
|
|
||||||
if (weightTy.getRank() != 1 || biasTy.getRank() != 1 ||
|
if (weightTy.getRank() != 1 || biasTy.getRank() != 1 ||
|
||||||
runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) {
|
runningMeanTy.getRank() != 1 || runningVarTy.getRank() != 1) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expect weight, bias, running_mean and running_var to be rank 1");
|
op, "expect weight, bias, running_mean and running_var to be rank 1");
|
||||||
}
|
}
|
||||||
if (!weightTy.getElementType().template isa<mlir::FloatType>() ||
|
if (!isa<mlir::FloatType>(weightTy.getElementType()) ||
|
||||||
!biasTy.getElementType().template isa<mlir::FloatType>() ||
|
!isa<mlir::FloatType>(biasTy.getElementType()) ||
|
||||||
!runningMeanTy.getElementType().template isa<mlir::FloatType>() ||
|
!isa<mlir::FloatType>(runningMeanTy.getElementType()) ||
|
||||||
!runningVarTy.getElementType().template isa<mlir::FloatType>()) {
|
!isa<mlir::FloatType>(runningVarTy.getElementType())) {
|
||||||
return op.emitError("only float weight/bias/runningMean/runningVar tensor "
|
return op.emitError("only float weight/bias/runningMean/runningVar tensor "
|
||||||
"of float type is supported");
|
"of float type is supported");
|
||||||
}
|
}
|
||||||
|
@ -1261,8 +1252,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
// supported mixed types, like input type is fp16 and weight type is fp32.
|
// supported mixed types, like input type is fp16 and weight type is fp32.
|
||||||
if (inputTy.getElementType() != weightTy.getElementType()) {
|
if (inputTy.getElementType() != weightTy.getElementType()) {
|
||||||
RankedTensorType convertedType = inputTy;
|
RankedTensorType convertedType = inputTy;
|
||||||
if (weightTy.getElementType().cast<FloatType>().getWidth() >
|
if (cast<FloatType>(weightTy.getElementType()).getWidth() >
|
||||||
inputTy.getElementType().cast<FloatType>().getWidth()) {
|
cast<FloatType>(inputTy.getElementType()).getWidth()) {
|
||||||
convertedType = RankedTensorType::get(inputTy.getShape(),
|
convertedType = RankedTensorType::get(inputTy.getShape(),
|
||||||
weightTy.getElementType());
|
weightTy.getElementType());
|
||||||
}
|
}
|
||||||
|
@ -1302,8 +1293,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
// supported mixed types, like input type is fp16 and weight type is fp32.
|
// supported mixed types, like input type is fp16 and weight type is fp32.
|
||||||
if (inputTy.getElementType() != weightTy.getElementType()) {
|
if (inputTy.getElementType() != weightTy.getElementType()) {
|
||||||
RankedTensorType convertedType = inputTy;
|
RankedTensorType convertedType = inputTy;
|
||||||
if (weightTy.getElementType().cast<FloatType>().getWidth() >
|
if (cast<FloatType>(weightTy.getElementType()).getWidth() >
|
||||||
inputTy.getElementType().cast<FloatType>().getWidth()) {
|
cast<FloatType>(inputTy.getElementType()).getWidth()) {
|
||||||
convertedType = RankedTensorType::get(inputTy.getShape(),
|
convertedType = RankedTensorType::get(inputTy.getShape(),
|
||||||
weightTy.getElementType());
|
weightTy.getElementType());
|
||||||
}
|
}
|
||||||
|
@ -1340,7 +1331,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
AtenNativeLayerNormOp op, OpAdaptor adaptor,
|
AtenNativeLayerNormOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getInput();
|
Value input = adaptor.getInput();
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto inputShape = inputTy.getShape();
|
auto inputShape = inputTy.getShape();
|
||||||
auto inputRank = inputTy.getRank();
|
auto inputRank = inputTy.getRank();
|
||||||
Value weight = adaptor.getWeight();
|
Value weight = adaptor.getWeight();
|
||||||
|
@ -1365,12 +1356,12 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
failed(checkNotNone(rewriter, op, bias))) {
|
failed(checkNotNone(rewriter, op, bias))) {
|
||||||
return op->emitError("none weight or bias is unsupported");
|
return op->emitError("none weight or bias is unsupported");
|
||||||
}
|
}
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
auto biasTy = bias.getType().cast<RankedTensorType>();
|
auto biasTy = cast<RankedTensorType>(bias.getType());
|
||||||
|
|
||||||
if (!inputTy.getElementType().isa<mlir::FloatType>() ||
|
if (!isa<mlir::FloatType>(inputTy.getElementType()) ||
|
||||||
!biasTy.getElementType().isa<mlir::FloatType>() ||
|
!isa<mlir::FloatType>(biasTy.getElementType()) ||
|
||||||
!weightTy.getElementType().isa<mlir::FloatType>()) {
|
!isa<mlir::FloatType>(weightTy.getElementType())) {
|
||||||
return op->emitError("currently only float data type are supported");
|
return op->emitError("currently only float data type are supported");
|
||||||
}
|
}
|
||||||
int64_t normalizedShapeRank = normalizedShape.size();
|
int64_t normalizedShapeRank = normalizedShape.size();
|
||||||
|
@ -1423,7 +1414,7 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
SmallVector<APFloat> oneConstVec(
|
SmallVector<APFloat> oneConstVec(
|
||||||
numFeatureDimSize,
|
numFeatureDimSize,
|
||||||
APFloat(
|
APFloat(
|
||||||
inputTy.getElementType().cast<mlir::FloatType>().getFloatSemantics(),
|
cast<mlir::FloatType>(inputTy.getElementType()).getFloatSemantics(),
|
||||||
1));
|
1));
|
||||||
auto oneOrZeroConstType =
|
auto oneOrZeroConstType =
|
||||||
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
|
RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType());
|
||||||
|
@ -1443,9 +1434,9 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
|
|
||||||
// Reshape back
|
// Reshape back
|
||||||
auto outputTy =
|
auto outputTy =
|
||||||
getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
|
||||||
auto outputMeanOrVarTy =
|
auto outputMeanOrVarTy =
|
||||||
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||||
|
|
||||||
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
|
auto output = rewriter.create<stablehlo::DynamicReshapeOp>(
|
||||||
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
op->getLoc(), outputTy, batchNormTrainingResult.getResult(0),
|
||||||
|
@ -1482,7 +1473,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
AtenCatOp op, OpAdaptor adaptor,
|
AtenCatOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -1516,7 +1507,7 @@ LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
|
||||||
AtenNumelOp op, OpAdaptor adaptor,
|
AtenNumelOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto self = adaptor.getSelf();
|
auto self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
|
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
|
||||||
size_t rank = selfTy.getRank();
|
size_t rank = selfTy.getRank();
|
||||||
|
|
||||||
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
|
||||||
|
@ -1544,7 +1535,7 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||||
AtenClampOp op, OpAdaptor adaptor,
|
AtenClampOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto inputElemType = inputType.getElementType();
|
auto inputElemType = inputType.getElementType();
|
||||||
Value minValue = adaptor.getMin();
|
Value minValue = adaptor.getMin();
|
||||||
Value maxValue = adaptor.getMax();
|
Value maxValue = adaptor.getMax();
|
||||||
|
@ -1716,7 +1707,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto outType =
|
auto outType =
|
||||||
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
||||||
if (!outType) {
|
if (!outType) {
|
||||||
return op.emitError("only tensor type is supported");
|
return op.emitError("only tensor type is supported");
|
||||||
}
|
}
|
||||||
|
@ -1764,15 +1755,15 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
|
||||||
AtenPowTensorTensorOp op, OpAdaptor adaptor,
|
AtenPowTensorTensorOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
auto lhsTy = cast<TensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getExponent();
|
Value rhs = adaptor.getExponent();
|
||||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
auto rhsTy = cast<TensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return op.emitError("only Tensor types supported");
|
return op.emitError("only Tensor types supported");
|
||||||
|
|
||||||
auto outTy =
|
auto outTy =
|
||||||
this->getTypeConverter()->convertType(op.getType()).cast<TensorType>();
|
cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outTy);
|
||||||
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outTy);
|
||||||
|
@ -1790,12 +1781,12 @@ LogicalResult ConvertAtenOp<AtenUniformOp>::matchAndRewrite(
|
||||||
Value generator = adaptor.getGenerator();
|
Value generator = adaptor.getGenerator();
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
if (!generator.getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(generator.getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "The generator has to be None because only global default "
|
op, "The generator has to be None because only global default "
|
||||||
"generator is supported");
|
"generator is supported");
|
||||||
|
|
||||||
auto elements = self.getType().cast<RankedTensorType>().getShape();
|
auto elements = cast<RankedTensorType>(self.getType()).getShape();
|
||||||
if (llvm::any_of(elements,
|
if (llvm::any_of(elements,
|
||||||
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
|
[](int64_t dim) { return dim == ShapedType::kDynamic; }))
|
||||||
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
|
return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD");
|
||||||
|
@ -1824,14 +1815,14 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
||||||
|
|
||||||
// The pin_memory should be either `False` or `none`.
|
// The pin_memory should be either `False` or `none`.
|
||||||
bool pinMemory;
|
bool pinMemory;
|
||||||
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
|
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||||
pinMemory))
|
pinMemory))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: pin_memory must be either None or false");
|
op, "unimplemented: pin_memory must be either None or false");
|
||||||
|
|
||||||
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
// Only `none`, `contiguous` and `preserve` memory_format is supported.
|
||||||
if (!op.getMemoryFormat().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType())) {
|
||||||
int64_t memoryFormat;
|
int64_t memoryFormat;
|
||||||
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
|
if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1844,7 +1835,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
||||||
"memory_format is supported");
|
"memory_format is supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!op.getDevice().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getDevice().getType())) {
|
||||||
std::string device;
|
std::string device;
|
||||||
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
if (!matchPattern(op.getDevice(), m_TorchConstantDevice(device)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1853,7 +1844,7 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
||||||
|
|
||||||
// TODO: Add support for non-strided layout.
|
// TODO: Add support for non-strided layout.
|
||||||
// torch.layout is by default strided i.e. 0.
|
// torch.layout is by default strided i.e. 0.
|
||||||
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||||
int64_t tensorLayout;
|
int64_t tensorLayout;
|
||||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1876,9 +1867,9 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
|
||||||
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
resultSizeIndex.push_back(castIntToIndex(rewriter, loc, size));
|
||||||
|
|
||||||
auto resultType =
|
auto resultType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
Type resultElementType;
|
Type resultElementType;
|
||||||
if (op.getDtype().getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||||
resultElementType = resultType.getElementType();
|
resultElementType = resultType.getElementType();
|
||||||
} else {
|
} else {
|
||||||
int64_t dtypeInt;
|
int64_t dtypeInt;
|
||||||
|
@ -1931,7 +1922,7 @@ LogicalResult ConvertAtenOp<AtenFillScalarOp>::matchAndRewrite(
|
||||||
AtenFillScalarOp op, OpAdaptor adaptor,
|
AtenFillScalarOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
auto dtype = outType.getElementType();
|
auto dtype = outType.getElementType();
|
||||||
Value scalarTensor =
|
Value scalarTensor =
|
||||||
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype);
|
hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype);
|
||||||
|
@ -1951,7 +1942,7 @@ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
SmallVector<int64_t> dims;
|
SmallVector<int64_t> dims;
|
||||||
if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) {
|
if (!matchPattern(op.getDims(), m_TorchListOfConstantInts(dims))) {
|
||||||
|
|
|
@ -64,7 +64,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||||
loc, rewriter.getIntegerAttr(intType, 1));
|
loc, rewriter.getIntegerAttr(intType, 1));
|
||||||
|
|
||||||
// sliceSizes
|
// sliceSizes
|
||||||
auto inputRankTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputRankTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
auto inputRank = inputRankTy.getRank();
|
auto inputRank = inputRankTy.getRank();
|
||||||
SmallVector<Value, 4> sliceSizes;
|
SmallVector<Value, 4> sliceSizes;
|
||||||
sliceSizes.reserve(inputRank);
|
sliceSizes.reserve(inputRank);
|
||||||
|
@ -85,7 +85,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
|
||||||
for (int64_t r = 0; r < axis; ++r) {
|
for (int64_t r = 0; r < axis; ++r) {
|
||||||
offsetDims.push_back(r);
|
offsetDims.push_back(r);
|
||||||
}
|
}
|
||||||
auto indicesRankTy = indices.getType().dyn_cast<RankedTensorType>();
|
auto indicesRankTy = dyn_cast<RankedTensorType>(indices.getType());
|
||||||
auto indicesRank = indicesRankTy.getRank();
|
auto indicesRank = indicesRankTy.getRank();
|
||||||
for (int64_t r = axis + 1; r < inputRank; ++r) {
|
for (int64_t r = axis + 1; r < inputRank; ++r) {
|
||||||
offsetDims.push_back(r + indicesRank - 1);
|
offsetDims.push_back(r + indicesRank - 1);
|
||||||
|
@ -132,8 +132,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
SmallVector<Value> &strides) {
|
SmallVector<Value> &strides) {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
RankedTensorType inputType =
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
input.getType().template cast<RankedTensorType>();
|
|
||||||
|
|
||||||
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
|
@ -161,7 +160,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
|
||||||
|
|
||||||
int64_t step;
|
int64_t step;
|
||||||
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
|
||||||
if (!op.getStep().getType().template isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(op.getStep().getType()))
|
||||||
return op->emitError("unimplemented: step is not constant");
|
return op->emitError("unimplemented: step is not constant");
|
||||||
step = 1;
|
step = 1;
|
||||||
}
|
}
|
||||||
|
@ -225,7 +224,7 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
|
||||||
// concat index tensor into to indices tensor for concat
|
// concat index tensor into to indices tensor for concat
|
||||||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||||
auto indexTensor = indexTensors[i];
|
auto indexTensor = indexTensors[i];
|
||||||
auto indexTensorType = indexTensor.getType().cast<RankedTensorType>();
|
auto indexTensorType = cast<RankedTensorType>(indexTensor.getType());
|
||||||
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
|
for (int64_t size : makeShapeTorchCompatible(indexTensorType.getShape())) {
|
||||||
if (size == kUnknownSize)
|
if (size == kUnknownSize)
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -249,7 +248,7 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op,
|
||||||
|
|
||||||
SmallVector<Value> broadcastedIndices;
|
SmallVector<Value> broadcastedIndices;
|
||||||
Type indexElemTy =
|
Type indexElemTy =
|
||||||
indexTensors[0].getType().cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(indexTensors[0].getType()).getElementType();
|
||||||
RankedTensorType bcastIndexType =
|
RankedTensorType bcastIndexType =
|
||||||
RankedTensorType::get(indicesShape, indexElemTy);
|
RankedTensorType::get(indicesShape, indexElemTy);
|
||||||
for (auto indexTensor : indexTensors) {
|
for (auto indexTensor : indexTensors) {
|
||||||
|
@ -290,7 +289,7 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
||||||
AtenEmbeddingOp op, OpAdaptor adaptor,
|
AtenEmbeddingOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto weight = adaptor.getWeight();
|
auto weight = adaptor.getWeight();
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
if (!weightTy)
|
if (!weightTy)
|
||||||
return op.emitError("only ranked tensor types are supported");
|
return op.emitError("only ranked tensor types are supported");
|
||||||
|
|
||||||
|
@ -332,17 +331,17 @@ LogicalResult ConvertAtenOp<AtenEmbeddingBagPaddingIdxOp>::matchAndRewrite(
|
||||||
Value indices = adaptor.getIndices();
|
Value indices = adaptor.getIndices();
|
||||||
Value offsets = adaptor.getOffsets();
|
Value offsets = adaptor.getOffsets();
|
||||||
|
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
if (weightTy && weightTy.hasStaticShape() && weightTy.getRank() != 2)
|
if (weightTy && weightTy.hasStaticShape() && weightTy.getRank() != 2)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "weight must be rank 2 tensor with static shapes");
|
op, "weight must be rank 2 tensor with static shapes");
|
||||||
|
|
||||||
auto indicesTy = indices.getType().cast<RankedTensorType>();
|
auto indicesTy = cast<RankedTensorType>(indices.getType());
|
||||||
if (indicesTy && indicesTy.hasStaticShape() && indicesTy.getRank() != 1)
|
if (indicesTy && indicesTy.hasStaticShape() && indicesTy.getRank() != 1)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "indices must be a vector with static shapes");
|
op, "indices must be a vector with static shapes");
|
||||||
|
|
||||||
auto offsetsTy = offsets.getType().cast<RankedTensorType>();
|
auto offsetsTy = cast<RankedTensorType>(offsets.getType());
|
||||||
if (offsetsTy && offsetsTy.getRank() != 1 && offsetsTy.hasStaticShape() &&
|
if (offsetsTy && offsetsTy.getRank() != 1 && offsetsTy.hasStaticShape() &&
|
||||||
offsetsTy.getShape()[0] == 1)
|
offsetsTy.getShape()[0] == 1)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -485,7 +484,7 @@ LogicalResult ConvertAtenOp<AtenIndexSelectOp>::matchAndRewrite(
|
||||||
AtenIndexSelectOp op, OpAdaptor adaptor,
|
AtenIndexSelectOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto self = adaptor.getSelf();
|
auto self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return op.emitError("only ranked tensor types are supported");
|
return op.emitError("only ranked tensor types are supported");
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
|
@ -514,8 +513,8 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
Value index = adaptor.getIndex();
|
Value index = adaptor.getIndex();
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto indexType = index.getType().cast<RankedTensorType>();
|
auto indexType = cast<RankedTensorType>(index.getType());
|
||||||
auto indexElemType = indexType.getElementType();
|
auto indexElemType = indexType.getElementType();
|
||||||
|
|
||||||
if (indexType.getRank() != inputType.getRank()) {
|
if (indexType.getRank() != inputType.getRank()) {
|
||||||
|
@ -623,7 +622,7 @@ LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
|
|
||||||
Value src = adaptor.getSrc();
|
Value src = adaptor.getSrc();
|
||||||
auto srcType = src.getType().cast<RankedTensorType>();
|
auto srcType = cast<RankedTensorType>(src.getType());
|
||||||
int64_t srcRank = srcType.getRank();
|
int64_t srcRank = srcType.getRank();
|
||||||
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
|
||||||
auto abstractSrcType = RankedTensorType::get(
|
auto abstractSrcType = RankedTensorType::get(
|
||||||
|
@ -651,9 +650,9 @@ public:
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
Value index = adaptor.getIndex();
|
Value index = adaptor.getIndex();
|
||||||
Value src = adaptor.getSrc();
|
Value src = adaptor.getSrc();
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto indexType = index.getType().cast<RankedTensorType>();
|
auto indexType = cast<RankedTensorType>(index.getType());
|
||||||
auto srcType = src.getType().cast<RankedTensorType>();
|
auto srcType = cast<RankedTensorType>(src.getType());
|
||||||
auto indexElemType = indexType.getElementType();
|
auto indexElemType = indexType.getElementType();
|
||||||
|
|
||||||
if (indexType.getRank() != inputType.getRank() ||
|
if (indexType.getRank() != inputType.getRank() ||
|
||||||
|
@ -789,9 +788,9 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Location loc = op->getLoc();
|
Location loc = op->getLoc();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTensorType = input.getType().cast<RankedTensorType>();
|
auto inputTensorType = cast<RankedTensorType>(input.getType());
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
auto outShape = outType.getShape();
|
auto outShape = outType.getShape();
|
||||||
Value indexList = op.getIndices();
|
Value indexList = op.getIndices();
|
||||||
SmallVector<Value> indicesTorchType;
|
SmallVector<Value> indicesTorchType;
|
||||||
|
@ -857,10 +856,10 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
Value values = adaptor.getValues();
|
Value values = adaptor.getValues();
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
int64_t inputRank = inputType.getRank();
|
int64_t inputRank = inputType.getRank();
|
||||||
auto valuesType = values.getType().cast<RankedTensorType>();
|
auto valuesType = cast<RankedTensorType>(values.getType());
|
||||||
auto valuesShape = valuesType.getShape();
|
auto valuesShape = valuesType.getShape();
|
||||||
bool accumulate;
|
bool accumulate;
|
||||||
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
|
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate))) {
|
||||||
|
|
|
@ -32,7 +32,7 @@ namespace {
|
||||||
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||||
ArrayRef<int64_t> shape, ArrayRef<Value> dimSizes,
|
ArrayRef<int64_t> shape, ArrayRef<Value> dimSizes,
|
||||||
ArrayRef<int64_t> broadcastDims) {
|
ArrayRef<int64_t> broadcastDims) {
|
||||||
auto tensorTy = tensor.getType().dyn_cast<RankedTensorType>();
|
auto tensorTy = dyn_cast<RankedTensorType>(tensor.getType());
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
Value stablehloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor,
|
||||||
|
|
||||||
Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
|
||||||
ArrayRef<int64_t> inpTransDims) {
|
ArrayRef<int64_t> inpTransDims) {
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
auto rank = inputTy.getRank();
|
auto rank = inputTy.getRank();
|
||||||
auto transDims = hlo::toPositiveDims(inpTransDims, rank);
|
auto transDims = hlo::toPositiveDims(inpTransDims, rank);
|
||||||
auto inpShape = inputTy.getShape();
|
auto inpShape = inputTy.getShape();
|
||||||
|
@ -70,8 +70,8 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
|
||||||
int64_t lhsResultDim, int64_t rhsResultDim,
|
int64_t lhsResultDim, int64_t rhsResultDim,
|
||||||
int64_t lhsContractingDim,
|
int64_t lhsContractingDim,
|
||||||
int64_t rhsContractingDim) {
|
int64_t rhsContractingDim) {
|
||||||
auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
|
auto lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
|
||||||
auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
|
auto rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
auto oldLhsShape = lhsTy.getShape();
|
auto oldLhsShape = lhsTy.getShape();
|
||||||
auto oldRhsShape = rhsTy.getShape();
|
auto oldRhsShape = rhsTy.getShape();
|
||||||
|
@ -129,8 +129,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
||||||
size_t dimSizeIndexBits) {
|
size_t dimSizeIndexBits) {
|
||||||
Value lhs = inpLhs;
|
Value lhs = inpLhs;
|
||||||
Value rhs = inpRhs;
|
Value rhs = inpRhs;
|
||||||
auto lhsRankTy = inpLhs.getType().dyn_cast<RankedTensorType>();
|
auto lhsRankTy = dyn_cast<RankedTensorType>(inpLhs.getType());
|
||||||
auto rhsRankTy = inpRhs.getType().dyn_cast<RankedTensorType>();
|
auto rhsRankTy = dyn_cast<RankedTensorType>(inpRhs.getType());
|
||||||
|
|
||||||
auto lhsRank = lhsRankTy.getRank();
|
auto lhsRank = lhsRankTy.getRank();
|
||||||
auto rhsRank = rhsRankTy.getRank();
|
auto rhsRank = rhsRankTy.getRank();
|
||||||
|
@ -177,8 +177,8 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
lhsShape = lhs.getType().cast<RankedTensorType>().getShape();
|
lhsShape = cast<RankedTensorType>(lhs.getType()).getShape();
|
||||||
rhsShape = rhs.getType().cast<RankedTensorType>().getShape();
|
rhsShape = cast<RankedTensorType>(rhs.getType()).getShape();
|
||||||
|
|
||||||
// check shape compatibility, check if we should broadcast
|
// check shape compatibility, check if we should broadcast
|
||||||
// first, we should got a new batch shape. Check from (0, nBatchDims)
|
// first, we should got a new batch shape. Check from (0, nBatchDims)
|
||||||
|
@ -266,8 +266,8 @@ public:
|
||||||
LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
|
LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter, Value &lhs,
|
ConversionPatternRewriter &rewriter, Value &lhs,
|
||||||
Value &rhs, Value &output) const {
|
Value &rhs, Value &output) const {
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
auto lhsRank = lhsTy.getRank();
|
auto lhsRank = lhsTy.getRank();
|
||||||
auto rhsRank = rhsTy.getRank();
|
auto rhsRank = rhsTy.getRank();
|
||||||
|
@ -370,10 +370,10 @@ public:
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Value &lhs, Value &rhs) const override {
|
Value &lhs, Value &rhs) const override {
|
||||||
lhs = adaptor.getSelf();
|
lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
|
|
||||||
rhs = adaptor.getOther();
|
rhs = adaptor.getOther();
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
|
@ -393,10 +393,10 @@ public:
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Value &lhs, Value &rhs) const override {
|
Value &lhs, Value &rhs) const override {
|
||||||
lhs = adaptor.getSelf();
|
lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
|
|
||||||
rhs = adaptor.getMat2();
|
rhs = adaptor.getMat2();
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
|
@ -429,10 +429,10 @@ public:
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Value &lhs, Value &rhs) const override {
|
Value &lhs, Value &rhs) const override {
|
||||||
lhs = adaptor.getInput();
|
lhs = adaptor.getInput();
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
|
|
||||||
rhs = adaptor.getWeight();
|
rhs = adaptor.getWeight();
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
|
@ -464,16 +464,15 @@ public:
|
||||||
auto biasTy = bias.getType();
|
auto biasTy = bias.getType();
|
||||||
|
|
||||||
// StableHLO does not mandate that elementwise op tensors need to be ranked.
|
// StableHLO does not mandate that elementwise op tensors need to be ranked.
|
||||||
if (!biasTy.template isa<Torch::NoneType>() &&
|
if (!isa<Torch::NoneType>(biasTy) && !isa<RankedTensorType>(biasTy))
|
||||||
!biasTy.template isa<RankedTensorType>())
|
|
||||||
return op.emitError("only ranked tensor types are supported in StableHLO "
|
return op.emitError("only ranked tensor types are supported in StableHLO "
|
||||||
"matmul for bias tensor");
|
"matmul for bias tensor");
|
||||||
|
|
||||||
// weight.T
|
// weight.T
|
||||||
rhs = getPermutedTensor(rewriter, op, rhs, {1, 0});
|
rhs = getPermutedTensor(rewriter, op, rhs, {1, 0});
|
||||||
|
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
|
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
|
||||||
rhsTy.getRank() - lhsTy.getRank());
|
rhsTy.getRank() - lhsTy.getRank());
|
||||||
|
|
||||||
|
@ -503,7 +502,7 @@ public:
|
||||||
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
|
||||||
|
|
||||||
Value matmulPlusBias = matmulOutput;
|
Value matmulPlusBias = matmulOutput;
|
||||||
if (!biasTy.template isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(biasTy)) {
|
||||||
// Bias addition broadcasts to the matmul output shape.
|
// Bias addition broadcasts to the matmul output shape.
|
||||||
matmulPlusBias = rewriter
|
matmulPlusBias = rewriter
|
||||||
.create<chlo::BroadcastAddOp>(
|
.create<chlo::BroadcastAddOp>(
|
||||||
|
@ -525,7 +524,7 @@ public:
|
||||||
|
|
||||||
Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,
|
Value reshapeConvWeight(PatternRewriter &rewriter, Operation *op,
|
||||||
Value weight, int64_t groups) const {
|
Value weight, int64_t groups) const {
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
auto weightElemTy = weightTy.getElementType();
|
auto weightElemTy = weightTy.getElementType();
|
||||||
auto rank = weightTy.getRank();
|
auto rank = weightTy.getRank();
|
||||||
const auto &options = getOptions();
|
const auto &options = getOptions();
|
||||||
|
@ -588,8 +587,8 @@ public:
|
||||||
ArrayRef<int64_t> dilation,
|
ArrayRef<int64_t> dilation,
|
||||||
ArrayRef<int64_t> outputPadding,
|
ArrayRef<int64_t> outputPadding,
|
||||||
int64_t groups) const {
|
int64_t groups) const {
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
auto weightShape = weightTy.getShape();
|
auto weightShape = weightTy.getShape();
|
||||||
|
|
||||||
auto nDims = inputTy.getRank();
|
auto nDims = inputTy.getRank();
|
||||||
|
@ -727,11 +726,11 @@ public:
|
||||||
Value weight = adaptor.getWeight();
|
Value weight = adaptor.getWeight();
|
||||||
|
|
||||||
// The input shape is [N, C, H, W]
|
// The input shape is [N, C, H, W]
|
||||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
// The weight shape is [OC, (IC//G), KH, KW]
|
// The weight shape is [OC, (IC//G), KH, KW]
|
||||||
// If transposed is set to true,
|
// If transposed is set to true,
|
||||||
// the weight shape changes to [IC, (OC//G), KH, KW]
|
// the weight shape changes to [IC, (OC//G), KH, KW]
|
||||||
auto weightTy = weight.getType().template cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
auto outTy = getTypeConverter()
|
auto outTy = getTypeConverter()
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.template cast<RankedTensorType>();
|
.template cast<RankedTensorType>();
|
||||||
|
@ -819,11 +818,11 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle bias
|
// Handle bias
|
||||||
if (!bias.getType().cast<RankedTensorType>()) {
|
if (!cast<RankedTensorType>(bias.getType())) {
|
||||||
return op.emitError("bias provided but not a ranked tensor");
|
return op.emitError("bias provided but not a ranked tensor");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto biasTy = bias.getType().cast<RankedTensorType>();
|
auto biasTy = cast<RankedTensorType>(bias.getType());
|
||||||
if (!biasTy.getElementType().isIntOrFloat()) {
|
if (!biasTy.getElementType().isIntOrFloat()) {
|
||||||
return op.emitError("only floating-point or integer datatype "
|
return op.emitError("only floating-point or integer datatype "
|
||||||
"legalization for bias supported");
|
"legalization for bias supported");
|
||||||
|
|
|
@ -81,12 +81,12 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dOp>::matchAndRewrite(
|
||||||
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
AtenMaxPool2dOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto inputElemTy = inputTy.getElementType();
|
auto inputElemTy = inputTy.getElementType();
|
||||||
|
|
||||||
auto inputRank = inputTy.getRank();
|
auto inputRank = inputTy.getRank();
|
||||||
auto outTy =
|
auto outTy =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
if (inputRank <= 2) {
|
if (inputRank <= 2) {
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
|
@ -176,14 +176,14 @@ LogicalResult ConvertAtenOp<AtenMaxPool2dWithIndicesOp>::matchAndRewrite(
|
||||||
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
|
AtenMaxPool2dWithIndicesOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto inputElemTy = inputTy.getElementType();
|
auto inputElemTy = inputTy.getElementType();
|
||||||
auto inputShape = inputTy.getShape();
|
auto inputShape = inputTy.getShape();
|
||||||
auto inputRank = inputTy.getRank();
|
auto inputRank = inputTy.getRank();
|
||||||
auto outValTy =
|
auto outValTy =
|
||||||
getTypeConverter()->convertType(op.getType(0)).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(0)));
|
||||||
auto outIdxTy =
|
auto outIdxTy =
|
||||||
getTypeConverter()->convertType(op.getType(1)).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||||
|
|
||||||
if (inputRank <= 2) {
|
if (inputRank <= 2) {
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
|
@ -366,7 +366,7 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
RankedTensorType inputTy = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputTy = cast<RankedTensorType>(input.getType());
|
||||||
Type inputElemTy = inputTy.getElementType();
|
Type inputElemTy = inputTy.getElementType();
|
||||||
int64_t inputRank = inputTy.getRank();
|
int64_t inputRank = inputTy.getRank();
|
||||||
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
|
RankedTensorType outTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
|
||||||
|
@ -539,11 +539,11 @@ LogicalResult ConvertAtenOp<AtenCumsumOp>::matchAndRewrite(
|
||||||
AtenCumsumOp op, OpAdaptor adaptor,
|
AtenCumsumOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto outTy =
|
auto outTy =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
input = hlo::promoteType(rewriter, op.getLoc(), input, outTy);
|
||||||
inputTy = input.getType().cast<RankedTensorType>();
|
inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto inputElemTy = inputTy.getElementType();
|
auto inputElemTy = inputTy.getElementType();
|
||||||
auto inputRank = inputTy.getRank();
|
auto inputRank = inputTy.getRank();
|
||||||
auto inputShape = inputTy.getShape();
|
auto inputShape = inputTy.getShape();
|
||||||
|
|
|
@ -126,7 +126,7 @@ static std::optional<ValueRange>
|
||||||
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input,
|
||||||
ArrayRef<Value> inputShapeVec, int64_t dim,
|
ArrayRef<Value> inputShapeVec, int64_t dim,
|
||||||
size_t dimSizeIndexBits) {
|
size_t dimSizeIndexBits) {
|
||||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
if (!inputTy) {
|
if (!inputTy) {
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
@ -249,7 +249,7 @@ LogicalResult ConvertAtenReductionOp<AtenArgmaxOp>::matchAndRewrite(
|
||||||
AtenArgmaxOp op, OpAdaptor adaptor,
|
AtenArgmaxOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
if (!inputTy) {
|
if (!inputTy) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only Tensor types supported in StableHLO");
|
op, "only Tensor types supported in StableHLO");
|
||||||
|
@ -321,7 +321,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite(
|
||||||
AtenMaxDimOp op, OpAdaptor adaptor,
|
AtenMaxDimOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().template dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
if (!inputTy) {
|
if (!inputTy) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only Tensor types supported in StableHLO");
|
op, "only Tensor types supported in StableHLO");
|
||||||
|
@ -410,7 +410,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
||||||
AtenSumOp op, OpAdaptor adaptor,
|
AtenSumOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
auto outTy = getTypeConverter()
|
auto outTy = getTypeConverter()
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.template dyn_cast<RankedTensorType>();
|
.template dyn_cast<RankedTensorType>();
|
||||||
|
@ -423,7 +423,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumOp>::matchAndRewrite(
|
||||||
auto dstElemTy = outTy.getElementType();
|
auto dstElemTy = outTy.getElementType();
|
||||||
input =
|
input =
|
||||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
}
|
}
|
||||||
auto inputElemTy = inputTy.getElementType();
|
auto inputElemTy = inputTy.getElementType();
|
||||||
if (!inputElemTy.isIntOrFloat()) {
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
@ -626,7 +626,7 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
|
||||||
AtenProdOp op, OpAdaptor adaptor,
|
AtenProdOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
auto outTy = getTypeConverter()
|
auto outTy = getTypeConverter()
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.template dyn_cast<RankedTensorType>();
|
.template dyn_cast<RankedTensorType>();
|
||||||
|
@ -639,7 +639,7 @@ LogicalResult ConvertAtenReductionOp<AtenProdOp>::matchAndRewrite(
|
||||||
auto dstElemTy = outTy.getElementType();
|
auto dstElemTy = outTy.getElementType();
|
||||||
input =
|
input =
|
||||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
}
|
}
|
||||||
auto inputElemTy = inputTy.getElementType();
|
auto inputElemTy = inputTy.getElementType();
|
||||||
if (!inputElemTy.isIntOrFloat()) {
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
@ -699,7 +699,7 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite(
|
||||||
AtenMaxOp op, OpAdaptor adaptor,
|
AtenMaxOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
if (!inputTy) {
|
if (!inputTy) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only Tensor types supported in StableHLO");
|
op, "only Tensor types supported in StableHLO");
|
||||||
|
@ -762,7 +762,7 @@ LogicalResult ConvertAtenReductionOp<AtenMinOp>::matchAndRewrite(
|
||||||
AtenMinOp op, OpAdaptor adaptor,
|
AtenMinOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
if (!inputTy) {
|
if (!inputTy) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only Tensor types supported in StableHLO");
|
op, "only Tensor types supported in StableHLO");
|
||||||
|
@ -825,7 +825,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
||||||
AtenSumDimIntListOp op, OpAdaptor adaptor,
|
AtenSumDimIntListOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
auto outTy = getTypeConverter()
|
auto outTy = getTypeConverter()
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.template dyn_cast<RankedTensorType>();
|
.template dyn_cast<RankedTensorType>();
|
||||||
|
@ -838,7 +838,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
|
||||||
auto dstElemTy = outTy.getElementType();
|
auto dstElemTy = outTy.getElementType();
|
||||||
input =
|
input =
|
||||||
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
|
||||||
inputTy = input.getType().dyn_cast<RankedTensorType>();
|
inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
}
|
}
|
||||||
auto inputElemTy = inputTy.getElementType();
|
auto inputElemTy = inputTy.getElementType();
|
||||||
if (!inputElemTy.isIntOrFloat()) {
|
if (!inputElemTy.isIntOrFloat()) {
|
||||||
|
@ -958,7 +958,7 @@ LogicalResult ConvertAtenReductionOp<AtenFrobeniusNormDimOp>::matchAndRewrite(
|
||||||
const TorchToStablehloOptions &options = getOptions();
|
const TorchToStablehloOptions &options = getOptions();
|
||||||
|
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputType = input.getType().dyn_cast<RankedTensorType>();
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||||
if (!inputType) {
|
if (!inputType) {
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
"only ranked tensor input supported in AtenFrobeniusNormDimOp");
|
"only ranked tensor input supported in AtenFrobeniusNormDimOp");
|
||||||
|
@ -1070,7 +1070,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
|
||||||
const TorchToStablehloOptions &options = getOptions();
|
const TorchToStablehloOptions &options = getOptions();
|
||||||
|
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
auto inputType = input.getType().dyn_cast<RankedTensorType>();
|
auto inputType = dyn_cast<RankedTensorType>(input.getType());
|
||||||
if (!inputType) {
|
if (!inputType) {
|
||||||
return op.emitError(
|
return op.emitError(
|
||||||
"only ranked tensor input supported in AtenLinalgVectorNormOp");
|
"only ranked tensor input supported in AtenLinalgVectorNormOp");
|
||||||
|
@ -1078,7 +1078,7 @@ LogicalResult ConvertAtenReductionOp<AtenLinalgVectorNormOp>::matchAndRewrite(
|
||||||
int64_t inputRank = inputType.getRank();
|
int64_t inputRank = inputType.getRank();
|
||||||
|
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
auto outElemType = outType.getElementType();
|
auto outElemType = outType.getElementType();
|
||||||
if (!isa<mlir::FloatType>(outElemType)) {
|
if (!isa<mlir::FloatType>(outElemType)) {
|
||||||
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");
|
return op.emitError("only float dtype allowed in AtenLinalgVectorNormOp");
|
||||||
|
|
|
@ -144,7 +144,7 @@ Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter,
|
||||||
|
|
||||||
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
|
Value promoteType(PatternRewriter &rewriter, Location loc, Value input,
|
||||||
TensorType outType) {
|
TensorType outType) {
|
||||||
TensorType in_type = input.getType().cast<TensorType>();
|
TensorType in_type = cast<TensorType>(input.getType());
|
||||||
|
|
||||||
if (in_type.getElementType() != outType.getElementType()) {
|
if (in_type.getElementType() != outType.getElementType()) {
|
||||||
TensorType promotedType =
|
TensorType promotedType =
|
||||||
|
@ -162,7 +162,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input,
|
||||||
// dimension, the dimension sizes must either be equal, one of them is 1, or
|
// dimension, the dimension sizes must either be equal, one of them is 1, or
|
||||||
// one of them does not exist.
|
// one of them does not exist.
|
||||||
Operation *op = input.getDefiningOp();
|
Operation *op = input.getDefiningOp();
|
||||||
TensorType in_type = input.getType().dyn_cast<TensorType>();
|
TensorType in_type = dyn_cast<TensorType>(input.getType());
|
||||||
|
|
||||||
if (in_type.getElementType() != outType.getElementType()) {
|
if (in_type.getElementType() != outType.getElementType()) {
|
||||||
TensorType promoted_type =
|
TensorType promoted_type =
|
||||||
|
@ -217,7 +217,7 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
||||||
Operation *op, Value value,
|
Operation *op, Value value,
|
||||||
ArrayRef<int64_t> inpDims,
|
ArrayRef<int64_t> inpDims,
|
||||||
size_t dimSizeIndexBits) {
|
size_t dimSizeIndexBits) {
|
||||||
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
|
auto valueTy = dyn_cast<RankedTensorType>(value.getType());
|
||||||
if (!valueTy) {
|
if (!valueTy) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "getDimSizesOfTensor(): the input is not a ranked tensor");
|
op, "getDimSizesOfTensor(): the input is not a ranked tensor");
|
||||||
|
@ -240,7 +240,7 @@ FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
||||||
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
FailureOr<SmallVector<Value, 4>> getDimSizesOfTensor(PatternRewriter &rewriter,
|
||||||
Operation *op, Value value,
|
Operation *op, Value value,
|
||||||
size_t dimSizeIndexBits) {
|
size_t dimSizeIndexBits) {
|
||||||
auto valueTy = value.getType().dyn_cast<RankedTensorType>();
|
auto valueTy = dyn_cast<RankedTensorType>(value.getType());
|
||||||
if (!valueTy) {
|
if (!valueTy) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "getDimSizesOfTensor(): the input is not a ranked tensor");
|
op, "getDimSizesOfTensor(): the input is not a ranked tensor");
|
||||||
|
@ -279,7 +279,7 @@ FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
op, "unsqueeze dimensions must be specified in order");
|
op, "unsqueeze dimensions must be specified in order");
|
||||||
|
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto rankTy = tensor.getType().dyn_cast<RankedTensorType>();
|
auto rankTy = dyn_cast<RankedTensorType>(tensor.getType());
|
||||||
auto oldShape = rankTy.getShape();
|
auto oldShape = rankTy.getShape();
|
||||||
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
Type intType = rewriter.getIntegerType(dimSizeIndexBits);
|
||||||
auto one = rewriter.create<arith::ConstantOp>(
|
auto one = rewriter.create<arith::ConstantOp>(
|
||||||
|
|
|
@ -72,7 +72,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op,
|
||||||
SmallVector<Value, 4> endIndices;
|
SmallVector<Value, 4> endIndices;
|
||||||
SmallVector<Value, 4> strides;
|
SmallVector<Value, 4> strides;
|
||||||
|
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
size_t rank = inputTy.getRank();
|
size_t rank = inputTy.getRank();
|
||||||
startIndices.reserve(rank);
|
startIndices.reserve(rank);
|
||||||
endIndices.reserve(rank);
|
endIndices.reserve(rank);
|
||||||
|
@ -116,7 +116,7 @@ FailureOr<Value> getDynamicSlice(PatternRewriter &rewriter, Operation *op,
|
||||||
std::optional<Value> stepOpt, int64_t dim,
|
std::optional<Value> stepOpt, int64_t dim,
|
||||||
size_t dimSizeIndexBits) {
|
size_t dimSizeIndexBits) {
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto inputTy = input.getType().dyn_cast<RankedTensorType>();
|
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
|
||||||
auto rank = inputTy.getRank();
|
auto rank = inputTy.getRank();
|
||||||
|
|
||||||
dim = (dim + rank) % rank;
|
dim = (dim + rank) % rank;
|
||||||
|
@ -168,8 +168,7 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto rankType =
|
auto rankType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
adaptor.getSelf().getType().template dyn_cast<RankedTensorType>();
|
|
||||||
if (!rankType)
|
if (!rankType)
|
||||||
return op.emitError("Only ranked tensor types are currently supported");
|
return op.emitError("Only ranked tensor types are currently supported");
|
||||||
|
|
||||||
|
@ -233,11 +232,11 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
||||||
AtenSliceTensorOp op, OpAdaptor adaptor,
|
AtenSliceTensorOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto self = adaptor.getSelf();
|
auto self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return op.emitError("only ranked tensor types are supported");
|
return op.emitError("only ranked tensor types are supported");
|
||||||
auto outTy =
|
auto outTy =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -275,7 +274,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
|
||||||
AtenSqueezeOp op, OpAdaptor adaptor,
|
AtenSqueezeOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto self = adaptor.getSelf();
|
auto self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return op.emitError("only ranked tensor types are supported");
|
return op.emitError("only ranked tensor types are supported");
|
||||||
|
|
||||||
|
@ -318,7 +317,7 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
|
||||||
AtenSqueezeDimOp op, OpAdaptor adaptor,
|
AtenSqueezeDimOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto self = adaptor.getSelf();
|
auto self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return op.emitError("only ranked tensor types are supported");
|
return op.emitError("only ranked tensor types are supported");
|
||||||
|
|
||||||
|
@ -369,7 +368,7 @@ template <>
|
||||||
LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
||||||
AtenUnsqueezeOp op, OpAdaptor adaptor,
|
AtenUnsqueezeOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType) {
|
if (!selfType) {
|
||||||
return op.emitError("only tensor types are currently supported");
|
return op.emitError("only tensor types are currently supported");
|
||||||
}
|
}
|
||||||
|
@ -378,7 +377,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
||||||
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
|
||||||
return op->emitError("dim must be a Scalar constant");
|
return op->emitError("dim must be a Scalar constant");
|
||||||
int64_t inputRank =
|
int64_t inputRank =
|
||||||
adaptor.getSelf().getType().cast<RankedTensorType>().getRank();
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
dim = toPositiveDim(dim, inputRank + 1);
|
dim = toPositiveDim(dim, inputRank + 1);
|
||||||
if (!isValidDim(dim, inputRank + 1))
|
if (!isValidDim(dim, inputRank + 1))
|
||||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
@ -397,7 +396,7 @@ template <>
|
||||||
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
LogicalResult ConvertAtenOp<PrimsCollapseOp>::matchAndRewrite(
|
||||||
PrimsCollapseOp op, OpAdaptor adaptor,
|
PrimsCollapseOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto selfType = adaptor.getA().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getA().getType());
|
||||||
if (!selfType) {
|
if (!selfType) {
|
||||||
return op.emitError("only tensor types are currently supported");
|
return op.emitError("only tensor types are currently supported");
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,8 +89,8 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
|
||||||
Value indices, Value src,
|
Value indices, Value src,
|
||||||
int64_t dim) {
|
int64_t dim) {
|
||||||
// Get information on types for inputs
|
// Get information on types for inputs
|
||||||
RankedTensorType indexType = indices.getType().cast<RankedTensorType>();
|
RankedTensorType indexType = cast<RankedTensorType>(indices.getType());
|
||||||
RankedTensorType srcSelf = src.getType().cast<RankedTensorType>();
|
RankedTensorType srcSelf = cast<RankedTensorType>(src.getType());
|
||||||
|
|
||||||
// Store location for insertions
|
// Store location for insertions
|
||||||
Location loc = src.getLoc();
|
Location loc = src.getLoc();
|
||||||
|
@ -219,7 +219,7 @@ static Value createTMTensorScatterOp(
|
||||||
llvm::ArrayRef<int64_t> dimensionsMap, bool uniqueIndices,
|
llvm::ArrayRef<int64_t> dimensionsMap, bool uniqueIndices,
|
||||||
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
|
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
|
||||||
auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap);
|
auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap);
|
||||||
auto originalTensorType = original.getType().cast<RankedTensorType>();
|
auto originalTensorType = cast<RankedTensorType>(original.getType());
|
||||||
Type originalElementType = originalTensorType.getElementType();
|
Type originalElementType = originalTensorType.getElementType();
|
||||||
auto scatterOp = b.create<TMTensor::ScatterOp>(
|
auto scatterOp = b.create<TMTensor::ScatterOp>(
|
||||||
loc, originalTensorType, ValueRange{updates, indices},
|
loc, originalTensorType, ValueRange{updates, indices},
|
||||||
|
@ -241,8 +241,8 @@ static Value createTMTensorScanOp(
|
||||||
OpBuilder &b, Location loc, Value input, Value output, Value accumulator,
|
OpBuilder &b, Location loc, Value input, Value output, Value accumulator,
|
||||||
int64_t dim, bool inclusive,
|
int64_t dim, bool inclusive,
|
||||||
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
|
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
|
||||||
auto inputType = input.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto accType = accumulator.getType().cast<RankedTensorType>();
|
auto accType = cast<RankedTensorType>(accumulator.getType());
|
||||||
Type elementType = inputType.getElementType();
|
Type elementType = inputType.getElementType();
|
||||||
auto scanOp = b.create<TMTensor::ScanOp>(
|
auto scanOp = b.create<TMTensor::ScanOp>(
|
||||||
loc, TypeRange{inputType, accType}, input,
|
loc, TypeRange{inputType, accType}, input,
|
||||||
|
@ -287,7 +287,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc,
|
||||||
|
|
||||||
// Step 3. Create comparison op which will be used as the sorting predicate.
|
// Step 3. Create comparison op which will be used as the sorting predicate.
|
||||||
Value compareOp;
|
Value compareOp;
|
||||||
if (auto intType = elementTypes[0].dyn_cast<mlir::IntegerType>()) {
|
if (auto intType = dyn_cast<mlir::IntegerType>(elementTypes[0])) {
|
||||||
// Case for using arith::CmpIOp.
|
// Case for using arith::CmpIOp.
|
||||||
arith::CmpIPredicate ge = arith::CmpIPredicate::sge;
|
arith::CmpIPredicate ge = arith::CmpIPredicate::sge;
|
||||||
arith::CmpIPredicate le = arith::CmpIPredicate::sle;
|
arith::CmpIPredicate le = arith::CmpIPredicate::sle;
|
||||||
|
@ -329,9 +329,9 @@ public:
|
||||||
Value index = adaptor.getIndex();
|
Value index = adaptor.getIndex();
|
||||||
Value src = adaptor.getSrc();
|
Value src = adaptor.getSrc();
|
||||||
|
|
||||||
RankedTensorType selfType = self.getType().cast<RankedTensorType>();
|
RankedTensorType selfType = cast<RankedTensorType>(self.getType());
|
||||||
RankedTensorType indexType = index.getType().cast<RankedTensorType>();
|
RankedTensorType indexType = cast<RankedTensorType>(index.getType());
|
||||||
RankedTensorType srcType = src.getType().cast<RankedTensorType>();
|
RankedTensorType srcType = cast<RankedTensorType>(src.getType());
|
||||||
if (selfType.getRank() != indexType.getRank() ||
|
if (selfType.getRank() != indexType.getRank() ||
|
||||||
indexType.getRank() != srcType.getRank())
|
indexType.getRank() != srcType.getRank())
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -385,7 +385,7 @@ public:
|
||||||
// TODO: Add a check to verify that the input tensor elements are all
|
// TODO: Add a check to verify that the input tensor elements are all
|
||||||
// non-negative.
|
// non-negative.
|
||||||
// Check whether the input is a 1-d tensor of integer type or not.
|
// Check whether the input is a 1-d tensor of integer type or not.
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
if (inputType.getRank() != 1 ||
|
if (inputType.getRank() != 1 ||
|
||||||
!inputType.getElementType().isa<mlir::IntegerType>())
|
!inputType.getElementType().isa<mlir::IntegerType>())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -394,7 +394,7 @@ public:
|
||||||
|
|
||||||
// Check whether the input tensor element type is i64 or not.
|
// Check whether the input tensor element type is i64 or not.
|
||||||
IntegerType inputIntegerType =
|
IntegerType inputIntegerType =
|
||||||
inputType.getElementType().cast<IntegerType>();
|
cast<IntegerType>(inputType.getElementType());
|
||||||
if (inputIntegerType.getWidth() != 64)
|
if (inputIntegerType.getWidth() != 64)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op,
|
op,
|
||||||
|
@ -409,7 +409,7 @@ public:
|
||||||
SmallVector<int64_t> maxTensorSizes;
|
SmallVector<int64_t> maxTensorSizes;
|
||||||
ValueTensorType maxTensorType = ValueTensorType::get(
|
ValueTensorType maxTensorType = ValueTensorType::get(
|
||||||
context, llvm::ArrayRef(maxTensorSizes),
|
context, llvm::ArrayRef(maxTensorSizes),
|
||||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
cast<ValueTensorType>(torchTypeInput.getType()).getDtype());
|
||||||
Value maxTensor =
|
Value maxTensor =
|
||||||
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
|
rewriter.create<AtenMaxOp>(loc, maxTensorType, torchTypeInput);
|
||||||
maxTensor = typeConverter->materializeTargetConversion(
|
maxTensor = typeConverter->materializeTargetConversion(
|
||||||
|
@ -432,7 +432,7 @@ public:
|
||||||
makeShapeTorchCompatible(inputType.getShape())[0], 1};
|
makeShapeTorchCompatible(inputType.getShape())[0], 1};
|
||||||
ValueTensorType expandInputType = ValueTensorType::get(
|
ValueTensorType expandInputType = ValueTensorType::get(
|
||||||
context, llvm::ArrayRef(expandedInputSizes),
|
context, llvm::ArrayRef(expandedInputSizes),
|
||||||
torchTypeInput.getType().cast<ValueTensorType>().getDtype());
|
cast<ValueTensorType>(torchTypeInput.getType()).getDtype());
|
||||||
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
Value torchCstOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(1));
|
loc, rewriter.getI64IntegerAttr(1));
|
||||||
Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>(
|
Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>(
|
||||||
|
@ -571,7 +571,7 @@ Value combinePutIndices(Location loc, llvm::ArrayRef<Value> indicesRef,
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseTensorType unsqueezedTensorType =
|
BaseTensorType unsqueezedTensorType =
|
||||||
indices[0].getType().cast<BaseTensorType>();
|
cast<BaseTensorType>(indices[0].getType());
|
||||||
Value indicesTorchList = b.create<PrimListConstructOp>(
|
Value indicesTorchList = b.create<PrimListConstructOp>(
|
||||||
loc, Torch::ListType::get(unsqueezedTensorType), indices);
|
loc, Torch::ListType::get(unsqueezedTensorType), indices);
|
||||||
llvm::SmallVector<int64_t, 2> concatShape{
|
llvm::SmallVector<int64_t, 2> concatShape{
|
||||||
|
@ -691,7 +691,7 @@ public:
|
||||||
auto inputType = cast<ValueTensorType>(input.getType());
|
auto inputType = cast<ValueTensorType>(input.getType());
|
||||||
auto valuesType = cast<ValueTensorType>(values.getType());
|
auto valuesType = cast<ValueTensorType>(values.getType());
|
||||||
int64_t inputRank = inputType.getSizes().size();
|
int64_t inputRank = inputType.getSizes().size();
|
||||||
auto valuesTensorType = op.getValues().getType().cast<BaseTensorType>();
|
auto valuesTensorType = cast<BaseTensorType>(op.getValues().getType());
|
||||||
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
auto resultType = typeConverter->convertType(op->getResult(0).getType())
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
|
|
||||||
|
@ -902,9 +902,9 @@ public:
|
||||||
Value gradOutput = adaptor.getGradOutput();
|
Value gradOutput = adaptor.getGradOutput();
|
||||||
Value input = adaptor.getSelf();
|
Value input = adaptor.getSelf();
|
||||||
RankedTensorType gradOutputType =
|
RankedTensorType gradOutputType =
|
||||||
gradOutput.getType().cast<RankedTensorType>();
|
cast<RankedTensorType>(gradOutput.getType());
|
||||||
Type gradOutputElemType = gradOutputType.getElementType();
|
Type gradOutputElemType = gradOutputType.getElementType();
|
||||||
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
|
RankedTensorType inputType = cast<RankedTensorType>(input.getType());
|
||||||
Type inputElemType = inputType.getElementType();
|
Type inputElemType = inputType.getElementType();
|
||||||
int64_t tensorOperandRank = inputType.getRank();
|
int64_t tensorOperandRank = inputType.getRank();
|
||||||
|
|
||||||
|
@ -914,7 +914,7 @@ public:
|
||||||
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
|
||||||
indices = typeConverter->materializeTargetConversion(
|
indices = typeConverter->materializeTargetConversion(
|
||||||
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
|
||||||
RankedTensorType indicesType = indices.getType().cast<RankedTensorType>();
|
RankedTensorType indicesType = cast<RankedTensorType>(indices.getType());
|
||||||
Type indicesElemType = indicesType.getElementType();
|
Type indicesElemType = indicesType.getElementType();
|
||||||
|
|
||||||
// The element type of the `input` and `grad_output` should be same.
|
// The element type of the `input` and `grad_output` should be same.
|
||||||
|
@ -1100,11 +1100,11 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
RankedTensorType selfType =
|
RankedTensorType selfType =
|
||||||
adaptor.getSelf().getType().cast<RankedTensorType>();
|
cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
RankedTensorType indexType =
|
RankedTensorType indexType =
|
||||||
adaptor.getIndex().getType().cast<RankedTensorType>();
|
cast<RankedTensorType>(adaptor.getIndex().getType());
|
||||||
RankedTensorType srcType =
|
RankedTensorType srcType =
|
||||||
adaptor.getSrc().getType().cast<RankedTensorType>();
|
cast<RankedTensorType>(adaptor.getSrc().getType());
|
||||||
|
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
|
|
||||||
|
@ -1324,7 +1324,7 @@ public:
|
||||||
|
|
||||||
// Step 1. Fetch Input to sort.
|
// Step 1. Fetch Input to sort.
|
||||||
Value inputTensor = adaptor.getSelf();
|
Value inputTensor = adaptor.getSelf();
|
||||||
auto inputType = inputTensor.getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(inputTensor.getType());
|
||||||
unsigned inputRank = inputType.getRank();
|
unsigned inputRank = inputType.getRank();
|
||||||
|
|
||||||
// Step 2. Fetch dimension to perform sort in.
|
// Step 2. Fetch dimension to perform sort in.
|
||||||
|
@ -1414,7 +1414,7 @@ public:
|
||||||
.cast<RankedTensorType>();
|
.cast<RankedTensorType>();
|
||||||
Type elementType = resultType.getElementType();
|
Type elementType = resultType.getElementType();
|
||||||
Type inputElementType =
|
Type inputElementType =
|
||||||
input.getType().cast<RankedTensorType>().getElementType();
|
cast<RankedTensorType>(input.getType()).getElementType();
|
||||||
|
|
||||||
// Converting the input element type to the result's element type.
|
// Converting the input element type to the result's element type.
|
||||||
// The only possible mismatch would be when the input element type is an
|
// The only possible mismatch would be when the input element type is an
|
||||||
|
@ -1486,7 +1486,7 @@ public:
|
||||||
Value isCausal = op.getIsCausal();
|
Value isCausal = op.getIsCausal();
|
||||||
Value scale = op.getScale();
|
Value scale = op.getScale();
|
||||||
Type elementType =
|
Type elementType =
|
||||||
adaptor.getQuery().getType().cast<ShapedType>().getElementType();
|
cast<ShapedType>(adaptor.getQuery().getType()).getElementType();
|
||||||
|
|
||||||
// Verify inputs (only support defaults)
|
// Verify inputs (only support defaults)
|
||||||
if (!mask.getType().isa<Torch::NoneType>())
|
if (!mask.getType().isa<Torch::NoneType>())
|
||||||
|
@ -1557,10 +1557,9 @@ public:
|
||||||
key = collapseBatch(key);
|
key = collapseBatch(key);
|
||||||
value = collapseBatch(value);
|
value = collapseBatch(value);
|
||||||
|
|
||||||
SmallVector<int64_t> outSizes(
|
SmallVector<int64_t> outSizes(cast<ShapedType>(query.getType()).getShape());
|
||||||
query.getType().cast<ShapedType>().getShape());
|
|
||||||
SmallVector<int64_t> valueSizes(
|
SmallVector<int64_t> valueSizes(
|
||||||
value.getType().cast<ShapedType>().getShape());
|
cast<ShapedType>(value.getType()).getShape());
|
||||||
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
|
outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1];
|
||||||
SmallVector<Value> outSizesDynamic(
|
SmallVector<Value> outSizesDynamic(
|
||||||
getTensorSizes(rewriter, op.getLoc(), query));
|
getTensorSizes(rewriter, op.getLoc(), query));
|
||||||
|
|
|
@ -79,9 +79,9 @@ public:
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto loc = op.getLoc();
|
auto loc = op.getLoc();
|
||||||
auto operand = adaptor.getOperands()[0];
|
auto operand = adaptor.getOperands()[0];
|
||||||
auto operandTy = operand.getType().cast<RankedTensorType>();
|
auto operandTy = cast<RankedTensorType>(operand.getType());
|
||||||
auto resultTy =
|
auto resultTy =
|
||||||
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
int64_t rank = operandTy.getRank();
|
int64_t rank = operandTy.getRank();
|
||||||
if (rank == 0) {
|
if (rank == 0) {
|
||||||
|
|
|
@ -43,7 +43,7 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = cast<TensorType>(self.getType());
|
||||||
|
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -93,9 +93,9 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().cast<TensorType>();
|
auto lhsTy = cast<TensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
auto rhsTy = rhs.getType().cast<TensorType>();
|
auto rhsTy = cast<TensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -235,15 +235,15 @@ public:
|
||||||
// alpha : scalar: i32/i64/f32
|
// alpha : scalar: i32/i64/f32
|
||||||
// output: tensor: tensor<i32/i64/f32>
|
// output: tensor: tensor<i32/i64/f32>
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
auto rhsType = rhs.getType().dyn_cast<TensorType>();
|
auto rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsType)
|
if (!lhsType)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Only Tensor types supported in TOSA");
|
"Only Tensor types supported in TOSA");
|
||||||
|
|
||||||
if (auto lhsElemTy = lhsType.getElementType().dyn_cast<IntegerType>()) {
|
if (auto lhsElemTy = dyn_cast<IntegerType>(lhsType.getElementType())) {
|
||||||
if (lhsElemTy.getWidth() > 64)
|
if (lhsElemTy.getWidth() > 64)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Integers with widths greater than 64 are not supported");
|
op, "Integers with widths greater than 64 are not supported");
|
||||||
|
@ -284,7 +284,7 @@ public:
|
||||||
op->getLoc(),
|
op->getLoc(),
|
||||||
RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs);
|
RankedTensorType::get(rhsType.getShape(), rhsAlphaMulElemType), rhs);
|
||||||
// reinitialize right value type to tensor<i32/f32>
|
// reinitialize right value type to tensor<i32/f32>
|
||||||
rhsType = rhs.getType().dyn_cast<TensorType>();
|
rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||||
}
|
}
|
||||||
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
|
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||||
|
|
||||||
|
@ -337,9 +337,9 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
auto lhsTy = dyn_cast<TensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
auto rhsTy = dyn_cast<TensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy)
|
if (!lhsTy)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -409,7 +409,7 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
auto lhsType = dyn_cast<TensorType>(lhs.getType());
|
||||||
|
|
||||||
if (!lhsType)
|
if (!lhsType)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -430,7 +430,7 @@ public:
|
||||||
} else {
|
} else {
|
||||||
Value rhsAsTensor;
|
Value rhsAsTensor;
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
auto rhsType = rhs.getType().dyn_cast<TensorType>();
|
auto rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||||
if (!rhsType) {
|
if (!rhsType) {
|
||||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
|
if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(),
|
||||||
rhsAsTensor, outElemTy, {}))) {
|
rhsAsTensor, outElemTy, {}))) {
|
||||||
|
@ -469,9 +469,9 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value lhs = adaptor.getSelf();
|
Value lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
auto lhsTy = dyn_cast<TensorType>(lhs.getType());
|
||||||
Value rhs = adaptor.getOther();
|
Value rhs = adaptor.getOther();
|
||||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
auto rhsTy = dyn_cast<TensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy)
|
if (!lhsTy)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -497,7 +497,7 @@ public:
|
||||||
|
|
||||||
// auto result;
|
// auto result;
|
||||||
Value result;
|
Value result;
|
||||||
if (outType.getElementType().template isa<mlir::FloatType>()) {
|
if (isa<mlir::FloatType>(outType.getElementType())) {
|
||||||
// The input to the reciprocal is an integer sometimes, and we may need to
|
// The input to the reciprocal is an integer sometimes, and we may need to
|
||||||
// promote it to a floating point. Per TOSA specification, the input types
|
// promote it to a floating point. Per TOSA specification, the input types
|
||||||
// can only be floating point for tosa::ReciprocalOp.
|
// can only be floating point for tosa::ReciprocalOp.
|
||||||
|
@ -538,7 +538,7 @@ LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
|
||||||
AtenTanhOp op, OpAdaptor adaptor,
|
AtenTanhOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = cast<TensorType>(self.getType());
|
||||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||||
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
|
rewriter.replaceOpWithNewOp<tosa::TanhOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), self);
|
op, getTypeConverter()->convertType(op.getType()), self);
|
||||||
|
@ -555,7 +555,7 @@ LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
|
||||||
AtenSigmoidOp op, OpAdaptor adaptor,
|
AtenSigmoidOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = cast<TensorType>(self.getType());
|
||||||
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
if (selfTy && selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||||
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
|
rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
|
||||||
op, getTypeConverter()->convertType(op.getType()), self);
|
op, getTypeConverter()->convertType(op.getType()), self);
|
||||||
|
@ -572,7 +572,7 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
|
||||||
AtenReluOp op, OpAdaptor adaptor,
|
AtenReluOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = cast<TensorType>(self.getType());
|
||||||
|
|
||||||
// Maps to tosa.clamp which has both int and fp limits.
|
// Maps to tosa.clamp which has both int and fp limits.
|
||||||
int64_t clampMin = 0;
|
int64_t clampMin = 0;
|
||||||
|
@ -602,7 +602,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = cast<TensorType>(self.getType());
|
||||||
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
|
if (!selfTy.getElementType().isa<mlir::FloatType>()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only floating-point datatype legalization currently supported");
|
op, "Only floating-point datatype legalization currently supported");
|
||||||
|
@ -660,7 +660,7 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<TensorType>();
|
auto selfTy = cast<TensorType>(self.getType());
|
||||||
|
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -713,7 +713,7 @@ class ConvertAtenMultipleDimsReductionOp
|
||||||
"non-const dim parameter unsupported");
|
"non-const dim parameter unsupported");
|
||||||
int64_t N = reduceDims.size();
|
int64_t N = reduceDims.size();
|
||||||
int64_t inputRank =
|
int64_t inputRank =
|
||||||
adaptor.getSelf().getType().template cast<RankedTensorType>().getRank();
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
for (unsigned i = 0; i < N; i++) {
|
for (unsigned i = 0; i < N; i++) {
|
||||||
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
|
reduceDims[i] = toPositiveDim(reduceDims[i], inputRank);
|
||||||
if (!isValidDim(reduceDims[i], inputRank))
|
if (!isValidDim(reduceDims[i], inputRank))
|
||||||
|
@ -751,7 +751,7 @@ class ConvertAtenOneDimReductionOp
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"non-const dim parameter unsupported");
|
"non-const dim parameter unsupported");
|
||||||
int64_t inputRank =
|
int64_t inputRank =
|
||||||
adaptor.getSelf().getType().template cast<RankedTensorType>().getRank();
|
cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
|
||||||
reduceDim = toPositiveDim(reduceDim, inputRank);
|
reduceDim = toPositiveDim(reduceDim, inputRank);
|
||||||
if (!isValidDim(reduceDim, inputRank))
|
if (!isValidDim(reduceDim, inputRank))
|
||||||
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
return rewriter.notifyMatchFailure(op, "dim is statically invalid");
|
||||||
|
@ -782,7 +782,7 @@ public:
|
||||||
ElementsAttr &reduceDimsAttr,
|
ElementsAttr &reduceDimsAttr,
|
||||||
bool &keepDims) const override {
|
bool &keepDims) const override {
|
||||||
auto self = adaptor.getSelf();
|
auto self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
|
|
||||||
// Select all dims to reduce
|
// Select all dims to reduce
|
||||||
SmallVector<int64_t, 4> reduceDims;
|
SmallVector<int64_t, 4> reduceDims;
|
||||||
|
@ -804,7 +804,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
|
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -835,7 +835,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
||||||
// Create a single instance of tosa.argmax.
|
// Create a single instance of tosa.argmax.
|
||||||
// Multiple dims require chained construct.
|
// Multiple dims require chained construct.
|
||||||
auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
|
auto buildArgmax = [&](int64_t reduceDim, Value input) -> Value {
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
|
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
|
||||||
SmallVector<int64_t> outputShapeArr = {};
|
SmallVector<int64_t> outputShapeArr = {};
|
||||||
int32_t i = 0;
|
int32_t i = 0;
|
||||||
|
@ -865,7 +865,7 @@ LogicalResult ConvertAtenOp<AtenArgmaxOp>::matchAndRewrite(
|
||||||
// Convert the final index to i64 for backend finalization, However, i64
|
// Convert the final index to i64 for backend finalization, However, i64
|
||||||
// is not a defined type for tosa.cast, so using arith.extsi instead.
|
// is not a defined type for tosa.cast, so using arith.extsi instead.
|
||||||
auto castToInt64 = [&](Value result) -> LogicalResult {
|
auto castToInt64 = [&](Value result) -> LogicalResult {
|
||||||
auto resTy = result.getType().cast<ShapedType>();
|
auto resTy = cast<ShapedType>(result.getType());
|
||||||
if (!resTy)
|
if (!resTy)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Argmax: Result is not a shaped type");
|
"Argmax: Result is not a shaped type");
|
||||||
|
@ -915,7 +915,7 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
|
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1010,7 +1010,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
|
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1021,7 +1021,7 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
||||||
op, "Only floating-point datatype legalization supported");
|
op, "Only floating-point datatype legalization supported");
|
||||||
|
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
Value expTensor;
|
Value expTensor;
|
||||||
Value expScalar = op.getExponent();
|
Value expScalar = op.getExponent();
|
||||||
|
@ -1063,8 +1063,8 @@ public:
|
||||||
ConversionPatternRewriter &rewriter, Value &lhs,
|
ConversionPatternRewriter &rewriter, Value &lhs,
|
||||||
Value &rhs, Value &output) const {
|
Value &rhs, Value &output) const {
|
||||||
|
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
auto lhsRank = lhsTy.getRank();
|
auto lhsRank = lhsTy.getRank();
|
||||||
auto rhsRank = rhsTy.getRank();
|
auto rhsRank = rhsTy.getRank();
|
||||||
|
@ -1097,7 +1097,7 @@ public:
|
||||||
// construct the input and output reshaping logic.
|
// construct the input and output reshaping logic.
|
||||||
auto getRankBroadcastedShape = [&](Value tensor,
|
auto getRankBroadcastedShape = [&](Value tensor,
|
||||||
bool isRHS) -> SmallVector<int64_t> {
|
bool isRHS) -> SmallVector<int64_t> {
|
||||||
auto tensorTy = tensor.getType().cast<TensorType>();
|
auto tensorTy = cast<TensorType>(tensor.getType());
|
||||||
auto tensorShape = makeShapeTorchCompatible(tensorTy.getShape());
|
auto tensorShape = makeShapeTorchCompatible(tensorTy.getShape());
|
||||||
auto tensorRank = tensorTy.getRank();
|
auto tensorRank = tensorTy.getRank();
|
||||||
|
|
||||||
|
@ -1151,7 +1151,7 @@ public:
|
||||||
// TOSA matmul is performed on two 3D inputs and generates a 3D output.
|
// TOSA matmul is performed on two 3D inputs and generates a 3D output.
|
||||||
// Lower ranked tensors are dim-1 reshaped up to 3D
|
// Lower ranked tensors are dim-1 reshaped up to 3D
|
||||||
auto reshapeUpTo3DTensor = [&](Value tensor) -> Value {
|
auto reshapeUpTo3DTensor = [&](Value tensor) -> Value {
|
||||||
auto tensorTy = tensor.getType().cast<TensorType>();
|
auto tensorTy = cast<TensorType>(tensor.getType());
|
||||||
auto rank = tensorTy.getRank();
|
auto rank = tensorTy.getRank();
|
||||||
|
|
||||||
assert(rank <= 3 && "reshapeUpTo3D tensor must receive rank <= 3");
|
assert(rank <= 3 && "reshapeUpTo3D tensor must receive rank <= 3");
|
||||||
|
@ -1440,9 +1440,9 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
auto matmulLhsShape = makeShapeTorchCompatible(
|
auto matmulLhsShape = makeShapeTorchCompatible(
|
||||||
matmulLhs.getType().template cast<RankedTensorType>().getShape());
|
cast<RankedTensorType>(matmulLhs.getType()).getShape());
|
||||||
auto matmulRhsShape = makeShapeTorchCompatible(
|
auto matmulRhsShape = makeShapeTorchCompatible(
|
||||||
matmulRhs.getType().template cast<RankedTensorType>().getShape());
|
cast<RankedTensorType>(matmulRhs.getType()).getShape());
|
||||||
|
|
||||||
// The reshape/transpose should ensure the tosa.matmul always has same
|
// The reshape/transpose should ensure the tosa.matmul always has same
|
||||||
// batch size for either matrix. If if shapes are dynamic, they'll be
|
// batch size for either matrix. If if shapes are dynamic, they'll be
|
||||||
|
@ -1642,10 +1642,10 @@ public:
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Value &lhs, Value &rhs) const override {
|
Value &lhs, Value &rhs) const override {
|
||||||
lhs = adaptor.getSelf();
|
lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
|
|
||||||
rhs = adaptor.getOther();
|
rhs = adaptor.getOther();
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1666,10 +1666,10 @@ public:
|
||||||
Value &lhs, Value &rhs) const override {
|
Value &lhs, Value &rhs) const override {
|
||||||
|
|
||||||
lhs = adaptor.getSelf();
|
lhs = adaptor.getSelf();
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
|
|
||||||
rhs = adaptor.getMat2();
|
rhs = adaptor.getMat2();
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1703,10 +1703,10 @@ public:
|
||||||
Value &lhs, Value &rhs) const override {
|
Value &lhs, Value &rhs) const override {
|
||||||
|
|
||||||
lhs = adaptor.getInput();
|
lhs = adaptor.getInput();
|
||||||
auto lhsTy = lhs.getType().cast<RankedTensorType>();
|
auto lhsTy = cast<RankedTensorType>(lhs.getType());
|
||||||
|
|
||||||
rhs = adaptor.getWeight();
|
rhs = adaptor.getWeight();
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
|
|
||||||
if (!lhsTy || !rhsTy)
|
if (!lhsTy || !rhsTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1744,14 +1744,13 @@ public:
|
||||||
auto biasTy = bias.getType();
|
auto biasTy = bias.getType();
|
||||||
|
|
||||||
// TOSA does not mandate that elementwise op tensors need to be ranked.
|
// TOSA does not mandate that elementwise op tensors need to be ranked.
|
||||||
if (!biasTy.template isa<Torch::NoneType>() &&
|
if (!isa<Torch::NoneType>(biasTy) && !isa<TensorType>(biasTy))
|
||||||
!biasTy.template isa<TensorType>())
|
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types supported in GEMM to TOSA for bias tensor");
|
op, "Only tensor types supported in GEMM to TOSA for bias tensor");
|
||||||
|
|
||||||
// RHS must have its last two dims transposed prior to matrix
|
// RHS must have its last two dims transposed prior to matrix
|
||||||
// multiplication.
|
// multiplication.
|
||||||
auto rhsTy = rhs.getType().cast<RankedTensorType>();
|
auto rhsTy = cast<RankedTensorType>(rhs.getType());
|
||||||
auto rhsRank = rhsTy.getRank();
|
auto rhsRank = rhsTy.getRank();
|
||||||
auto rhsShape = makeShapeTorchCompatible(rhsTy.getShape());
|
auto rhsShape = makeShapeTorchCompatible(rhsTy.getShape());
|
||||||
auto rhsElemTy = rhsTy.getElementType();
|
auto rhsElemTy = rhsTy.getElementType();
|
||||||
|
@ -1789,7 +1788,7 @@ public:
|
||||||
"Failed to perform matmul operation");
|
"Failed to perform matmul operation");
|
||||||
|
|
||||||
Value matmulPlusBias = matmulOutput;
|
Value matmulPlusBias = matmulOutput;
|
||||||
if (!biasTy.template isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(biasTy)) {
|
||||||
// Bias addition broadcasts to the matmul output shape.
|
// Bias addition broadcasts to the matmul output shape.
|
||||||
matmulPlusBias =
|
matmulPlusBias =
|
||||||
rewriter
|
rewriter
|
||||||
|
@ -1818,7 +1817,7 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
||||||
auto otherScalar = op.getOther();
|
auto otherScalar = op.getOther();
|
||||||
auto alphaScalar = op.getAlpha();
|
auto alphaScalar = op.getAlpha();
|
||||||
|
|
||||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types supported in TOSA Rsub");
|
op, "Only ranked tensor types supported in TOSA Rsub");
|
||||||
|
@ -1867,8 +1866,8 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
||||||
auto input = adaptor.getInput();
|
auto input = adaptor.getInput();
|
||||||
auto weight = adaptor.getWeight();
|
auto weight = adaptor.getWeight();
|
||||||
|
|
||||||
auto inputTy = input.getType().cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto weightTy = weight.getType().cast<RankedTensorType>();
|
auto weightTy = cast<RankedTensorType>(weight.getType());
|
||||||
auto outputTy = getTypeConverter()
|
auto outputTy = getTypeConverter()
|
||||||
->convertType(op.getType())
|
->convertType(op.getType())
|
||||||
.template cast<RankedTensorType>();
|
.template cast<RankedTensorType>();
|
||||||
|
@ -1893,7 +1892,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
||||||
// Bias is optional. TOSA mandates a zero tensor here, so construct one if
|
// Bias is optional. TOSA mandates a zero tensor here, so construct one if
|
||||||
// required.
|
// required.
|
||||||
auto bias = adaptor.getBias();
|
auto bias = adaptor.getBias();
|
||||||
if (adaptor.getBias().getType().template isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(adaptor.getBias().getType())) {
|
||||||
// TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
|
// TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
|
||||||
// accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
|
// accumulator) are 48-bit and not 32-bit, and requires the use of APInt to
|
||||||
// define a 48-bit int.
|
// define a 48-bit int.
|
||||||
|
@ -1909,7 +1908,7 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
|
||||||
.value();
|
.value();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!bias.getType().cast<RankedTensorType>())
|
if (!cast<RankedTensorType>(bias.getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Bias provided but not a ranked tensor");
|
op, "Bias provided but not a ranked tensor");
|
||||||
}
|
}
|
||||||
|
@ -2115,7 +2114,7 @@ LogicalResult ConvertAtenOp<AtenReshapeOp>::matchAndRewrite(
|
||||||
|
|
||||||
auto self = adaptor.getSelf();
|
auto self = adaptor.getSelf();
|
||||||
|
|
||||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types supported in TOSA Reshape");
|
op, "Only ranked tensor types supported in TOSA Reshape");
|
||||||
|
@ -2199,7 +2198,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a ranked tensor output
|
// Not a ranked tensor output
|
||||||
if (!adaptor.getInput().getType().dyn_cast<RankedTensorType>())
|
if (!dyn_cast<RankedTensorType>(adaptor.getInput().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types are supported");
|
op, "Only ranked tensor types are supported");
|
||||||
|
|
||||||
|
@ -2211,8 +2210,8 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
if (op.getMomentum().getType().isa<Torch::NoneType>())
|
if (op.getMomentum().getType().isa<Torch::NoneType>())
|
||||||
return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");
|
return rewriter.notifyMatchFailure(op, "Unsupported None for momentum");
|
||||||
|
|
||||||
auto meanType = adaptor.getRunningMean().getType().dyn_cast<TensorType>();
|
auto meanType = dyn_cast<TensorType>(adaptor.getRunningMean().getType());
|
||||||
auto varianceType = adaptor.getRunningVar().getType().dyn_cast<TensorType>();
|
auto varianceType = dyn_cast<TensorType>(adaptor.getRunningVar().getType());
|
||||||
if (!varianceType || !meanType)
|
if (!varianceType || !meanType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types are supported");
|
op, "Only ranked tensor types are supported");
|
||||||
|
@ -2225,7 +2224,7 @@ LogicalResult ConvertAtenOp<AtenBatchNormOp>::matchAndRewrite(
|
||||||
const TypeConverter *converter, Type outType,
|
const TypeConverter *converter, Type outType,
|
||||||
const Value toBcast, Value &result) {
|
const Value toBcast, Value &result) {
|
||||||
RankedTensorType toBcastType =
|
RankedTensorType toBcastType =
|
||||||
toBcast.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(toBcast.getType());
|
||||||
if (toBcastType.getRank() > 1)
|
if (toBcastType.getRank() > 1)
|
||||||
return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");
|
return rewriter.notifyMatchFailure(op, "Rank cannot be more than 1");
|
||||||
|
|
||||||
|
@ -2298,11 +2297,11 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
// eventually being reshaped for broadcasting.
|
// eventually being reshaped for broadcasting.
|
||||||
|
|
||||||
// Not a ranked tensor output
|
// Not a ranked tensor output
|
||||||
if (!adaptor.getInput().getType().dyn_cast<RankedTensorType>())
|
if (!dyn_cast<RankedTensorType>(adaptor.getInput().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types are supported");
|
op, "Only ranked tensor types are supported");
|
||||||
|
|
||||||
auto inputType = adaptor.getInput().getType().cast<RankedTensorType>();
|
auto inputType = cast<RankedTensorType>(adaptor.getInput().getType());
|
||||||
if (inputType.getRank() > 4)
|
if (inputType.getRank() > 4)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Only up to 4D tensors are supported");
|
"Only up to 4D tensors are supported");
|
||||||
|
@ -2317,8 +2316,8 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
|
||||||
if (adaptor.getBias().getType().isa<Torch::NoneType>())
|
if (adaptor.getBias().getType().isa<Torch::NoneType>())
|
||||||
return rewriter.notifyMatchFailure(op, "Unsupported None for bias");
|
return rewriter.notifyMatchFailure(op, "Unsupported None for bias");
|
||||||
|
|
||||||
auto weightType = adaptor.getWeight().getType().cast<RankedTensorType>();
|
auto weightType = cast<RankedTensorType>(adaptor.getWeight().getType());
|
||||||
auto biasType = adaptor.getBias().getType().cast<RankedTensorType>();
|
auto biasType = cast<RankedTensorType>(adaptor.getBias().getType());
|
||||||
int64_t inputRank = inputType.getRank();
|
int64_t inputRank = inputType.getRank();
|
||||||
Type elemTy = inputType.getElementType();
|
Type elemTy = inputType.getElementType();
|
||||||
SmallVector<int64_t> inputTypeShape(
|
SmallVector<int64_t> inputTypeShape(
|
||||||
|
@ -2461,7 +2460,7 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
|
||||||
// element type. All tensors with element types other than integer can reuse
|
// element type. All tensors with element types other than integer can reuse
|
||||||
// existing elements attribute.
|
// existing elements attribute.
|
||||||
// TODO: what about unsigned integer?
|
// TODO: what about unsigned integer?
|
||||||
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
|
if (auto elements = dyn_cast<DenseIntElementsAttr>(op.getValueAttr())) {
|
||||||
if (elements.getElementType().isSignedInteger()) {
|
if (elements.getElementType().isSignedInteger()) {
|
||||||
Type builtinTensorElemTy = outputTy.getElementType();
|
Type builtinTensorElemTy = outputTy.getElementType();
|
||||||
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
unsigned bitWidth = builtinTensorElemTy.getIntOrFloatBitWidth();
|
||||||
|
@ -2483,7 +2482,7 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a ranked tensor type
|
// Not a ranked tensor type
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Only ranked tensor types supported");
|
"Only ranked tensor types supported");
|
||||||
|
@ -2548,7 +2547,7 @@ LogicalResult ConvertAtenOp<AtenUnflattenIntOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a ranked tensor type
|
// Not a ranked tensor type
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType || !selfType.hasStaticShape())
|
if (!selfType || !selfType.hasStaticShape())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op,
|
op,
|
||||||
|
@ -2602,7 +2601,7 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a ranked tensor type
|
// Not a ranked tensor type
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
auto selfType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op,
|
op,
|
||||||
|
@ -2637,7 +2636,7 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -2665,7 +2664,7 @@ LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -2715,7 +2714,7 @@ LogicalResult ConvertAtenOp<AtenUnsqueezeOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType) {
|
if (!selfType) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -2763,7 +2762,7 @@ LogicalResult ConvertAtenOp<AtenContiguousOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -2781,7 +2780,7 @@ LogicalResult ConvertAtenOp<AtenDropoutOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getInput().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getInput().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -2807,7 +2806,7 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -2869,7 +2868,7 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
|
||||||
//
|
//
|
||||||
// Erf = 1 - 1 / (1 + a1X + a2X + a3X + a4X)^4
|
// Erf = 1 - 1 / (1 + a1X + a2X + a3X + a4X)^4
|
||||||
|
|
||||||
auto outType = x.getType().cast<TensorType>();
|
auto outType = cast<TensorType>(x.getType());
|
||||||
auto loc = op->getLoc();
|
auto loc = op->getLoc();
|
||||||
auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
|
auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
|
||||||
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
|
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
|
||||||
|
@ -2949,7 +2948,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -2986,7 +2985,7 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -3043,7 +3042,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType) {
|
if (!selfType) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types are currently supported");
|
op, "Only tensor types are currently supported");
|
||||||
|
@ -3063,7 +3062,7 @@ LogicalResult ConvertAtenOp<AtenHardtanhBackwardOp>::matchAndRewrite(
|
||||||
}
|
}
|
||||||
|
|
||||||
Value gradOutput = adaptor.getGradOutput();
|
Value gradOutput = adaptor.getGradOutput();
|
||||||
auto gradOutputType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto gradOutputType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
|
|
||||||
Type gradOutputElemType = gradOutputType.getElementType();
|
Type gradOutputElemType = gradOutputType.getElementType();
|
||||||
|
|
||||||
|
@ -3119,14 +3118,14 @@ LogicalResult ConvertAtenOp<AtenEmbeddingOp>::matchAndRewrite(
|
||||||
Value weight = adaptor.getWeight();
|
Value weight = adaptor.getWeight();
|
||||||
Value indices = adaptor.getIndices();
|
Value indices = adaptor.getIndices();
|
||||||
RankedTensorType outType =
|
RankedTensorType outType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
|
|
||||||
auto indicesType = indices.getType().dyn_cast<RankedTensorType>();
|
auto indicesType = dyn_cast<RankedTensorType>(indices.getType());
|
||||||
if (!indicesType || !indicesType.getElementType().isa<IntegerType>())
|
if (!indicesType || !indicesType.getElementType().isa<IntegerType>())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Indices must be of integer tensor type");
|
op, "Indices must be of integer tensor type");
|
||||||
|
|
||||||
auto weightType = weight.getType().cast<RankedTensorType>();
|
auto weightType = cast<RankedTensorType>(weight.getType());
|
||||||
if (weightType.getRank() != 2)
|
if (weightType.getRank() != 2)
|
||||||
return op.emitError("weight must be of rank 2");
|
return op.emitError("weight must be of rank 2");
|
||||||
|
|
||||||
|
@ -3216,7 +3215,7 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
|
||||||
AtenTransposeIntOp op, OpAdaptor adaptor,
|
AtenTransposeIntOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
|
@ -3258,12 +3257,12 @@ LogicalResult ConvertAtenOp<AtenMaxDimOp>::matchAndRewrite(
|
||||||
AtenMaxDimOp op, OpAdaptor adaptor,
|
AtenMaxDimOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
auto indicesType =
|
auto indicesType =
|
||||||
getTypeConverter()->convertType(op.getType(1)).dyn_cast<TensorType>();
|
dyn_cast<TensorType>(getTypeConverter()->convertType(op.getType(1)));
|
||||||
if (!indicesType)
|
if (!indicesType)
|
||||||
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
|
||||||
|
|
||||||
|
@ -3334,7 +3333,7 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
||||||
AtenSliceTensorOp op, OpAdaptor adaptor,
|
AtenSliceTensorOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType || !selfType.hasStaticShape())
|
if (!selfType || !selfType.hasStaticShape())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types with static shape are supported");
|
op, "Only tensor types with static shape are supported");
|
||||||
|
@ -3406,7 +3405,7 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType || !selfType.hasStaticShape())
|
if (!selfType || !selfType.hasStaticShape())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types with static shape are supported");
|
op, "Only tensor types with static shape are supported");
|
||||||
|
@ -3500,13 +3499,13 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
auto inputType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
auto inputType = dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
if (!inputType)
|
if (!inputType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only RankedTensorType input are currently supported");
|
op, "Only RankedTensorType input are currently supported");
|
||||||
|
|
||||||
auto index = adaptor.getIndex();
|
auto index = adaptor.getIndex();
|
||||||
auto indexType = adaptor.getIndex().getType().dyn_cast<RankedTensorType>();
|
auto indexType = dyn_cast<RankedTensorType>(adaptor.getIndex().getType());
|
||||||
auto inputShape = inputType.getShape();
|
auto inputShape = inputType.getShape();
|
||||||
int paramsRank = inputShape.size();
|
int paramsRank = inputShape.size();
|
||||||
|
|
||||||
|
@ -3593,13 +3592,13 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
|
|
||||||
auto fillValues = adaptor.getValues();
|
auto fillValues = adaptor.getValues();
|
||||||
auto valuesType = adaptor.getValues().getType().dyn_cast<TensorType>();
|
auto valuesType = dyn_cast<TensorType>(adaptor.getValues().getType());
|
||||||
if (!valuesType)
|
if (!valuesType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
|
@ -3640,7 +3639,7 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Multiple None index is not support for now.");
|
op, "Multiple None index is not support for now.");
|
||||||
}
|
}
|
||||||
auto indexNextType = indexNext.getType().dyn_cast<RankedTensorType>();
|
auto indexNextType = dyn_cast<RankedTensorType>(indexNext.getType());
|
||||||
auto indexNextShape = indexNextType.getShape();
|
auto indexNextShape = indexNextType.getShape();
|
||||||
|
|
||||||
int64_t size = 1;
|
int64_t size = 1;
|
||||||
|
@ -3652,7 +3651,7 @@ LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
|
||||||
.value();
|
.value();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||||
auto indexShape = indexType.getShape();
|
auto indexShape = indexType.getShape();
|
||||||
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
|
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
|
||||||
indexesRank.push_back(indexType.getRank());
|
indexesRank.push_back(indexType.getRank());
|
||||||
|
@ -3734,7 +3733,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
||||||
// [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]]
|
// [[16, 17, 18, 19, 20], [11, 12, 13, 14, 15], [ 6, 7, 8, 9, 10]]]
|
||||||
auto input = adaptor.getSelf();
|
auto input = adaptor.getSelf();
|
||||||
auto inputTensorType =
|
auto inputTensorType =
|
||||||
adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(adaptor.getSelf().getType());
|
||||||
// Check input is a tensor type.
|
// Check input is a tensor type.
|
||||||
if (!inputTensorType)
|
if (!inputTensorType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -3771,7 +3770,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
||||||
for (size_t i = 0; i < indexTensors.size(); i++) {
|
for (size_t i = 0; i < indexTensors.size(); i++) {
|
||||||
auto index = indexTensors[i];
|
auto index = indexTensors[i];
|
||||||
|
|
||||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||||
auto indexShape = indexType.getShape();
|
auto indexShape = indexType.getShape();
|
||||||
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
|
indexesShape.push_back(makeShapeTorchCompatible(indexShape));
|
||||||
indexesRank.push_back(indexType.getRank());
|
indexesRank.push_back(indexType.getRank());
|
||||||
|
@ -3837,7 +3836,7 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
|
||||||
|
|
||||||
// Support for multiple index
|
// Support for multiple index
|
||||||
auto index = indexTensors[0];
|
auto index = indexTensors[0];
|
||||||
auto indexType = index.getType().dyn_cast<RankedTensorType>();
|
auto indexType = dyn_cast<RankedTensorType>(index.getType());
|
||||||
auto indexShape = indexType.getShape();
|
auto indexShape = indexType.getShape();
|
||||||
// index i64 to i32 for tosa compatible
|
// index i64 to i32 for tosa compatible
|
||||||
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
|
||||||
|
@ -3879,7 +3878,7 @@ LogicalResult ConvertAtenOp<AtenAbsOp>::matchAndRewrite(
|
||||||
AtenAbsOp op, OpAdaptor adaptor,
|
AtenAbsOp op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
|
@ -3896,11 +3895,11 @@ LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
auto condType = adaptor.getCondition().getType().dyn_cast<TensorType>();
|
auto condType = dyn_cast<TensorType>(adaptor.getCondition().getType());
|
||||||
if (!condType)
|
if (!condType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types condition are currently supported");
|
op, "Only tensor types condition are currently supported");
|
||||||
|
@ -3919,11 +3918,11 @@ LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
|
auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
|
||||||
if (!otherType)
|
if (!otherType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types condition are currently supported");
|
op, "Only tensor types condition are currently supported");
|
||||||
|
@ -3955,8 +3954,8 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
|
||||||
op, "unimplemented: equal_nan is expected to be false");
|
op, "unimplemented: equal_nan is expected to be false");
|
||||||
|
|
||||||
// check tensor type.
|
// check tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
|
auto otherType = dyn_cast<TensorType>(adaptor.getOther().getType());
|
||||||
if (!selfType || !otherType)
|
if (!selfType || !otherType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
|
@ -3998,7 +3997,7 @@ LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType)
|
if (!selfType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only tensor types input are currently supported");
|
op, "only tensor types input are currently supported");
|
||||||
|
@ -4251,8 +4250,8 @@ LogicalResult ConvertAtenOp<AtenCopyOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
auto srcType = adaptor.getSrc().getType().dyn_cast<TensorType>();
|
auto srcType = dyn_cast<TensorType>(adaptor.getSrc().getType());
|
||||||
if (!selfType || !selfType.hasStaticShape())
|
if (!selfType || !selfType.hasStaticShape())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types with static shape are supported");
|
op, "Only tensor types with static shape are supported");
|
||||||
|
@ -4297,7 +4296,7 @@ LogicalResult ConvertAtenOp<AtenToDtypeOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType || !selfType.hasStaticShape())
|
if (!selfType || !selfType.hasStaticShape())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types with static shape are supported");
|
op, "Only tensor types with static shape are supported");
|
||||||
|
@ -4355,14 +4354,14 @@ LogicalResult ConvertAtenOp<AtenRemainderScalarOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().template cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
|
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only ranked tensor types supported in TOSA Remainder");
|
op, "Only ranked tensor types supported in TOSA Remainder");
|
||||||
|
|
||||||
auto outType =
|
auto outType =
|
||||||
getTypeConverter()->convertType(op.getType()).template cast<TensorType>();
|
cast<TensorType>(getTypeConverter()->convertType(op.getType()));
|
||||||
|
|
||||||
Type outElemTy = outType.getElementType();
|
Type outElemTy = outType.getElementType();
|
||||||
if (!outElemTy.isIntOrFloat())
|
if (!outElemTy.isIntOrFloat())
|
||||||
|
@ -4438,7 +4437,7 @@ public:
|
||||||
// Apply the transposeDims vector on input to generate a transposed form.
|
// Apply the transposeDims vector on input to generate a transposed form.
|
||||||
Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter,
|
Value transposeTensor(AtenOpT op, ConversionPatternRewriter &rewriter,
|
||||||
Value input, ArrayRef<int32_t> transposeDims) const {
|
Value input, ArrayRef<int32_t> transposeDims) const {
|
||||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto inputElemTy = inputTy.getElementType();
|
auto inputElemTy = inputTy.getElementType();
|
||||||
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
|
auto inputShape = makeShapeTorchCompatible(inputTy.getShape());
|
||||||
auto inputRank = inputTy.getRank();
|
auto inputRank = inputTy.getRank();
|
||||||
|
@ -4462,8 +4461,7 @@ public:
|
||||||
Value transposePoolingInputToHwc(AtenOpT op,
|
Value transposePoolingInputToHwc(AtenOpT op,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Value input) const {
|
Value input) const {
|
||||||
auto inputRank =
|
auto inputRank = cast<RankedTensorType>(input.getType()).getRank();
|
||||||
input.getType().template cast<RankedTensorType>().getRank();
|
|
||||||
|
|
||||||
SmallVector<int32_t> nchwToNhwc4DTransposeDims({0, 2, 3, 1});
|
SmallVector<int32_t> nchwToNhwc4DTransposeDims({0, 2, 3, 1});
|
||||||
SmallVector<int32_t> chwToHwc3DTransposeDims({1, 2, 0});
|
SmallVector<int32_t> chwToHwc3DTransposeDims({1, 2, 0});
|
||||||
|
@ -4476,7 +4474,7 @@ public:
|
||||||
Value transposePoolingOutputToChw(AtenOpT op,
|
Value transposePoolingOutputToChw(AtenOpT op,
|
||||||
ConversionPatternRewriter &rewriter,
|
ConversionPatternRewriter &rewriter,
|
||||||
Value input) const {
|
Value input) const {
|
||||||
auto inputTy = input.getType().template cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(input.getType());
|
||||||
auto inputRank = inputTy.getRank();
|
auto inputRank = inputTy.getRank();
|
||||||
|
|
||||||
SmallVector<int32_t> nhwcToNchw4DTransposeDims({0, 3, 1, 2});
|
SmallVector<int32_t> nhwcToNchw4DTransposeDims({0, 3, 1, 2});
|
||||||
|
@ -4547,7 +4545,7 @@ public:
|
||||||
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
|
||||||
Type &outputTy) const override {
|
Type &outputTy) const override {
|
||||||
auto inputXchw = adaptor.getSelf();
|
auto inputXchw = adaptor.getSelf();
|
||||||
auto inputTy = inputXchw.getType().template cast<RankedTensorType>();
|
auto inputTy = cast<RankedTensorType>(inputXchw.getType());
|
||||||
if (!inputTy)
|
if (!inputTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Adaptive avgpool requires ranked tensor input");
|
op, "Adaptive avgpool requires ranked tensor input");
|
||||||
|
@ -4659,7 +4657,7 @@ static LogicalResult getOutputTypeAndPoolingParameters(
|
||||||
DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
|
DenseI64ArrayAttr &kernel, DenseI64ArrayAttr &stride,
|
||||||
DenseI64ArrayAttr &pad) {
|
DenseI64ArrayAttr &pad) {
|
||||||
|
|
||||||
RankedTensorType inputTy = inputXchw.getType().cast<RankedTensorType>();
|
RankedTensorType inputTy = cast<RankedTensorType>(inputXchw.getType());
|
||||||
if (!inputTy)
|
if (!inputTy)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Pooling op requires ranked tensor input");
|
op, "Pooling op requires ranked tensor input");
|
||||||
|
@ -4797,7 +4795,7 @@ public:
|
||||||
// FIXME: Handle layout, device and pin_memory. Assume dtype has been
|
// FIXME: Handle layout, device and pin_memory. Assume dtype has been
|
||||||
// processed to set output type correctly?
|
// processed to set output type correctly?
|
||||||
// The layout arg should be either `none` or `0` i.e. strided.
|
// The layout arg should be either `none` or `0` i.e. strided.
|
||||||
if (!op.getLayout().getType().template isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||||
int64_t tensorLayout;
|
int64_t tensorLayout;
|
||||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -4808,7 +4806,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
bool pinMemory;
|
bool pinMemory;
|
||||||
if (!op.getPinMemory().getType().template isa<Torch::NoneType>() &&
|
if (!isa<Torch::NoneType>(op.getPinMemory().getType()) &&
|
||||||
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
(!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)) ||
|
||||||
pinMemory)) {
|
pinMemory)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -4892,19 +4890,19 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Not a tensor type.
|
// Not a tensor type.
|
||||||
auto selfType = adaptor.getSelf().getType().template dyn_cast<TensorType>();
|
auto selfType = dyn_cast<TensorType>(adaptor.getSelf().getType());
|
||||||
if (!selfType || !outType.hasStaticShape())
|
if (!selfType || !outType.hasStaticShape())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op,
|
op,
|
||||||
"Only tensor types with static shapes input are currently supported");
|
"Only tensor types with static shapes input are currently supported");
|
||||||
|
|
||||||
auto maskType = adaptor.getMask().getType().template dyn_cast<TensorType>();
|
auto maskType = dyn_cast<TensorType>(adaptor.getMask().getType());
|
||||||
if (!maskType)
|
if (!maskType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types mask are currently supported");
|
op, "Only tensor types mask are currently supported");
|
||||||
|
|
||||||
Value rhs = adaptor.getValue();
|
Value rhs = adaptor.getValue();
|
||||||
auto rhsType = rhs.getType().template dyn_cast<TensorType>();
|
auto rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||||
Value rhsAsTensor;
|
Value rhsAsTensor;
|
||||||
if (!rhsType) { // scalar
|
if (!rhsType) { // scalar
|
||||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(),
|
if (failed(torchScalarToTosaTensor(rewriter, op, op.getValue(),
|
||||||
|
@ -4913,11 +4911,11 @@ public:
|
||||||
op, "Currently only scalar constants are supported for "
|
op, "Currently only scalar constants are supported for "
|
||||||
"conversion in TOSA operation");
|
"conversion in TOSA operation");
|
||||||
} else { // tensor
|
} else { // tensor
|
||||||
rhsType = rhs.getType().dyn_cast<TensorType>();
|
rhsType = dyn_cast<TensorType>(rhs.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
|
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||||
auto rhsTensorType = rhsTensor.getType().template dyn_cast<TensorType>();
|
auto rhsTensorType = dyn_cast<TensorType>(rhsTensor.getType());
|
||||||
if (rhsTensorType.getElementType() != outElemTy)
|
if (rhsTensorType.getElementType() != outElemTy)
|
||||||
rhsTensor = rewriter.create<tosa::CastOp>(
|
rhsTensor = rewriter.create<tosa::CastOp>(
|
||||||
op.getLoc(),
|
op.getLoc(),
|
||||||
|
@ -4940,7 +4938,7 @@ public:
|
||||||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
int64_t memoryFormat;
|
int64_t memoryFormat;
|
||||||
if (!op.getMemoryFormat().getType().template isa<Torch::NoneType>() &&
|
if (!isa<Torch::NoneType>(op.getMemoryFormat().getType()) &&
|
||||||
(!matchPattern(op.getMemoryFormat(),
|
(!matchPattern(op.getMemoryFormat(),
|
||||||
m_TorchConstantInt(&memoryFormat)) ||
|
m_TorchConstantInt(&memoryFormat)) ||
|
||||||
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
(memoryFormat != torch_upstream::MemoryFormat::Contiguous &&
|
||||||
|
@ -4964,7 +4962,7 @@ LogicalResult ConvertAtenOp<AtenConstantPadNdOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value self = adaptor.getSelf();
|
Value self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().cast<RankedTensorType>();
|
auto selfTy = cast<RankedTensorType>(self.getType());
|
||||||
auto selfElemTy = selfTy.getElementType();
|
auto selfElemTy = selfTy.getElementType();
|
||||||
int64_t rank = selfTy.getRank();
|
int64_t rank = selfTy.getRank();
|
||||||
|
|
||||||
|
@ -5033,7 +5031,7 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
const TypeConverter *typeConverter = this->getTypeConverter();
|
const TypeConverter *typeConverter = this->getTypeConverter();
|
||||||
auto outType =
|
auto outType =
|
||||||
typeConverter->convertType(op.getType()).cast<RankedTensorType>();
|
cast<RankedTensorType>(typeConverter->convertType(op.getType()));
|
||||||
int64_t rank = outType.getRank();
|
int64_t rank = outType.getRank();
|
||||||
int64_t dim;
|
int64_t dim;
|
||||||
|
|
||||||
|
@ -5074,7 +5072,7 @@ LogicalResult ConvertAtenOp<AtenSqrtOp>::matchAndRewrite(
|
||||||
|
|
||||||
// Converts AtenSqrtOp into pow(x, 0.5)
|
// Converts AtenSqrtOp into pow(x, 0.5)
|
||||||
auto self = adaptor.getSelf();
|
auto self = adaptor.getSelf();
|
||||||
auto selfTy = self.getType().dyn_cast<TensorType>();
|
auto selfTy = dyn_cast<TensorType>(self.getType());
|
||||||
if (!selfTy)
|
if (!selfTy)
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Only Tensor types supported in TOSA");
|
"Only Tensor types supported in TOSA");
|
||||||
|
|
|
@ -117,8 +117,8 @@ template <>
|
||||||
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
|
tosa::DivOp createBinaryOpAndCast<DivOp>(PatternRewriter &rewriter,
|
||||||
Operation *op, TensorType outType,
|
Operation *op, TensorType outType,
|
||||||
Value lhs, Value rhs) {
|
Value lhs, Value rhs) {
|
||||||
auto lhsElemTy = lhs.getType().cast<TensorType>().getElementType();
|
auto lhsElemTy = cast<TensorType>(lhs.getType()).getElementType();
|
||||||
auto rhsElemTy = rhs.getType().cast<TensorType>().getElementType();
|
auto rhsElemTy = cast<TensorType>(rhs.getType()).getElementType();
|
||||||
if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
|
if (isa<mlir::FloatType>(lhsElemTy) || isa<mlir::FloatType>(rhsElemTy)) {
|
||||||
(void)rewriter.notifyMatchFailure(op,
|
(void)rewriter.notifyMatchFailure(op,
|
||||||
"tosa.div only supports integer type");
|
"tosa.div only supports integer type");
|
||||||
|
@ -148,8 +148,8 @@ std::optional<Value> convertTorchIndexToTfIndices(PatternRewriter &rewriter,
|
||||||
// [2,1] [[0, 3, 2],[0, 3, 1]]
|
// [2,1] [[0, 3, 2],[0, 3, 1]]
|
||||||
// ]] 1*4*2 ]] 1*4*2*3
|
// ]] 1*4*2 ]] 1*4*2*3
|
||||||
|
|
||||||
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
|
||||||
auto indexType = indexValue.getType().dyn_cast<RankedTensorType>();
|
auto indexType = dyn_cast<RankedTensorType>(indexValue.getType());
|
||||||
auto paramsShape = paramsType.getShape(); // [1 4 3]
|
auto paramsShape = paramsType.getShape(); // [1 4 3]
|
||||||
auto indexShape = indexType.getShape(); // [1 4 2]
|
auto indexShape = indexType.getShape(); // [1 4 2]
|
||||||
int paramsRank = paramsShape.size(); // 3
|
int paramsRank = paramsShape.size(); // 3
|
||||||
|
@ -214,8 +214,8 @@ std::optional<Value> convertGatherNdOp(PatternRewriter &rewriter, Operation *op,
|
||||||
Type outType, Value paramsValue,
|
Type outType, Value paramsValue,
|
||||||
Value indicesValue) {
|
Value indicesValue) {
|
||||||
auto resultType = dyn_cast<ShapedType>(outType);
|
auto resultType = dyn_cast<ShapedType>(outType);
|
||||||
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
|
||||||
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
|
auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());
|
||||||
|
|
||||||
if (!resultType || !paramsType || !indicesType)
|
if (!resultType || !paramsType || !indicesType)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
@ -420,9 +420,9 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
|
||||||
Value paramsValue, Value indicesValue,
|
Value paramsValue, Value indicesValue,
|
||||||
Value fillValues) {
|
Value fillValues) {
|
||||||
auto resultType = dyn_cast<ShapedType>(outType);
|
auto resultType = dyn_cast<ShapedType>(outType);
|
||||||
auto paramsType = paramsValue.getType().dyn_cast<RankedTensorType>();
|
auto paramsType = dyn_cast<RankedTensorType>(paramsValue.getType());
|
||||||
auto indicesType = indicesValue.getType().dyn_cast<RankedTensorType>();
|
auto indicesType = dyn_cast<RankedTensorType>(indicesValue.getType());
|
||||||
auto fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
|
auto fillValuesType = dyn_cast<RankedTensorType>(fillValues.getType());
|
||||||
|
|
||||||
if (!resultType || !paramsType || !indicesType)
|
if (!resultType || !paramsType || !indicesType)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
@ -572,7 +572,7 @@ std::optional<Value> convertScatterNdOp(PatternRewriter &rewriter,
|
||||||
tosaFillValuesTileOp.getResult(),
|
tosaFillValuesTileOp.getResult(),
|
||||||
rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape));
|
rewriter.getDenseI64ArrayAttr(newTosaFillValuesShape));
|
||||||
fillValues = newTosaFillValuesReshapeOp.getResult();
|
fillValues = newTosaFillValuesReshapeOp.getResult();
|
||||||
fillValuesType = fillValues.getType().dyn_cast<RankedTensorType>();
|
fillValuesType = dyn_cast<RankedTensorType>(fillValues.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
// fillK: range of each index, total number of fillInput(could be scatter)
|
// fillK: range of each index, total number of fillInput(could be scatter)
|
||||||
|
@ -691,7 +691,7 @@ std::optional<Value> convertReduceOpCommon(
|
||||||
Type reduce_element_type, bool is_quantized, double input_scale,
|
Type reduce_element_type, bool is_quantized, double input_scale,
|
||||||
int64_t input_zp, double output_scale, int64_t output_zp) {
|
int64_t input_zp, double output_scale, int64_t output_zp) {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
@ -754,7 +754,7 @@ convertReduceAllOp(PatternRewriter &rewriter, Operation *op,
|
||||||
RankedTensorType output_type, Value input_value,
|
RankedTensorType output_type, Value input_value,
|
||||||
ElementsAttr axes_elems, bool keep_dims) {
|
ElementsAttr axes_elems, bool keep_dims) {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
@ -769,7 +769,7 @@ convertReduceAnyOp(PatternRewriter &rewriter, Operation *op,
|
||||||
RankedTensorType output_type, Value input_value,
|
RankedTensorType output_type, Value input_value,
|
||||||
ElementsAttr axes_elems, bool keep_dims) {
|
ElementsAttr axes_elems, bool keep_dims) {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
@ -784,7 +784,7 @@ convertReduceMinOp(PatternRewriter &rewriter, Operation *op,
|
||||||
RankedTensorType output_type, Value input_value,
|
RankedTensorType output_type, Value input_value,
|
||||||
ElementsAttr axes_elems, bool keep_dims) {
|
ElementsAttr axes_elems, bool keep_dims) {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
@ -799,7 +799,7 @@ convertReduceMaxOp(PatternRewriter &rewriter, Operation *op,
|
||||||
RankedTensorType output_type, Value input_value,
|
RankedTensorType output_type, Value input_value,
|
||||||
ElementsAttr axes_elems, bool keep_dims) {
|
ElementsAttr axes_elems, bool keep_dims) {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
@ -814,7 +814,7 @@ convertReduceProdOp(PatternRewriter &rewriter, Operation *op,
|
||||||
RankedTensorType output_type, Value input_value,
|
RankedTensorType output_type, Value input_value,
|
||||||
ElementsAttr axes_elems, bool keep_dims) {
|
ElementsAttr axes_elems, bool keep_dims) {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
@ -840,7 +840,7 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
|
||||||
RankedTensorType output_type, Value input_value,
|
RankedTensorType output_type, Value input_value,
|
||||||
ElementsAttr axes_elems, bool keep_dims) {
|
ElementsAttr axes_elems, bool keep_dims) {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
@ -863,9 +863,9 @@ convertReduceSumOp(PatternRewriter &rewriter, Operation *op,
|
||||||
|
|
||||||
if (input_is_qtype) {
|
if (input_is_qtype) {
|
||||||
auto input_qtype =
|
auto input_qtype =
|
||||||
input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
|
cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||||
auto output_qtype =
|
auto output_qtype =
|
||||||
output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
|
cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||||
|
|
||||||
int32_t input_shift = 20;
|
int32_t input_shift = 20;
|
||||||
|
|
||||||
|
@ -895,7 +895,7 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
|
||||||
// op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
|
// op2 = mul(op1, 1.0 / num_elements_on_reduced_axis)
|
||||||
|
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
@ -940,9 +940,9 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op,
|
||||||
|
|
||||||
if (input_is_qtype) {
|
if (input_is_qtype) {
|
||||||
auto input_qtype =
|
auto input_qtype =
|
||||||
input_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
|
cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||||
auto output_qtype =
|
auto output_qtype =
|
||||||
output_type.getElementType().cast<mlir::quant::UniformQuantizedType>();
|
cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||||
|
|
||||||
// Combine 'div_scale' as part of output rescale
|
// Combine 'div_scale' as part of output rescale
|
||||||
output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
|
output_scale = div_scale * input_qtype.getScale() / output_qtype.getScale();
|
||||||
|
@ -976,7 +976,7 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op,
|
||||||
RankedTensorType output_type, Value input_value,
|
RankedTensorType output_type, Value input_value,
|
||||||
ElementsAttr axes_elems, bool keep_dims) {
|
ElementsAttr axes_elems, bool keep_dims) {
|
||||||
RankedTensorType input_type =
|
RankedTensorType input_type =
|
||||||
input_value.getType().dyn_cast<RankedTensorType>();
|
dyn_cast<RankedTensorType>(input_value.getType());
|
||||||
if (!input_type)
|
if (!input_type)
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
|
||||||
Value input_val, double input_scale,
|
Value input_val, double input_scale,
|
||||||
int64_t input_zp) {
|
int64_t input_zp) {
|
||||||
// Output is always int32 type
|
// Output is always int32 type
|
||||||
auto input_type = input_val.getType().dyn_cast<mlir::ShapedType>();
|
auto input_type = dyn_cast<mlir::ShapedType>(input_val.getType());
|
||||||
assert(input_type);
|
assert(input_type);
|
||||||
auto output_type = input_type.clone(rewriter.getI32Type());
|
auto output_type = input_type.clone(rewriter.getI32Type());
|
||||||
|
|
||||||
|
@ -58,9 +58,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
|
||||||
Value conv_val, ShapedType input_type,
|
Value conv_val, ShapedType input_type,
|
||||||
ShapedType weight_type, ShapedType output_type) {
|
ShapedType weight_type, ShapedType output_type) {
|
||||||
auto input_qtype =
|
auto input_qtype =
|
||||||
input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
|
dyn_cast<mlir::quant::UniformQuantizedType>(input_type.getElementType());
|
||||||
auto output_qtype = output_type.getElementType()
|
auto output_qtype =
|
||||||
.dyn_cast<mlir::quant::UniformQuantizedType>();
|
dyn_cast<mlir::quant::UniformQuantizedType>(output_type.getElementType());
|
||||||
|
|
||||||
double input_scale = input_qtype.getScale();
|
double input_scale = input_qtype.getScale();
|
||||||
|
|
||||||
|
@ -71,8 +71,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
|
||||||
int32_t scale_width = scale32 ? 32 : 16;
|
int32_t scale_width = scale32 ? 32 : 16;
|
||||||
|
|
||||||
if (auto weight_per_tensor_qtype =
|
if (auto weight_per_tensor_qtype =
|
||||||
weight_type.getElementType()
|
dyn_cast<mlir::quant::UniformQuantizedType>(
|
||||||
.dyn_cast<mlir::quant::UniformQuantizedType>()) {
|
weight_type.getElementType())) {
|
||||||
// Per-tensor quantization
|
// Per-tensor quantization
|
||||||
double weight_scale = weight_per_tensor_qtype.getScale();
|
double weight_scale = weight_per_tensor_qtype.getScale();
|
||||||
|
|
||||||
|
@ -94,8 +94,8 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
|
||||||
return rescale_op.getResult();
|
return rescale_op.getResult();
|
||||||
|
|
||||||
} else if (auto weight_per_channel_qtype =
|
} else if (auto weight_per_channel_qtype =
|
||||||
weight_type.getElementType()
|
dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(
|
||||||
.dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
|
weight_type.getElementType())) {
|
||||||
// Per-channel quantization
|
// Per-channel quantization
|
||||||
SmallVector<int32_t> multiplier_arr;
|
SmallVector<int32_t> multiplier_arr;
|
||||||
SmallVector<int8_t> shift_arr;
|
SmallVector<int8_t> shift_arr;
|
||||||
|
@ -311,7 +311,7 @@ static LogicalResult checkValidityOfCast(Type src, Type dest) {
|
||||||
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
||||||
Value src, Type destType, Value &result) {
|
Value src, Type destType, Value &result) {
|
||||||
|
|
||||||
Type srcElemTy = src.getType().dyn_cast<TensorType>().getElementType();
|
Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
|
||||||
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
|
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();
|
||||||
|
|
||||||
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
|
if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
|
||||||
|
@ -319,7 +319,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
||||||
op, "casting to result dtype is invalid or unsupported");
|
op, "casting to result dtype is invalid or unsupported");
|
||||||
|
|
||||||
if (destElemTy.isInteger(1)) {
|
if (destElemTy.isInteger(1)) {
|
||||||
auto srcType = src.getType().dyn_cast<TensorType>();
|
auto srcType = dyn_cast<TensorType>(src.getType());
|
||||||
SmallVector<int64_t> srcShape(srcType.getShape());
|
SmallVector<int64_t> srcShape(srcType.getShape());
|
||||||
uint64_t num_total_elements = 1;
|
uint64_t num_total_elements = 1;
|
||||||
for (int64_t a : srcShape)
|
for (int64_t a : srcShape)
|
||||||
|
@ -355,7 +355,7 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
|
||||||
|
|
||||||
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) {
|
||||||
Operation *op = input.getDefiningOp();
|
Operation *op = input.getDefiningOp();
|
||||||
TensorType inType = input.getType().cast<TensorType>();
|
TensorType inType = cast<TensorType>(input.getType());
|
||||||
|
|
||||||
if (inType.getElementType() != outType.getElementType()) {
|
if (inType.getElementType() != outType.getElementType()) {
|
||||||
TensorType promotedType =
|
TensorType promotedType =
|
||||||
|
|
|
@ -52,7 +52,7 @@ LogicalResult checkNotNone(PatternRewriter &rewriter, Operation *op, Value v) {
|
||||||
// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
|
// Generate IR: dim = dim >= 0 ? dim : dim + inputRank
|
||||||
Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
|
Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim,
|
||||||
Value inputRank) {
|
Value inputRank) {
|
||||||
assert(dim.getType().isa<IntegerType>() &&
|
assert(isa<IntegerType>(dim.getType()) &&
|
||||||
"dim arg of toPositiveDim must be integer type");
|
"dim arg of toPositiveDim must be integer type");
|
||||||
Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
|
Value dimAddInputRank = b.create<arith::AddIOp>(loc, dim, inputRank);
|
||||||
Value cst0 =
|
Value cst0 =
|
||||||
|
@ -132,7 +132,7 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
|
||||||
Type elemTy) {
|
Type elemTy) {
|
||||||
Value initTensor =
|
Value initTensor =
|
||||||
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
b.create<tensor::EmptyOp>(loc, getAsOpFoldResult(sizes), elemTy);
|
||||||
RankedTensorType type = initTensor.getType().cast<RankedTensorType>();
|
RankedTensorType type = cast<RankedTensorType>(initTensor.getType());
|
||||||
Value c0 =
|
Value c0 =
|
||||||
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
|
b.create<arith::ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
|
||||||
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
|
||||||
|
@ -172,7 +172,7 @@ Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
|
||||||
|
|
||||||
SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
|
SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
|
||||||
Value tensor, int dim) {
|
Value tensor, int dim) {
|
||||||
RankedTensorType type = tensor.getType().cast<RankedTensorType>();
|
RankedTensorType type = cast<RankedTensorType>(tensor.getType());
|
||||||
assert(dim < type.getRank() &&
|
assert(dim < type.getRank() &&
|
||||||
"The given dim must be smaller than tensor rank");
|
"The given dim must be smaller than tensor rank");
|
||||||
(void)type;
|
(void)type;
|
||||||
|
@ -183,7 +183,7 @@ SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor) {
|
SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc, Value tensor) {
|
||||||
RankedTensorType type = tensor.getType().cast<RankedTensorType>();
|
RankedTensorType type = cast<RankedTensorType>(tensor.getType());
|
||||||
return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
|
return getTensorSizesUntilDim(b, loc, tensor, type.getRank() - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v,
|
||||||
|
|
||||||
OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
|
OpFoldResult TMTensor::getDim(OpBuilder &builder, Location loc, Value v,
|
||||||
int64_t dim) {
|
int64_t dim) {
|
||||||
auto t = v.getType().cast<ShapedType>();
|
auto t = cast<ShapedType>(v.getType());
|
||||||
if (t.isDynamicDim(dim)) {
|
if (t.isDynamicDim(dim)) {
|
||||||
return getDimValue(builder, loc, v, dim);
|
return getDimValue(builder, loc, v, dim);
|
||||||
}
|
}
|
||||||
|
@ -123,7 +123,7 @@ bool AttentionOp::payloadUsesValueFromOperand(OpOperand *opOperand) {
|
||||||
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
|
static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes,
|
||||||
Value rhs, ValueRange rhsSizes, Value output,
|
Value rhs, ValueRange rhsSizes, Value output,
|
||||||
ValueRange outputSizes, bool transposed = false) {
|
ValueRange outputSizes, bool transposed = false) {
|
||||||
auto elementType = lhs.getType().cast<MemRefType>().getElementType();
|
auto elementType = cast<MemRefType>(lhs.getType()).getElementType();
|
||||||
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
|
Value one = b.create<arith::ConstantIndexOp>(loc, 1);
|
||||||
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
|
||||||
auto rank = outputSizes.size();
|
auto rank = outputSizes.size();
|
||||||
|
@ -168,9 +168,9 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b,
|
||||||
Value key = getKey();
|
Value key = getKey();
|
||||||
Value value = getValue();
|
Value value = getValue();
|
||||||
Value output = getOutput();
|
Value output = getOutput();
|
||||||
auto queryType = query.getType().cast<MemRefType>();
|
auto queryType = cast<MemRefType>(query.getType());
|
||||||
auto keyType = key.getType().cast<MemRefType>();
|
auto keyType = cast<MemRefType>(key.getType());
|
||||||
auto valueType = value.getType().cast<MemRefType>();
|
auto valueType = cast<MemRefType>(value.getType());
|
||||||
auto queryRank = queryType.getRank();
|
auto queryRank = queryType.getRank();
|
||||||
auto keyRank = keyType.getRank();
|
auto keyRank = keyType.getRank();
|
||||||
auto valueRank = valueType.getRank();
|
auto valueRank = valueType.getRank();
|
||||||
|
@ -330,12 +330,12 @@ LogicalResult ScanOp::verify() {
|
||||||
if (getNumOutputs() != 2) {
|
if (getNumOutputs() != 2) {
|
||||||
return emitOpError("expected two output operands");
|
return emitOpError("expected two output operands");
|
||||||
}
|
}
|
||||||
if (!input().getType().isa<ShapedType>()) {
|
if (!isa<ShapedType>(input().getType())) {
|
||||||
return emitOpError("expected first input element type to be shaped");
|
return emitOpError("expected first input element type to be shaped");
|
||||||
}
|
}
|
||||||
auto accumulatorType = accumulator().getType().cast<ShapedType>();
|
auto accumulatorType = cast<ShapedType>(accumulator().getType());
|
||||||
auto inputType = input().getType().cast<ShapedType>();
|
auto inputType = cast<ShapedType>(input().getType());
|
||||||
auto outputType = output().getType().cast<ShapedType>();
|
auto outputType = cast<ShapedType>(output().getType());
|
||||||
ArrayRef<int64_t> inputShapes = inputType.getShape();
|
ArrayRef<int64_t> inputShapes = inputType.getShape();
|
||||||
ArrayRef<int64_t> outputShapes = outputType.getShape();
|
ArrayRef<int64_t> outputShapes = outputType.getShape();
|
||||||
if (accumulatorType.getElementType() != inputType.getElementType()) {
|
if (accumulatorType.getElementType() != inputType.getElementType()) {
|
||||||
|
@ -706,7 +706,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
|
||||||
loadIndices.push_back(Value());
|
loadIndices.push_back(Value());
|
||||||
|
|
||||||
// Populate with empty values.
|
// Populate with empty values.
|
||||||
auto originalTy = original().getType().cast<ShapedType>();
|
auto originalTy = cast<ShapedType>(original().getType());
|
||||||
starts.resize(originalTy.getRank(), Value());
|
starts.resize(originalTy.getRank(), Value());
|
||||||
auto updateIvs = ivs.drop_front(1);
|
auto updateIvs = ivs.drop_front(1);
|
||||||
|
|
||||||
|
@ -797,7 +797,7 @@ LogicalResult SortOp::verify() {
|
||||||
if (yieldOp.getNumOperands() != 1) {
|
if (yieldOp.getNumOperands() != 1) {
|
||||||
return op->emitOpError("should yield exactly one operand");
|
return op->emitOpError("should yield exactly one operand");
|
||||||
}
|
}
|
||||||
auto ty = yieldOp.getOperand(0).getType().dyn_cast<IntegerType>();
|
auto ty = dyn_cast<IntegerType>(yieldOp.getOperand(0).getType());
|
||||||
if (!ty || ty.getWidth() != 1) {
|
if (!ty || ty.getWidth() != 1) {
|
||||||
return op->emitOpError("should yield i1 type");
|
return op->emitOpError("should yield i1 type");
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ using namespace ::mlir;
|
||||||
using namespace ::mlir::torch::TMTensor;
|
using namespace ::mlir::torch::TMTensor;
|
||||||
|
|
||||||
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
|
static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
|
||||||
auto memrefType = memref.getType().cast<MemRefType>();
|
auto memrefType = cast<MemRefType>(memref.getType());
|
||||||
auto alloc = b.create<memref::AllocOp>(
|
auto alloc = b.create<memref::AllocOp>(
|
||||||
loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType());
|
loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType());
|
||||||
b.create<memref::CopyOp>(loc, memref, alloc);
|
b.create<memref::CopyOp>(loc, memref, alloc);
|
||||||
|
|
|
@ -80,7 +80,7 @@ struct ScalarLoopOpInterfaceLowerToLoopsPattern : public RewritePattern {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
if (llvm::any_of(scalarLoopOp->getResults(),
|
if (llvm::any_of(scalarLoopOp->getResults(),
|
||||||
[&](Value v) { return v.getType().isa<ShapedType>(); })) {
|
[&](Value v) { return isa<ShapedType>(v.getType()); })) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
scalarLoopOp, "lower to loops needs to have tensor semantics");
|
scalarLoopOp, "lower to loops needs to have tensor semantics");
|
||||||
}
|
}
|
||||||
|
|
|
@ -122,14 +122,14 @@ LogicalResult TorchDialect::verifyRegionArgAttribute(Operation *op,
|
||||||
auto func = dyn_cast<func::FuncOp>(op);
|
auto func = dyn_cast<func::FuncOp>(op);
|
||||||
if (!func)
|
if (!func)
|
||||||
return op->emitError() << "'torch.type_bound' must be attached to a func";
|
return op->emitError() << "'torch.type_bound' must be attached to a func";
|
||||||
TypeAttr attr = namedAttr.getValue().dyn_cast<TypeAttr>();
|
TypeAttr attr = dyn_cast<TypeAttr>(namedAttr.getValue());
|
||||||
if (!attr)
|
if (!attr)
|
||||||
return op->emitError() << "'torch.type_bound' must be TypeAttr";
|
return op->emitError() << "'torch.type_bound' must be TypeAttr";
|
||||||
auto type = attr.getValue().dyn_cast<BaseTensorType>();
|
auto type = dyn_cast<BaseTensorType>(attr.getValue());
|
||||||
if (!type)
|
if (!type)
|
||||||
return op->emitError() << "'torch.type_bound' must be of "
|
return op->emitError() << "'torch.type_bound' must be of "
|
||||||
"!torch.tensor/!torch.vtensor type";
|
"!torch.tensor/!torch.vtensor type";
|
||||||
if (!func.getFunctionType().getInput(argIndex).isa<BaseTensorType>())
|
if (!isa<BaseTensorType>(func.getFunctionType().getInput(argIndex)))
|
||||||
return op->emitError() << "'torch.type_bound' must be attached to an "
|
return op->emitError() << "'torch.type_bound' must be attached to an "
|
||||||
"argument of !torch.tensor/!torch.vtensor type";
|
"argument of !torch.tensor/!torch.vtensor type";
|
||||||
return success();
|
return success();
|
||||||
|
|
|
@ -75,7 +75,7 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder,
|
||||||
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||||
BaseTensorType newType,
|
BaseTensorType newType,
|
||||||
Value tensor) {
|
Value tensor) {
|
||||||
auto originalType = tensor.getType().cast<BaseTensorType>();
|
auto originalType = cast<BaseTensorType>(tensor.getType());
|
||||||
// Adjust the static information in the type to match between the original and
|
// Adjust the static information in the type to match between the original and
|
||||||
// new types.
|
// new types.
|
||||||
if (!originalType.hasSameSizesAndDtype(newType)) {
|
if (!originalType.hasSameSizesAndDtype(newType)) {
|
||||||
|
@ -87,7 +87,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||||
// up creating one op that converts between the value and non-value tensor
|
// up creating one op that converts between the value and non-value tensor
|
||||||
// domains. If both the original and new types are both non-value tensors,
|
// domains. If both the original and new types are both non-value tensors,
|
||||||
// then we do the copy by going to a value tensor and back.
|
// then we do the copy by going to a value tensor and back.
|
||||||
if (tensor.getType().isa<NonValueTensorType>())
|
if (isa<NonValueTensorType>(tensor.getType()))
|
||||||
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
|
tensor = builder.create<CopyToValueTensorOp>(loc, tensor);
|
||||||
if (isa<NonValueTensorType>(newType))
|
if (isa<NonValueTensorType>(newType))
|
||||||
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
|
tensor = builder.create<CopyToNonValueTensorOp>(loc, tensor);
|
||||||
|
@ -96,7 +96,7 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc,
|
||||||
}
|
}
|
||||||
|
|
||||||
bool mlir::torch::Torch::isListPotentiallyMutated(Value list) {
|
bool mlir::torch::Torch::isListPotentiallyMutated(Value list) {
|
||||||
assert(list.getType().isa<Torch::ListType>());
|
assert(isa<Torch::ListType>(list.getType()));
|
||||||
return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands);
|
return llvm::any_of(list.getUsers(), potentiallyMutatesListOperands);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,8 +148,7 @@ static Value getScalarIntValue(Value input, Location loc,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
if (auto valueTensorLiteralOp = input.getDefiningOp<ValueTensorLiteralOp>()) {
|
||||||
auto val = valueTensorLiteralOp.getValue()
|
auto val = cast<DenseIntElementsAttr>(valueTensorLiteralOp.getValue())
|
||||||
.cast<DenseIntElementsAttr>()
|
|
||||||
.getSplatValue<int64_t>();
|
.getSplatValue<int64_t>();
|
||||||
return rewriter.create<Torch::ConstantIntOp>(
|
return rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(val));
|
loc, rewriter.getI64IntegerAttr(val));
|
||||||
|
@ -777,7 +776,7 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) {
|
||||||
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) {
|
||||||
if (getOperand(0).getType() != getResult().getType())
|
if (getOperand(0).getType() != getResult().getType())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>()) {
|
if (auto tensorType = dyn_cast<BaseTensorType>(getOperand(0).getType())) {
|
||||||
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
if (tensorType.hasSizes() && tensorType.getSizes().size() == 0)
|
||||||
return getOperand(0);
|
return getOperand(0);
|
||||||
}
|
}
|
||||||
|
@ -798,11 +797,11 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
|
||||||
if (!matchPattern(getCopy(), m_TorchConstantBool(©Arg)) || copyArg)
|
if (!matchPattern(getCopy(), m_TorchConstantBool(©Arg)) || copyArg)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
// The memory_format arg must be `none`.
|
// The memory_format arg must be `none`.
|
||||||
if (!getMemoryFormat().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto inputType = getSelf().getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(getSelf().getType());
|
||||||
auto resType = getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(getType());
|
||||||
// If the types aren't equal, then we can't fold.
|
// If the types aren't equal, then we can't fold.
|
||||||
if (inputType != resType)
|
if (inputType != resType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -821,7 +820,7 @@ OpFoldResult AtenToDtypeOp::fold(FoldAdaptor adaptor) {
|
||||||
|
|
||||||
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
||||||
// The pin_memory arg should be either constant `False` or `none`.
|
// The pin_memory arg should be either constant `False` or `none`.
|
||||||
if (!getPinMemory().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(getPinMemory().getType())) {
|
||||||
bool pinMemory;
|
bool pinMemory;
|
||||||
if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
if (!matchPattern(getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -844,15 +843,15 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
// The device arg must be `none`.
|
// The device arg must be `none`.
|
||||||
if (!getDevice().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(getDevice().getType()))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
// The memory_format arg must be `none`.
|
// The memory_format arg must be `none`.
|
||||||
if (!getMemoryFormat().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(getMemoryFormat().getType()))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto inputType = getSelf().getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(getSelf().getType());
|
||||||
auto resType = getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(getType());
|
||||||
// If the types aren't equal, then we can't fold.
|
// If the types aren't equal, then we can't fold.
|
||||||
if (inputType != resType)
|
if (inputType != resType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -863,7 +862,7 @@ OpFoldResult AtenToDtypeLayoutOp::fold(FoldAdaptor adaptor) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
// The layout arg should be either `none` or `0` i.e. strided.
|
// The layout arg should be either `none` or `0` i.e. strided.
|
||||||
if (!getLayout().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(getLayout().getType())) {
|
||||||
int64_t tensorLayout;
|
int64_t tensorLayout;
|
||||||
if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout)))
|
if (!matchPattern(getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -882,7 +881,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
|
||||||
// is false
|
// is false
|
||||||
patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
|
patterns.add(+[](AtenToDtypeLayoutOp op, PatternRewriter &rewriter) {
|
||||||
// The pin_memory arg should be either constant `False` or `none`.
|
// The pin_memory arg should be either constant `False` or `none`.
|
||||||
if (!op.getPinMemory().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getPinMemory().getType())) {
|
||||||
bool pinMemory;
|
bool pinMemory;
|
||||||
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
if (!matchPattern(op.getPinMemory(), m_TorchConstantBool(&pinMemory)))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -891,7 +890,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
|
||||||
}
|
}
|
||||||
|
|
||||||
// The layout arg should be either `none` or `0` i.e. strided.
|
// The layout arg should be either `none` or `0` i.e. strided.
|
||||||
if (!op.getLayout().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getLayout().getType())) {
|
||||||
int64_t tensorLayout;
|
int64_t tensorLayout;
|
||||||
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
if (!matchPattern(op.getLayout(), m_TorchConstantInt(&tensorLayout)))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -899,7 +898,7 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns(
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (op.getDevice().getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(op.getDevice().getType())) {
|
||||||
// The device arg is `none`. Rewrite to to.dtype.
|
// The device arg is `none`. Rewrite to to.dtype.
|
||||||
AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
|
AtenToDtypeOp toDtype = rewriter.create<AtenToDtypeOp>(
|
||||||
op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
|
op.getLoc(), op.getType(), op.getSelf(), op.getDtype(),
|
||||||
|
@ -985,10 +984,10 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
|
||||||
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
auto inputType = dyn_cast<BaseTensorType>(getOperand(0).getType());
|
||||||
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
|
if (!inputType || !inputType.hasSizes() || inputType.getSizes().size() != 1)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
auto resType = getType().dyn_cast<BaseTensorType>();
|
auto resType = dyn_cast<BaseTensorType>(getType());
|
||||||
if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
|
if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (inputType != resType)
|
if (inputType != resType)
|
||||||
|
@ -1011,7 +1010,7 @@ OpFoldResult PrimsViewOfOp::fold(FoldAdaptor adaptor) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenDimOp::fold(FoldAdaptor adaptor) {
|
||||||
if (auto tensorType = getOperand().getType().dyn_cast<BaseTensorType>()) {
|
if (auto tensorType = dyn_cast<BaseTensorType>(getOperand().getType())) {
|
||||||
if (tensorType.hasSizes())
|
if (tensorType.hasSizes())
|
||||||
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
return IntegerAttr::get(IntegerType::get(getContext(), 64),
|
||||||
tensorType.getSizes().size());
|
tensorType.getSizes().size());
|
||||||
|
@ -1117,7 +1116,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
|
if (isa<AtenDivTensorModeOp, AtenDivScalarModeOp>(op)) {
|
||||||
if (op->getOperand(2).getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(op->getOperand(2).getType())) {
|
||||||
// None rounding mode
|
// None rounding mode
|
||||||
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
|
Value quotient = rewriter.create<AtenDivOp>(loc, lhs, rhs);
|
||||||
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
|
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType,
|
||||||
|
@ -1879,9 +1878,9 @@ OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
|
||||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||||
if (resultType && resultType.hasDtype() &&
|
if (resultType && resultType.hasDtype() &&
|
||||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
isa<mlir::IntegerType>(resultType.getDtype())) {
|
||||||
return getSelf();
|
return getSelf();
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
|
@ -1892,9 +1891,9 @@ OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
|
||||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||||
if (resultType && resultType.hasDtype() &&
|
if (resultType && resultType.hasDtype() &&
|
||||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
isa<mlir::IntegerType>(resultType.getDtype())) {
|
||||||
return getSelf();
|
return getSelf();
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
|
@ -1905,9 +1904,9 @@ OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
||||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||||
if (resultType && resultType.hasDtype() &&
|
if (resultType && resultType.hasDtype() &&
|
||||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
isa<mlir::IntegerType>(resultType.getDtype())) {
|
||||||
return getSelf();
|
return getSelf();
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
|
@ -1918,7 +1917,7 @@ OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenTruncOp::fold(FoldAdaptor adaptor) {
|
||||||
auto resultType = getType().dyn_cast<ValueTensorType>();
|
auto resultType = dyn_cast<ValueTensorType>(getType());
|
||||||
if (resultType && resultType.hasDtype() &&
|
if (resultType && resultType.hasDtype() &&
|
||||||
resultType.getDtype().isa<mlir::IntegerType>()) {
|
resultType.getDtype().isa<mlir::IntegerType>()) {
|
||||||
return getSelf();
|
return getSelf();
|
||||||
|
@ -1987,7 +1986,7 @@ void AtenDivScalarModeOp::getCanonicalizationPatterns(
|
||||||
void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
void AtenNumelOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) {
|
patterns.add(+[](AtenNumelOp op, PatternRewriter &rewriter) {
|
||||||
auto inputType = op.getSelf().getType().dyn_cast<BaseTensorType>();
|
auto inputType = dyn_cast<BaseTensorType>(op.getSelf().getType());
|
||||||
if (!inputType || !inputType.areAllSizesKnown()) {
|
if (!inputType || !inputType.areAllSizesKnown()) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -2113,7 +2112,7 @@ traceKnownSizeTensorType(Value value, std::optional<int64_t> dim) {
|
||||||
if (!value || !value.getType().isa<BaseTensorType>())
|
if (!value || !value.getType().isa<BaseTensorType>())
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto tensorType = value.getType().cast<BaseTensorType>();
|
auto tensorType = cast<BaseTensorType>(value.getType());
|
||||||
if (foundType(tensorType, dim))
|
if (foundType(tensorType, dim))
|
||||||
return tensorType;
|
return tensorType;
|
||||||
|
|
||||||
|
@ -2649,7 +2648,7 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
|
||||||
.dyn_cast_or_null<ElementsAttr>();
|
.dyn_cast_or_null<ElementsAttr>();
|
||||||
if (!attr)
|
if (!attr)
|
||||||
return failure();
|
return failure();
|
||||||
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
|
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
|
||||||
NonValueTensorType returnType =
|
NonValueTensorType returnType =
|
||||||
NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
|
NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
|
||||||
tensorType.getElementType());
|
tensorType.getElementType());
|
||||||
|
@ -2691,7 +2690,7 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
|
||||||
.dyn_cast_or_null<ElementsAttr>();
|
.dyn_cast_or_null<ElementsAttr>();
|
||||||
if (!attr)
|
if (!attr)
|
||||||
return failure();
|
return failure();
|
||||||
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
|
RankedTensorType tensorType = cast<RankedTensorType>(attr.getType());
|
||||||
ValueTensorType returnType =
|
ValueTensorType returnType =
|
||||||
ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
|
ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
|
||||||
tensorType.getElementType());
|
tensorType.getElementType());
|
||||||
|
@ -2751,8 +2750,8 @@ void TensorStaticInfoCastOp::getCanonicalizationPatterns(
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult CopyToNonValueTensorOp::verify() {
|
LogicalResult CopyToNonValueTensorOp::verify() {
|
||||||
auto resultType = getResult().getType().cast<BaseTensorType>();
|
auto resultType = cast<BaseTensorType>(getResult().getType());
|
||||||
auto operandType = getOperand().getType().cast<BaseTensorType>();
|
auto operandType = cast<BaseTensorType>(getOperand().getType());
|
||||||
if (!resultType.hasSameSizesAndDtype(operandType))
|
if (!resultType.hasSameSizesAndDtype(operandType))
|
||||||
return emitError() << "operand and result must have same sizes and dtype";
|
return emitError() << "operand and result must have same sizes and dtype";
|
||||||
return success();
|
return success();
|
||||||
|
@ -2762,7 +2761,7 @@ LogicalResult CopyToNonValueTensorOp::inferReturnTypes(
|
||||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
auto resultType = operands[0].getType().cast<ValueTensorType>();
|
auto resultType = cast<ValueTensorType>(operands[0].getType());
|
||||||
inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
|
inferredReturnTypes.push_back(resultType.getWithoutValueSemantics());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -2778,8 +2777,8 @@ void CopyToNonValueTensorOp::getEffects(
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult CopyToValueTensorOp::verify() {
|
LogicalResult CopyToValueTensorOp::verify() {
|
||||||
auto resultType = getResult().getType().cast<BaseTensorType>();
|
auto resultType = cast<BaseTensorType>(getResult().getType());
|
||||||
auto operandType = getOperand().getType().cast<BaseTensorType>();
|
auto operandType = cast<BaseTensorType>(getOperand().getType());
|
||||||
if (!resultType.hasSameSizesAndDtype(operandType))
|
if (!resultType.hasSameSizesAndDtype(operandType))
|
||||||
return emitError() << "operand and result must have same sizes and dtype";
|
return emitError() << "operand and result must have same sizes and dtype";
|
||||||
return success();
|
return success();
|
||||||
|
@ -2789,7 +2788,7 @@ LogicalResult CopyToValueTensorOp::inferReturnTypes(
|
||||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
auto resultType = operands[0].getType().cast<NonValueTensorType>();
|
auto resultType = cast<NonValueTensorType>(operands[0].getType());
|
||||||
inferredReturnTypes.push_back(resultType.getWithValueSemantics());
|
inferredReturnTypes.push_back(resultType.getWithValueSemantics());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -3004,7 +3003,7 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
|
||||||
auto operandType = getSelf().getType().dyn_cast<BaseTensorType>();
|
auto operandType = dyn_cast<BaseTensorType>(getSelf().getType());
|
||||||
if (!operandType)
|
if (!operandType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
if (operandType.hasDtype()) {
|
if (operandType.hasDtype()) {
|
||||||
|
@ -3493,8 +3492,8 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) {
|
||||||
auto inType = getOperand(0).getType().dyn_cast<BaseTensorType>();
|
auto inType = dyn_cast<BaseTensorType>(getOperand(0).getType());
|
||||||
auto outType = getResult().getType().dyn_cast<BaseTensorType>();
|
auto outType = dyn_cast<BaseTensorType>(getResult().getType());
|
||||||
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() ||
|
||||||
!outType.hasDtype())
|
!outType.hasDtype())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -3534,8 +3533,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) {
|
||||||
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
|
IntegerAttr end = dyn_cast_or_null<IntegerAttr>(adaptor.getEnd());
|
||||||
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
|
IntegerAttr step = dyn_cast_or_null<IntegerAttr>(adaptor.getStep());
|
||||||
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
|
IntegerAttr dim = dyn_cast_or_null<IntegerAttr>(adaptor.getDim());
|
||||||
auto inType = getOperand(0).getType().dyn_cast<ValueTensorType>();
|
auto inType = dyn_cast<ValueTensorType>(getOperand(0).getType());
|
||||||
auto outType = getResult().getType().dyn_cast<ValueTensorType>();
|
auto outType = dyn_cast<ValueTensorType>(getResult().getType());
|
||||||
|
|
||||||
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
if (start && end && step && step.getValue().getSExtValue() == 1 &&
|
||||||
start.getValue().getSExtValue() == 0 &&
|
start.getValue().getSExtValue() == 0 &&
|
||||||
|
@ -3793,7 +3792,7 @@ OpFoldResult AtenSqrtIntOp::fold(FoldAdaptor adaptor) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
|
OpFoldResult PrimDtypeOp::fold(FoldAdaptor adaptor) {
|
||||||
BaseTensorType tensorType = getA().getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(getA().getType());
|
||||||
if (tensorType.hasDtype()) {
|
if (tensorType.hasDtype()) {
|
||||||
torch_upstream::ScalarType scalarType =
|
torch_upstream::ScalarType scalarType =
|
||||||
Torch::getScalarTypeForType(tensorType.getDtype());
|
Torch::getScalarTypeForType(tensorType.getDtype());
|
||||||
|
@ -4568,7 +4567,7 @@ LogicalResult AtenNormScalarOp::verify() {
|
||||||
// Per PyTorch docs, only float and complex types are valid for norm
|
// Per PyTorch docs, only float and complex types are valid for norm
|
||||||
// operation.
|
// operation.
|
||||||
|
|
||||||
auto inTensor = getSelf().getType().cast<BaseTensorType>();
|
auto inTensor = cast<BaseTensorType>(getSelf().getType());
|
||||||
|
|
||||||
// If no dtype is specified, it will default to a float one.
|
// If no dtype is specified, it will default to a float one.
|
||||||
if (!inTensor.hasDtype()) {
|
if (!inTensor.hasDtype()) {
|
||||||
|
@ -4605,8 +4604,8 @@ LogicalResult AtenPermuteOp::verify() {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto outType = getResult().getType().cast<BaseTensorType>();
|
auto outType = cast<BaseTensorType>(getResult().getType());
|
||||||
auto inType = getSelf().getType().cast<BaseTensorType>();
|
auto inType = cast<BaseTensorType>(getSelf().getType());
|
||||||
|
|
||||||
if (!outType.hasSizes() || !inType.hasSizes()) {
|
if (!outType.hasSizes() || !inType.hasSizes()) {
|
||||||
return success();
|
return success();
|
||||||
|
@ -4689,8 +4688,8 @@ LogicalResult AtenPermuteOp::verify() {
|
||||||
|
|
||||||
LogicalResult AtenLinalgCrossOp::verify() {
|
LogicalResult AtenLinalgCrossOp::verify() {
|
||||||
|
|
||||||
auto selfType = getSelf().getType().cast<BaseTensorType>();
|
auto selfType = cast<BaseTensorType>(getSelf().getType());
|
||||||
auto otherType = getOther().getType().cast<BaseTensorType>();
|
auto otherType = cast<BaseTensorType>(getOther().getType());
|
||||||
|
|
||||||
if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
|
if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() ||
|
||||||
!otherType.hasSizes()) {
|
!otherType.hasSizes()) {
|
||||||
|
@ -4857,7 +4856,7 @@ LogicalResult GlobalSlotModuleInitializerOp::verify() {
|
||||||
|
|
||||||
// Check that initial values satisfy type bounds.
|
// Check that initial values satisfy type bounds.
|
||||||
for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) {
|
for (int i = 0, e = initialize.getNumOperands(); i < e; ++i) {
|
||||||
auto symName = initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
|
auto symName = cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
|
||||||
auto initialValue = initialize.getOperand(i);
|
auto initialValue = initialize.getOperand(i);
|
||||||
auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue());
|
auto globalSlotOp = symbolTable.lookup<GlobalSlotOp>(symName.getValue());
|
||||||
if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) {
|
if (!isValidSubtype(initialValue.getType(), globalSlotOp.getTypeBound())) {
|
||||||
|
|
|
@ -49,7 +49,7 @@ public:
|
||||||
// The incoporation of the torch.type_bound arg attr is context-dependent.
|
// The incoporation of the torch.type_bound arg attr is context-dependent.
|
||||||
|
|
||||||
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
|
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
|
||||||
if (type.value().isa<NonValueTensorType>()) {
|
if (isa<NonValueTensorType>(type.value())) {
|
||||||
auto typeBoundAttr =
|
auto typeBoundAttr =
|
||||||
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
|
func.getArgAttrOfType<TypeAttr>(type.index(), typeBoundIdent);
|
||||||
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
|
Type bound = typeBoundAttr ? typeBoundAttr.getValue() : Type();
|
||||||
|
@ -61,7 +61,7 @@ public:
|
||||||
? typeBoundAttr.getValue()
|
? typeBoundAttr.getValue()
|
||||||
: type.value());
|
: type.value());
|
||||||
continue;
|
continue;
|
||||||
} else if (auto none = type.value().dyn_cast<Torch::NoneType>()) {
|
} else if (auto none = dyn_cast<Torch::NoneType>(type.value())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// TODO: add tuple type.
|
// TODO: add tuple type.
|
||||||
|
@ -111,7 +111,7 @@ public:
|
||||||
|
|
||||||
SmallVector<Value> newOperands;
|
SmallVector<Value> newOperands;
|
||||||
for (auto operand : llvm::enumerate(adaptor.getOperands())) {
|
for (auto operand : llvm::enumerate(adaptor.getOperands())) {
|
||||||
if (operand.value().getType().isa<Torch::NoneType>())
|
if (isa<Torch::NoneType>(operand.value().getType()))
|
||||||
continue;
|
continue;
|
||||||
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
|
auto it = typeBoundMap.find({call.getCallee(), operand.index()});
|
||||||
if (it != typeBoundMap.end()) {
|
if (it != typeBoundMap.end()) {
|
||||||
|
@ -167,9 +167,9 @@ public:
|
||||||
for (auto operand : adaptor.getOperands()) {
|
for (auto operand : adaptor.getOperands()) {
|
||||||
if (!operand)
|
if (!operand)
|
||||||
continue;
|
continue;
|
||||||
if (operand.getType().isa<Torch::NoneType>())
|
if (isa<Torch::NoneType>(operand.getType()))
|
||||||
continue;
|
continue;
|
||||||
if (auto tuple = operand.getType().dyn_cast<Torch::TupleType>()) {
|
if (auto tuple = dyn_cast<Torch::TupleType>(operand.getType())) {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
|
for (auto en : llvm::enumerate(tuple.getContainedTypes())) {
|
||||||
auto i = rewriter.create<ConstantIntOp>(
|
auto i = rewriter.create<ConstantIntOp>(
|
||||||
|
@ -207,7 +207,7 @@ static LogicalResult adjustCallingConventions(func::FuncOp func,
|
||||||
[](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs,
|
[](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs,
|
||||||
Location loc) -> Value {
|
Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<BaseTensorType>());
|
assert(isa<BaseTensorType>(inputs[0].getType()));
|
||||||
return copyTensorToType(builder, loc, type, inputs[0]);
|
return copyTensorToType(builder, loc, type, inputs[0]);
|
||||||
});
|
});
|
||||||
patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
|
patterns.add<AdjustCallingConventionForFunc>(typeConverter, context);
|
||||||
|
|
|
@ -29,7 +29,7 @@ using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
// Helper function to check whether the `dtype` is None or Float type.
|
// Helper function to check whether the `dtype` is None or Float type.
|
||||||
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
static bool isNoneOrFloatDtype(MLIRContext *context, Value dtype) {
|
||||||
if (dtype.getType().isa<Torch::NoneType>())
|
if (isa<Torch::NoneType>(dtype.getType()))
|
||||||
return true;
|
return true;
|
||||||
int64_t dtypeInt;
|
int64_t dtypeInt;
|
||||||
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
if (!matchPattern(dtype, m_TorchConstantInt(&dtypeInt)))
|
||||||
|
@ -87,7 +87,7 @@ static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc,
|
||||||
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
||||||
Value dtype = rewriter.create<ConstantNoneOp>(loc);
|
Value dtype = rewriter.create<ConstantNoneOp>(loc);
|
||||||
Type resultType = computeReductionType(
|
Type resultType = computeReductionType(
|
||||||
rewriter, op, input.getType().cast<BaseTensorType>(), dim, keepDim);
|
rewriter, op, cast<BaseTensorType>(input.getType()), dim, keepDim);
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
|
return rewriter.create<AtenSumDimIntListOp>(loc, resultType, input, dimList,
|
||||||
|
@ -100,7 +100,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc,
|
||||||
bool keepDim) {
|
bool keepDim) {
|
||||||
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
Value keepDimCst = rewriter.create<ConstantBoolOp>(loc, keepDim);
|
||||||
BaseTensorType valueType =
|
BaseTensorType valueType =
|
||||||
computeReductionType(rewriter, op, input.getType().cast<BaseTensorType>(),
|
computeReductionType(rewriter, op, cast<BaseTensorType>(input.getType()),
|
||||||
dim, keepDim)
|
dim, keepDim)
|
||||||
.cast<BaseTensorType>();
|
.cast<BaseTensorType>();
|
||||||
if (!valueType)
|
if (!valueType)
|
||||||
|
@ -296,7 +296,7 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
int64_t contractingDimsLength,
|
int64_t contractingDimsLength,
|
||||||
int64_t otherDimsLength,
|
int64_t otherDimsLength,
|
||||||
int64_t reduceDimsLength, bool isLhs) {
|
int64_t reduceDimsLength, bool isLhs) {
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
auto inputRank = batchDimsLength + contractingDimsLength + otherDimsLength +
|
||||||
reduceDimsLength;
|
reduceDimsLength;
|
||||||
SmallVector<Value> inputShapeTensor;
|
SmallVector<Value> inputShapeTensor;
|
||||||
|
@ -415,7 +415,7 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
SmallVector<char> &contractingDims,
|
SmallVector<char> &contractingDims,
|
||||||
SmallVector<char> &otherDims,
|
SmallVector<char> &otherDims,
|
||||||
SmallVector<char> &reduceDims, bool isLhs) {
|
SmallVector<char> &reduceDims, bool isLhs) {
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
llvm::SmallDenseMap<char, int64_t> dimTokenMap;
|
llvm::SmallDenseMap<char, int64_t> dimTokenMap;
|
||||||
for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
|
for (size_t idx = 0; idx < dimTokens.size(); ++idx) {
|
||||||
dimTokenMap[dimTokens[idx]] = idx;
|
dimTokenMap[dimTokens[idx]] = idx;
|
||||||
|
@ -451,8 +451,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc,
|
||||||
Value &result,
|
Value &result,
|
||||||
SmallVector<char> &resultTokens,
|
SmallVector<char> &resultTokens,
|
||||||
SmallVector<char> &finalResultTokens) {
|
SmallVector<char> &finalResultTokens) {
|
||||||
auto lhsType = lhs.getType().cast<BaseTensorType>();
|
auto lhsType = cast<BaseTensorType>(lhs.getType());
|
||||||
auto rhsType = rhs.getType().cast<BaseTensorType>();
|
auto rhsType = cast<BaseTensorType>(rhs.getType());
|
||||||
|
|
||||||
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
|
Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype()
|
||||||
: rhsType.getOptionalDtype();
|
: rhsType.getOptionalDtype();
|
||||||
|
@ -562,7 +562,7 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter,
|
||||||
Value input,
|
Value input,
|
||||||
SmallVector<char> &inputTokens,
|
SmallVector<char> &inputTokens,
|
||||||
SmallVector<char> &outTokens) {
|
SmallVector<char> &outTokens) {
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
|
|
||||||
llvm::SmallDenseSet<char> outTokenSet(outTokens.begin(), outTokens.end());
|
llvm::SmallDenseSet<char> outTokenSet(outTokens.begin(), outTokens.end());
|
||||||
SmallVector<int64_t> sumDims;
|
SmallVector<int64_t> sumDims;
|
||||||
|
@ -643,7 +643,7 @@ public:
|
||||||
op, "Expected a constant boolean value for keepDim");
|
op, "Expected a constant boolean value for keepDim");
|
||||||
|
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
auto inputTy = input.getType().dyn_cast<Torch::ValueTensorType>();
|
auto inputTy = dyn_cast<Torch::ValueTensorType>(input.getType());
|
||||||
if (!inputTy || !inputTy.hasSizes()) {
|
if (!inputTy || !inputTy.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Expected input type having sizes");
|
"Expected input type having sizes");
|
||||||
|
@ -677,7 +677,7 @@ public:
|
||||||
MLIRContext *context = op.getContext();
|
MLIRContext *context = op.getContext();
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes() || !inputType.hasDtype()) {
|
if (!inputType.hasSizes() || !inputType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "should have shape and dtype");
|
return rewriter.notifyMatchFailure(op, "should have shape and dtype");
|
||||||
}
|
}
|
||||||
|
@ -764,7 +764,7 @@ public:
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
|
|
||||||
auto resultTy = op.getType().cast<BaseTensorType>();
|
auto resultTy = cast<BaseTensorType>(op.getType());
|
||||||
if (!resultTy.hasSizes() || !resultTy.hasDtype()) {
|
if (!resultTy.hasSizes() || !resultTy.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected result type to have sizes and dtype");
|
op, "expected result type to have sizes and dtype");
|
||||||
|
@ -785,8 +785,8 @@ public:
|
||||||
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
rewriter.create<AtenAddIntOp>(loc, one.getType(), start, one);
|
||||||
Value slice = rewriter.create<AtenSliceTensorOp>(
|
Value slice = rewriter.create<AtenSliceTensorOp>(
|
||||||
loc,
|
loc,
|
||||||
computeReductionType(rewriter, op,
|
computeReductionType(rewriter, op, cast<BaseTensorType>(self.getType()),
|
||||||
self.getType().cast<BaseTensorType>(), dim,
|
dim,
|
||||||
/*keepDim=*/true),
|
/*keepDim=*/true),
|
||||||
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
op.getSelf(), dim, start, startPlusOne, /*step=*/one);
|
||||||
|
|
||||||
|
@ -988,7 +988,7 @@ public:
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
|
|
||||||
auto outputTy = op.getType().dyn_cast<Torch::ValueTensorType>();
|
auto outputTy = dyn_cast<Torch::ValueTensorType>(op.getType());
|
||||||
if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
|
if (!outputTy || !outputTy.hasSizes() || !outputTy.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Expected output type having sizes and dtype");
|
op, "Expected output type having sizes and dtype");
|
||||||
|
@ -1069,7 +1069,7 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"unimplemented: m must be constant");
|
"unimplemented: m must be constant");
|
||||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
auto outType = op.getType().dyn_cast<BaseTensorType>();
|
auto outType = dyn_cast<BaseTensorType>(op.getType());
|
||||||
if (!outType)
|
if (!outType)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Only tensor types input are currently supported");
|
op, "Only tensor types input are currently supported");
|
||||||
|
@ -1111,13 +1111,13 @@ public:
|
||||||
|
|
||||||
// compare unsqueezed input with boundaries
|
// compare unsqueezed input with boundaries
|
||||||
auto eqType = ValueTensorType::get(
|
auto eqType = ValueTensorType::get(
|
||||||
context, op.getType().cast<BaseTensorType>().getSizes(),
|
context, cast<BaseTensorType>(op.getType()).getSizes(),
|
||||||
IntegerType::get(context, 1));
|
IntegerType::get(context, 1));
|
||||||
Value eqTensor =
|
Value eqTensor =
|
||||||
rewriter.create<AtenEqTensorOp>(loc, eqType, unsqzRangeN, rangeM);
|
rewriter.create<AtenEqTensorOp>(loc, eqType, unsqzRangeN, rangeM);
|
||||||
|
|
||||||
Value dtype = op.getDtype();
|
Value dtype = op.getDtype();
|
||||||
if (dtype.getType().isa<Torch::BoolType>()) {
|
if (isa<Torch::BoolType>(dtype.getType())) {
|
||||||
rewriter.replaceOp(op, eqTensor);
|
rewriter.replaceOp(op, eqTensor);
|
||||||
return success();
|
return success();
|
||||||
} else {
|
} else {
|
||||||
|
@ -1210,7 +1210,7 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
// TODO: Handle non value tensor type operands.
|
// TODO: Handle non value tensor type operands.
|
||||||
if (!input.getType().isa<ValueTensorType>()) {
|
if (!isa<ValueTensorType>(input.getType())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: only value tensor type operands are supported");
|
op, "unimplemented: only value tensor type operands are supported");
|
||||||
}
|
}
|
||||||
|
@ -1248,7 +1248,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
auto allTensorHasSizes = [](Value tensor) {
|
auto allTensorHasSizes = [](Value tensor) {
|
||||||
auto type = tensor.getType().dyn_cast<BaseTensorType>();
|
auto type = dyn_cast<BaseTensorType>(tensor.getType());
|
||||||
if (!type || !type.hasSizes())
|
if (!type || !type.hasSizes())
|
||||||
return false;
|
return false;
|
||||||
return true;
|
return true;
|
||||||
|
@ -1267,7 +1267,7 @@ public:
|
||||||
if (equation.find("...") != std::string::npos) {
|
if (equation.find("...") != std::string::npos) {
|
||||||
SmallVector<int64_t> inputRanks;
|
SmallVector<int64_t> inputRanks;
|
||||||
for (Value tensor : inputTensors) {
|
for (Value tensor : inputTensors) {
|
||||||
auto type = tensor.getType().cast<BaseTensorType>();
|
auto type = cast<BaseTensorType>(tensor.getType());
|
||||||
inputRanks.push_back(type.getSizes().size());
|
inputRanks.push_back(type.getSizes().size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1332,10 +1332,10 @@ public:
|
||||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||||
Value one =
|
Value one =
|
||||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
BaseTensorType inputType = self.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(self.getType());
|
||||||
|
|
||||||
Value output = op.getResult();
|
Value output = op.getResult();
|
||||||
BaseTensorType outputType = output.getType().cast<BaseTensorType>();
|
BaseTensorType outputType = cast<BaseTensorType>(output.getType());
|
||||||
|
|
||||||
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
||||||
int64_t diagonalSize = std::min(inputShape[0], inputShape[1]);
|
int64_t diagonalSize = std::min(inputShape[0], inputShape[1]);
|
||||||
|
@ -1399,7 +1399,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
|
LogicalResult matchAndRewrite(AtenSoftmaxIntOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
|
BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resultTensorType.hasDtype()) {
|
if (!resultTensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected result type to have a dtype");
|
op, "expected result type to have a dtype");
|
||||||
|
@ -1410,7 +1410,7 @@ public:
|
||||||
"Only support floating-point type");
|
"Only support floating-point type");
|
||||||
|
|
||||||
// If `dtype` arg is non-none then convert the input to `dtype`.
|
// If `dtype` arg is non-none then convert the input to `dtype`.
|
||||||
if (!op.getDtype().getType().isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(op.getDtype().getType())) {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
|
||||||
|
@ -1440,15 +1440,15 @@ public:
|
||||||
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
|
LogicalResult matchAndRewrite(Aten_SoftmaxOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
|
||||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||||
bool halfToFloat;
|
bool halfToFloat;
|
||||||
if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat)))
|
if (!matchPattern(op.getHalfToFloat(), m_TorchConstantBool(&halfToFloat)))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Expected a boolean value for half_to_float");
|
op, "Expected a boolean value for half_to_float");
|
||||||
|
|
||||||
BaseTensorType resultTensorType = op.getType().cast<BaseTensorType>();
|
BaseTensorType resultTensorType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resultTensorType.hasDtype()) {
|
if (!resultTensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected result type to have a dtype");
|
op, "expected result type to have a dtype");
|
||||||
|
@ -1500,8 +1500,8 @@ public:
|
||||||
Value output = op.getOutput();
|
Value output = op.getOutput();
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
|
|
||||||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
|
||||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||||
|
|
||||||
Value newGrad =
|
Value newGrad =
|
||||||
|
@ -1536,8 +1536,8 @@ public:
|
||||||
// Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
|
// Since, dTanh(x) = (1 - tanh(x)^2) hence, dOutput = (1 - output^2).
|
||||||
Value output = op.getOutput();
|
Value output = op.getOutput();
|
||||||
|
|
||||||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
|
||||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||||
|
|
||||||
Value tanhSquare =
|
Value tanhSquare =
|
||||||
|
@ -1567,8 +1567,8 @@ public:
|
||||||
Value output = op.getOutput();
|
Value output = op.getOutput();
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
|
|
||||||
BaseTensorType tensorType = gradOutput.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(gradOutput.getType());
|
||||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||||
|
|
||||||
Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
|
Value expOut = rewriter.create<AtenExpOp>(loc, tensorType, output);
|
||||||
|
@ -1650,8 +1650,8 @@ public:
|
||||||
Value keepDim = op.getKeepdim();
|
Value keepDim = op.getKeepdim();
|
||||||
Value result = op.getResult();
|
Value result = op.getResult();
|
||||||
|
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
BaseTensorType indicesTensorType = result.getType().cast<BaseTensorType>();
|
BaseTensorType indicesTensorType = cast<BaseTensorType>(result.getType());
|
||||||
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
std::optional<unsigned> maybeInputRank = getTensorRank(input);
|
||||||
if (!maybeInputRank) {
|
if (!maybeInputRank) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -1670,7 +1670,7 @@ public:
|
||||||
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
|
// `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so
|
||||||
// first the input tensor is flattened to 1d tensor and then the reduction
|
// first the input tensor is flattened to 1d tensor and then the reduction
|
||||||
// happens on the 0th dimension.
|
// happens on the 0th dimension.
|
||||||
if (dim.getType().isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(dim.getType())) {
|
||||||
BaseTensorType flattenType =
|
BaseTensorType flattenType =
|
||||||
inputType
|
inputType
|
||||||
.getWithSizesAndDtype({kUnknownSize},
|
.getWithSizesAndDtype({kUnknownSize},
|
||||||
|
@ -1720,7 +1720,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes()) {
|
if (!inputType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: input must have known sizes");
|
op, "unimplemented: input must have known sizes");
|
||||||
|
@ -1728,7 +1728,7 @@ public:
|
||||||
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
ArrayRef<int64_t> inputShape = inputType.getSizes();
|
||||||
|
|
||||||
Value boundaries = op.getBoundaries();
|
Value boundaries = op.getBoundaries();
|
||||||
auto boundariesType = boundaries.getType().cast<BaseTensorType>();
|
auto boundariesType = cast<BaseTensorType>(boundaries.getType());
|
||||||
if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
|
if (!boundariesType.hasSizes() || boundariesType.getSizes().size() != 1) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"unimplemented: boundaries must have "
|
"unimplemented: boundaries must have "
|
||||||
|
@ -1827,7 +1827,7 @@ static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
|
||||||
Value xMax =
|
Value xMax =
|
||||||
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true);
|
||||||
if (!xMax)
|
if (!xMax)
|
||||||
|
@ -1856,12 +1856,12 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
|
LogicalResult matchAndRewrite(AtenLogSoftmaxIntOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
if (!op.getDtype().getType().isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(op.getDtype().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Unimplemented non-None dtype for log_softmax");
|
op, "Unimplemented non-None dtype for log_softmax");
|
||||||
|
|
||||||
BaseTensorType tensorType = self.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(self.getType());
|
||||||
if (!tensorType.hasDtype() || !tensorType.getDtype().isa<mlir::FloatType>())
|
if (!tensorType.hasDtype() || !isa<mlir::FloatType>(tensorType.getDtype()))
|
||||||
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
return rewriter.notifyMatchFailure(op, "Only support floating type");
|
||||||
|
|
||||||
Value logSoftmax = getLogSoftmaxResult(op, rewriter);
|
Value logSoftmax = getLogSoftmaxResult(op, rewriter);
|
||||||
|
@ -1974,7 +1974,7 @@ public:
|
||||||
Type opType = op.getType();
|
Type opType = op.getType();
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
|
|
||||||
auto resType = self.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(self.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -2088,7 +2088,7 @@ public:
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value inValue = op.getSelf();
|
Value inValue = op.getSelf();
|
||||||
auto inType = inValue.getType().cast<BaseTensorType>();
|
auto inType = cast<BaseTensorType>(inValue.getType());
|
||||||
auto maybeSizes = inType.getOptionalSizes();
|
auto maybeSizes = inType.getOptionalSizes();
|
||||||
if (!maybeSizes) {
|
if (!maybeSizes) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -2234,7 +2234,7 @@ public:
|
||||||
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
|
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
|
||||||
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
|
||||||
Value input) {
|
Value input) {
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
|
|
||||||
Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
|
Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
|
||||||
Value cst6 =
|
Value cst6 =
|
||||||
|
@ -2252,7 +2252,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenRelu6Op op,
|
LogicalResult matchAndRewrite(AtenRelu6Op op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -2304,7 +2304,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Value negativeSlope = op.getNegativeSlope();
|
Value negativeSlope = op.getNegativeSlope();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -2341,7 +2341,7 @@ public:
|
||||||
Value gradOutput = op.getGradOutput();
|
Value gradOutput = op.getGradOutput();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Value negativeSlope = op.getNegativeSlope();
|
Value negativeSlope = op.getNegativeSlope();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -2382,7 +2382,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Value weight = op.getWeight();
|
Value weight = op.getWeight();
|
||||||
auto resType = op.getType().cast<ValueTensorType>();
|
auto resType = cast<ValueTensorType>(op.getType());
|
||||||
auto boolTensorType = rewriter.getType<ValueTensorType>(
|
auto boolTensorType = rewriter.getType<ValueTensorType>(
|
||||||
resType.getOptionalSizes(), rewriter.getI1Type());
|
resType.getOptionalSizes(), rewriter.getI1Type());
|
||||||
Value zero =
|
Value zero =
|
||||||
|
@ -2408,14 +2408,14 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenLerpScalarOp op,
|
LogicalResult matchAndRewrite(AtenLerpScalarOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
Value cstOne =
|
Value cstOne =
|
||||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||||
auto start = op.getSelf();
|
auto start = op.getSelf();
|
||||||
auto inputType = start.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(start.getType());
|
||||||
|
|
||||||
auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
|
auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
|
||||||
start, cstOne);
|
start, cstOne);
|
||||||
|
@ -2442,7 +2442,7 @@ public:
|
||||||
Value alpha = op.getAlpha();
|
Value alpha = op.getAlpha();
|
||||||
Value scale = op.getScale();
|
Value scale = op.getScale();
|
||||||
Value inputScale = op.getInputScale();
|
Value inputScale = op.getInputScale();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -2486,7 +2486,7 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -2578,7 +2578,7 @@ public:
|
||||||
}
|
}
|
||||||
// Ensure all tensors have known sizes
|
// Ensure all tensors have known sizes
|
||||||
for (Value tensor : tensors) {
|
for (Value tensor : tensors) {
|
||||||
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
|
||||||
if (!tensorType.hasSizes()) {
|
if (!tensorType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: one tensor does not have known sizes");
|
op, "unimplemented: one tensor does not have known sizes");
|
||||||
|
@ -2596,8 +2596,9 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
Type listElemType =
|
Type listElemType =
|
||||||
op.getType().cast<BaseTensorType>().getWithSizesAndDtype(
|
cast<BaseTensorType>(op.getType())
|
||||||
/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
|
.getWithSizesAndDtype(
|
||||||
|
/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr);
|
||||||
Type listType = Torch::ListType::get(listElemType);
|
Type listType = Torch::ListType::get(listElemType);
|
||||||
Value unsqueezedTensorList = rewriter.create<PrimListConstructOp>(
|
Value unsqueezedTensorList = rewriter.create<PrimListConstructOp>(
|
||||||
op.getLoc(), listType, unsqueezedTensors);
|
op.getLoc(), listType, unsqueezedTensors);
|
||||||
|
@ -2635,7 +2636,7 @@ public:
|
||||||
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
Value constOne = rewriter.create<Torch::ConstantIntOp>(
|
||||||
loc, rewriter.getI64IntegerAttr(1));
|
loc, rewriter.getI64IntegerAttr(1));
|
||||||
auto self = op.getSelf();
|
auto self = op.getSelf();
|
||||||
auto selfTy = self.getType().cast<BaseTensorType>();
|
auto selfTy = cast<BaseTensorType>(self.getType());
|
||||||
// roll(input, shift, dim) = cat({
|
// roll(input, shift, dim) = cat({
|
||||||
// slice(input, dim, -shift, none),
|
// slice(input, dim, -shift, none),
|
||||||
// slice(input, dim, 0, -shift)}, dim)
|
// slice(input, dim, 0, -shift)}, dim)
|
||||||
|
@ -2817,7 +2818,7 @@ public:
|
||||||
if (!selfTy.hasSizes())
|
if (!selfTy.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Unimplemented: no implementation for rankless tensor");
|
op, "Unimplemented: no implementation for rankless tensor");
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasSizes())
|
if (!resType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Unimplemented: no implementation for rankless tensor");
|
op, "Unimplemented: no implementation for rankless tensor");
|
||||||
|
@ -2968,7 +2969,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
MLIRContext *context = op.getContext();
|
MLIRContext *context = op.getContext();
|
||||||
BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
|
BaseTensorType outputTensorType = cast<BaseTensorType>(op.getType());
|
||||||
if (!outputTensorType.hasSizes())
|
if (!outputTensorType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: output must have known sizes");
|
op, "unimplemented: output must have known sizes");
|
||||||
|
@ -2977,7 +2978,7 @@ public:
|
||||||
if (!maybeRank)
|
if (!maybeRank)
|
||||||
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
|
||||||
unsigned inputRank = *maybeRank;
|
unsigned inputRank = *maybeRank;
|
||||||
auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
|
auto inputTensorType = cast<Torch::ValueTensorType>(self.getType());
|
||||||
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
if (!inputTensorType || !inputTensorType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"Expected input type having sizes");
|
"Expected input type having sizes");
|
||||||
|
@ -3077,7 +3078,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenWhereScalarOp op,
|
LogicalResult matchAndRewrite(AtenWhereScalarOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -3100,7 +3101,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
|
LogicalResult matchAndRewrite(AtenWhereScalarOtherOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -3122,7 +3123,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
|
LogicalResult matchAndRewrite(AtenWhereScalarSelfOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -3186,7 +3187,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
|
LogicalResult matchAndRewrite(AtenMaskedFillScalarOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -3227,7 +3228,7 @@ static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter,
|
||||||
int64_t dimB,
|
int64_t dimB,
|
||||||
Value &transposed) {
|
Value &transposed) {
|
||||||
Type transposedType;
|
Type transposedType;
|
||||||
if (failed(getTransposedType(input.getType().cast<Torch::BaseTensorType>(),
|
if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
|
||||||
dimA, dimB, transposedType)))
|
dimA, dimB, transposedType)))
|
||||||
return failure();
|
return failure();
|
||||||
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
|
||||||
|
@ -3578,7 +3579,7 @@ public:
|
||||||
op.getGroups(), op.getDilation());
|
op.getGroups(), op.getDilation());
|
||||||
|
|
||||||
Type transposedType;
|
Type transposedType;
|
||||||
if (failed(getTransposedType(input.getType().cast<BaseTensorType>(), 0, 1,
|
if (failed(getTransposedType(cast<BaseTensorType>(input.getType()), 0, 1,
|
||||||
transposedType)))
|
transposedType)))
|
||||||
return failure();
|
return failure();
|
||||||
Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
Value inputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||||
|
@ -3605,7 +3606,7 @@ public:
|
||||||
ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2],
|
ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2],
|
||||||
gradOutputSize[3]});
|
gradOutputSize[3]});
|
||||||
|
|
||||||
BaseTensorType gradOutputTy = gradOutput.getType().cast<BaseTensorType>();
|
BaseTensorType gradOutputTy = cast<BaseTensorType>(gradOutput.getType());
|
||||||
if (!gradOutputTy.hasSizes())
|
if (!gradOutputTy.hasSizes())
|
||||||
return failure();
|
return failure();
|
||||||
SmallVector<int64_t> gradOutputSizesInt(gradOutputTy.getSizes());
|
SmallVector<int64_t> gradOutputSizesInt(gradOutputTy.getSizes());
|
||||||
|
@ -3625,7 +3626,7 @@ public:
|
||||||
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
|
loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList);
|
||||||
|
|
||||||
BaseTensorType inputTransposedTy =
|
BaseTensorType inputTransposedTy =
|
||||||
inputTransposed.getType().cast<BaseTensorType>();
|
cast<BaseTensorType>(inputTransposed.getType());
|
||||||
if (!inputTransposedTy.hasSizes())
|
if (!inputTransposedTy.hasSizes())
|
||||||
return failure();
|
return failure();
|
||||||
SmallVector<int64_t> inputTransposedSizesInt(
|
SmallVector<int64_t> inputTransposedSizesInt(
|
||||||
|
@ -3660,7 +3661,7 @@ public:
|
||||||
/*dilation=*/op.getStride(), op.getTransposed(),
|
/*dilation=*/op.getStride(), op.getTransposed(),
|
||||||
op.getOutputPadding(), numGroup);
|
op.getOutputPadding(), numGroup);
|
||||||
|
|
||||||
BaseTensorType weightTy = weight.getType().cast<BaseTensorType>();
|
BaseTensorType weightTy = cast<BaseTensorType>(weight.getType());
|
||||||
if (!weightTy.hasSizes())
|
if (!weightTy.hasSizes())
|
||||||
return failure();
|
return failure();
|
||||||
SmallVector<int64_t> weightSizes(weightTy.getSizes());
|
SmallVector<int64_t> weightSizes(weightTy.getSizes());
|
||||||
|
@ -3707,7 +3708,7 @@ public:
|
||||||
gradWeight = rewriter.create<Torch::AtenViewOp>(
|
gradWeight = rewriter.create<Torch::AtenViewOp>(
|
||||||
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
|
loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList);
|
||||||
|
|
||||||
gradWeightTy = gradWeight.getType().cast<BaseTensorType>();
|
gradWeightTy = cast<BaseTensorType>(gradWeight.getType());
|
||||||
SmallVector<int64_t, 5> gradWeightDimsOrder =
|
SmallVector<int64_t, 5> gradWeightDimsOrder =
|
||||||
computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size());
|
computeDimsOrderForMoveDim(0, 2, gradWeightViewShapeInt.size());
|
||||||
SmallVector<int64_t, 5> gradWeightMoveDimShape;
|
SmallVector<int64_t, 5> gradWeightMoveDimShape;
|
||||||
|
@ -3733,7 +3734,7 @@ public:
|
||||||
/*keepdim=*/cstFalse,
|
/*keepdim=*/cstFalse,
|
||||||
/*dtype=*/cstNone);
|
/*dtype=*/cstNone);
|
||||||
} else {
|
} else {
|
||||||
if (failed(getTransposedType(gradOutput.getType().cast<BaseTensorType>(),
|
if (failed(getTransposedType(cast<BaseTensorType>(gradOutput.getType()),
|
||||||
0, 1, transposedType)))
|
0, 1, transposedType)))
|
||||||
return failure();
|
return failure();
|
||||||
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
Value gradOutputTransposed = rewriter.create<Torch::AtenTransposeIntOp>(
|
||||||
|
@ -3792,7 +3793,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Handle integer type operands.
|
// TODO: Handle integer type operands.
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "unimplemented: non-floating point dtype");
|
op, "unimplemented: non-floating point dtype");
|
||||||
|
@ -3821,7 +3822,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Value output = op.getResult();
|
Value output = op.getResult();
|
||||||
BaseTensorType outputTensorType = output.getType().cast<BaseTensorType>();
|
BaseTensorType outputTensorType = cast<BaseTensorType>(output.getType());
|
||||||
Value sum =
|
Value sum =
|
||||||
rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.getDtype());
|
rewriter.create<AtenSumOp>(loc, outputTensorType, input, op.getDtype());
|
||||||
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
|
Value numTensorElements = rewriter.create<AtenNumelOp>(loc, input);
|
||||||
|
@ -3854,7 +3855,7 @@ public:
|
||||||
Type outputType = op.getType();
|
Type outputType = op.getType();
|
||||||
MLIRContext *context = op.getContext();
|
MLIRContext *context = op.getContext();
|
||||||
|
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>() ||
|
||||||
!isNoneOrFloatDtype(context, dtype)) {
|
!isNoneOrFloatDtype(context, dtype)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -3944,7 +3945,7 @@ public:
|
||||||
rewriter.replaceOp(op, input);
|
rewriter.replaceOp(op, input);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support floating type input for training mode");
|
op, "only support floating type input for training mode");
|
||||||
|
@ -3992,7 +3993,7 @@ public:
|
||||||
rewriter.replaceOp(op, ArrayRef<Value>{input, trueMask});
|
rewriter.replaceOp(op, ArrayRef<Value>{input, trueMask});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support floating type input for training mode");
|
op, "only support floating type input for training mode");
|
||||||
|
@ -4029,7 +4030,7 @@ public:
|
||||||
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
|
||||||
}
|
}
|
||||||
unsigned inputRank = *maybeInputRank;
|
unsigned inputRank = *maybeInputRank;
|
||||||
BaseTensorType rank0FloatTensorTy = op.getType().cast<BaseTensorType>();
|
BaseTensorType rank0FloatTensorTy = cast<BaseTensorType>(op.getType());
|
||||||
if (!rank0FloatTensorTy.hasSizes() ||
|
if (!rank0FloatTensorTy.hasSizes() ||
|
||||||
rank0FloatTensorTy.getSizes().size() != 0) {
|
rank0FloatTensorTy.getSizes().size() != 0) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -4060,7 +4061,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenStdOp op,
|
LogicalResult matchAndRewrite(AtenStdOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
|
||||||
if (!inputTensorTy.hasDtype() ||
|
if (!inputTensorTy.hasDtype() ||
|
||||||
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
!inputTensorTy.getDtype().isa<mlir::FloatType>()) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
|
@ -4084,7 +4085,7 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
|
|
||||||
Value inputTimesBeta =
|
Value inputTimesBeta =
|
||||||
rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.getBeta());
|
rewriter.create<AtenMulScalarOp>(loc, inputType, input, op.getBeta());
|
||||||
|
@ -4116,7 +4117,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenStdDimOp op,
|
LogicalResult matchAndRewrite(AtenStdDimOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
||||||
if (!inputTensorType.hasDtype() ||
|
if (!inputTensorType.hasDtype() ||
|
||||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -4141,7 +4142,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenStdCorrectionOp op,
|
LogicalResult matchAndRewrite(AtenStdCorrectionOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
BaseTensorType inputTensorType = self.getType().cast<BaseTensorType>();
|
BaseTensorType inputTensorType = cast<BaseTensorType>(self.getType());
|
||||||
if (!inputTensorType.hasDtype() ||
|
if (!inputTensorType.hasDtype() ||
|
||||||
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
!inputTensorType.getDtype().isa<mlir::FloatType>()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -4167,8 +4168,8 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -4208,8 +4209,8 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
auto resType = op.getType().cast<BaseTensorType>();
|
auto resType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resType.hasDtype()) {
|
if (!resType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
return rewriter.notifyMatchFailure(op, "result should have dtype");
|
||||||
}
|
}
|
||||||
|
@ -4235,7 +4236,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Type resultType = op.getType();
|
Type resultType = op.getType();
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support floating-point type");
|
"only support floating-point type");
|
||||||
|
@ -4268,8 +4269,8 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter,
|
||||||
Operation *op, Location loc,
|
Operation *op, Location loc,
|
||||||
Value input, Value prob,
|
Value input, Value prob,
|
||||||
Value &output) {
|
Value &output) {
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
auto probType = prob.getType().cast<BaseTensorType>();
|
auto probType = cast<BaseTensorType>(prob.getType());
|
||||||
// Both the `input` and `prob` must be ranked tensors.
|
// Both the `input` and `prob` must be ranked tensors.
|
||||||
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
|
if (!inputType.hasSizes() || !inputType.hasDtype() || !probType.hasSizes() ||
|
||||||
!probType.hasDtype()) {
|
!probType.hasDtype()) {
|
||||||
|
@ -4338,12 +4339,12 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Value p = op.getP();
|
Value p = op.getP();
|
||||||
if (!op.getGenerator().getType().template isa<Torch::NoneType>())
|
if (!isa<Torch::NoneType>(op.getGenerator().getType()))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "The generator has to be None because only global default "
|
op, "The generator has to be None because only global default "
|
||||||
"generator is supported");
|
"generator is supported");
|
||||||
|
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
SmallVector<int64_t> empty;
|
SmallVector<int64_t> empty;
|
||||||
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
|
Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty),
|
||||||
rewriter.getF64Type());
|
rewriter.getF64Type());
|
||||||
|
@ -4485,7 +4486,7 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
|
|
||||||
auto input = op.getInput().getType().cast<BaseTensorType>();
|
auto input = cast<BaseTensorType>(op.getInput().getType());
|
||||||
if (!input.hasSizes())
|
if (!input.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "input tensor should have known sizes.");
|
op, "input tensor should have known sizes.");
|
||||||
|
@ -4518,7 +4519,7 @@ class DecomposeAtenInstanceNormOp
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto context = op.getContext();
|
auto context = op.getContext();
|
||||||
|
|
||||||
auto inputTy = op.getInput().getType().cast<BaseTensorType>();
|
auto inputTy = cast<BaseTensorType>(op.getInput().getType());
|
||||||
int64_t inputRank = inputTy.getSizes().size();
|
int64_t inputRank = inputTy.getSizes().size();
|
||||||
SmallVector<int64_t> reducedShape(inputTy.getSizes());
|
SmallVector<int64_t> reducedShape(inputTy.getSizes());
|
||||||
SmallVector<int64_t> reduceDimInts;
|
SmallVector<int64_t> reduceDimInts;
|
||||||
|
@ -4583,7 +4584,7 @@ class DecomposeAtenInstanceNormOp
|
||||||
loc, op.getResult().getType(), inputNormalized);
|
loc, op.getResult().getType(), inputNormalized);
|
||||||
|
|
||||||
Value weight = op.getWeight();
|
Value weight = op.getWeight();
|
||||||
auto weightTy = weight.getType().cast<BaseTensorType>();
|
auto weightTy = cast<BaseTensorType>(weight.getType());
|
||||||
dtype = weightTy.getOptionalDtype();
|
dtype = weightTy.getOptionalDtype();
|
||||||
|
|
||||||
SmallVector<int64_t> weightShape(weightTy.getSizes());
|
SmallVector<int64_t> weightShape(weightTy.getSizes());
|
||||||
|
@ -4610,7 +4611,7 @@ class DecomposeAtenInstanceNormOp
|
||||||
rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
|
rewriter.create<AtenExpandAsOp>(loc, inputTy, weight, op.getInput());
|
||||||
|
|
||||||
Value bias = op.getBias();
|
Value bias = op.getBias();
|
||||||
auto biasTy = bias.getType().cast<BaseTensorType>();
|
auto biasTy = cast<BaseTensorType>(bias.getType());
|
||||||
dtype = biasTy.getOptionalDtype();
|
dtype = biasTy.getOptionalDtype();
|
||||||
|
|
||||||
SmallVector<int64_t> biasShape(biasTy.getSizes());
|
SmallVector<int64_t> biasShape(biasTy.getSizes());
|
||||||
|
@ -4654,7 +4655,7 @@ class DecomposeAtenNativeLayerNormOp
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto context = op.getContext();
|
auto context = op.getContext();
|
||||||
|
|
||||||
auto inputTy = op.getInput().getType().cast<BaseTensorType>();
|
auto inputTy = cast<BaseTensorType>(op.getInput().getType());
|
||||||
if (!inputTy.hasSizes())
|
if (!inputTy.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "input tensor should have known sizes.");
|
op, "input tensor should have known sizes.");
|
||||||
|
@ -4889,10 +4890,10 @@ class DecomposeAtenNativeGroupNormOp
|
||||||
Value eps = op.getEps();
|
Value eps = op.getEps();
|
||||||
|
|
||||||
// Check the rank of the input/outputs tensor.
|
// Check the rank of the input/outputs tensor.
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
auto outputType = op.getResult0().getType().cast<BaseTensorType>();
|
auto outputType = cast<BaseTensorType>(op.getResult0().getType());
|
||||||
auto meanType = op.getResult1().getType().cast<BaseTensorType>();
|
auto meanType = cast<BaseTensorType>(op.getResult1().getType());
|
||||||
auto rsqrtVarType = op.getResult2().getType().cast<BaseTensorType>();
|
auto rsqrtVarType = cast<BaseTensorType>(op.getResult2().getType());
|
||||||
if (!inputType.hasSizes() || !outputType.hasSizes() ||
|
if (!inputType.hasSizes() || !outputType.hasSizes() ||
|
||||||
!meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
|
!meanType.hasSizes() || !rsqrtVarType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -5059,8 +5060,8 @@ class DecomposeAtenNativeBatchNormOp
|
||||||
|
|
||||||
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
|
SmallVector<int64_t> runningStatsShapeInt(inputRank, 1);
|
||||||
runningStatsShapeInt[1] =
|
runningStatsShapeInt[1] =
|
||||||
runningMean.getType().cast<BaseTensorType>().getSizes()[0];
|
cast<BaseTensorType>(runningMean.getType()).getSizes()[0];
|
||||||
Type dtype = input.getType().cast<ValueTensorType>().getOptionalDtype();
|
Type dtype = cast<ValueTensorType>(input.getType()).getOptionalDtype();
|
||||||
Type reshapeType = ValueTensorType::get(
|
Type reshapeType = ValueTensorType::get(
|
||||||
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
|
context, llvm::ArrayRef(runningStatsShapeInt), dtype);
|
||||||
|
|
||||||
|
@ -5175,8 +5176,7 @@ class DecomposeConstantTensorNewLikeOp : public OpRewritePattern<OpTy> {
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value dtype = op.getDtype();
|
Value dtype = op.getDtype();
|
||||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||||
BaseTensorType tensorType =
|
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
op.getSelf().getType().template cast<BaseTensorType>();
|
|
||||||
if (!tensorType.hasDtype()) {
|
if (!tensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected input tensor to have a dtype");
|
op, "expected input tensor to have a dtype");
|
||||||
|
@ -5200,7 +5200,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenFullOp op,
|
LogicalResult matchAndRewrite(AtenFullOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
|
BaseTensorType outTy = cast<BaseTensorType>(op.getType());
|
||||||
if (!outTy.hasDtype()) {
|
if (!outTy.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected result type to have a dtype");
|
op, "expected result type to have a dtype");
|
||||||
|
@ -5231,12 +5231,12 @@ public:
|
||||||
Value weight = op.getWeight();
|
Value weight = op.getWeight();
|
||||||
Value bias = op.getBias();
|
Value bias = op.getBias();
|
||||||
|
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
|
if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected input to be rank 2 or greater");
|
op, "expected input to be rank 2 or greater");
|
||||||
|
|
||||||
BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
|
BaseTensorType weightType = cast<BaseTensorType>(weight.getType());
|
||||||
// `weight` must be a rank 2 matrix.
|
// `weight` must be a rank 2 matrix.
|
||||||
if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
|
if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
|
||||||
return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");
|
return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2");
|
||||||
|
@ -5255,7 +5255,7 @@ public:
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
|
BaseTensorType biasType = cast<BaseTensorType>(bias.getType());
|
||||||
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
|
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
|
||||||
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
|
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");
|
||||||
|
|
||||||
|
@ -5280,7 +5280,7 @@ public:
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
Type type = op.getType();
|
Type type = op.getType();
|
||||||
|
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasDtype())
|
if (!inputType.hasDtype())
|
||||||
return rewriter.notifyMatchFailure(op, "Dtype not present");
|
return rewriter.notifyMatchFailure(op, "Dtype not present");
|
||||||
|
|
||||||
|
@ -5306,7 +5306,7 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(AtenFullLikeOp op,
|
LogicalResult matchAndRewrite(AtenFullLikeOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
BaseTensorType outTy = op.getType().template cast<BaseTensorType>();
|
BaseTensorType outTy = cast<BaseTensorType>(op.getType());
|
||||||
if (!outTy.hasDtype()) {
|
if (!outTy.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected result type to have a dtype");
|
op, "expected result type to have a dtype");
|
||||||
|
@ -5335,7 +5335,7 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Value dtype = op.getDtype();
|
Value dtype = op.getDtype();
|
||||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||||
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
if (!tensorType.hasDtype()) {
|
if (!tensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected input tensor to have a dtype");
|
op, "expected input tensor to have a dtype");
|
||||||
|
@ -5393,7 +5393,7 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(Aten_ToCopyOp op,
|
LogicalResult matchAndRewrite(Aten_ToCopyOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto resultType = op.getType().cast<BaseTensorType>();
|
auto resultType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resultType.hasDtype()) {
|
if (!resultType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected result type to have a dtype");
|
op, "expected result type to have a dtype");
|
||||||
|
@ -5419,12 +5419,12 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(AtenCopyOp op,
|
LogicalResult matchAndRewrite(AtenCopyOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto resultType = op.getType().cast<BaseTensorType>();
|
auto resultType = cast<BaseTensorType>(op.getType());
|
||||||
if (!resultType.hasDtype()) {
|
if (!resultType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected result type to have a dtype");
|
op, "expected result type to have a dtype");
|
||||||
}
|
}
|
||||||
auto srcTy = op.getSrc().getType().cast<BaseTensorType>();
|
auto srcTy = cast<BaseTensorType>(op.getSrc().getType());
|
||||||
if (!srcTy.hasSizes() || !srcTy.hasDtype()) {
|
if (!srcTy.hasSizes() || !srcTy.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected src type to have a known rank and dtype");
|
op, "expected src type to have a known rank and dtype");
|
||||||
|
@ -5448,7 +5448,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
|
||||||
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
|
||||||
Value dtype = op.getDtype();
|
Value dtype = op.getDtype();
|
||||||
if (dtype.getType().isa<Torch::NoneType>()) {
|
if (dtype.getType().isa<Torch::NoneType>()) {
|
||||||
BaseTensorType tensorType = op.getSelf().getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
if (!tensorType.hasDtype()) {
|
if (!tensorType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expected input tensor to have a dtype");
|
op, "expected input tensor to have a dtype");
|
||||||
|
@ -5588,7 +5588,7 @@ public:
|
||||||
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
Value constNone = rewriter.create<ConstantNoneOp>(loc);
|
||||||
|
|
||||||
Value dtype = op.getDtype();
|
Value dtype = op.getDtype();
|
||||||
if (dtype.getType().template isa<Torch::NoneType>()) {
|
if (isa<Torch::NoneType>(dtype.getType())) {
|
||||||
dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
|
dtype = rewriter.create<Torch::PrimDtypeOp>(loc, op.getSelf());
|
||||||
}
|
}
|
||||||
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
|
rewriter.replaceOpWithNewOp<AtenToDtypeOp>(op, op.getType(), op.getSelf(),
|
||||||
|
@ -5665,7 +5665,7 @@ class DecomposeAtenAdaptiveAvgPool1dOp
|
||||||
|
|
||||||
SmallVector<Value, 1> kernelSize;
|
SmallVector<Value, 1> kernelSize;
|
||||||
if (outputSizeInt == 1) {
|
if (outputSizeInt == 1) {
|
||||||
BaseTensorType inputTensorType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
|
||||||
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
|
||||||
kernelSize.push_back(
|
kernelSize.push_back(
|
||||||
inputShape[rank - 1] == kUnknownSize
|
inputShape[rank - 1] == kUnknownSize
|
||||||
|
@ -5839,7 +5839,7 @@ class DecomposeAtenCosineSimilarityOp
|
||||||
SmallVector<Value> indexBroadcastShapeValue;
|
SmallVector<Value> indexBroadcastShapeValue;
|
||||||
computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
|
computeBroadcastShape(rewriter, loc, x1, x2, indexBroadcastShapeInt,
|
||||||
indexBroadcastShapeValue);
|
indexBroadcastShapeValue);
|
||||||
Type dtype = x1.getType().cast<BaseTensorType>().getOptionalDtype();
|
Type dtype = cast<BaseTensorType>(x1.getType()).getOptionalDtype();
|
||||||
Type broadcastType = ValueTensorType::get(
|
Type broadcastType = ValueTensorType::get(
|
||||||
op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype);
|
op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype);
|
||||||
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
|
||||||
|
@ -5925,9 +5925,9 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern<AtenBaddbmmOp> {
|
||||||
Value alphaTimesBmm =
|
Value alphaTimesBmm =
|
||||||
rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
|
rewriter.create<AtenMulScalarOp>(loc, op.getType(), bmm, op.getAlpha());
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
BaseTensorType resultType =
|
BaseTensorType resultType =
|
||||||
op->getResult(0).getType().cast<BaseTensorType>();
|
cast<BaseTensorType>(op->getResult(0).getType());
|
||||||
if (inputType.hasDtype() && resultType.hasDtype() &&
|
if (inputType.hasDtype() && resultType.hasDtype() &&
|
||||||
inputType.getDtype() != resultType.getDtype()) {
|
inputType.getDtype() != resultType.getDtype()) {
|
||||||
input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype());
|
input = convertTensorToDtype(rewriter, loc, input, resultType.getDtype());
|
||||||
|
@ -6011,7 +6011,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
Value dimList = op.getDim();
|
Value dimList = op.getDim();
|
||||||
Value keepDim = op.getKeepdim();
|
Value keepDim = op.getKeepdim();
|
||||||
BaseTensorType inputTensorTy = self.getType().cast<BaseTensorType>();
|
BaseTensorType inputTensorTy = cast<BaseTensorType>(self.getType());
|
||||||
Type outputType = op.getType();
|
Type outputType = op.getType();
|
||||||
BaseTensorType outputTensorType = cast<BaseTensorType>(outputType);
|
BaseTensorType outputTensorType = cast<BaseTensorType>(outputType);
|
||||||
if (!outputTensorType.hasDtype()) {
|
if (!outputTensorType.hasDtype()) {
|
||||||
|
@ -6030,7 +6030,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
||||||
// computation of the result.
|
// computation of the result.
|
||||||
if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
|
if (inputTensorTy.getDtype().getIntOrFloatBitWidth() != 64) {
|
||||||
self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
|
self = convertTensorToDtype(rewriter, loc, self, rewriter.getF64Type());
|
||||||
inputTensorTy = self.getType().cast<BaseTensorType>();
|
inputTensorTy = cast<BaseTensorType>(self.getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
std::optional<unsigned> maybeInputRank = getTensorRank(self);
|
||||||
|
@ -6040,7 +6040,7 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter,
|
||||||
unsigned inputRank = *maybeInputRank;
|
unsigned inputRank = *maybeInputRank;
|
||||||
SmallVector<Value> dimListElements;
|
SmallVector<Value> dimListElements;
|
||||||
bool isNoneOrEmpty = true;
|
bool isNoneOrEmpty = true;
|
||||||
if (!dimList.getType().template isa<Torch::NoneType>()) {
|
if (!isa<Torch::NoneType>(dimList.getType())) {
|
||||||
if (!getListConstructElements(dimList, dimListElements))
|
if (!getListConstructElements(dimList, dimListElements))
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "expect dimList to be constructed from list construct");
|
op, "expect dimList to be constructed from list construct");
|
||||||
|
@ -6287,8 +6287,8 @@ public:
|
||||||
op, "Expected a constant integer value for reduction");
|
op, "Expected a constant integer value for reduction");
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
BaseTensorType resultType = op.getType().cast<BaseTensorType>();
|
BaseTensorType resultType = cast<BaseTensorType>(op.getType());
|
||||||
BaseTensorType inputType = op.getSelf().getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(op.getSelf().getType());
|
||||||
if (!inputType.hasSizes())
|
if (!inputType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "Expected the input tensor to have sizes");
|
op, "Expected the input tensor to have sizes");
|
||||||
|
@ -6506,7 +6506,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
|
LogicalResult matchAndRewrite(AtenRandnGeneratorOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto resultType = op.getType().cast<BaseTensorType>();
|
auto resultType = cast<BaseTensorType>(op.getType());
|
||||||
|
|
||||||
if (!resultType.hasDtype()) {
|
if (!resultType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -6617,7 +6617,7 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto resultType = op.getType().cast<BaseTensorType>();
|
auto resultType = cast<BaseTensorType>(op.getType());
|
||||||
|
|
||||||
if (!resultType.hasDtype()) {
|
if (!resultType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -6943,7 +6943,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
||||||
auto context = op.getContext();
|
auto context = op.getContext();
|
||||||
|
|
||||||
Value input = op.getSelf();
|
Value input = op.getSelf();
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes())
|
if (!inputType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "input tensor should have known sizes.");
|
op, "input tensor should have known sizes.");
|
||||||
|
@ -6974,7 +6974,7 @@ class DecomposeAtenOneHotOp : public OpRewritePattern<AtenOneHotOp> {
|
||||||
|
|
||||||
// compare
|
// compare
|
||||||
auto eqType = ValueTensorType::get(
|
auto eqType = ValueTensorType::get(
|
||||||
context, op.getType().cast<BaseTensorType>().getSizes(),
|
context, cast<BaseTensorType>(op.getType()).getSizes(),
|
||||||
IntegerType::get(context, 1));
|
IntegerType::get(context, 1));
|
||||||
Value eqTensor = rewriter.create<AtenEqTensorOp>(
|
Value eqTensor = rewriter.create<AtenEqTensorOp>(
|
||||||
loc, eqType, unsqueezeTensor, arangeTensor);
|
loc, eqType, unsqueezeTensor, arangeTensor);
|
||||||
|
@ -7019,7 +7019,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenScalarTensorOp op,
|
LogicalResult matchAndRewrite(AtenScalarTensorOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
auto resultTy = op.getResult().getType().cast<BaseTensorType>();
|
auto resultTy = cast<BaseTensorType>(op.getResult().getType());
|
||||||
auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType());
|
auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType());
|
||||||
Value numToTensor = rewriter.create<PrimNumToTensorScalarOp>(
|
Value numToTensor = rewriter.create<PrimNumToTensorScalarOp>(
|
||||||
op.getLoc(),
|
op.getLoc(),
|
||||||
|
@ -7060,7 +7060,7 @@ public:
|
||||||
|
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
Value dim = op.getDim();
|
Value dim = op.getDim();
|
||||||
auto selfType = self.getType().cast<BaseTensorType>();
|
auto selfType = cast<BaseTensorType>(self.getType());
|
||||||
auto sortIndicesType = selfType.getWithSizesAndDtype(
|
auto sortIndicesType = selfType.getWithSizesAndDtype(
|
||||||
selfType.getOptionalSizes(),
|
selfType.getOptionalSizes(),
|
||||||
IntegerType::get(context, 64, IntegerType::Signed));
|
IntegerType::get(context, 64, IntegerType::Signed));
|
||||||
|
@ -7111,8 +7111,8 @@ public:
|
||||||
Value sizeList = rewriter.create<PrimListConstructOp>(
|
Value sizeList = rewriter.create<PrimListConstructOp>(
|
||||||
loc, ListType::get(IntType::get(context)), sizes);
|
loc, ListType::get(IntType::get(context)), sizes);
|
||||||
|
|
||||||
auto selfType = self.getType().cast<BaseTensorType>();
|
auto selfType = cast<BaseTensorType>(self.getType());
|
||||||
auto indexType = index.getType().cast<BaseTensorType>();
|
auto indexType = cast<BaseTensorType>(index.getType());
|
||||||
BaseTensorType srcType =
|
BaseTensorType srcType =
|
||||||
selfType
|
selfType
|
||||||
.getWithSizesAndDtype(indexType.getOptionalSizes(),
|
.getWithSizesAndDtype(indexType.getOptionalSizes(),
|
||||||
|
@ -7135,7 +7135,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(AtenSgnOp op,
|
LogicalResult matchAndRewrite(AtenSgnOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
auto outType = op.getType().cast<BaseTensorType>();
|
auto outType = cast<BaseTensorType>(op.getType());
|
||||||
if (!outType.hasDtype()) {
|
if (!outType.hasDtype()) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"expected result type to have dtype");
|
"expected result type to have dtype");
|
||||||
|
@ -7273,14 +7273,14 @@ public:
|
||||||
"failed to get elements of `indices`");
|
"failed to get elements of `indices`");
|
||||||
|
|
||||||
auto input = op.getSelf();
|
auto input = op.getSelf();
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes()) {
|
if (!inputType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only input with shape information is supported");
|
op, "only input with shape information is supported");
|
||||||
}
|
}
|
||||||
auto inputSizes = inputType.getSizes();
|
auto inputSizes = inputType.getSizes();
|
||||||
int64_t inputRank = inputSizes.size();
|
int64_t inputRank = inputSizes.size();
|
||||||
auto outputType = op.getType().cast<BaseTensorType>();
|
auto outputType = cast<BaseTensorType>(op.getType());
|
||||||
if (!outputType.hasSizes()) {
|
if (!outputType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only output with shape information is supported");
|
op, "only output with shape information is supported");
|
||||||
|
@ -7438,7 +7438,7 @@ public:
|
||||||
op, "failed to get elements of `dims` param");
|
op, "failed to get elements of `dims` param");
|
||||||
}
|
}
|
||||||
auto dimsSize = dimsElements.size();
|
auto dimsSize = dimsElements.size();
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes()) {
|
if (!inputType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "only support input tensor with shape information");
|
op, "only support input tensor with shape information");
|
||||||
|
|
|
@ -89,7 +89,7 @@ public:
|
||||||
.cast<ValueTensorType>()
|
.cast<ValueTensorType>()
|
||||||
.getOptionalDtype();
|
.getOptionalDtype();
|
||||||
auto torchQType =
|
auto torchQType =
|
||||||
quant.getType().cast<ValueTensorType>().getOptionalDtype();
|
cast<ValueTensorType>(quant.getType()).getOptionalDtype();
|
||||||
auto transQTy =
|
auto transQTy =
|
||||||
rewriter.getType<ValueTensorType>(trans.getResult()
|
rewriter.getType<ValueTensorType>(trans.getResult()
|
||||||
.getType()
|
.getType()
|
||||||
|
@ -152,7 +152,7 @@ public:
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value bias = operands[2];
|
Value bias = operands[2];
|
||||||
auto biasTy = bias.getType().dyn_cast<ValueTensorType>();
|
auto biasTy = dyn_cast<ValueTensorType>(bias.getType());
|
||||||
|
|
||||||
if (biasTy) {
|
if (biasTy) {
|
||||||
auto biasETy = biasTy.getOptionalDtype();
|
auto biasETy = biasTy.getOptionalDtype();
|
||||||
|
|
|
@ -134,7 +134,7 @@ private:
|
||||||
slotName = setAttrOp.getName();
|
slotName = setAttrOp.getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
auto moduleType = module.getType().cast<NnModuleType>();
|
auto moduleType = cast<NnModuleType>(module.getType());
|
||||||
auto slots = moduleClassNameToSlots.find(moduleType.getClassName());
|
auto slots = moduleClassNameToSlots.find(moduleType.getClassName());
|
||||||
// TODO: Improve verifier so that this can never happen
|
// TODO: Improve verifier so that this can never happen
|
||||||
if (slots == moduleClassNameToSlots.end())
|
if (slots == moduleClassNameToSlots.end())
|
||||||
|
@ -163,13 +163,13 @@ private:
|
||||||
}
|
}
|
||||||
|
|
||||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||||
nnModule.getType().cast<NnModuleType>().getClassName());
|
cast<NnModuleType>(nnModule.getType()).getClassName());
|
||||||
for (auto t :
|
for (auto t :
|
||||||
llvm::zip(nnModule.getOps<SlotOp>(), classType.getOps<AttrOp>())) {
|
llvm::zip(nnModule.getOps<SlotOp>(), classType.getOps<AttrOp>())) {
|
||||||
auto slot = std::get<0>(t);
|
auto slot = std::get<0>(t);
|
||||||
auto attr = std::get<1>(t);
|
auto attr = std::get<1>(t);
|
||||||
nameStack.push_back(attr.getName().str());
|
nameStack.push_back(attr.getName().str());
|
||||||
if (attr.getType().isa<NnModuleType>()) {
|
if (isa<NnModuleType>(attr.getType())) {
|
||||||
if (failed(recursivelyTraverse(
|
if (failed(recursivelyTraverse(
|
||||||
slot.getValue().getDefiningOp<NnModuleOp>())))
|
slot.getValue().getDefiningOp<NnModuleOp>())))
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -333,7 +333,7 @@ static LogicalResult analyzeInstances(func::FuncOp func,
|
||||||
for (auto &argInstance : argInstances)
|
for (auto &argInstance : argInstances)
|
||||||
mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
|
mapping.map(func.getArgument(argInstance.argIndex), argInstance.instance);
|
||||||
auto walkResult = func.walk([&](PrimGetAttrOp op) {
|
auto walkResult = func.walk([&](PrimGetAttrOp op) {
|
||||||
if (!op.getType().isa<NnModuleType>())
|
if (!isa<NnModuleType>(op.getType()))
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
auto instance = mapping.lookupOrNull(op.getReceiver());
|
auto instance = mapping.lookupOrNull(op.getReceiver());
|
||||||
assert(instance && "verifyFuncConformsToSubset should ensure this");
|
assert(instance && "verifyFuncConformsToSubset should ensure this");
|
||||||
|
@ -355,7 +355,7 @@ createMonomorphizationForCall(func::CallOp op, IRMapping &mapping,
|
||||||
Monomorphization monomorphization;
|
Monomorphization monomorphization;
|
||||||
monomorphization.func = func;
|
monomorphization.func = func;
|
||||||
for (auto operand : llvm::enumerate(op->getOperands())) {
|
for (auto operand : llvm::enumerate(op->getOperands())) {
|
||||||
if (!operand.value().getType().isa<NnModuleType>())
|
if (!isa<NnModuleType>(operand.value().getType()))
|
||||||
continue;
|
continue;
|
||||||
Value instance = mapping.lookupOrNull(operand.value());
|
Value instance = mapping.lookupOrNull(operand.value());
|
||||||
assert(instance && "verifyFuncConformsToSubset should ensure this");
|
assert(instance && "verifyFuncConformsToSubset should ensure this");
|
||||||
|
@ -377,7 +377,7 @@ public:
|
||||||
monomorphization.func = func;
|
monomorphization.func = func;
|
||||||
bool canTriviallyMonomorphize = true;
|
bool canTriviallyMonomorphize = true;
|
||||||
for (auto arg : llvm::enumerate(func.getArguments())) {
|
for (auto arg : llvm::enumerate(func.getArguments())) {
|
||||||
auto type = arg.value().getType().dyn_cast<NnModuleType>();
|
auto type = dyn_cast<NnModuleType>(arg.value().getType());
|
||||||
if (!type)
|
if (!type)
|
||||||
continue;
|
continue;
|
||||||
auto classType = symbolTable.lookup<ClassTypeOp>(type.getClassName());
|
auto classType = symbolTable.lookup<ClassTypeOp>(type.getClassName());
|
||||||
|
@ -436,7 +436,7 @@ private:
|
||||||
// !torch.nn.Module<"..."> types.
|
// !torch.nn.Module<"..."> types.
|
||||||
static LogicalResult verifyNnModuleValueUses(Value value) {
|
static LogicalResult verifyNnModuleValueUses(Value value) {
|
||||||
// Trivially succeed for non-module types.
|
// Trivially succeed for non-module types.
|
||||||
if (!value.getType().isa<NnModuleType>())
|
if (!isa<NnModuleType>(value.getType()))
|
||||||
return success();
|
return success();
|
||||||
for (Operation *op : value.getUsers()) {
|
for (Operation *op : value.getUsers()) {
|
||||||
if (isa<func::CallOp, PrimGetAttrOp>(op))
|
if (isa<func::CallOp, PrimGetAttrOp>(op))
|
||||||
|
@ -516,7 +516,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
};
|
};
|
||||||
auto handlePrimGetAttr = [&](PrimGetAttrOp op) {
|
auto handlePrimGetAttr = [&](PrimGetAttrOp op) {
|
||||||
if (!op.getType().isa<NnModuleType>()) {
|
if (!isa<NnModuleType>(op.getType())) {
|
||||||
auto instance =
|
auto instance =
|
||||||
mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
|
mapping.lookup(op.getReceiver()).getDefiningOp<NnModuleOp>();
|
||||||
SlotOp affectedSlot;
|
SlotOp affectedSlot;
|
||||||
|
@ -540,7 +540,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
|
||||||
Monomorphization monomorphization = std::move(*maybeMonomorphization);
|
Monomorphization monomorphization = std::move(*maybeMonomorphization);
|
||||||
auto newArguments = llvm::to_vector<6>(
|
auto newArguments = llvm::to_vector<6>(
|
||||||
llvm::make_filter_range(op->getOperands(), [](Value v) {
|
llvm::make_filter_range(op->getOperands(), [](Value v) {
|
||||||
return !v.getType().isa<NnModuleType>();
|
return !isa<NnModuleType>(v.getType());
|
||||||
}));
|
}));
|
||||||
assert(newFuncs.find(monomorphization) != newFuncs.end());
|
assert(newFuncs.find(monomorphization) != newFuncs.end());
|
||||||
auto newOp = OpBuilder(op).create<func::CallOp>(
|
auto newOp = OpBuilder(op).create<func::CallOp>(
|
||||||
|
@ -564,7 +564,7 @@ static LogicalResult rewriteMonomorphizedFuncClone(
|
||||||
}
|
}
|
||||||
llvm::BitVector argsToErase(func.getNumArguments());
|
llvm::BitVector argsToErase(func.getNumArguments());
|
||||||
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
|
for (auto type : llvm::enumerate(func.getArgumentTypes())) {
|
||||||
if (type.value().isa<NnModuleType>()) {
|
if (isa<NnModuleType>(type.value())) {
|
||||||
argsToErase.set(type.index());
|
argsToErase.set(type.index());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -248,8 +248,8 @@ bool InlineGlobalSlotsAnalysis::isValueSafeTransferFunction(Value value) {
|
||||||
}))
|
}))
|
||||||
continue;
|
continue;
|
||||||
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
|
if (auto initialize = dyn_cast<Torch::InitializeGlobalSlotsOp>(op)) {
|
||||||
auto symName = initialize.getSlotSymNames()[use.getOperandNumber()]
|
auto symName = cast<FlatSymbolRefAttr>(
|
||||||
.cast<FlatSymbolRefAttr>();
|
initialize.getSlotSymNames()[use.getOperandNumber()]);
|
||||||
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
auto *state = getOrCreateFor<InlineGlobalSlotsAnalysisState>(
|
||||||
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
|
value, getProgramPoint<FlatSymbolRefProgramPoint>(symName));
|
||||||
if (state->isSafe)
|
if (state->isSafe)
|
||||||
|
@ -333,10 +333,10 @@ class InlineGlobalSlotsPass
|
||||||
DenseSet</*FlatSymbolRefAttr*/ Attribute> safeToInline;
|
DenseSet</*FlatSymbolRefAttr*/ Attribute> safeToInline;
|
||||||
for (int i = 0, e = initialize->getNumOperands(); i != e; i++) {
|
for (int i = 0, e = initialize->getNumOperands(); i != e; i++) {
|
||||||
auto slotSymName =
|
auto slotSymName =
|
||||||
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
|
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
|
||||||
Value operand = initialize.getOperand(i);
|
Value operand = initialize.getOperand(i);
|
||||||
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
|
auto symbolRefPoint = solver.getProgramPoint<FlatSymbolRefProgramPoint>(
|
||||||
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>());
|
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]));
|
||||||
auto *state =
|
auto *state =
|
||||||
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
|
solver.lookupState<InlineGlobalSlotsAnalysisState>(symbolRefPoint);
|
||||||
// We roll the analysis of whether a slot is set or public into the
|
// We roll the analysis of whether a slot is set or public into the
|
||||||
|
@ -408,7 +408,7 @@ class InlineGlobalSlotsPass
|
||||||
SmallVector<Value> newInitialValues;
|
SmallVector<Value> newInitialValues;
|
||||||
for (int i = 0, e = initialize.getNumOperands(); i != e; i++) {
|
for (int i = 0, e = initialize.getNumOperands(); i != e; i++) {
|
||||||
auto slotSymName =
|
auto slotSymName =
|
||||||
initialize.getSlotSymNames()[i].cast<FlatSymbolRefAttr>();
|
cast<FlatSymbolRefAttr>(initialize.getSlotSymNames()[i]);
|
||||||
if (!safeToInline.count(slotSymName)) {
|
if (!safeToInline.count(slotSymName)) {
|
||||||
newSlotSymNames.push_back(slotSymName);
|
newSlotSymNames.push_back(slotSymName);
|
||||||
newInitialValues.push_back(initialize.getOperand(i));
|
newInitialValues.push_back(initialize.getOperand(i));
|
||||||
|
|
|
@ -118,7 +118,7 @@ static LogicalResult checkType(Operation *op, Type type,
|
||||||
if (auto optionalType = dyn_cast<OptionalType>(type)) {
|
if (auto optionalType = dyn_cast<OptionalType>(type)) {
|
||||||
// TODO: Be stricter about tensor types.
|
// TODO: Be stricter about tensor types.
|
||||||
// See comment below for ListType.
|
// See comment below for ListType.
|
||||||
if (optionalType.getContainedType().isa<ValueTensorType>())
|
if (isa<ValueTensorType>(optionalType.getContainedType()))
|
||||||
return success();
|
return success();
|
||||||
return checkType(op, optionalType.getContainedType(),
|
return checkType(op, optionalType.getContainedType(),
|
||||||
actuallyEmitDiagnostics);
|
actuallyEmitDiagnostics);
|
||||||
|
@ -134,7 +134,7 @@ static LogicalResult checkType(Operation *op, Type type,
|
||||||
// the contained type information. Somehow this slips through and works.
|
// the contained type information. Somehow this slips through and works.
|
||||||
// We should be stricter about this and properly infer the contained type
|
// We should be stricter about this and properly infer the contained type
|
||||||
// and shape.
|
// and shape.
|
||||||
if (listType.getContainedType().isa<ValueTensorType>())
|
if (isa<ValueTensorType>(listType.getContainedType()))
|
||||||
return success();
|
return success();
|
||||||
return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics);
|
return checkType(op, listType.getContainedType(), actuallyEmitDiagnostics);
|
||||||
}
|
}
|
||||||
|
@ -535,7 +535,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
|
||||||
}
|
}
|
||||||
target.addDynamicallyLegalOp<OperatorOp>(
|
target.addDynamicallyLegalOp<OperatorOp>(
|
||||||
[backendLegalOpsSet](OperatorOp opOp) {
|
[backendLegalOpsSet](OperatorOp opOp) {
|
||||||
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
|
auto opName = cast<StringAttr>(opOp->getAttr("name")).getValue();
|
||||||
return backendLegalOpsSet.contains(opName);
|
return backendLegalOpsSet.contains(opName);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,7 @@ public:
|
||||||
op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
|
op.getLoc(), op.getOperand(0).getType(), op.getOperand(0),
|
||||||
op.getOperand(3), op.getOperand(4));
|
op.getOperand(3), op.getOperand(4));
|
||||||
|
|
||||||
auto clampTy = clamp.getType().cast<Torch::ValueTensorType>();
|
auto clampTy = cast<Torch::ValueTensorType>(clamp.getType());
|
||||||
if (!clampTy.hasDtype())
|
if (!clampTy.hasDtype())
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"dequantization has unknown dtype");
|
"dequantization has unknown dtype");
|
||||||
|
|
|
@ -23,7 +23,7 @@ using namespace mlir::torch;
|
||||||
using namespace mlir::torch::Torch;
|
using namespace mlir::torch::Torch;
|
||||||
|
|
||||||
static Value assertNonValueTensor(Value tensor) {
|
static Value assertNonValueTensor(Value tensor) {
|
||||||
assert(tensor.getType().isa<NonValueTensorType>() &&
|
assert(isa<NonValueTensorType>(tensor.getType()) &&
|
||||||
"tensor is expected to be a non-value tensor");
|
"tensor is expected to be a non-value tensor");
|
||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
@ -102,7 +102,7 @@ public:
|
||||||
// to use value semantics (which happens for example with ops
|
// to use value semantics (which happens for example with ops
|
||||||
// that take two aliases as input), then it is possible that the
|
// that take two aliases as input), then it is possible that the
|
||||||
// op no longer generates an alias.
|
// op no longer generates an alias.
|
||||||
if (userResult.getType().isa<NonValueTensorType>())
|
if (isa<NonValueTensorType>(userResult.getType()))
|
||||||
availableAliases.insert(userResult);
|
availableAliases.insert(userResult);
|
||||||
result.viewLikeOps.push_back(user);
|
result.viewLikeOps.push_back(user);
|
||||||
} else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
} else if (auto copyToValueTensor = dyn_cast<CopyToValueTensorOp>(user)) {
|
||||||
|
@ -177,7 +177,7 @@ public:
|
||||||
for (Operation *viewLikeOp : ops.viewLikeOps) {
|
for (Operation *viewLikeOp : ops.viewLikeOps) {
|
||||||
rewriter.modifyOpInPlace(viewLikeOp, [&] {
|
rewriter.modifyOpInPlace(viewLikeOp, [&] {
|
||||||
Value result = viewLikeOp->getResult(0);
|
Value result = viewLikeOp->getResult(0);
|
||||||
auto resultType = result.getType().dyn_cast<NonValueTensorType>();
|
auto resultType = dyn_cast<NonValueTensorType>(result.getType());
|
||||||
if (resultType)
|
if (resultType)
|
||||||
result.setType(resultType.getWithValueSemantics());
|
result.setType(resultType.getWithValueSemantics());
|
||||||
});
|
});
|
||||||
|
@ -230,7 +230,7 @@ public:
|
||||||
if (isViewLikeOp(op)) {
|
if (isViewLikeOp(op)) {
|
||||||
// We currently only support view-like ops with one tensor output.
|
// We currently only support view-like ops with one tensor output.
|
||||||
if (op->getNumResults() != 1 ||
|
if (op->getNumResults() != 1 ||
|
||||||
!op->getResult(0).getType().isa<BaseTensorType>()) {
|
!isa<BaseTensorType>(op->getResult(0).getType())) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
copy, "unsupported: view-like ops must have one tensor output, "
|
copy, "unsupported: view-like ops must have one tensor output, "
|
||||||
"and the tensor output must be the first result");
|
"and the tensor output must be the first result");
|
||||||
|
@ -242,7 +242,7 @@ public:
|
||||||
// non-value tensor and the output being a value tensor. If this is the
|
// non-value tensor and the output being a value tensor. If this is the
|
||||||
// case then there is no need to look at the users of the result of the
|
// case then there is no need to look at the users of the result of the
|
||||||
// op.
|
// op.
|
||||||
if (opResult.getType().isa<NonValueTensorType>()) {
|
if (isa<NonValueTensorType>(opResult.getType())) {
|
||||||
if (operand.getOperandNumber() == 0) {
|
if (operand.getOperandNumber() == 0) {
|
||||||
validViewLikeOps.insert(op);
|
validViewLikeOps.insert(op);
|
||||||
llvm::append_range(workList, opResult.getUses());
|
llvm::append_range(workList, opResult.getUses());
|
||||||
|
@ -339,7 +339,7 @@ public:
|
||||||
for (Operation *op : viewLikeOps) {
|
for (Operation *op : viewLikeOps) {
|
||||||
rewriter.modifyOpInPlace(op, [&]() {
|
rewriter.modifyOpInPlace(op, [&]() {
|
||||||
if (auto nonValueTensorType =
|
if (auto nonValueTensorType =
|
||||||
op->getResult(0).getType().dyn_cast<NonValueTensorType>()) {
|
dyn_cast<NonValueTensorType>(op->getResult(0).getType())) {
|
||||||
originalTypes[op->getResult(0)] = nonValueTensorType;
|
originalTypes[op->getResult(0)] = nonValueTensorType;
|
||||||
op->getResult(0).setType(nonValueTensorType.getWithValueSemantics());
|
op->getResult(0).setType(nonValueTensorType.getWithValueSemantics());
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,7 +30,7 @@ public:
|
||||||
LogicalResult matchAndRewrite(PrimCallMethodOp op,
|
LogicalResult matchAndRewrite(PrimCallMethodOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto classType = symbolTable.lookup<ClassTypeOp>(
|
auto classType = symbolTable.lookup<ClassTypeOp>(
|
||||||
op.getReceiver().getType().cast<NnModuleType>().getClassName());
|
cast<NnModuleType>(op.getReceiver().getType()).getClassName());
|
||||||
assert(classType && "malformed module -- missing ClassTypeOp");
|
assert(classType && "malformed module -- missing ClassTypeOp");
|
||||||
func::FuncOp func;
|
func::FuncOp func;
|
||||||
for (auto method : classType.getOps<MethodOp>()) {
|
for (auto method : classType.getOps<MethodOp>()) {
|
||||||
|
@ -94,7 +94,7 @@ class PrepareForGlobalizeObjectGraphPass
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addIllegalOp<PrimCallMethodOp>();
|
target.addIllegalOp<PrimCallMethodOp>();
|
||||||
target.addDynamicallyLegalOp<func::ConstantOp>(
|
target.addDynamicallyLegalOp<func::ConstantOp>(
|
||||||
[](func::ConstantOp op) { return !op.getType().isa<FunctionType>(); });
|
[](func::ConstantOp op) { return !isa<FunctionType>(op.getType()); });
|
||||||
target.addIllegalOp<func::CallIndirectOp>();
|
target.addIllegalOp<func::CallIndirectOp>();
|
||||||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,7 @@ public:
|
||||||
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
|
Value falseVal = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
|
||||||
|
|
||||||
// Create IndexPut_Op
|
// Create IndexPut_Op
|
||||||
BaseTensorType tensorType = op.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(op.getType());
|
||||||
Type rangeType = tensorType.getWithSizesAndDtype(
|
Type rangeType = tensorType.getWithSizesAndDtype(
|
||||||
{kUnknownSize}, tensorType.getOptionalDtype());
|
{kUnknownSize}, tensorType.getOptionalDtype());
|
||||||
Value range = rewriter.create<AtenArangeStartStepOp>(
|
Value range = rewriter.create<AtenArangeStartStepOp>(
|
||||||
|
@ -130,8 +130,7 @@ public:
|
||||||
|
|
||||||
// Create IndexPut_Op
|
// Create IndexPut_Op
|
||||||
// Convert indexNum to indexTensor for the selectOp
|
// Convert indexNum to indexTensor for the selectOp
|
||||||
BaseTensorType selectOutTy =
|
BaseTensorType selectOutTy = cast<BaseTensorType>(selectOp.getType());
|
||||||
selectOp.getType().template cast<BaseTensorType>();
|
|
||||||
SmallVector<int64_t> empty;
|
SmallVector<int64_t> empty;
|
||||||
auto dtype = getTypeForTorchType(selectOp.getContext(),
|
auto dtype = getTypeForTorchType(selectOp.getContext(),
|
||||||
selectOp.getIndex().getType());
|
selectOp.getIndex().getType());
|
||||||
|
@ -141,7 +140,7 @@ public:
|
||||||
selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
|
selectOp.getLoc(), emptyTensorType, selectOp.getIndex());
|
||||||
|
|
||||||
// Create indicesVector for IndexPut_Op by TorchNone and indexTensor
|
// Create indicesVector for IndexPut_Op by TorchNone and indexTensor
|
||||||
BaseTensorType tensorType = op->getResultTypes()[0].cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(op->getResultTypes()[0]);
|
||||||
SmallVector<Value> indicesVector(dim, noneVal);
|
SmallVector<Value> indicesVector(dim, noneVal);
|
||||||
indicesVector.push_back(indexTensor);
|
indicesVector.push_back(indexTensor);
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,9 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter,
|
||||||
Location loc, Value overwriterTensor,
|
Location loc, Value overwriterTensor,
|
||||||
Value overwrittenTensor) {
|
Value overwrittenTensor) {
|
||||||
Type overwriterTensorType = overwriterTensor.getType();
|
Type overwriterTensorType = overwriterTensor.getType();
|
||||||
Type overwrittenTensorType = overwrittenTensor.getType()
|
Type overwrittenTensorType =
|
||||||
.dyn_cast<NonValueTensorType>()
|
dyn_cast<NonValueTensorType>(overwrittenTensor.getType())
|
||||||
.getWithValueSemantics();
|
.getWithValueSemantics();
|
||||||
if (overwriterTensorType != overwrittenTensorType) {
|
if (overwriterTensorType != overwrittenTensorType) {
|
||||||
overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
|
overwriterTensor = rewriter.create<TensorStaticInfoCastOp>(
|
||||||
loc, overwrittenTensorType, overwriterTensor);
|
loc, overwrittenTensorType, overwriterTensor);
|
||||||
|
@ -58,7 +58,7 @@ operatorOpHasValueSemantics(OperatorOp opOp,
|
||||||
std::optional<SymbolTable> extraLibrary) {
|
std::optional<SymbolTable> extraLibrary) {
|
||||||
if (!extraLibrary.has_value())
|
if (!extraLibrary.has_value())
|
||||||
return false;
|
return false;
|
||||||
auto opName = opOp->getAttr("name").cast<StringAttr>().getValue();
|
auto opName = cast<StringAttr>(opOp->getAttr("name")).getValue();
|
||||||
std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix(
|
std::string libFuncName = (mlir::torch::Torch::getLibraryFunctionPrefix(
|
||||||
LibraryFunctionKind::HasValueSemantics) +
|
LibraryFunctionKind::HasValueSemantics) +
|
||||||
Twine(opName))
|
Twine(opName))
|
||||||
|
@ -96,8 +96,8 @@ public:
|
||||||
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
||||||
opOperand.get()));
|
opOperand.get()));
|
||||||
} else if (auto listType = dyn_cast<ListType>(operandType)) {
|
} else if (auto listType = dyn_cast<ListType>(operandType)) {
|
||||||
if (!(listType.getContainedType().isa<NonValueTensorType>() ||
|
if (!(isa<NonValueTensorType>(listType.getContainedType()) ||
|
||||||
listType.getContainedType().isa<OptionalType>()))
|
isa<OptionalType>(listType.getContainedType())))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Construct a new list whose elements are value tensors copied from
|
// Construct a new list whose elements are value tensors copied from
|
||||||
|
@ -116,7 +116,7 @@ public:
|
||||||
|
|
||||||
// TODO: Handle optional type in list type.
|
// TODO: Handle optional type in list type.
|
||||||
if (auto optionalType =
|
if (auto optionalType =
|
||||||
listType.getContainedType().dyn_cast<OptionalType>()) {
|
dyn_cast<OptionalType>(listType.getContainedType())) {
|
||||||
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
|
if (!llvm::all_of(listConstruct.getElements(), [](Value val) {
|
||||||
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
|
return val.getType().isa<NonValueTensorType, Torch::NoneType>();
|
||||||
})) {
|
})) {
|
||||||
|
@ -129,7 +129,7 @@ public:
|
||||||
|
|
||||||
auto newListElements = llvm::to_vector(llvm::map_range(
|
auto newListElements = llvm::to_vector(llvm::map_range(
|
||||||
listConstruct.getElements(), [&](Value tensor) -> Value {
|
listConstruct.getElements(), [&](Value tensor) -> Value {
|
||||||
if (tensor.getType().isa<NonValueTensorType>()) {
|
if (isa<NonValueTensorType>(tensor.getType())) {
|
||||||
return rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
return rewriter.create<CopyToValueTensorOp>(op->getLoc(),
|
||||||
tensor);
|
tensor);
|
||||||
}
|
}
|
||||||
|
@ -147,7 +147,7 @@ public:
|
||||||
} else if (auto optionalType = dyn_cast<OptionalType>(operandType)) {
|
} else if (auto optionalType = dyn_cast<OptionalType>(operandType)) {
|
||||||
// TODO: A more general way to handle the optional type is to
|
// TODO: A more general way to handle the optional type is to
|
||||||
// introduce a `copy.to_optional_vtensor` op.
|
// introduce a `copy.to_optional_vtensor` op.
|
||||||
if (!optionalType.getContainedType().isa<NonValueTensorType>())
|
if (!isa<NonValueTensorType>(optionalType.getContainedType()))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Create a new optional value whose input is a value tensor copied
|
// Create a new optional value whose input is a value tensor copied
|
||||||
|
@ -160,7 +160,7 @@ public:
|
||||||
"derefine");
|
"derefine");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!derefine.getOperand().getType().isa<NonValueTensorType>())
|
if (!isa<NonValueTensorType>(derefine.getOperand().getType()))
|
||||||
continue;
|
continue;
|
||||||
auto newOperand = rewriter.create<CopyToValueTensorOp>(
|
auto newOperand = rewriter.create<CopyToValueTensorOp>(
|
||||||
op->getLoc(), derefine.getOperand());
|
op->getLoc(), derefine.getOperand());
|
||||||
|
@ -172,7 +172,7 @@ public:
|
||||||
// Convert all results.
|
// Convert all results.
|
||||||
rewriter.setInsertionPointAfter(op);
|
rewriter.setInsertionPointAfter(op);
|
||||||
for (Value result : op->getResults()) {
|
for (Value result : op->getResults()) {
|
||||||
auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
|
auto tensorType = dyn_cast<NonValueTensorType>(result.getType());
|
||||||
if (!tensorType)
|
if (!tensorType)
|
||||||
continue;
|
continue;
|
||||||
result.setType(tensorType.getWithValueSemantics());
|
result.setType(tensorType.getWithValueSemantics());
|
||||||
|
|
|
@ -84,7 +84,7 @@ class RefinePublicReturnPass
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto tensorType = newOperand.getType().dyn_cast<BaseTensorType>()) {
|
if (auto tensorType = dyn_cast<BaseTensorType>(newOperand.getType())) {
|
||||||
newOperands.push_back(
|
newOperands.push_back(
|
||||||
copyTensorToType(builder, returnOp->getLoc(),
|
copyTensorToType(builder, returnOp->getLoc(),
|
||||||
tensorType.getWithValueSemantics(), newOperand));
|
tensorType.getWithValueSemantics(), newOperand));
|
||||||
|
|
|
@ -118,7 +118,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable(
|
||||||
assert(call.getNumResults() == 1 &&
|
assert(call.getNumResults() == 1 &&
|
||||||
"Multiple results are packed in a tuple in Python!");
|
"Multiple results are packed in a tuple in Python!");
|
||||||
Value result = call.getResult(0);
|
Value result = call.getResult(0);
|
||||||
if (auto tupleType = result.getType().dyn_cast<Torch::TupleType>()) {
|
if (auto tupleType = dyn_cast<Torch::TupleType>(result.getType())) {
|
||||||
auto unpack = b.create<PrimTupleUnpackOp>(
|
auto unpack = b.create<PrimTupleUnpackOp>(
|
||||||
loc, tupleType.getContainedTypes(), result);
|
loc, tupleType.getContainedTypes(), result);
|
||||||
llvm::append_range(unpackedResults, unpack.getResults());
|
llvm::append_range(unpackedResults, unpack.getResults());
|
||||||
|
@ -275,7 +275,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
// for i in range(len(operand)):
|
// for i in range(len(operand)):
|
||||||
// adjusted_list.append(adjust(operand[i]))
|
// adjusted_list.append(adjust(operand[i]))
|
||||||
// return adjusted_list
|
// return adjusted_list
|
||||||
auto providedType = operand.getType().cast<Torch::ListType>();
|
auto providedType = cast<Torch::ListType>(operand.getType());
|
||||||
Value adjustedList =
|
Value adjustedList =
|
||||||
b.create<PrimListConstructOp>(loc, desiredListType, ValueRange({}));
|
b.create<PrimListConstructOp>(loc, desiredListType, ValueRange({}));
|
||||||
// Create a for-like PrimLoopOp.
|
// Create a for-like PrimLoopOp.
|
||||||
|
@ -312,7 +312,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand,
|
||||||
// signature uses `Scalar` (see comments in torch_ods_gen.py for
|
// signature uses `Scalar` (see comments in torch_ods_gen.py for
|
||||||
// explanation).
|
// explanation).
|
||||||
if (isa<Torch::FloatType>(desiredType) &&
|
if (isa<Torch::FloatType>(desiredType) &&
|
||||||
operand.getType().isa<Torch::IntType>()) {
|
isa<Torch::IntType>(operand.getType())) {
|
||||||
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||||
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
auto dtypeArgAdjuster = [](OpBuilder &b, Location loc, Value operand,
|
||||||
Type desiredType) -> Value {
|
Type desiredType) -> Value {
|
||||||
if (isa<Torch::TupleType>(desiredType) &&
|
if (isa<Torch::TupleType>(desiredType) &&
|
||||||
operand.getType().isa<Torch::BaseTensorType>()) {
|
isa<Torch::BaseTensorType>(operand.getType())) {
|
||||||
Type intType = Torch::IntType::get(b.getContext());
|
Type intType = Torch::IntType::get(b.getContext());
|
||||||
Type sizeListType = Torch::ListType::get(intType);
|
Type sizeListType = Torch::ListType::get(intType);
|
||||||
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
|
Value size = b.create<AtenSizeOp>(loc, sizeListType, operand);
|
||||||
|
|
|
@ -41,8 +41,8 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc,
|
||||||
auto desiredListType = dyn_cast<Torch::ListType>(desiredType);
|
auto desiredListType = dyn_cast<Torch::ListType>(desiredType);
|
||||||
if (!desiredListType)
|
if (!desiredListType)
|
||||||
return operand;
|
return operand;
|
||||||
if (operand.getType().isa<Torch::BaseTensorType>() &&
|
if (isa<Torch::BaseTensorType>(operand.getType()) &&
|
||||||
desiredListType.getContainedType().isa<Torch::IntType>()) {
|
isa<Torch::IntType>(desiredListType.getContainedType())) {
|
||||||
return b.create<AtenSizeOp>(loc, desiredType, operand);
|
return b.create<AtenSizeOp>(loc, desiredType, operand);
|
||||||
}
|
}
|
||||||
return operand;
|
return operand;
|
||||||
|
|
|
@ -259,7 +259,7 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
||||||
Type originalResultType = result.getType();
|
Type originalResultType = result.getType();
|
||||||
Type updatedType;
|
Type updatedType;
|
||||||
if (auto originalBaseTensorType =
|
if (auto originalBaseTensorType =
|
||||||
originalResultType.template dyn_cast<BaseTensorType>()) {
|
dyn_cast<BaseTensorType>(originalResultType)) {
|
||||||
// If we didn't get any new information, there is nothing left for us to do.
|
// If we didn't get any new information, there is nothing left for us to do.
|
||||||
updatedType = meetTensorTypes(originalBaseTensorType,
|
updatedType = meetTensorTypes(originalBaseTensorType,
|
||||||
cast<BaseTensorType>(newResultType));
|
cast<BaseTensorType>(newResultType));
|
||||||
|
@ -267,7 +267,7 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
calculateOp, "New type information does not refine old type");
|
calculateOp, "New type information does not refine old type");
|
||||||
} else if (auto originalResultType =
|
} else if (auto originalResultType =
|
||||||
result.getType().template dyn_cast<Torch::NumberType>()) {
|
dyn_cast<Torch::NumberType>(result.getType())) {
|
||||||
if (!isa<Torch::FloatType, Torch::IntType>(newResultType)) {
|
if (!isa<Torch::FloatType, Torch::IntType>(newResultType)) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
calculateOp,
|
calculateOp,
|
||||||
|
|
|
@ -35,7 +35,7 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
||||||
|
|
||||||
// Calculate the updated type incorporating the new information.
|
// Calculate the updated type incorporating the new information.
|
||||||
Type impliedTypeFromDtype;
|
Type impliedTypeFromDtype;
|
||||||
if (result.getType().isa<Torch::NumberType>()) {
|
if (isa<Torch::NumberType>(result.getType())) {
|
||||||
FailureOr<Type> torchType =
|
FailureOr<Type> torchType =
|
||||||
getTorchTypeForScalarType(op->getContext(), dtypeScalarType);
|
getTorchTypeForScalarType(op->getContext(), dtypeScalarType);
|
||||||
if (failed(torchType)) {
|
if (failed(torchType)) {
|
||||||
|
@ -45,7 +45,7 @@ static LogicalResult refineDtypeCalculateResult(DtypeCalculateOp op,
|
||||||
}
|
}
|
||||||
impliedTypeFromDtype = *torchType;
|
impliedTypeFromDtype = *torchType;
|
||||||
} else if (auto originalResultType =
|
} else if (auto originalResultType =
|
||||||
result.getType().dyn_cast<BaseTensorType>()) {
|
dyn_cast<BaseTensorType>(result.getType())) {
|
||||||
FailureOr<Type> builtinType =
|
FailureOr<Type> builtinType =
|
||||||
getTypeForScalarType(op->getContext(), dtypeScalarType);
|
getTypeForScalarType(op->getContext(), dtypeScalarType);
|
||||||
if (failed(builtinType)) {
|
if (failed(builtinType)) {
|
||||||
|
@ -168,12 +168,12 @@ public:
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
LogicalResult matchAndRewrite(PrimNumToTensorScalarOp op,
|
LogicalResult matchAndRewrite(PrimNumToTensorScalarOp op,
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
auto originalResultType = op.getResult().getType().cast<BaseTensorType>();
|
auto originalResultType = cast<BaseTensorType>(op.getResult().getType());
|
||||||
if (originalResultType.hasDtype())
|
if (originalResultType.hasDtype())
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
op, "`PrimNumToTensorScalarOp` already has a dtype");
|
op, "`PrimNumToTensorScalarOp` already has a dtype");
|
||||||
|
|
||||||
if (op.getA().getType().isa<Torch::NumberType>()) {
|
if (isa<Torch::NumberType>(op.getA().getType())) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"`PrimNumToTensorScalarOp`'s input "
|
"`PrimNumToTensorScalarOp`'s input "
|
||||||
"should have concrete Scalar Type.");
|
"should have concrete Scalar Type.");
|
||||||
|
|
|
@ -27,7 +27,7 @@ public:
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value self = op.getSelf();
|
Value self = op.getSelf();
|
||||||
MLIRContext *context = op.getContext();
|
MLIRContext *context = op.getContext();
|
||||||
auto tensorType = self.getType().cast<BaseTensorType>();
|
auto tensorType = cast<BaseTensorType>(self.getType());
|
||||||
if (!tensorType.hasSizes())
|
if (!tensorType.hasSizes())
|
||||||
return rewriter.notifyMatchFailure(op, "unranked tensor");
|
return rewriter.notifyMatchFailure(op, "unranked tensor");
|
||||||
int64_t rank = tensorType.getSizes().size();
|
int64_t rank = tensorType.getSizes().size();
|
||||||
|
@ -96,7 +96,7 @@ static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
|
||||||
sizes.push_back(kUnknownSize);
|
sizes.push_back(kUnknownSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto originalResultType = result.getType().cast<BaseTensorType>();
|
auto originalResultType = cast<BaseTensorType>(result.getType());
|
||||||
auto impliedTypesFromShape =
|
auto impliedTypesFromShape =
|
||||||
cast<BaseTensorType>(originalResultType)
|
cast<BaseTensorType>(originalResultType)
|
||||||
.getWithSizesAndDtype(ArrayRef(sizes),
|
.getWithSizesAndDtype(ArrayRef(sizes),
|
||||||
|
|
|
@ -44,9 +44,9 @@ bool Torch::getListConstructElements(Value v, SmallVectorImpl<Value> &elems) {
|
||||||
}
|
}
|
||||||
|
|
||||||
torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||||
if (type.isa<Float32Type>())
|
if (isa<Float32Type>(type))
|
||||||
return torch_upstream::ScalarType::Float;
|
return torch_upstream::ScalarType::Float;
|
||||||
if (type.isa<Float64Type>())
|
if (isa<Float64Type>(type))
|
||||||
return torch_upstream::ScalarType::Double;
|
return torch_upstream::ScalarType::Double;
|
||||||
if (type.isSignedInteger(64))
|
if (type.isSignedInteger(64))
|
||||||
return torch_upstream::ScalarType::Long;
|
return torch_upstream::ScalarType::Long;
|
||||||
|
@ -64,11 +64,11 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
|
||||||
return torch_upstream::ScalarType::Byte;
|
return torch_upstream::ScalarType::Byte;
|
||||||
if (type.isSignedInteger(8))
|
if (type.isSignedInteger(8))
|
||||||
return torch_upstream::ScalarType::Char;
|
return torch_upstream::ScalarType::Char;
|
||||||
if (type.isa<QUInt8Type>())
|
if (isa<QUInt8Type>(type))
|
||||||
return torch_upstream::ScalarType::QUInt8;
|
return torch_upstream::ScalarType::QUInt8;
|
||||||
if (type.isa<QInt8Type>())
|
if (isa<QInt8Type>(type))
|
||||||
return torch_upstream::ScalarType::QInt8;
|
return torch_upstream::ScalarType::QInt8;
|
||||||
if (type.isa<QInt32Type>())
|
if (isa<QInt32Type>(type))
|
||||||
return torch_upstream::ScalarType::QInt32;
|
return torch_upstream::ScalarType::QInt32;
|
||||||
if (isa<ComplexType>(type)) {
|
if (isa<ComplexType>(type)) {
|
||||||
mlir::Type complexElemType = cast<ComplexType>(type).getElementType();
|
mlir::Type complexElemType = cast<ComplexType>(type).getElementType();
|
||||||
|
@ -185,7 +185,7 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
|
||||||
// Helper to convert a tensor to a specific scalar type.
|
// Helper to convert a tensor to a specific scalar type.
|
||||||
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
|
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
|
||||||
Value input, Type dtype) {
|
Value input, Type dtype) {
|
||||||
BaseTensorType origType = input.getType().cast<BaseTensorType>();
|
BaseTensorType origType = cast<BaseTensorType>(input.getType());
|
||||||
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
|
Type newType = origType.getWithSizesAndDtype(origType.getSizes(), dtype);
|
||||||
// `convertIntVal` contains the corresponding integer for the dtype which is
|
// `convertIntVal` contains the corresponding integer for the dtype which is
|
||||||
// used by the aten.to.dtype op.
|
// used by the aten.to.dtype op.
|
||||||
|
@ -202,7 +202,7 @@ bool Torch::isBuiltInType(Type type) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::optional<unsigned> Torch::getTensorRank(Value tensor) {
|
std::optional<unsigned> Torch::getTensorRank(Value tensor) {
|
||||||
BaseTensorType tensorType = tensor.getType().cast<BaseTensorType>();
|
BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType());
|
||||||
if (!tensorType.hasSizes())
|
if (!tensorType.hasSizes())
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
return tensorType.getSizes().size();
|
return tensorType.getSizes().size();
|
||||||
|
@ -279,7 +279,7 @@ SmallVector<int64_t> Torch::makeShapeTorchCompatible(ArrayRef<int64_t> shape) {
|
||||||
// Return the squeezed tensor or failure.
|
// Return the squeezed tensor or failure.
|
||||||
FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
|
FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
Location loc, int64_t dim, Value input) {
|
Location loc, int64_t dim, Value input) {
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes()) {
|
if (!inputType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(loc, "input tensor must have size");
|
return rewriter.notifyMatchFailure(loc, "input tensor must have size");
|
||||||
}
|
}
|
||||||
|
@ -314,7 +314,7 @@ FailureOr<Value> Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op,
|
||||||
// Return the unsqueezed tensor or failure.
|
// Return the unsqueezed tensor or failure.
|
||||||
FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
|
FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter,
|
||||||
Operation *op, Value input, Value dim) {
|
Operation *op, Value input, Value dim) {
|
||||||
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
|
BaseTensorType inputType = cast<BaseTensorType>(input.getType());
|
||||||
if (!inputType.hasSizes()) {
|
if (!inputType.hasSizes()) {
|
||||||
return rewriter.notifyMatchFailure(op, "input tensor must have size");
|
return rewriter.notifyMatchFailure(op, "input tensor must have size");
|
||||||
}
|
}
|
||||||
|
@ -348,9 +348,9 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc,
|
||||||
SmallVector<int64_t> &resultShape,
|
SmallVector<int64_t> &resultShape,
|
||||||
SmallVector<Value> &resultShapeValue) {
|
SmallVector<Value> &resultShapeValue) {
|
||||||
SmallVector<int64_t> shapeA{
|
SmallVector<int64_t> shapeA{
|
||||||
inputA.getType().cast<BaseTensorType>().getSizes()};
|
cast<BaseTensorType>(inputA.getType()).getSizes()};
|
||||||
SmallVector<int64_t> shapeB{
|
SmallVector<int64_t> shapeB{
|
||||||
inputB.getType().cast<BaseTensorType>().getSizes()};
|
cast<BaseTensorType>(inputB.getType()).getSizes()};
|
||||||
unsigned rankA = shapeA.size();
|
unsigned rankA = shapeA.size();
|
||||||
unsigned rankB = shapeB.size();
|
unsigned rankB = shapeB.size();
|
||||||
unsigned minRank = rankA > rankB ? rankB : rankA;
|
unsigned minRank = rankA > rankB ? rankB : rankA;
|
||||||
|
@ -504,9 +504,8 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc,
|
||||||
BaseTensorType inputType, Value scalar) {
|
BaseTensorType inputType, Value scalar) {
|
||||||
assert(inputType.hasDtype() && "input must have dtype");
|
assert(inputType.hasDtype() && "input must have dtype");
|
||||||
SmallVector<int64_t> sizes;
|
SmallVector<int64_t> sizes;
|
||||||
BaseTensorType rank0TensorTy =
|
BaseTensorType rank0TensorTy = cast<BaseTensorType>(
|
||||||
inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())
|
inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()));
|
||||||
.cast<BaseTensorType>();
|
|
||||||
Value dimList = rewriter.create<PrimListConstructOp>(
|
Value dimList = rewriter.create<PrimListConstructOp>(
|
||||||
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())),
|
||||||
ValueRange{});
|
ValueRange{});
|
||||||
|
@ -531,9 +530,9 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
|
||||||
return rewriter.getF32Type();
|
return rewriter.getF32Type();
|
||||||
if (inputType.isBF16())
|
if (inputType.isBF16())
|
||||||
return rewriter.getF32Type();
|
return rewriter.getF32Type();
|
||||||
if (inputType.isa<Float32Type>())
|
if (isa<Float32Type>(inputType))
|
||||||
return rewriter.getF32Type();
|
return rewriter.getF32Type();
|
||||||
if (inputType.isa<Float64Type>())
|
if (isa<Float64Type>(inputType))
|
||||||
return rewriter.getF64Type();
|
return rewriter.getF64Type();
|
||||||
if (inputType.isFloat8E5M2())
|
if (inputType.isFloat8E5M2())
|
||||||
return rewriter.getF32Type();
|
return rewriter.getF32Type();
|
||||||
|
|
|
@ -34,9 +34,9 @@ static bool haveSameSizeAndElementType(TensorType lhs, TensorType rhs) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
LogicalResult ToBuiltinTensorOp::verify() {
|
LogicalResult ToBuiltinTensorOp::verify() {
|
||||||
auto resultType = getResult().getType().cast<TensorType>();
|
auto resultType = cast<TensorType>(getResult().getType());
|
||||||
auto operandType =
|
auto operandType =
|
||||||
getOperand().getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
|
cast<Torch::ValueTensorType>(getOperand().getType()).toBuiltinTensor();
|
||||||
if (!haveSameSizeAndElementType(resultType, operandType)) {
|
if (!haveSameSizeAndElementType(resultType, operandType)) {
|
||||||
return emitError()
|
return emitError()
|
||||||
<< "operand and result must have the same size and dtype";
|
<< "operand and result must have the same size and dtype";
|
||||||
|
@ -49,7 +49,7 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
||||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
auto resultType =
|
auto resultType =
|
||||||
operands[0].getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
|
cast<Torch::ValueTensorType>(operands[0].getType()).toBuiltinTensor();
|
||||||
if (!resultType)
|
if (!resultType)
|
||||||
return failure();
|
return failure();
|
||||||
inferredReturnTypes.push_back(resultType);
|
inferredReturnTypes.push_back(resultType);
|
||||||
|
@ -62,8 +62,8 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
|
||||||
|
|
||||||
LogicalResult FromBuiltinTensorOp::verify() {
|
LogicalResult FromBuiltinTensorOp::verify() {
|
||||||
auto resultType =
|
auto resultType =
|
||||||
getResult().getType().cast<Torch::ValueTensorType>().toBuiltinTensor();
|
cast<Torch::ValueTensorType>(getResult().getType()).toBuiltinTensor();
|
||||||
auto operandType = getOperand().getType().cast<TensorType>();
|
auto operandType = cast<TensorType>(getOperand().getType());
|
||||||
if (!haveSameSizeAndElementType(resultType, operandType)) {
|
if (!haveSameSizeAndElementType(resultType, operandType)) {
|
||||||
return emitError()
|
return emitError()
|
||||||
<< "operand and result must have the same size and dtype";
|
<< "operand and result must have the same size and dtype";
|
||||||
|
|
|
@ -36,7 +36,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
||||||
ValueRange inputs,
|
ValueRange inputs,
|
||||||
Location loc) -> Value {
|
Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
if (!inputs[0].getType().isa<Torch::BaseTensorType>())
|
if (!isa<Torch::BaseTensorType>(inputs[0].getType()))
|
||||||
return {};
|
return {};
|
||||||
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
|
return builder.create<ToBuiltinTensorOp>(loc, inputs[0]);
|
||||||
});
|
});
|
||||||
|
@ -44,7 +44,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
|
||||||
Torch::ValueTensorType type,
|
Torch::ValueTensorType type,
|
||||||
ValueRange inputs, Location loc) -> Value {
|
ValueRange inputs, Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<TensorType>());
|
assert(isa<TensorType>(inputs[0].getType()));
|
||||||
return builder.create<FromBuiltinTensorOp>(loc, type, inputs[0]);
|
return builder.create<FromBuiltinTensorOp>(loc, type, inputs[0]);
|
||||||
};
|
};
|
||||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||||
|
@ -64,13 +64,13 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target,
|
||||||
if (!(type.getWidth() == 1 && type.isSignless()))
|
if (!(type.getWidth() == 1 && type.isSignless()))
|
||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<Torch::BoolType>());
|
assert(isa<Torch::BoolType>(inputs[0].getType()));
|
||||||
return builder.create<ToI1Op>(loc, inputs[0]).getResult();
|
return builder.create<ToI1Op>(loc, inputs[0]).getResult();
|
||||||
});
|
});
|
||||||
auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type,
|
auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type,
|
||||||
ValueRange inputs, Location loc) -> Value {
|
ValueRange inputs, Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<IntegerType>());
|
assert(isa<IntegerType>(inputs[0].getType()));
|
||||||
return builder.create<FromI1Op>(loc, inputs[0]);
|
return builder.create<FromI1Op>(loc, inputs[0]);
|
||||||
};
|
};
|
||||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||||
|
@ -99,7 +99,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target,
|
||||||
auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
|
auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type,
|
||||||
ValueRange inputs, Location loc) -> Value {
|
ValueRange inputs, Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<IntegerType>());
|
assert(isa<IntegerType>(inputs[0].getType()));
|
||||||
return builder.create<FromI64Op>(loc, inputs[0]);
|
return builder.create<FromI64Op>(loc, inputs[0]);
|
||||||
};
|
};
|
||||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||||
|
@ -116,13 +116,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target,
|
||||||
[](OpBuilder &builder, Float64Type type, ValueRange inputs,
|
[](OpBuilder &builder, Float64Type type, ValueRange inputs,
|
||||||
Location loc) -> std::optional<Value> {
|
Location loc) -> std::optional<Value> {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<Torch::FloatType>());
|
assert(isa<Torch::FloatType>(inputs[0].getType()));
|
||||||
return builder.create<ToF64Op>(loc, inputs[0]).getResult();
|
return builder.create<ToF64Op>(loc, inputs[0]).getResult();
|
||||||
});
|
});
|
||||||
auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type,
|
auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type,
|
||||||
ValueRange inputs, Location loc) -> Value {
|
ValueRange inputs, Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<Float64Type>());
|
assert(isa<Float64Type>(inputs[0].getType()));
|
||||||
return builder.create<FromF64Op>(loc, inputs[0]);
|
return builder.create<FromF64Op>(loc, inputs[0]);
|
||||||
};
|
};
|
||||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||||
|
@ -153,7 +153,7 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target,
|
||||||
auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type,
|
auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type,
|
||||||
ValueRange inputs, Location loc) -> Value {
|
ValueRange inputs, Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
assert(inputs[0].getType().isa<IntegerType>());
|
assert(isa<IntegerType>(inputs[0].getType()));
|
||||||
return builder.create<I64ToGeneratorOp>(loc, inputs[0]);
|
return builder.create<I64ToGeneratorOp>(loc, inputs[0]);
|
||||||
};
|
};
|
||||||
typeConverter.addSourceMaterialization(sourceMaterialization);
|
typeConverter.addSourceMaterialization(sourceMaterialization);
|
||||||
|
|
|
@ -42,7 +42,7 @@ public:
|
||||||
|
|
||||||
// get inputs: lhs, rhsQuant, scales, zps
|
// get inputs: lhs, rhsQuant, scales, zps
|
||||||
Value lhs = adaptor.getOperands()[0];
|
Value lhs = adaptor.getOperands()[0];
|
||||||
auto lhsType = lhs.getType().cast<RankedTensorType>();
|
auto lhsType = cast<RankedTensorType>(lhs.getType());
|
||||||
if (!lhsType) {
|
if (!lhsType) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ public:
|
||||||
int lhsReductDimSize = lhsShape.back();
|
int lhsReductDimSize = lhsShape.back();
|
||||||
|
|
||||||
Value rhsQuant = adaptor.getOperands()[1];
|
Value rhsQuant = adaptor.getOperands()[1];
|
||||||
auto rhsType = rhsQuant.getType().cast<RankedTensorType>();
|
auto rhsType = cast<RankedTensorType>(rhsQuant.getType());
|
||||||
if (!rhsType) {
|
if (!rhsType) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,7 +59,7 @@ public:
|
||||||
if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
|
if (failed(getConstantIntegerFromDefiningOp(bitWidth, unpackedBitWidth)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
auto rhsType = rhs.getType().dyn_cast<ValueTensorType>();
|
auto rhsType = dyn_cast<ValueTensorType>(rhs.getType());
|
||||||
if (!rhsType)
|
if (!rhsType)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ public:
|
||||||
ValueTensorType newRhsType = ValueTensorType::get(
|
ValueTensorType newRhsType = ValueTensorType::get(
|
||||||
rewriter.getContext(), tensorShape, unpackedElementType);
|
rewriter.getContext(), tensorShape, unpackedElementType);
|
||||||
|
|
||||||
auto elements = constOp.getValueAttr().dyn_cast<DenseIntElementsAttr>();
|
auto elements = dyn_cast<DenseIntElementsAttr>(constOp.getValueAttr());
|
||||||
if (!elements)
|
if (!elements)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
|
|
|
@ -234,7 +234,7 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
|
||||||
if (!globalOp.getValue().has_value())
|
if (!globalOp.getValue().has_value())
|
||||||
return globalOp.emitError("global op must have a value");
|
return globalOp.emitError("global op must have a value");
|
||||||
|
|
||||||
RankedTensorType tensorType = globalOp.getType().cast<RankedTensorType>();
|
RankedTensorType tensorType = cast<RankedTensorType>(globalOp.getType());
|
||||||
MemRefType memrefType =
|
MemRefType memrefType =
|
||||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||||
|
|
||||||
|
@ -252,7 +252,7 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp,
|
||||||
static LogicalResult
|
static LogicalResult
|
||||||
bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp,
|
bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp,
|
||||||
OpBuilder &b, SmallVector<Operation *> &toErase) {
|
OpBuilder &b, SmallVector<Operation *> &toErase) {
|
||||||
RankedTensorType tensorType = globalLoadOp.getType().cast<RankedTensorType>();
|
RankedTensorType tensorType = cast<RankedTensorType>(globalLoadOp.getType());
|
||||||
MemRefType memrefType =
|
MemRefType memrefType =
|
||||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||||
|
|
||||||
|
@ -271,7 +271,7 @@ bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp,
|
||||||
OpBuilder &b,
|
OpBuilder &b,
|
||||||
SmallVector<Operation *> &toErase) {
|
SmallVector<Operation *> &toErase) {
|
||||||
RankedTensorType tensorType =
|
RankedTensorType tensorType =
|
||||||
globalStoreOp.getValue().getType().cast<RankedTensorType>();
|
cast<RankedTensorType>(globalStoreOp.getValue().getType());
|
||||||
MemRefType memrefType =
|
MemRefType memrefType =
|
||||||
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
|
||||||
|
|
||||||
|
@ -300,7 +300,7 @@ class MLProgramBufferize : public MLProgramBufferizeBase<MLProgramBufferize> {
|
||||||
SmallVector<Operation *> toErase;
|
SmallVector<Operation *> toErase;
|
||||||
|
|
||||||
auto walkResult = module.walk([&](ml_program::GlobalOp op) {
|
auto walkResult = module.walk([&](ml_program::GlobalOp op) {
|
||||||
if (auto type = op.getType().dyn_cast<RankedTensorType>()) {
|
if (auto type = dyn_cast<RankedTensorType>(op.getType())) {
|
||||||
if (!type.hasStaticShape()) {
|
if (!type.hasStaticShape()) {
|
||||||
// If the ml_program.global has dynamically shaped tensor.
|
// If the ml_program.global has dynamically shaped tensor.
|
||||||
op.emitError(
|
op.emitError(
|
||||||
|
@ -387,8 +387,8 @@ mlir::torch::RefBackend::createExpandOpsForLLVMPass() {
|
||||||
|
|
||||||
Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
|
Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from,
|
||||||
Value to) {
|
Value to) {
|
||||||
auto memrefTypeFrom = from.getType().cast<MemRefType>();
|
auto memrefTypeFrom = cast<MemRefType>(from.getType());
|
||||||
auto memrefTypeTo = to.getType().cast<MemRefType>();
|
auto memrefTypeTo = cast<MemRefType>(to.getType());
|
||||||
(void)memrefTypeFrom;
|
(void)memrefTypeFrom;
|
||||||
assert(memrefTypeFrom && memrefTypeTo &&
|
assert(memrefTypeFrom && memrefTypeTo &&
|
||||||
memrefTypeFrom.getRank() == memrefTypeTo.getRank());
|
memrefTypeFrom.getRank() == memrefTypeTo.getRank());
|
||||||
|
|
Loading…
Reference in New Issue